Re-write locking to support transactions.

This commit is contained in:
Friedrich Lindenberg 2013-05-13 21:21:25 +02:00
parent 33908b2699
commit db71b6d631
2 changed files with 55 additions and 26 deletions

View File

@ -33,14 +33,23 @@ class Database(object):
@property @property
def executable(self): def executable(self):
""" The current connection or engine against which statements """ The current connection or engine against which statements
will be executed. """ will be executed. """
if hasattr(self.local, 'connection'): if hasattr(self.local, 'connection'):
return self.local.connection return self.local.connection
return self.engine return self.engine
def _acquire(self):
self.lock.acquire()
def _release(self):
if not hasattr(self.local, 'tx'):
self.lock.release()
else:
self.local.must_release = True
def begin(self): def begin(self):
""" Enter a transaction explicitly. No data will be written """ Enter a transaction explicitly. No data will be written
until the transaction has been committed. """ until the transaction has been committed. """
if not hasattr(self.local, 'connection'): if not hasattr(self.local, 'connection'):
self.local.connection = self.engine.connect() self.local.connection = self.engine.connect()
@ -52,12 +61,18 @@ class Database(object):
since the transaction was begun permanent. """ since the transaction was begun permanent. """
self.local.tx.commit() self.local.tx.commit()
del self.local.tx del self.local.tx
if not hasattr(self.local, 'must_release'):
self.lock.release()
del self.local.must_release
def rollback(self): def rollback(self):
""" Roll back the current transaction, discarding all statements """ Roll back the current transaction, discarding all statements
executed since the transaction was begun. """ executed since the transaction was begun. """
self.local.tx.rollback() self.local.tx.rollback()
del self.local.tx del self.local.tx
if not hasattr(self.local, 'must_release'):
self.lock.release()
del self.local.must_release
@property @property
def tables(self): def tables(self):
@ -79,7 +94,8 @@ class Database(object):
table = db.create_table('population') table = db.create_table('population')
""" """
with self.lock: self._acquire()
try:
log.debug("Creating table: %s on %r" % (table_name, self.engine)) log.debug("Creating table: %s on %r" % (table_name, self.engine))
table = SQLATable(table_name, self.metadata) table = SQLATable(table_name, self.metadata)
col = Column('id', Integer, primary_key=True) col = Column('id', Integer, primary_key=True)
@ -87,6 +103,8 @@ class Database(object):
table.create(self.engine) table.create(self.engine)
self._tables[table_name] = table self._tables[table_name] = table
return Table(self, table) return Table(self, table)
finally:
self._release()
def load_table(self, table_name): def load_table(self, table_name):
""" """
@ -100,11 +118,14 @@ class Database(object):
table = db.load_table('population') table = db.load_table('population')
""" """
with self.lock: self._acquire()
try:
log.debug("Loading table: %s on %r" % (table_name, self)) log.debug("Loading table: %s on %r" % (table_name, self))
table = SQLATable(table_name, self.metadata, autoload=True) table = SQLATable(table_name, self.metadata, autoload=True)
self._tables[table_name] = table self._tables[table_name] = table
return Table(self, table) return Table(self, table)
finally:
self._release()
def get_table(self, table_name): def get_table(self, table_name):
""" """
@ -118,13 +139,16 @@ class Database(object):
# you can also use the short-hand syntax: # you can also use the short-hand syntax:
table = db['population'] table = db['population']
""" """
with self.lock: if table_name in self._tables:
if table_name in self._tables: return Table(self, self._tables[table_name])
return Table(self, self._tables[table_name]) self._acquire()
try:
if self.engine.has_table(table_name): if self.engine.has_table(table_name):
return self.load_table(table_name) return self.load_table(table_name)
else: else:
return self.create_table(table_name) return self.create_table(table_name)
finally:
self._release()
def __getitem__(self, table_name): def __getitem__(self, table_name):
return self.get_table(table_name) return self.get_table(table_name)

View File

@ -39,10 +39,10 @@ class Table(object):
dropping the table. If you want to re-create the table, make dropping the table. If you want to re-create the table, make
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`. sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
""" """
self.database._acquire()
self._is_dropped = True self._is_dropped = True
with self.database.lock: self.database._tables.pop(self.table.name, None)
self.database._tables.pop(self.table.name, None) self.table.drop(self.database.engine)
self.table.drop(self.database.executable)
def _check_dropped(self): def _check_dropped(self):
if self._is_dropped: if self._is_dropped:
@ -198,11 +198,14 @@ class Table(object):
table.create_column('created_at', sqlalchemy.DateTime) table.create_column('created_at', sqlalchemy.DateTime)
""" """
self._check_dropped() self._check_dropped()
with self.database.lock: self.database._acquire()
try:
if name not in self.table.columns.keys(): if name not in self.table.columns.keys():
col = Column(name, type) col = Column(name, type)
col.create(self.table, col.create(self.table,
connection=self.database.engine) connection=self.database.engine)
finally:
self.database._release()
def create_index(self, columns, name=None): def create_index(self, columns, name=None):
""" """
@ -212,20 +215,22 @@ class Table(object):
table.create_index(['name', 'country']) table.create_index(['name', 'country'])
""" """
self._check_dropped() self._check_dropped()
with self.database.lock: if not name:
if not name: sig = abs(hash('||'.join(columns)))
sig = abs(hash('||'.join(columns))) name = 'ix_%s_%s' % (self.table.name, sig)
name = 'ix_%s_%s' % (self.table.name, sig) if name in self.indexes:
if name in self.indexes: return self.indexes[name]
return self.indexes[name] try:
try: self.database._acquire()
columns = [self.table.c[c] for c in columns] columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns) idx = Index(name, *columns)
idx.create(self.database.engine) idx.create(self.database.engine)
except: except:
idx = None idx = None
self.indexes[name] = idx finally:
return idx self.database._release()
self.indexes[name] = idx
return idx
def find_one(self, **_filter): def find_one(self, **_filter):
""" """