cleanup inspection code

This commit is contained in:
Friedrich Lindenberg 2017-09-04 22:26:51 +02:00
parent df0d79d75c
commit 7e614c0933
2 changed files with 38 additions and 37 deletions

View File

@ -69,6 +69,11 @@ class Database(object):
ctx = MigrationContext.configure(self.executable)
return Operations(ctx)
@property
def inspect(self):
"""Get a SQLAlchemy inspector."""
return Inspector.from_engine(self.executable)
@property
def metadata(self):
"""Return a SQLAlchemy schema cache object."""
@ -120,8 +125,7 @@ class Database(object):
@property
def tables(self):
"""Get a listing of all tables that exist in the database."""
inspector = Inspector.from_engine(self.executable)
return inspector.get_table_names(schema=self.schema)
return self.inspect.get_table_names(schema=self.schema)
def __contains__(self, table_name):
"""Check if the given table name exists in the database."""

View File

@ -4,7 +4,6 @@ from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError
@ -25,7 +24,7 @@ class Table(object):
def __init__(self, database, table_name, primary_id=None,
primary_type=None, auto_create=False):
"""Initialise the table from database schema."""
self.database = database
self.db = database
self.name = normalize_table_name(table_name)
self._table = None
self._indexes = []
@ -38,7 +37,7 @@ class Table(object):
"""Check to see if the table currently exists in the database."""
if self._table is not None:
return True
return self.name in self.database
return self.name in self.db
@property
def table(self):
@ -77,7 +76,7 @@ class Table(object):
Returns the inserted row's primary key.
"""
row = self._sync_columns(row, ensure, types=types)
res = self.database.executable.execute(self.table.insert(row))
res = self.db.executable.execute(self.table.insert(row))
if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
return True
@ -155,7 +154,7 @@ class Table(object):
if not len(row):
return self.count(clause)
stmt = self.table.update(whereclause=clause, values=row)
rp = self.database.executable.execute(stmt)
rp = self.db.executable.execute(stmt)
if rp.supports_sane_rowcount():
return rp.rowcount
if return_count:
@ -194,16 +193,16 @@ class Table(object):
return False
clause = self._args_to_clause(filters, clauses=clauses)
stmt = self.table.delete(whereclause=clause)
rp = self.database.executable.execute(stmt)
rp = self.db.executable.execute(stmt)
return rp.rowcount > 0
def _reflect_table(self):
"""Load the tables definition from the database."""
with self.database.lock:
with self.db.lock:
try:
self._table = SQLATable(self.name,
self.database.metadata,
schema=self.database.schema,
self.db.metadata,
schema=self.db.schema,
autoload=True)
except NoSuchTableError:
pass
@ -218,10 +217,10 @@ class Table(object):
if not self._auto_create:
raise DatasetException("Table does not exist: %s" % self.name)
# Keep the lock scope small because this is run very often.
with self.database.lock:
with self.db.lock:
self._table = SQLATable(self.name,
self.database.metadata,
schema=self.database.schema)
self.db.metadata,
schema=self.db.schema)
if self._primary_id is not False:
# This can go wrong on DBMS like MySQL and SQLite where
# tables cannot have no columns.
@ -234,12 +233,11 @@ class Table(object):
self._table.append_column(column)
for column in columns:
self._table.append_column(column)
self._table.create(self.database.executable, checkfirst=True)
self._table.create(self.db.executable, checkfirst=True)
elif len(columns):
with self.database.lock:
with self.db.lock:
for column in columns:
self.database.op.add_column(self.name, column,
self.database.schema)
self.db.op.add_column(self.name, column, self.db.schema)
self._reflect_table()
def _sync_columns(self, row, ensure, types=None):
@ -260,7 +258,7 @@ class Table(object):
if ensure and name not in columns:
_type = types.get(name)
if _type is None:
_type = self.database.types.guess(value)
_type = self.db.types.guess(value)
sync_columns.append(Column(name, _type))
columns.append(name)
if name in columns:
@ -270,7 +268,7 @@ class Table(object):
def _check_ensure(self, ensure):
if ensure is None:
return self.database.ensure_schema
return self.db.ensure_schema
return ensure
def _args_to_clause(self, args, clauses=()):
@ -313,17 +311,17 @@ class Table(object):
table.create_column('created_at', db.types.datetime)
"""
# TODO: create the table if it does not exist.
with self.database.lock:
with self.db.lock:
name = normalize_column_name(name)
if self.has_column(name):
log.debug("Column exists: %s" % name)
return
log.debug("Create column: %s on %s", name, self.name)
self.database.op.add_column(
self.db.op.add_column(
self.name,
Column(name, type),
self.database.schema
self.db.schema
)
self._reflect_table()
@ -339,7 +337,7 @@ class Table(object):
If a column of the same name already exists, no action is taken, even
if it is not of the type we would have created.
"""
type_ = self.database.types.guess(value)
type_ = self.db.types.guess(value)
self.create_column(name, type_)
def drop_column(self, name):
@ -347,15 +345,15 @@ class Table(object):
::
table.drop_column('created_at')
"""
if self.database.engine.dialect.name == 'sqlite':
if self.db.engine.dialect.name == 'sqlite':
raise RuntimeError("SQLite does not support dropping columns.")
name = normalize_column_name(name)
with self.database.lock:
with self.db.lock:
if not self.exists or not self.has_column(name):
log.debug("Column does not exist: %s", name)
return
self.database.op.drop_column(
self.db.op.drop_column(
self.table.name,
name,
self.table.schema
@ -367,9 +365,9 @@ class Table(object):
Deletes both the schema and all the contents within it.
"""
with self.database.lock:
with self.db.lock:
if self.exists:
self.table.drop(self.database.executable, checkfirst=True)
self.table.drop(self.db.executable, checkfirst=True)
self._table = None
def has_index(self, columns):
@ -382,8 +380,7 @@ class Table(object):
for column in columns:
if not self.has_column(column):
raise DatasetException("Column does not exist: %s" % column)
inspector = Inspector.from_engine(self.database.executable)
indexes = inspector.get_indexes(self.name, schema=self.database.schema)
indexes = self.db.inspect.get_indexes(self.name, schema=self.db.schema)
for index in indexes:
if columns == set(index.get('column_names', [])):
self._indexes.append(columns)
@ -399,7 +396,7 @@ class Table(object):
table.create_index(['name', 'country'])
"""
columns = [normalize_column_name(c) for c in ensure_tuple(columns)]
with self.database.lock:
with self.db.lock:
if not self.exists:
raise DatasetException("Table has not been created yet.")
@ -407,7 +404,7 @@ class Table(object):
name = name or index_name(self.name, columns)
columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns, **kw)
idx.create(self.database.executable)
idx.create(self.db.executable)
def find(self, *_clauses, **kwargs):
"""Perform a simple search on the table.
@ -452,8 +449,8 @@ class Table(object):
offset=_offset)
if len(order_by):
query = query.order_by(*order_by)
return ResultIter(self.database.executable.execute(query),
row_type=self.database.row_type,
return ResultIter(self.db.executable.execute(query),
row_type=self.db.row_type,
step=_step)
def find_one(self, *args, **kwargs):
@ -488,7 +485,7 @@ class Table(object):
args = self._args_to_clause(kwargs, clauses=_clauses)
query = select([func.count()], whereclause=args)
query = query.select_from(self.table)
rp = self.database.executable.execute(query)
rp = self.db.executable.execute(query)
return rp.fetchone()[0]
def __len__(self):
@ -531,7 +528,7 @@ class Table(object):
distinct=True,
whereclause=and_(*filters),
order_by=[c.asc() for c in columns])
return self.database.query(q)
return self.db.query(q)
# Legacy methods for running find queries.
all = find