From db71b6d6318d4bb2b34f2380aba39ddf61bc7b28 Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Mon, 13 May 2013 21:21:25 +0200 Subject: [PATCH] Re-write locking to support transactions. --- dataset/persistence/database.py | 38 +++++++++++++++++++++++------ dataset/persistence/table.py | 43 ++++++++++++++++++--------------- 2 files changed, 55 insertions(+), 26 deletions(-) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 78773ec..068155e 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -33,14 +33,23 @@ class Database(object): @property def executable(self): - """ The current connection or engine against which statements + """ The current connection or engine against which statements will be executed. """ if hasattr(self.local, 'connection'): return self.local.connection return self.engine + def _acquire(self): + self.lock.acquire() + + def _release(self): + if not hasattr(self.local, 'tx'): + self.lock.release() + else: + self.local.must_release = True + def begin(self): - """ Enter a transaction explicitly. No data will be written + """ Enter a transaction explicitly. No data will be written until the transaction has been committed. """ if not hasattr(self.local, 'connection'): self.local.connection = self.engine.connect() @@ -52,12 +61,18 @@ class Database(object): since the transaction was begun permanent. """ self.local.tx.commit() del self.local.tx + if not hasattr(self.local, 'must_release'): + self.lock.release() + del self.local.must_release def rollback(self): """ Roll back the current transaction, discarding all statements executed since the transaction was begun. """ self.local.tx.rollback() del self.local.tx + if not hasattr(self.local, 'must_release'): + self.lock.release() + del self.local.must_release @property def tables(self): @@ -79,7 +94,8 @@ class Database(object): table = db.create_table('population') """ - with self.lock: + self._acquire() + try: log.debug("Creating table: %s on %r" % (table_name, self.engine)) table = SQLATable(table_name, self.metadata) col = Column('id', Integer, primary_key=True) @@ -87,6 +103,8 @@ class Database(object): table.create(self.engine) self._tables[table_name] = table return Table(self, table) + finally: + self._release() def load_table(self, table_name): """ @@ -100,11 +118,14 @@ class Database(object): table = db.load_table('population') """ - with self.lock: + self._acquire() + try: log.debug("Loading table: %s on %r" % (table_name, self)) table = SQLATable(table_name, self.metadata, autoload=True) self._tables[table_name] = table return Table(self, table) + finally: + self._release() def get_table(self, table_name): """ @@ -118,13 +139,16 @@ class Database(object): # you can also use the short-hand syntax: table = db['population'] """ - with self.lock: - if table_name in self._tables: - return Table(self, self._tables[table_name]) + if table_name in self._tables: + return Table(self, self._tables[table_name]) + self._acquire() + try: if self.engine.has_table(table_name): return self.load_table(table_name) else: return self.create_table(table_name) + finally: + self._release() def __getitem__(self, table_name): return self.get_table(table_name) diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index ffa29d4..6bd8062 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -39,10 +39,10 @@ class Table(object): dropping the table. If you want to re-create the table, make sure to get a fresh instance from the :py:class:`Database `. """ + self.database._acquire() self._is_dropped = True - with self.database.lock: - self.database._tables.pop(self.table.name, None) - self.table.drop(self.database.executable) + self.database._tables.pop(self.table.name, None) + self.table.drop(self.database.engine) def _check_dropped(self): if self._is_dropped: @@ -198,11 +198,14 @@ class Table(object): table.create_column('created_at', sqlalchemy.DateTime) """ self._check_dropped() - with self.database.lock: + self.database._acquire() + try: if name not in self.table.columns.keys(): col = Column(name, type) col.create(self.table, - connection=self.database.engine) + connection=self.database.engine) + finally: + self.database._release() def create_index(self, columns, name=None): """ @@ -212,20 +215,22 @@ class Table(object): table.create_index(['name', 'country']) """ self._check_dropped() - with self.database.lock: - if not name: - sig = abs(hash('||'.join(columns))) - name = 'ix_%s_%s' % (self.table.name, sig) - if name in self.indexes: - return self.indexes[name] - try: - columns = [self.table.c[c] for c in columns] - idx = Index(name, *columns) - idx.create(self.database.engine) - except: - idx = None - self.indexes[name] = idx - return idx + if not name: + sig = abs(hash('||'.join(columns))) + name = 'ix_%s_%s' % (self.table.name, sig) + if name in self.indexes: + return self.indexes[name] + try: + self.database._acquire() + columns = [self.table.c[c] for c in columns] + idx = Index(name, *columns) + idx.create(self.database.engine) + except: + idx = None + finally: + self.database._release() + self.indexes[name] = idx + return idx def find_one(self, **_filter): """