Re-write locking to support transactions.

This commit is contained in:
Friedrich Lindenberg 2013-05-13 21:21:25 +02:00
parent 33908b2699
commit db71b6d631
2 changed files with 55 additions and 26 deletions

View File

@ -39,6 +39,15 @@ class Database(object):
return self.local.connection
return self.engine
def _acquire(self):
self.lock.acquire()
def _release(self):
if not hasattr(self.local, 'tx'):
self.lock.release()
else:
self.local.must_release = True
def begin(self):
""" Enter a transaction explicitly. No data will be written
until the transaction has been committed. """
@ -52,12 +61,18 @@ class Database(object):
since the transaction was begun permanent. """
self.local.tx.commit()
del self.local.tx
if not hasattr(self.local, 'must_release'):
self.lock.release()
del self.local.must_release
def rollback(self):
""" Roll back the current transaction, discarding all statements
executed since the transaction was begun. """
self.local.tx.rollback()
del self.local.tx
if not hasattr(self.local, 'must_release'):
self.lock.release()
del self.local.must_release
@property
def tables(self):
@ -79,7 +94,8 @@ class Database(object):
table = db.create_table('population')
"""
with self.lock:
self._acquire()
try:
log.debug("Creating table: %s on %r" % (table_name, self.engine))
table = SQLATable(table_name, self.metadata)
col = Column('id', Integer, primary_key=True)
@ -87,6 +103,8 @@ class Database(object):
table.create(self.engine)
self._tables[table_name] = table
return Table(self, table)
finally:
self._release()
def load_table(self, table_name):
"""
@ -100,11 +118,14 @@ class Database(object):
table = db.load_table('population')
"""
with self.lock:
self._acquire()
try:
log.debug("Loading table: %s on %r" % (table_name, self))
table = SQLATable(table_name, self.metadata, autoload=True)
self._tables[table_name] = table
return Table(self, table)
finally:
self._release()
def get_table(self, table_name):
"""
@ -118,13 +139,16 @@ class Database(object):
# you can also use the short-hand syntax:
table = db['population']
"""
with self.lock:
if table_name in self._tables:
return Table(self, self._tables[table_name])
self._acquire()
try:
if self.engine.has_table(table_name):
return self.load_table(table_name)
else:
return self.create_table(table_name)
finally:
self._release()
def __getitem__(self, table_name):
return self.get_table(table_name)

View File

@ -39,10 +39,10 @@ class Table(object):
dropping the table. If you want to re-create the table, make
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
"""
self.database._acquire()
self._is_dropped = True
with self.database.lock:
self.database._tables.pop(self.table.name, None)
self.table.drop(self.database.executable)
self.table.drop(self.database.engine)
def _check_dropped(self):
if self._is_dropped:
@ -198,11 +198,14 @@ class Table(object):
table.create_column('created_at', sqlalchemy.DateTime)
"""
self._check_dropped()
with self.database.lock:
self.database._acquire()
try:
if name not in self.table.columns.keys():
col = Column(name, type)
col.create(self.table,
connection=self.database.engine)
finally:
self.database._release()
def create_index(self, columns, name=None):
"""
@ -212,18 +215,20 @@ class Table(object):
table.create_index(['name', 'country'])
"""
self._check_dropped()
with self.database.lock:
if not name:
sig = abs(hash('||'.join(columns)))
name = 'ix_%s_%s' % (self.table.name, sig)
if name in self.indexes:
return self.indexes[name]
try:
self.database._acquire()
columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns)
idx.create(self.database.engine)
except:
idx = None
finally:
self.database._release()
self.indexes[name] = idx
return idx