Re-write locking to support transactions.
This commit is contained in:
parent
33908b2699
commit
db71b6d631
@ -39,6 +39,15 @@ class Database(object):
|
||||
return self.local.connection
|
||||
return self.engine
|
||||
|
||||
def _acquire(self):
|
||||
self.lock.acquire()
|
||||
|
||||
def _release(self):
|
||||
if not hasattr(self.local, 'tx'):
|
||||
self.lock.release()
|
||||
else:
|
||||
self.local.must_release = True
|
||||
|
||||
def begin(self):
|
||||
""" Enter a transaction explicitly. No data will be written
|
||||
until the transaction has been committed. """
|
||||
@ -52,12 +61,18 @@ class Database(object):
|
||||
since the transaction was begun permanent. """
|
||||
self.local.tx.commit()
|
||||
del self.local.tx
|
||||
if not hasattr(self.local, 'must_release'):
|
||||
self.lock.release()
|
||||
del self.local.must_release
|
||||
|
||||
def rollback(self):
|
||||
""" Roll back the current transaction, discarding all statements
|
||||
executed since the transaction was begun. """
|
||||
self.local.tx.rollback()
|
||||
del self.local.tx
|
||||
if not hasattr(self.local, 'must_release'):
|
||||
self.lock.release()
|
||||
del self.local.must_release
|
||||
|
||||
@property
|
||||
def tables(self):
|
||||
@ -79,7 +94,8 @@ class Database(object):
|
||||
|
||||
table = db.create_table('population')
|
||||
"""
|
||||
with self.lock:
|
||||
self._acquire()
|
||||
try:
|
||||
log.debug("Creating table: %s on %r" % (table_name, self.engine))
|
||||
table = SQLATable(table_name, self.metadata)
|
||||
col = Column('id', Integer, primary_key=True)
|
||||
@ -87,6 +103,8 @@ class Database(object):
|
||||
table.create(self.engine)
|
||||
self._tables[table_name] = table
|
||||
return Table(self, table)
|
||||
finally:
|
||||
self._release()
|
||||
|
||||
def load_table(self, table_name):
|
||||
"""
|
||||
@ -100,11 +118,14 @@ class Database(object):
|
||||
|
||||
table = db.load_table('population')
|
||||
"""
|
||||
with self.lock:
|
||||
self._acquire()
|
||||
try:
|
||||
log.debug("Loading table: %s on %r" % (table_name, self))
|
||||
table = SQLATable(table_name, self.metadata, autoload=True)
|
||||
self._tables[table_name] = table
|
||||
return Table(self, table)
|
||||
finally:
|
||||
self._release()
|
||||
|
||||
def get_table(self, table_name):
|
||||
"""
|
||||
@ -118,13 +139,16 @@ class Database(object):
|
||||
# you can also use the short-hand syntax:
|
||||
table = db['population']
|
||||
"""
|
||||
with self.lock:
|
||||
if table_name in self._tables:
|
||||
return Table(self, self._tables[table_name])
|
||||
self._acquire()
|
||||
try:
|
||||
if self.engine.has_table(table_name):
|
||||
return self.load_table(table_name)
|
||||
else:
|
||||
return self.create_table(table_name)
|
||||
finally:
|
||||
self._release()
|
||||
|
||||
def __getitem__(self, table_name):
|
||||
return self.get_table(table_name)
|
||||
|
||||
@ -39,10 +39,10 @@ class Table(object):
|
||||
dropping the table. If you want to re-create the table, make
|
||||
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
|
||||
"""
|
||||
self.database._acquire()
|
||||
self._is_dropped = True
|
||||
with self.database.lock:
|
||||
self.database._tables.pop(self.table.name, None)
|
||||
self.table.drop(self.database.executable)
|
||||
self.table.drop(self.database.engine)
|
||||
|
||||
def _check_dropped(self):
|
||||
if self._is_dropped:
|
||||
@ -198,11 +198,14 @@ class Table(object):
|
||||
table.create_column('created_at', sqlalchemy.DateTime)
|
||||
"""
|
||||
self._check_dropped()
|
||||
with self.database.lock:
|
||||
self.database._acquire()
|
||||
try:
|
||||
if name not in self.table.columns.keys():
|
||||
col = Column(name, type)
|
||||
col.create(self.table,
|
||||
connection=self.database.engine)
|
||||
finally:
|
||||
self.database._release()
|
||||
|
||||
def create_index(self, columns, name=None):
|
||||
"""
|
||||
@ -212,18 +215,20 @@ class Table(object):
|
||||
table.create_index(['name', 'country'])
|
||||
"""
|
||||
self._check_dropped()
|
||||
with self.database.lock:
|
||||
if not name:
|
||||
sig = abs(hash('||'.join(columns)))
|
||||
name = 'ix_%s_%s' % (self.table.name, sig)
|
||||
if name in self.indexes:
|
||||
return self.indexes[name]
|
||||
try:
|
||||
self.database._acquire()
|
||||
columns = [self.table.c[c] for c in columns]
|
||||
idx = Index(name, *columns)
|
||||
idx.create(self.database.engine)
|
||||
except:
|
||||
idx = None
|
||||
finally:
|
||||
self.database._release()
|
||||
self.indexes[name] = idx
|
||||
return idx
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user