Re-write locking to support transactions.
This commit is contained in:
parent
33908b2699
commit
db71b6d631
@ -39,6 +39,15 @@ class Database(object):
|
|||||||
return self.local.connection
|
return self.local.connection
|
||||||
return self.engine
|
return self.engine
|
||||||
|
|
||||||
|
def _acquire(self):
|
||||||
|
self.lock.acquire()
|
||||||
|
|
||||||
|
def _release(self):
|
||||||
|
if not hasattr(self.local, 'tx'):
|
||||||
|
self.lock.release()
|
||||||
|
else:
|
||||||
|
self.local.must_release = True
|
||||||
|
|
||||||
def begin(self):
|
def begin(self):
|
||||||
""" Enter a transaction explicitly. No data will be written
|
""" Enter a transaction explicitly. No data will be written
|
||||||
until the transaction has been committed. """
|
until the transaction has been committed. """
|
||||||
@ -52,12 +61,18 @@ class Database(object):
|
|||||||
since the transaction was begun permanent. """
|
since the transaction was begun permanent. """
|
||||||
self.local.tx.commit()
|
self.local.tx.commit()
|
||||||
del self.local.tx
|
del self.local.tx
|
||||||
|
if not hasattr(self.local, 'must_release'):
|
||||||
|
self.lock.release()
|
||||||
|
del self.local.must_release
|
||||||
|
|
||||||
def rollback(self):
|
def rollback(self):
|
||||||
""" Roll back the current transaction, discarding all statements
|
""" Roll back the current transaction, discarding all statements
|
||||||
executed since the transaction was begun. """
|
executed since the transaction was begun. """
|
||||||
self.local.tx.rollback()
|
self.local.tx.rollback()
|
||||||
del self.local.tx
|
del self.local.tx
|
||||||
|
if not hasattr(self.local, 'must_release'):
|
||||||
|
self.lock.release()
|
||||||
|
del self.local.must_release
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tables(self):
|
def tables(self):
|
||||||
@ -79,7 +94,8 @@ class Database(object):
|
|||||||
|
|
||||||
table = db.create_table('population')
|
table = db.create_table('population')
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
self._acquire()
|
||||||
|
try:
|
||||||
log.debug("Creating table: %s on %r" % (table_name, self.engine))
|
log.debug("Creating table: %s on %r" % (table_name, self.engine))
|
||||||
table = SQLATable(table_name, self.metadata)
|
table = SQLATable(table_name, self.metadata)
|
||||||
col = Column('id', Integer, primary_key=True)
|
col = Column('id', Integer, primary_key=True)
|
||||||
@ -87,6 +103,8 @@ class Database(object):
|
|||||||
table.create(self.engine)
|
table.create(self.engine)
|
||||||
self._tables[table_name] = table
|
self._tables[table_name] = table
|
||||||
return Table(self, table)
|
return Table(self, table)
|
||||||
|
finally:
|
||||||
|
self._release()
|
||||||
|
|
||||||
def load_table(self, table_name):
|
def load_table(self, table_name):
|
||||||
"""
|
"""
|
||||||
@ -100,11 +118,14 @@ class Database(object):
|
|||||||
|
|
||||||
table = db.load_table('population')
|
table = db.load_table('population')
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
self._acquire()
|
||||||
|
try:
|
||||||
log.debug("Loading table: %s on %r" % (table_name, self))
|
log.debug("Loading table: %s on %r" % (table_name, self))
|
||||||
table = SQLATable(table_name, self.metadata, autoload=True)
|
table = SQLATable(table_name, self.metadata, autoload=True)
|
||||||
self._tables[table_name] = table
|
self._tables[table_name] = table
|
||||||
return Table(self, table)
|
return Table(self, table)
|
||||||
|
finally:
|
||||||
|
self._release()
|
||||||
|
|
||||||
def get_table(self, table_name):
|
def get_table(self, table_name):
|
||||||
"""
|
"""
|
||||||
@ -118,13 +139,16 @@ class Database(object):
|
|||||||
# you can also use the short-hand syntax:
|
# you can also use the short-hand syntax:
|
||||||
table = db['population']
|
table = db['population']
|
||||||
"""
|
"""
|
||||||
with self.lock:
|
|
||||||
if table_name in self._tables:
|
if table_name in self._tables:
|
||||||
return Table(self, self._tables[table_name])
|
return Table(self, self._tables[table_name])
|
||||||
|
self._acquire()
|
||||||
|
try:
|
||||||
if self.engine.has_table(table_name):
|
if self.engine.has_table(table_name):
|
||||||
return self.load_table(table_name)
|
return self.load_table(table_name)
|
||||||
else:
|
else:
|
||||||
return self.create_table(table_name)
|
return self.create_table(table_name)
|
||||||
|
finally:
|
||||||
|
self._release()
|
||||||
|
|
||||||
def __getitem__(self, table_name):
|
def __getitem__(self, table_name):
|
||||||
return self.get_table(table_name)
|
return self.get_table(table_name)
|
||||||
|
|||||||
@ -39,10 +39,10 @@ class Table(object):
|
|||||||
dropping the table. If you want to re-create the table, make
|
dropping the table. If you want to re-create the table, make
|
||||||
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
|
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
|
||||||
"""
|
"""
|
||||||
|
self.database._acquire()
|
||||||
self._is_dropped = True
|
self._is_dropped = True
|
||||||
with self.database.lock:
|
|
||||||
self.database._tables.pop(self.table.name, None)
|
self.database._tables.pop(self.table.name, None)
|
||||||
self.table.drop(self.database.executable)
|
self.table.drop(self.database.engine)
|
||||||
|
|
||||||
def _check_dropped(self):
|
def _check_dropped(self):
|
||||||
if self._is_dropped:
|
if self._is_dropped:
|
||||||
@ -198,11 +198,14 @@ class Table(object):
|
|||||||
table.create_column('created_at', sqlalchemy.DateTime)
|
table.create_column('created_at', sqlalchemy.DateTime)
|
||||||
"""
|
"""
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
with self.database.lock:
|
self.database._acquire()
|
||||||
|
try:
|
||||||
if name not in self.table.columns.keys():
|
if name not in self.table.columns.keys():
|
||||||
col = Column(name, type)
|
col = Column(name, type)
|
||||||
col.create(self.table,
|
col.create(self.table,
|
||||||
connection=self.database.engine)
|
connection=self.database.engine)
|
||||||
|
finally:
|
||||||
|
self.database._release()
|
||||||
|
|
||||||
def create_index(self, columns, name=None):
|
def create_index(self, columns, name=None):
|
||||||
"""
|
"""
|
||||||
@ -212,18 +215,20 @@ class Table(object):
|
|||||||
table.create_index(['name', 'country'])
|
table.create_index(['name', 'country'])
|
||||||
"""
|
"""
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
with self.database.lock:
|
|
||||||
if not name:
|
if not name:
|
||||||
sig = abs(hash('||'.join(columns)))
|
sig = abs(hash('||'.join(columns)))
|
||||||
name = 'ix_%s_%s' % (self.table.name, sig)
|
name = 'ix_%s_%s' % (self.table.name, sig)
|
||||||
if name in self.indexes:
|
if name in self.indexes:
|
||||||
return self.indexes[name]
|
return self.indexes[name]
|
||||||
try:
|
try:
|
||||||
|
self.database._acquire()
|
||||||
columns = [self.table.c[c] for c in columns]
|
columns = [self.table.c[c] for c in columns]
|
||||||
idx = Index(name, *columns)
|
idx = Index(name, *columns)
|
||||||
idx.create(self.database.engine)
|
idx.create(self.database.engine)
|
||||||
except:
|
except:
|
||||||
idx = None
|
idx = None
|
||||||
|
finally:
|
||||||
|
self.database._release()
|
||||||
self.indexes[name] = idx
|
self.indexes[name] = idx
|
||||||
return idx
|
return idx
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user