Use introspection for table indexes.
This commit is contained in:
parent
37d7f47d39
commit
213a7ce857
@ -1,12 +1,12 @@
|
||||
import logging
|
||||
from hashlib import sha1
|
||||
|
||||
from sqlalchemy.sql import and_, expression
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
from sqlalchemy.schema import Column, Index
|
||||
from sqlalchemy import func, select, false
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
|
||||
from dataset.persistence.util import normalize_column_name
|
||||
from dataset.persistence.util import normalize_column_name, index_name
|
||||
from dataset.persistence.util import ResultIter
|
||||
from dataset.util import DatasetException
|
||||
|
||||
@ -20,16 +20,23 @@ class Table(object):
|
||||
|
||||
def __init__(self, database, table):
|
||||
"""Initialise the table from database schema."""
|
||||
self.indexes = dict((i.name, i) for i in table.indexes)
|
||||
self.database = database
|
||||
self.name = table.name
|
||||
self.table = table
|
||||
self._is_dropped = False
|
||||
self._indexes = []
|
||||
|
||||
@property
def exists(self):
    """Report whether the table is currently present in the database.

    A table object already held by this instance counts as existing;
    otherwise the database is asked whether it contains the table name.
    """
    return self.table is not None or self.name in self.database
|
||||
|
||||
@property
|
||||
def columns(self):
|
||||
"""Get a listing of all columns that exist in the table."""
|
||||
if self.table is None:
|
||||
if not self.exists:
|
||||
return []
|
||||
return self.table.columns.keys()
|
||||
|
||||
@ -53,13 +60,6 @@ class Table(object):
|
||||
if self.table is None:
|
||||
raise DatasetException('The table has been dropped.')
|
||||
|
||||
def _prune_row(self, row):
|
||||
"""Remove keys from row not in column set."""
|
||||
# normalize keys
|
||||
row = {normalize_column_name(k): v for k, v in row.items()}
|
||||
# filter out keys not in column set
|
||||
return {k: row[k] for k in row if k in self.columns}
|
||||
|
||||
def insert(self, row, ensure=None, types=None):
|
||||
"""
|
||||
Add a row (type: dict) by inserting it into the table.
|
||||
@ -250,6 +250,8 @@ class Table(object):
|
||||
If no arguments are given, all records are deleted.
|
||||
"""
|
||||
self._check_dropped()
|
||||
if not self.exists:
|
||||
return
|
||||
if _filter or _clauses:
|
||||
q = self._args_to_clause(_filter, clauses=_clauses)
|
||||
stmt = self.table.delete(q)
|
||||
@ -275,6 +277,13 @@ class Table(object):
|
||||
self.table.name))
|
||||
self.create_column(column, _type)
|
||||
|
||||
def _prune_row(self, row):
|
||||
"""Remove keys from row not in column set."""
|
||||
# normalize keys
|
||||
row = {normalize_column_name(k): v for k, v in row.items()}
|
||||
# filter out keys not in column set
|
||||
return {k: row[k] for k in row if k in self.columns}
|
||||
|
||||
def _args_to_clause(self, args, ensure=None, clauses=()):
|
||||
ensure = self.database.ensure_schema if ensure is None else ensure
|
||||
if ensure:
|
||||
@ -325,15 +334,15 @@ class Table(object):
|
||||
self.create_column(name, type_)
|
||||
|
||||
def drop_column(self, name):
|
||||
"""
|
||||
Drop the column ``name``.
|
||||
"""Drop the column ``name``.
|
||||
|
||||
::
|
||||
|
||||
table.drop_column('created_at')
|
||||
"""
|
||||
if self.database.engine.dialect.name == 'sqlite':
|
||||
raise NotImplementedError("SQLite does not support dropping columns.")
|
||||
if not self.exists:
|
||||
return
|
||||
self._check_dropped()
|
||||
if name not in self.columns:
|
||||
log.debug("Column does not exist: %s", name)
|
||||
@ -344,7 +353,20 @@ class Table(object):
|
||||
name,
|
||||
self.table.schema
|
||||
)
|
||||
self.table = self.database.update_table(self.table.name)
|
||||
self.table = self.database._reflect_table(self.table.name)
|
||||
|
||||
def has_index(self, columns):
    """Check if an index exists to cover the given `columns`.

    Column names are normalised before comparison, and an index matches
    only when its set of column names equals the requested set exactly.
    Positive results are remembered in ``self._indexes`` so that later
    calls for the same column set skip database introspection.

    Returns ``False`` without introspecting when the table does not
    exist.
    """
    # A missing table has no indexes to look up; guard on ``exists`` the
    # same way the other accessors of this class do.
    if not self.exists:
        return False
    columns = {normalize_column_name(c) for c in columns}
    if columns in self._indexes:
        return True
    inspector = Inspector.from_engine(self.database.executable)
    indexes = inspector.get_indexes(self.name, schema=self.database.schema)
    for index in indexes:
        if columns == set(index.get('column_names', [])):
            # Cache the hit so the next lookup is answered locally.
            self._indexes.append(columns)
            return True
    return False
|
||||
|
||||
def create_index(self, columns, name=None, **kw):
|
||||
"""
|
||||
@ -356,30 +378,13 @@ class Table(object):
|
||||
table.create_index(['name', 'country'])
|
||||
"""
|
||||
self._check_dropped()
|
||||
columns = [normalize_column_name(c) for c in columns]
|
||||
with self.database.lock:
|
||||
if not name:
|
||||
sig = '||'.join(columns)
|
||||
|
||||
# This is a work-around for a bug in <=0.6.1 which would
|
||||
# create indexes based on hash() rather than a proper hash.
|
||||
key = abs(hash(sig))
|
||||
name = 'ix_%s_%s' % (self.table.name, key)
|
||||
if name in self.indexes:
|
||||
return self.indexes[name]
|
||||
|
||||
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
|
||||
name = 'ix_%s_%s' % (self.table.name, key)
|
||||
|
||||
if name in self.indexes:
|
||||
return self.indexes[name]
|
||||
try:
|
||||
if not self.has_index(columns):
|
||||
name = name or index_name(self.name, columns)
|
||||
columns = [self.table.c[c] for c in columns]
|
||||
idx = Index(name, *columns, **kw)
|
||||
idx.create(self.database.executable)
|
||||
except:
|
||||
idx = None
|
||||
self.indexes[name] = idx
|
||||
return idx
|
||||
|
||||
def _args_to_order_by(self, order_by):
|
||||
if not isinstance(order_by, (list, tuple)):
|
||||
@ -451,6 +456,8 @@ class Table(object):
|
||||
|
||||
row = table.find_one(country='United States')
|
||||
"""
|
||||
if not self.exists:
|
||||
return None
|
||||
kwargs['_limit'] = 1
|
||||
kwargs['_step'] = None
|
||||
resiter = self.find(*args, **kwargs)
|
||||
@ -465,6 +472,8 @@ class Table(object):
|
||||
# NOTE: this does not have support for limit and offset since I can't
|
||||
# see how this is useful. Still, there might be compatibility issues
|
||||
# with people using these flags. Let's see how it goes.
|
||||
if not self.exists:
|
||||
return 0
|
||||
self._check_dropped()
|
||||
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
|
||||
query = select([func.count()], whereclause=args)
|
||||
|
||||
@ -9,6 +9,7 @@ except ImportError: # pragma: no cover
|
||||
from ordereddict import OrderedDict
|
||||
|
||||
from six import string_types
|
||||
from hashlib import sha1
|
||||
|
||||
row_type = OrderedDict
|
||||
|
||||
@ -70,3 +71,10 @@ def safe_url(url):
|
||||
pwd = ':%s@' % parsed.password
|
||||
url = url.replace(pwd, ':*****@')
|
||||
return url
|
||||
|
||||
|
||||
def index_name(table, columns):
    """Build a deterministic, artificial name for an index.

    The result has the form ``ix_<table>_<key>``, where ``<key>`` is the
    first 16 hex digits of the SHA-1 digest of the column names joined
    with ``||``.
    """
    signature = '||'.join(columns).encode('utf-8')
    digest = sha1(signature).hexdigest()
    return 'ix_%s_%s' % (table, digest[:16])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user