Make table instances singleton-ish.

This commit is contained in:
Friedrich Lindenberg 2017-09-02 22:35:29 +02:00
parent bf17deeb7f
commit 37d7f47d39
3 changed files with 95 additions and 107 deletions

View File

@ -55,7 +55,7 @@ class Database(object):
self.url = url
self.row_type = row_type
self.ensure_schema = ensure_schema
self.reflect_views = reflect_views
self._tables = {}
@property
def executable(self):
@ -75,15 +75,8 @@ class Database(object):
"""Return a SQLAlchemy schema cache object."""
return MetaData(schema=self.schema, bind=self.executable)
def _acquire(self):
self.lock.acquire()
def _release(self):
self.lock.release()
def begin(self):
"""
Enter a transaction explicitly.
"""Enter a transaction explicitly.
No data will be written until the transaction has been committed.
"""
@ -92,8 +85,7 @@ class Database(object):
self.local.tx.append(self.executable.begin())
def commit(self):
"""
Commit the current transaction.
"""Commit the current transaction.
Make all statements executed since the transaction was begun permanent.
"""
@ -133,9 +125,12 @@ class Database(object):
inspector = Inspector.from_engine(self.executable)
return inspector.get_table_names(schema=self.schema)
def __contains__(self, member):
def __contains__(self, table_name):
"""Check if the given table name exists in the database."""
return member in self.tables
try:
return self._valid_table_name(table_name) in self.tables
except ValueError:
return False
def _valid_table_name(self, table_name):
"""Check if the table name is obviously invalid."""
@ -146,15 +141,12 @@ class Database(object):
def create_table(self, table_name, primary_id=None, primary_type=None):
"""Create a new table.
The new table will automatically have an `id` column unless specified via
optional parameter primary_id, which will be used as the primary key of the
table. Automatic id is set to be an auto-incrementing integer, while the
type of custom primary_id can be a
String or an Integer as specified with primary_type flag. The default
length of String is 255. The caller can specify the length.
The caller will be responsible for the uniqueness of manual primary_id.
This custom id feature is only available via direct create_table call.
Either loads a table or creates it if it doesn't exist yet. You can
define the name and type of the primary key field, if a new table is to
be created. The default is to create an auto-incrementing integer,
`id`. You can also set the primary key to be a string or big integer.
The caller will be responsible for the uniqueness of `primary_id` if
it is defined as a text type.
Returns a :py:class:`Table <dataset.Table>` instance.
::
@ -177,11 +169,9 @@ class Database(object):
assert not isinstance(primary_type, six.string_types), \
'Text-based primary_type support is dropped, use db.types.'
table_name = self._valid_table_name(table_name)
table = self._reflect_table(table_name)
if table is not None:
return Table(self, table)
self._acquire()
try:
with self.lock:
if table_name in self:
return self.load_table(table_name)
log.debug("Creating table: %s" % (table_name))
table = SQLATable(table_name, self.metadata, schema=self.schema)
if primary_id is not False:
@ -194,9 +184,8 @@ class Database(object):
autoincrement=autoincrement)
table.append_column(col)
table.create(self.executable, checkfirst=True)
return Table(self, table)
finally:
self._release()
self._tables[table_name] = Table(self, table)
return self._tables[table_name]
def load_table(self, table_name):
"""Load a table.
@ -211,38 +200,37 @@ class Database(object):
table = db.load_table('population')
"""
table_name = self._valid_table_name(table_name)
if table_name in self._tables:
return self._tables.get(table_name)
log.debug("Loading table: %s", table_name)
with self.lock:
table = self._reflect_table(table_name)
if table is not None:
return Table(self, table)
self._tables[table_name] = Table(self, table)
return self._tables[table_name]
def _reflect_table(self, table_name):
"""Reload a table schema from the database."""
table_name = self._valid_table_name(table_name)
try:
return SQLATable(table_name,
table = SQLATable(table_name,
self.metadata,
schema=self.schema,
autoload=True,
autoload_with=self.executable)
return table
except NoSuchTableError:
return None
def get_table(self, table_name, primary_id=None, primary_type=None):
"""
Smart wrapper around *load_table* and *create_table*.
"""Load or create a table.
Either loads a table or creates it if it doesn't exist yet.
For short-hand to create a table with custom id and type using [], where
table_name, primary_id, and primary_type are specified as a tuple
Returns a :py:class:`Table <dataset.Table>` instance.
This is now the same as `create_table`.
::
table = db.get_table('population')
# you can also use the short-hand syntax:
table = db['population']
"""
return self.create_table(table_name, primary_id, primary_type)
@ -250,7 +238,7 @@ class Database(object):
"""Get a given table."""
return self.get_table(table_name)
def query(self, query, **kw):
def query(self, query, **kwargs):
"""Run a statement on the database directly.
Allows for the execution of arbitrary read/write queries. A query can either be
@ -269,8 +257,9 @@ class Database(object):
"""
if isinstance(query, six.string_types):
query = text(query)
return ResultIter(self.executable.execute(query, **kw),
row_type=self.row_type)
_step = kwargs.pop('_step', 5000)
return ResultIter(self.executable.execute(query, **kwargs),
row_type=self.row_type, step=_step)
def __repr__(self):
"""Text representation contains the URL."""

View File

@ -22,13 +22,16 @@ class Table(object):
"""Initialise the table from database schema."""
self.indexes = dict((i.name, i) for i in table.indexes)
self.database = database
self.name = table.name
self.table = table
self._is_dropped = False
@property
def columns(self):
"""Get a listing of all columns that exist in the table."""
return list(self.table.columns.keys())
if self.table is None:
return []
return self.table.columns.keys()
def drop(self):
"""
@ -40,16 +43,15 @@ class Table(object):
sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
"""
self._check_dropped()
self.database._acquire()
try:
with self.database.lock:
self.table.drop(self.database.executable, checkfirst=True)
# self.database._tables.pop(self.name, None)
self.table = None
finally:
self.database._release()
def _check_dropped(self):
# self.table = self.database._reflect_table(self.name)
if self.table is None:
raise DatasetException('the table has been dropped. this object should not be used again.')
raise DatasetException('The table has been dropped.')
def _prune_row(self, row):
"""Remove keys from row not in column set."""
@ -165,10 +167,10 @@ class Table(object):
they will be created based on the settings of ``ensure`` and
``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`.
"""
self._check_dropped()
# check whether keys arg is a string and format as a list
if not isinstance(keys, (list, tuple)):
keys = [keys]
self._check_dropped()
if not keys or len(keys) == len(row):
return False
clause = [(u, row.get(u)) for u in keys]
@ -297,8 +299,7 @@ class Table(object):
table.create_column('created_at', db.types.datetime)
"""
self._check_dropped()
self.database._acquire()
try:
with self.database.lock:
name = normalize_column_name(name)
if name in self.columns:
log.debug("Column exists: %s" % name)
@ -310,8 +311,6 @@ class Table(object):
self.table.schema
)
self.table = self.database._reflect_table(self.table.name)
finally:
self.database._release()
def create_column_by_example(self, name, value):
"""
@ -336,19 +335,16 @@ class Table(object):
if self.database.engine.dialect.name == 'sqlite':
raise NotImplementedError("SQLite does not support dropping columns.")
self._check_dropped()
if name not in self.table.columns.keys():
if name not in self.columns:
log.debug("Column does not exist: %s", name)
return
self.database._acquire()
try:
with self.database.lock:
self.database.op.drop_column(
self.table.name,
name,
self.table.schema
)
self.table = self.database.update_table(self.table.name)
finally:
self.database._release()
def create_index(self, columns, name=None, **kw):
"""
@ -360,11 +356,12 @@ class Table(object):
table.create_index(['name', 'country'])
"""
self._check_dropped()
with self.database.lock:
if not name:
sig = '||'.join(columns)
# This is a work-around for a bug in <=0.6.1 which would create
# indexes based on hash() rather than a proper hash.
# This is a work-around for a bug in <=0.6.1 which would
# create indexes based on hash() rather than a proper hash.
key = abs(hash(sig))
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes:
@ -376,36 +373,14 @@ class Table(object):
if name in self.indexes:
return self.indexes[name]
try:
self.database._acquire()
columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns, **kw)
idx.create(self.database.executable)
except:
idx = None
finally:
self.database._release()
self.indexes[name] = idx
return idx
def find_one(self, *args, **kwargs):
"""
Get a single result from the table.
Works just like :py:meth:`find() <dataset.Table.find>` but returns one result, or None.
::
row = table.find_one(country='United States')
"""
kwargs['_limit'] = 1
kwargs['_step'] = None
resiter = self.find(*args, **kwargs)
try:
for row in resiter:
return row
finally:
resiter.close()
def _args_to_order_by(self, order_by):
if not isinstance(order_by, (list, tuple)):
order_by = [order_by]
@ -467,6 +442,24 @@ class Table(object):
return ResultIter(self.database.executable.execute(query),
row_type=self.database.row_type, step=_step)
def find_one(self, *args, **kwargs):
"""Get a single result from the table.
Works just like :py:meth:`find() <dataset.Table.find>` but returns one
result, or None.
::
row = table.find_one(country='United States')
"""
kwargs['_limit'] = 1
kwargs['_step'] = None
resiter = self.find(*args, **kwargs)
try:
for row in resiter:
return row
finally:
resiter.close()
def count(self, *_clauses, **kwargs):
"""Return the count of results for the given filter set."""
# NOTE: this does not have support for limit and offset since I can't

View File

@ -346,6 +346,12 @@ class TableTestCase(unittest.TestCase):
self.db['weather'].drop()
assert 'weather' not in self.db
def test_table_drop_then_create(self):
assert 'weather' in self.db
self.db['weather'].drop()
assert 'weather' not in self.db
self.db['weather'].insert({'foo': 'bar'})
def test_columns(self):
cols = self.tbl.columns
assert len(list(cols)) == 4, 'column count mismatch'