Use introspection for table indexes.

This commit is contained in:
Friedrich Lindenberg 2017-09-02 23:05:50 +02:00
parent 37d7f47d39
commit 213a7ce857
2 changed files with 52 additions and 35 deletions

View File

@ -1,12 +1,12 @@
import logging import logging
from hashlib import sha1
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false from sqlalchemy import func, select, false
from sqlalchemy.engine.reflection import Inspector
from dataset.persistence.util import normalize_column_name from dataset.persistence.util import normalize_column_name, index_name
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -20,16 +20,23 @@ class Table(object):
def __init__(self, database, table): def __init__(self, database, table):
"""Initialise the table from database schema.""" """Initialise the table from database schema."""
self.indexes = dict((i.name, i) for i in table.indexes)
self.database = database self.database = database
self.name = table.name self.name = table.name
self.table = table self.table = table
self._is_dropped = False self._is_dropped = False
self._indexes = []
@property
def exists(self):
    """Tell whether the table is currently present in the database.

    A table object that still holds a schema handle (``self.table``) is
    treated as existing without any further lookup; otherwise the answer
    is whether ``self.name`` is contained in ``self.database``.
    """
    return self.table is not None or self.name in self.database
@property @property
def columns(self): def columns(self):
"""Get a listing of all columns that exist in the table.""" """Get a listing of all columns that exist in the table."""
if self.table is None: if not self.exists:
return [] return []
return self.table.columns.keys() return self.table.columns.keys()
@ -53,13 +60,6 @@ class Table(object):
if self.table is None: if self.table is None:
raise DatasetException('The table has been dropped.') raise DatasetException('The table has been dropped.')
def _prune_row(self, row):
"""Remove keys from row not in column set."""
# normalize keys
row = {normalize_column_name(k): v for k, v in row.items()}
# filter out keys not in column set
return {k: row[k] for k in row if k in self.columns}
def insert(self, row, ensure=None, types=None): def insert(self, row, ensure=None, types=None):
""" """
Add a row (type: dict) by inserting it into the table. Add a row (type: dict) by inserting it into the table.
@ -250,6 +250,8 @@ class Table(object):
If no arguments are given, all records are deleted. If no arguments are given, all records are deleted.
""" """
self._check_dropped() self._check_dropped()
if not self.exists:
return
if _filter or _clauses: if _filter or _clauses:
q = self._args_to_clause(_filter, clauses=_clauses) q = self._args_to_clause(_filter, clauses=_clauses)
stmt = self.table.delete(q) stmt = self.table.delete(q)
@ -275,6 +277,13 @@ class Table(object):
self.table.name)) self.table.name))
self.create_column(column, _type) self.create_column(column, _type)
def _prune_row(self, row):
"""Remove keys from row not in column set."""
# normalize keys
row = {normalize_column_name(k): v for k, v in row.items()}
# filter out keys not in column set
return {k: row[k] for k in row if k in self.columns}
def _args_to_clause(self, args, ensure=None, clauses=()): def _args_to_clause(self, args, ensure=None, clauses=()):
ensure = self.database.ensure_schema if ensure is None else ensure ensure = self.database.ensure_schema if ensure is None else ensure
if ensure: if ensure:
@ -325,15 +334,15 @@ class Table(object):
self.create_column(name, type_) self.create_column(name, type_)
def drop_column(self, name): def drop_column(self, name):
""" """Drop the column ``name``.
Drop the column ``name``.
:: ::
table.drop_column('created_at') table.drop_column('created_at')
""" """
if self.database.engine.dialect.name == 'sqlite': if self.database.engine.dialect.name == 'sqlite':
raise NotImplementedError("SQLite does not support dropping columns.") raise NotImplementedError("SQLite does not support dropping columns.")
if not self.exists:
return
self._check_dropped() self._check_dropped()
if name not in self.columns: if name not in self.columns:
log.debug("Column does not exist: %s", name) log.debug("Column does not exist: %s", name)
@ -344,7 +353,20 @@ class Table(object):
name, name,
self.table.schema self.table.schema
) )
self.table = self.database.update_table(self.table.name) self.table = self.database._reflect_table(self.table.name)
def has_index(self, columns):
    """Check if an index exists to cover the given `columns`.

    :param columns: iterable of column names; each one is passed through
        ``normalize_column_name`` before comparison.
    :returns: ``True`` when an index over exactly this set of columns is
        already known or is reported by database introspection, ``False``
        otherwise (including when the table does not exist).
    """
    # Nothing to introspect for a table that is not there. This mirrors the
    # ``exists`` guard used by the other table operations.
    if not self.exists:
        return False
    columns = {normalize_column_name(c) for c in columns}
    # Column sets confirmed earlier are remembered to avoid asking the
    # database again.
    if columns in self._indexes:
        return True
    inspector = Inspector.from_engine(self.database.executable)
    indexes = inspector.get_indexes(self.name, schema=self.database.schema)
    for index in indexes:
        # Column order is ignored: only the set of indexed columns counts.
        if columns == set(index.get('column_names', [])):
            self._indexes.append(columns)
            return True
    return False
def create_index(self, columns, name=None, **kw): def create_index(self, columns, name=None, **kw):
""" """
@ -356,30 +378,13 @@ class Table(object):
table.create_index(['name', 'country']) table.create_index(['name', 'country'])
""" """
self._check_dropped() self._check_dropped()
columns = [normalize_column_name(c) for c in columns]
with self.database.lock: with self.database.lock:
if not name: if not self.has_index(columns):
sig = '||'.join(columns) name = name or index_name(self.name, columns)
# This is a work-around for a bug in <=0.6.1 which would
# create indexes based on hash() rather than a proper hash.
key = abs(hash(sig))
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes:
return self.indexes[name]
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes:
return self.indexes[name]
try:
columns = [self.table.c[c] for c in columns] columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns, **kw) idx = Index(name, *columns, **kw)
idx.create(self.database.executable) idx.create(self.database.executable)
except:
idx = None
self.indexes[name] = idx
return idx
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
if not isinstance(order_by, (list, tuple)): if not isinstance(order_by, (list, tuple)):
@ -451,6 +456,8 @@ class Table(object):
row = table.find_one(country='United States') row = table.find_one(country='United States')
""" """
if not self.exists:
return None
kwargs['_limit'] = 1 kwargs['_limit'] = 1
kwargs['_step'] = None kwargs['_step'] = None
resiter = self.find(*args, **kwargs) resiter = self.find(*args, **kwargs)
@ -465,6 +472,8 @@ class Table(object):
# NOTE: this does not have support for limit and offset since I can't # NOTE: this does not have support for limit and offset since I can't
# see how this is useful. Still, there might be compatibility issues # see how this is useful. Still, there might be compatibility issues
# with people using these flags. Let's see how it goes. # with people using these flags. Let's see how it goes.
if not self.exists:
return 0
self._check_dropped() self._check_dropped()
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses) args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
query = select([func.count()], whereclause=args) query = select([func.count()], whereclause=args)

View File

@ -9,6 +9,7 @@ except ImportError: # pragma: no cover
from ordereddict import OrderedDict from ordereddict import OrderedDict
from six import string_types from six import string_types
from hashlib import sha1
row_type = OrderedDict row_type = OrderedDict
@ -70,3 +71,10 @@ def safe_url(url):
pwd = ':%s@' % parsed.password pwd = ':%s@' % parsed.password
url = url.replace(pwd, ':*****@') url = url.replace(pwd, ':*****@')
return url return url
def index_name(table, columns):
    """Build an artificial, deterministic name for an index.

    The column names are joined with ``||``, hashed with SHA-1, and the
    first 16 hex digits of the digest are appended to ``ix_<table>_``.
    """
    digest = sha1('||'.join(columns).encode('utf-8'))
    return 'ix_%s_%s' % (table, digest.hexdigest()[:16])