Make table instances singleton-ish.

This commit is contained in:
Friedrich Lindenberg 2017-09-02 22:35:29 +02:00
parent bf17deeb7f
commit 37d7f47d39
3 changed files with 95 additions and 107 deletions

View File

@ -55,7 +55,7 @@ class Database(object):
self.url = url self.url = url
self.row_type = row_type self.row_type = row_type
self.ensure_schema = ensure_schema self.ensure_schema = ensure_schema
self.reflect_views = reflect_views self._tables = {}
@property @property
def executable(self): def executable(self):
@ -75,15 +75,8 @@ class Database(object):
"""Return a SQLAlchemy schema cache object.""" """Return a SQLAlchemy schema cache object."""
return MetaData(schema=self.schema, bind=self.executable) return MetaData(schema=self.schema, bind=self.executable)
def _acquire(self):
self.lock.acquire()
def _release(self):
self.lock.release()
def begin(self): def begin(self):
""" """Enter a transaction explicitly.
Enter a transaction explicitly.
No data will be written until the transaction has been committed. No data will be written until the transaction has been committed.
""" """
@ -92,8 +85,7 @@ class Database(object):
self.local.tx.append(self.executable.begin()) self.local.tx.append(self.executable.begin())
def commit(self): def commit(self):
""" """Commit the current transaction.
Commit the current transaction.
Make all statements executed since the transaction was begun permanent. Make all statements executed since the transaction was begun permanent.
""" """
@ -133,9 +125,12 @@ class Database(object):
inspector = Inspector.from_engine(self.executable) inspector = Inspector.from_engine(self.executable)
return inspector.get_table_names(schema=self.schema) return inspector.get_table_names(schema=self.schema)
def __contains__(self, member): def __contains__(self, table_name):
"""Check if the given table name exists in the database.""" """Check if the given table name exists in the database."""
return member in self.tables try:
return self._valid_table_name(table_name) in self.tables
except ValueError:
return False
def _valid_table_name(self, table_name): def _valid_table_name(self, table_name):
"""Check if the table name is obviously invalid.""" """Check if the table name is obviously invalid."""
@ -146,15 +141,12 @@ class Database(object):
def create_table(self, table_name, primary_id=None, primary_type=None): def create_table(self, table_name, primary_id=None, primary_type=None):
"""Create a new table. """Create a new table.
The new table will automatically have an `id` column unless specified via Either loads a table or creates it if it doesn't exist yet. You can
optional parameter primary_id, which will be used as the primary key of the define the name and type of the primary key field, if a new table is to
table. Automatic id is set to be an auto-incrementing integer, while the be created. The default is to create an auto-incrementing integer,
type of custom primary_id can be a `id`. You can also set the primary key to be a string or big integer.
String or an Integer as specified with primary_type flag. The default The caller will be responsible for the uniqueness of `primary_id` if
length of String is 255. The caller can specify the length. it is defined as a text type.
The caller will be responsible for the uniqueness of manual primary_id.
This custom id feature is only available via direct create_table call.
Returns a :py:class:`Table <dataset.Table>` instance. Returns a :py:class:`Table <dataset.Table>` instance.
:: ::
@ -177,11 +169,9 @@ class Database(object):
assert not isinstance(primary_type, six.string_types), \ assert not isinstance(primary_type, six.string_types), \
'Text-based primary_type support is dropped, use db.types.' 'Text-based primary_type support is dropped, use db.types.'
table_name = self._valid_table_name(table_name) table_name = self._valid_table_name(table_name)
table = self._reflect_table(table_name) with self.lock:
if table is not None: if table_name in self:
return Table(self, table) return self.load_table(table_name)
self._acquire()
try:
log.debug("Creating table: %s" % (table_name)) log.debug("Creating table: %s" % (table_name))
table = SQLATable(table_name, self.metadata, schema=self.schema) table = SQLATable(table_name, self.metadata, schema=self.schema)
if primary_id is not False: if primary_id is not False:
@ -194,9 +184,8 @@ class Database(object):
autoincrement=autoincrement) autoincrement=autoincrement)
table.append_column(col) table.append_column(col)
table.create(self.executable, checkfirst=True) table.create(self.executable, checkfirst=True)
return Table(self, table) self._tables[table_name] = Table(self, table)
finally: return self._tables[table_name]
self._release()
def load_table(self, table_name): def load_table(self, table_name):
"""Load a table. """Load a table.
@ -211,38 +200,37 @@ class Database(object):
table = db.load_table('population') table = db.load_table('population')
""" """
table_name = self._valid_table_name(table_name) table_name = self._valid_table_name(table_name)
if table_name in self._tables:
return self._tables.get(table_name)
log.debug("Loading table: %s", table_name) log.debug("Loading table: %s", table_name)
table = self._reflect_table(table_name) with self.lock:
if table is not None: table = self._reflect_table(table_name)
return Table(self, table) if table is not None:
self._tables[table_name] = Table(self, table)
return self._tables[table_name]
def _reflect_table(self, table_name): def _reflect_table(self, table_name):
"""Reload a table schema from the database.""" """Reload a table schema from the database."""
table_name = self._valid_table_name(table_name) table_name = self._valid_table_name(table_name)
try: try:
return SQLATable(table_name, table = SQLATable(table_name,
self.metadata, self.metadata,
schema=self.schema, schema=self.schema,
autoload=True, autoload=True,
autoload_with=self.executable) autoload_with=self.executable)
return table
except NoSuchTableError: except NoSuchTableError:
return None return None
def get_table(self, table_name, primary_id=None, primary_type=None): def get_table(self, table_name, primary_id=None, primary_type=None):
""" """Load or create a table.
Smart wrapper around *load_table* and *create_table*.
Either loads a table or creates it if it doesn't exist yet. This is now the same as `create_table`.
For short-hand to create a table with custom id and type using [], where
table_name, primary_id, and primary_type are specified as a tuple
Returns a :py:class:`Table <dataset.Table>` instance.
:: ::
table = db.get_table('population') table = db.get_table('population')
# you can also use the short-hand syntax: # you can also use the short-hand syntax:
table = db['population'] table = db['population']
""" """
return self.create_table(table_name, primary_id, primary_type) return self.create_table(table_name, primary_id, primary_type)
@ -250,7 +238,7 @@ class Database(object):
"""Get a given table.""" """Get a given table."""
return self.get_table(table_name) return self.get_table(table_name)
def query(self, query, **kw): def query(self, query, **kwargs):
"""Run a statement on the database directly. """Run a statement on the database directly.
Allows for the execution of arbitrary read/write queries. A query can either be Allows for the execution of arbitrary read/write queries. A query can either be
@ -269,8 +257,9 @@ class Database(object):
""" """
if isinstance(query, six.string_types): if isinstance(query, six.string_types):
query = text(query) query = text(query)
return ResultIter(self.executable.execute(query, **kw), _step = kwargs.pop('_step', 5000)
row_type=self.row_type) return ResultIter(self.executable.execute(query, **kwargs),
row_type=self.row_type, step=_step)
def __repr__(self): def __repr__(self):
"""Text representation contains the URL.""" """Text representation contains the URL."""

View File

@ -22,13 +22,16 @@ class Table(object):
"""Initialise the table from database schema.""" """Initialise the table from database schema."""
self.indexes = dict((i.name, i) for i in table.indexes) self.indexes = dict((i.name, i) for i in table.indexes)
self.database = database self.database = database
self.name = table.name
self.table = table self.table = table
self._is_dropped = False self._is_dropped = False
@property @property
def columns(self): def columns(self):
"""Get a listing of all columns that exist in the table.""" """Get a listing of all columns that exist in the table."""
return list(self.table.columns.keys()) if self.table is None:
return []
return self.table.columns.keys()
def drop(self): def drop(self):
""" """
@ -40,16 +43,15 @@ class Table(object):
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`. sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
""" """
self._check_dropped() self._check_dropped()
self.database._acquire() with self.database.lock:
try:
self.table.drop(self.database.executable, checkfirst=True) self.table.drop(self.database.executable, checkfirst=True)
# self.database._tables.pop(self.name, None)
self.table = None self.table = None
finally:
self.database._release()
def _check_dropped(self): def _check_dropped(self):
# self.table = self.database._reflect_table(self.name)
if self.table is None: if self.table is None:
raise DatasetException('the table has been dropped. this object should not be used again.') raise DatasetException('The table has been dropped.')
def _prune_row(self, row): def _prune_row(self, row):
"""Remove keys from row not in column set.""" """Remove keys from row not in column set."""
@ -165,10 +167,10 @@ class Table(object):
they will be created based on the settings of ``ensure`` and they will be created based on the settings of ``ensure`` and
``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`. ``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`.
""" """
self._check_dropped()
# check whether keys arg is a string and format as a list # check whether keys arg is a string and format as a list
if not isinstance(keys, (list, tuple)): if not isinstance(keys, (list, tuple)):
keys = [keys] keys = [keys]
self._check_dropped()
if not keys or len(keys) == len(row): if not keys or len(keys) == len(row):
return False return False
clause = [(u, row.get(u)) for u in keys] clause = [(u, row.get(u)) for u in keys]
@ -297,8 +299,7 @@ class Table(object):
table.create_column('created_at', db.types.datetime) table.create_column('created_at', db.types.datetime)
""" """
self._check_dropped() self._check_dropped()
self.database._acquire() with self.database.lock:
try:
name = normalize_column_name(name) name = normalize_column_name(name)
if name in self.columns: if name in self.columns:
log.debug("Column exists: %s" % name) log.debug("Column exists: %s" % name)
@ -310,8 +311,6 @@ class Table(object):
self.table.schema self.table.schema
) )
self.table = self.database._reflect_table(self.table.name) self.table = self.database._reflect_table(self.table.name)
finally:
self.database._release()
def create_column_by_example(self, name, value): def create_column_by_example(self, name, value):
""" """
@ -336,19 +335,16 @@ class Table(object):
if self.database.engine.dialect.name == 'sqlite': if self.database.engine.dialect.name == 'sqlite':
raise NotImplementedError("SQLite does not support dropping columns.") raise NotImplementedError("SQLite does not support dropping columns.")
self._check_dropped() self._check_dropped()
if name not in self.table.columns.keys(): if name not in self.columns:
log.debug("Column does not exist: %s", name) log.debug("Column does not exist: %s", name)
return return
self.database._acquire() with self.database.lock:
try:
self.database.op.drop_column( self.database.op.drop_column(
self.table.name, self.table.name,
name, name,
self.table.schema self.table.schema
) )
self.table = self.database.update_table(self.table.name) self.table = self.database.update_table(self.table.name)
finally:
self.database._release()
def create_index(self, columns, name=None, **kw): def create_index(self, columns, name=None, **kw):
""" """
@ -360,51 +356,30 @@ class Table(object):
table.create_index(['name', 'country']) table.create_index(['name', 'country'])
""" """
self._check_dropped() self._check_dropped()
if not name: with self.database.lock:
sig = '||'.join(columns) if not name:
sig = '||'.join(columns)
# This is a work-around for a bug in <=0.6.1 which would
# create indexes based on hash() rather than a proper hash.
key = abs(hash(sig))
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes:
return self.indexes[name]
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
name = 'ix_%s_%s' % (self.table.name, key)
# This is a work-around for a bug in <=0.6.1 which would create
# indexes based on hash() rather than a proper hash.
key = abs(hash(sig))
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes: if name in self.indexes:
return self.indexes[name] return self.indexes[name]
try:
key = sha1(sig.encode('utf-8')).hexdigest()[:16] columns = [self.table.c[c] for c in columns]
name = 'ix_%s_%s' % (self.table.name, key) idx = Index(name, *columns, **kw)
idx.create(self.database.executable)
if name in self.indexes: except:
return self.indexes[name] idx = None
try: self.indexes[name] = idx
self.database._acquire() return idx
columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns, **kw)
idx.create(self.database.executable)
except:
idx = None
finally:
self.database._release()
self.indexes[name] = idx
return idx
def find_one(self, *args, **kwargs):
"""
Get a single result from the table.
Works just like :py:meth:`find() <dataset.Table.find>` but returns one result, or None.
::
row = table.find_one(country='United States')
"""
kwargs['_limit'] = 1
kwargs['_step'] = None
resiter = self.find(*args, **kwargs)
try:
for row in resiter:
return row
finally:
resiter.close()
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
if not isinstance(order_by, (list, tuple)): if not isinstance(order_by, (list, tuple)):
@ -467,6 +442,24 @@ class Table(object):
return ResultIter(self.database.executable.execute(query), return ResultIter(self.database.executable.execute(query),
row_type=self.database.row_type, step=_step) row_type=self.database.row_type, step=_step)
def find_one(self, *args, **kwargs):
    """Get a single result from the table.

    Works just like :py:meth:`find() <dataset.Table.find>` but returns one
    result, or None.
    ::

        row = table.find_one(country='United States')
    """
    # Only one row is wanted: cap the query at a single result and hand
    # ``_step=None`` through to ``find`` unchanged.
    kwargs.update(_limit=1, _step=None)
    results = self.find(*args, **kwargs)
    try:
        # First row if there is one, otherwise None.
        return next(iter(results), None)
    finally:
        # Always release the result iterator, even when no row came back.
        results.close()
def count(self, *_clauses, **kwargs): def count(self, *_clauses, **kwargs):
"""Return the count of results for the given filter set.""" """Return the count of results for the given filter set."""
# NOTE: this does not have support for limit and offset since I can't # NOTE: this does not have support for limit and offset since I can't

View File

@ -346,6 +346,12 @@ class TableTestCase(unittest.TestCase):
self.db['weather'].drop() self.db['weather'].drop()
assert 'weather' not in self.db assert 'weather' not in self.db
def test_table_drop_then_create(self):
    """Drop the ``weather`` table, then insert into it again by name."""
    database = self.db
    assert 'weather' in database
    database['weather'].drop()
    assert 'weather' not in database
    # Indexing the database after the drop must still yield a table that
    # accepts an insert.
    database['weather'].insert({'foo': 'bar'})
def test_columns(self): def test_columns(self):
cols = self.tbl.columns cols = self.tbl.columns
assert len(list(cols)) == 4, 'column count mismatch' assert len(list(cols)) == 4, 'column count mismatch'