Use introspection for table indexes.

This commit is contained in:
Friedrich Lindenberg 2017-09-02 23:05:50 +02:00
parent 37d7f47d39
commit 213a7ce857
2 changed files with 52 additions and 35 deletions

View File

@ -1,12 +1,12 @@
import logging
from hashlib import sha1
from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.engine.reflection import Inspector
from dataset.persistence.util import normalize_column_name
from dataset.persistence.util import normalize_column_name, index_name
from dataset.persistence.util import ResultIter
from dataset.util import DatasetException
@ -20,16 +20,23 @@ class Table(object):
def __init__(self, database, table):
"""Initialise the table from database schema."""
self.indexes = dict((i.name, i) for i in table.indexes)
self.database = database
self.name = table.name
self.table = table
self._is_dropped = False
self._indexes = []
@property
def exists(self):
"""Check to see if the table currently exists in the database."""
if self.table is not None:
return True
return self.name in self.database
@property
def columns(self):
"""Get a listing of all columns that exist in the table."""
if self.table is None:
if not self.exists:
return []
return self.table.columns.keys()
@ -53,13 +60,6 @@ class Table(object):
if self.table is None:
raise DatasetException('The table has been dropped.')
def _prune_row(self, row):
"""Remove keys from row not in column set."""
# normalize keys
row = {normalize_column_name(k): v for k, v in row.items()}
# filter out keys not in column set
return {k: row[k] for k in row if k in self.columns}
def insert(self, row, ensure=None, types=None):
"""
Add a row (type: dict) by inserting it into the table.
@ -250,6 +250,8 @@ class Table(object):
If no arguments are given, all records are deleted.
"""
self._check_dropped()
if not self.exists:
return
if _filter or _clauses:
q = self._args_to_clause(_filter, clauses=_clauses)
stmt = self.table.delete(q)
@ -275,6 +277,13 @@ class Table(object):
self.table.name))
self.create_column(column, _type)
def _prune_row(self, row):
"""Remove keys from row not in column set."""
# normalize keys
row = {normalize_column_name(k): v for k, v in row.items()}
# filter out keys not in column set
return {k: row[k] for k in row if k in self.columns}
def _args_to_clause(self, args, ensure=None, clauses=()):
ensure = self.database.ensure_schema if ensure is None else ensure
if ensure:
@ -325,15 +334,15 @@ class Table(object):
self.create_column(name, type_)
def drop_column(self, name):
"""
Drop the column ``name``.
"""Drop the column ``name``.
::
table.drop_column('created_at')
"""
if self.database.engine.dialect.name == 'sqlite':
raise NotImplementedError("SQLite does not support dropping columns.")
if not self.exists:
return
self._check_dropped()
if name not in self.columns:
log.debug("Column does not exist: %s", name)
@ -344,7 +353,20 @@ class Table(object):
name,
self.table.schema
)
self.table = self.database.update_table(self.table.name)
self.table = self.database._reflect_table(self.table.name)
def has_index(self, columns):
"""Check if an index exists to cover the given `columns`."""
columns = set([normalize_column_name(c) for c in columns])
if columns in self._indexes:
return True
inspector = Inspector.from_engine(self.database.executable)
indexes = inspector.get_indexes(self.name, schema=self.database.schema)
for index in indexes:
if columns == set(index.get('column_names', [])):
self._indexes.append(columns)
return True
return False
def create_index(self, columns, name=None, **kw):
"""
@ -356,30 +378,13 @@ class Table(object):
table.create_index(['name', 'country'])
"""
self._check_dropped()
columns = [normalize_column_name(c) for c in columns]
with self.database.lock:
if not name:
sig = '||'.join(columns)
# This is a work-around for a bug in <=0.6.1 which would
# create indexes based on hash() rather than a proper hash.
key = abs(hash(sig))
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes:
return self.indexes[name]
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
name = 'ix_%s_%s' % (self.table.name, key)
if name in self.indexes:
return self.indexes[name]
try:
if not self.has_index(columns):
name = name or index_name(self.name, columns)
columns = [self.table.c[c] for c in columns]
idx = Index(name, *columns, **kw)
idx.create(self.database.executable)
except:
idx = None
self.indexes[name] = idx
return idx
def _args_to_order_by(self, order_by):
if not isinstance(order_by, (list, tuple)):
@ -451,6 +456,8 @@ class Table(object):
row = table.find_one(country='United States')
"""
if not self.exists:
return None
kwargs['_limit'] = 1
kwargs['_step'] = None
resiter = self.find(*args, **kwargs)
@ -465,6 +472,8 @@ class Table(object):
# NOTE: this does not have support for limit and offset since I can't
# see how this is useful. Still, there might be compatibility issues
# with people using these flags. Let's see how it goes.
if not self.exists:
return 0
self._check_dropped()
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
query = select([func.count()], whereclause=args)

View File

@ -9,6 +9,7 @@ except ImportError: # pragma: no cover
from ordereddict import OrderedDict
from six import string_types
from hashlib import sha1
row_type = OrderedDict
@ -70,3 +71,10 @@ def safe_url(url):
pwd = ':%s@' % parsed.password
url = url.replace(pwd, ':*****@')
return url
def index_name(table, columns):
"""Generate an artificial index name."""
sig = '||'.join(columns)
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
return 'ix_%s_%s' % (table, key)