Use introspection for table indexes.
This commit is contained in:
parent
37d7f47d39
commit
213a7ce857
@ -1,12 +1,12 @@
|
|||||||
import logging
|
import logging
|
||||||
from hashlib import sha1
|
|
||||||
|
|
||||||
from sqlalchemy.sql import and_, expression
|
from sqlalchemy.sql import and_, expression
|
||||||
from sqlalchemy.sql.expression import ClauseElement
|
from sqlalchemy.sql.expression import ClauseElement
|
||||||
from sqlalchemy.schema import Column, Index
|
from sqlalchemy.schema import Column, Index
|
||||||
from sqlalchemy import func, select, false
|
from sqlalchemy import func, select, false
|
||||||
|
from sqlalchemy.engine.reflection import Inspector
|
||||||
|
|
||||||
from dataset.persistence.util import normalize_column_name
|
from dataset.persistence.util import normalize_column_name, index_name
|
||||||
from dataset.persistence.util import ResultIter
|
from dataset.persistence.util import ResultIter
|
||||||
from dataset.util import DatasetException
|
from dataset.util import DatasetException
|
||||||
|
|
||||||
@ -20,16 +20,23 @@ class Table(object):
|
|||||||
|
|
||||||
def __init__(self, database, table):
|
def __init__(self, database, table):
|
||||||
"""Initialise the table from database schema."""
|
"""Initialise the table from database schema."""
|
||||||
self.indexes = dict((i.name, i) for i in table.indexes)
|
|
||||||
self.database = database
|
self.database = database
|
||||||
self.name = table.name
|
self.name = table.name
|
||||||
self.table = table
|
self.table = table
|
||||||
self._is_dropped = False
|
self._is_dropped = False
|
||||||
|
self._indexes = []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def exists(self):
|
||||||
|
"""Check to see if the table currently exists in the database."""
|
||||||
|
if self.table is not None:
|
||||||
|
return True
|
||||||
|
return self.name in self.database
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def columns(self):
|
def columns(self):
|
||||||
"""Get a listing of all columns that exist in the table."""
|
"""Get a listing of all columns that exist in the table."""
|
||||||
if self.table is None:
|
if not self.exists:
|
||||||
return []
|
return []
|
||||||
return self.table.columns.keys()
|
return self.table.columns.keys()
|
||||||
|
|
||||||
@ -53,13 +60,6 @@ class Table(object):
|
|||||||
if self.table is None:
|
if self.table is None:
|
||||||
raise DatasetException('The table has been dropped.')
|
raise DatasetException('The table has been dropped.')
|
||||||
|
|
||||||
def _prune_row(self, row):
|
|
||||||
"""Remove keys from row not in column set."""
|
|
||||||
# normalize keys
|
|
||||||
row = {normalize_column_name(k): v for k, v in row.items()}
|
|
||||||
# filter out keys not in column set
|
|
||||||
return {k: row[k] for k in row if k in self.columns}
|
|
||||||
|
|
||||||
def insert(self, row, ensure=None, types=None):
|
def insert(self, row, ensure=None, types=None):
|
||||||
"""
|
"""
|
||||||
Add a row (type: dict) by inserting it into the table.
|
Add a row (type: dict) by inserting it into the table.
|
||||||
@ -250,6 +250,8 @@ class Table(object):
|
|||||||
If no arguments are given, all records are deleted.
|
If no arguments are given, all records are deleted.
|
||||||
"""
|
"""
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
|
if not self.exists:
|
||||||
|
return
|
||||||
if _filter or _clauses:
|
if _filter or _clauses:
|
||||||
q = self._args_to_clause(_filter, clauses=_clauses)
|
q = self._args_to_clause(_filter, clauses=_clauses)
|
||||||
stmt = self.table.delete(q)
|
stmt = self.table.delete(q)
|
||||||
@ -275,6 +277,13 @@ class Table(object):
|
|||||||
self.table.name))
|
self.table.name))
|
||||||
self.create_column(column, _type)
|
self.create_column(column, _type)
|
||||||
|
|
||||||
|
def _prune_row(self, row):
|
||||||
|
"""Remove keys from row not in column set."""
|
||||||
|
# normalize keys
|
||||||
|
row = {normalize_column_name(k): v for k, v in row.items()}
|
||||||
|
# filter out keys not in column set
|
||||||
|
return {k: row[k] for k in row if k in self.columns}
|
||||||
|
|
||||||
def _args_to_clause(self, args, ensure=None, clauses=()):
|
def _args_to_clause(self, args, ensure=None, clauses=()):
|
||||||
ensure = self.database.ensure_schema if ensure is None else ensure
|
ensure = self.database.ensure_schema if ensure is None else ensure
|
||||||
if ensure:
|
if ensure:
|
||||||
@ -325,15 +334,15 @@ class Table(object):
|
|||||||
self.create_column(name, type_)
|
self.create_column(name, type_)
|
||||||
|
|
||||||
def drop_column(self, name):
|
def drop_column(self, name):
|
||||||
"""
|
"""Drop the column ``name``.
|
||||||
Drop the column ``name``.
|
|
||||||
|
|
||||||
::
|
::
|
||||||
|
|
||||||
table.drop_column('created_at')
|
table.drop_column('created_at')
|
||||||
"""
|
"""
|
||||||
if self.database.engine.dialect.name == 'sqlite':
|
if self.database.engine.dialect.name == 'sqlite':
|
||||||
raise NotImplementedError("SQLite does not support dropping columns.")
|
raise NotImplementedError("SQLite does not support dropping columns.")
|
||||||
|
if not self.exists:
|
||||||
|
return
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
if name not in self.columns:
|
if name not in self.columns:
|
||||||
log.debug("Column does not exist: %s", name)
|
log.debug("Column does not exist: %s", name)
|
||||||
@ -344,7 +353,20 @@ class Table(object):
|
|||||||
name,
|
name,
|
||||||
self.table.schema
|
self.table.schema
|
||||||
)
|
)
|
||||||
self.table = self.database.update_table(self.table.name)
|
self.table = self.database._reflect_table(self.table.name)
|
||||||
|
|
||||||
|
def has_index(self, columns):
|
||||||
|
"""Check if an index exists to cover the given `columns`."""
|
||||||
|
columns = set([normalize_column_name(c) for c in columns])
|
||||||
|
if columns in self._indexes:
|
||||||
|
return True
|
||||||
|
inspector = Inspector.from_engine(self.database.executable)
|
||||||
|
indexes = inspector.get_indexes(self.name, schema=self.database.schema)
|
||||||
|
for index in indexes:
|
||||||
|
if columns == set(index.get('column_names', [])):
|
||||||
|
self._indexes.append(columns)
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def create_index(self, columns, name=None, **kw):
|
def create_index(self, columns, name=None, **kw):
|
||||||
"""
|
"""
|
||||||
@ -356,30 +378,13 @@ class Table(object):
|
|||||||
table.create_index(['name', 'country'])
|
table.create_index(['name', 'country'])
|
||||||
"""
|
"""
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
|
columns = [normalize_column_name(c) for c in columns]
|
||||||
with self.database.lock:
|
with self.database.lock:
|
||||||
if not name:
|
if not self.has_index(columns):
|
||||||
sig = '||'.join(columns)
|
name = name or index_name(self.name, columns)
|
||||||
|
|
||||||
# This is a work-around for a bug in <=0.6.1 which would
|
|
||||||
# create indexes based on hash() rather than a proper hash.
|
|
||||||
key = abs(hash(sig))
|
|
||||||
name = 'ix_%s_%s' % (self.table.name, key)
|
|
||||||
if name in self.indexes:
|
|
||||||
return self.indexes[name]
|
|
||||||
|
|
||||||
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
|
|
||||||
name = 'ix_%s_%s' % (self.table.name, key)
|
|
||||||
|
|
||||||
if name in self.indexes:
|
|
||||||
return self.indexes[name]
|
|
||||||
try:
|
|
||||||
columns = [self.table.c[c] for c in columns]
|
columns = [self.table.c[c] for c in columns]
|
||||||
idx = Index(name, *columns, **kw)
|
idx = Index(name, *columns, **kw)
|
||||||
idx.create(self.database.executable)
|
idx.create(self.database.executable)
|
||||||
except:
|
|
||||||
idx = None
|
|
||||||
self.indexes[name] = idx
|
|
||||||
return idx
|
|
||||||
|
|
||||||
def _args_to_order_by(self, order_by):
|
def _args_to_order_by(self, order_by):
|
||||||
if not isinstance(order_by, (list, tuple)):
|
if not isinstance(order_by, (list, tuple)):
|
||||||
@ -451,6 +456,8 @@ class Table(object):
|
|||||||
|
|
||||||
row = table.find_one(country='United States')
|
row = table.find_one(country='United States')
|
||||||
"""
|
"""
|
||||||
|
if not self.exists:
|
||||||
|
return None
|
||||||
kwargs['_limit'] = 1
|
kwargs['_limit'] = 1
|
||||||
kwargs['_step'] = None
|
kwargs['_step'] = None
|
||||||
resiter = self.find(*args, **kwargs)
|
resiter = self.find(*args, **kwargs)
|
||||||
@ -465,6 +472,8 @@ class Table(object):
|
|||||||
# NOTE: this does not have support for limit and offset since I can't
|
# NOTE: this does not have support for limit and offset since I can't
|
||||||
# see how this is useful. Still, there might be compatibility issues
|
# see how this is useful. Still, there might be compatibility issues
|
||||||
# with people using these flags. Let's see how it goes.
|
# with people using these flags. Let's see how it goes.
|
||||||
|
if not self.exists:
|
||||||
|
return 0
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
|
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
|
||||||
query = select([func.count()], whereclause=args)
|
query = select([func.count()], whereclause=args)
|
||||||
|
|||||||
@ -9,6 +9,7 @@ except ImportError: # pragma: no cover
|
|||||||
from ordereddict import OrderedDict
|
from ordereddict import OrderedDict
|
||||||
|
|
||||||
from six import string_types
|
from six import string_types
|
||||||
|
from hashlib import sha1
|
||||||
|
|
||||||
row_type = OrderedDict
|
row_type = OrderedDict
|
||||||
|
|
||||||
@ -70,3 +71,10 @@ def safe_url(url):
|
|||||||
pwd = ':%s@' % parsed.password
|
pwd = ':%s@' % parsed.password
|
||||||
url = url.replace(pwd, ':*****@')
|
url = url.replace(pwd, ':*****@')
|
||||||
return url
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def index_name(table, columns):
|
||||||
|
"""Generate an artificial index name."""
|
||||||
|
sig = '||'.join(columns)
|
||||||
|
key = sha1(sig.encode('utf-8')).hexdigest()[:16]
|
||||||
|
return 'ix_%s_%s' % (table, key)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user