diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 649a17c..b5e8266 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -1,12 +1,12 @@ import logging -from hashlib import sha1 from sqlalchemy.sql import and_, expression from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.schema import Column, Index from sqlalchemy import func, select, false +from sqlalchemy.engine.reflection import Inspector -from dataset.persistence.util import normalize_column_name +from dataset.persistence.util import normalize_column_name, index_name from dataset.persistence.util import ResultIter from dataset.util import DatasetException @@ -20,16 +20,23 @@ class Table(object): def __init__(self, database, table): """Initialise the table from database schema.""" - self.indexes = dict((i.name, i) for i in table.indexes) self.database = database self.name = table.name self.table = table self._is_dropped = False + self._indexes = [] + + @property + def exists(self): + """Check to see if the table currently exists in the database.""" + if self.table is not None: + return True + return self.name in self.database @property def columns(self): """Get a listing of all columns that exist in the table.""" - if self.table is None: + if not self.exists: return [] return self.table.columns.keys() @@ -53,13 +60,6 @@ class Table(object): if self.table is None: raise DatasetException('The table has been dropped.') - def _prune_row(self, row): - """Remove keys from row not in column set.""" - # normalize keys - row = {normalize_column_name(k): v for k, v in row.items()} - # filter out keys not in column set - return {k: row[k] for k in row if k in self.columns} - def insert(self, row, ensure=None, types=None): """ Add a row (type: dict) by inserting it into the table. @@ -250,6 +250,8 @@ class Table(object): If no arguments are given, all records are deleted. """ self._check_dropped() + if not self.exists: + return if _filter or _clauses: q = self._args_to_clause(_filter, clauses=_clauses) stmt = self.table.delete(q) @@ -275,6 +277,13 @@ class Table(object): self.table.name)) self.create_column(column, _type) + def _prune_row(self, row): + """Remove keys from row not in column set.""" + # normalize keys + row = {normalize_column_name(k): v for k, v in row.items()} + # filter out keys not in column set + return {k: row[k] for k in row if k in self.columns} + def _args_to_clause(self, args, ensure=None, clauses=()): ensure = self.database.ensure_schema if ensure is None else ensure if ensure: @@ -325,15 +334,15 @@ class Table(object): self.create_column(name, type_) def drop_column(self, name): - """ - Drop the column ``name``. + """Drop the column ``name``. :: - table.drop_column('created_at') """ if self.database.engine.dialect.name == 'sqlite': raise NotImplementedError("SQLite does not support dropping columns.") + if not self.exists: + return self._check_dropped() if name not in self.columns: log.debug("Column does not exist: %s", name) @@ -344,7 +353,20 @@ class Table(object): name, self.table.schema ) - self.table = self.database.update_table(self.table.name) + self.table = self.database._reflect_table(self.table.name) + + def has_index(self, columns): + """Check if an index exists to cover the given `columns`.""" + columns = set([normalize_column_name(c) for c in columns]) + if columns in self._indexes: + return True + inspector = Inspector.from_engine(self.database.executable) + indexes = inspector.get_indexes(self.name, schema=self.database.schema) + for index in indexes: + if columns == set(index.get('column_names', [])): + self._indexes.append(columns) + return True + return False def create_index(self, columns, name=None, **kw): """ @@ -356,30 +378,13 @@ class Table(object): table.create_index(['name', 'country']) """ self._check_dropped() + columns = [normalize_column_name(c) for c in columns] with self.database.lock: - if not name: - sig = '||'.join(columns) - - # This is a work-around for a bug in <=0.6.1 which would - # create indexes based on hash() rather than a proper hash. - key = abs(hash(sig)) - name = 'ix_%s_%s' % (self.table.name, key) - if name in self.indexes: - return self.indexes[name] - - key = sha1(sig.encode('utf-8')).hexdigest()[:16] - name = 'ix_%s_%s' % (self.table.name, key) - - if name in self.indexes: - return self.indexes[name] - try: + if not self.has_index(columns): + name = name or index_name(self.name, columns) columns = [self.table.c[c] for c in columns] idx = Index(name, *columns, **kw) idx.create(self.database.executable) - except: - idx = None - self.indexes[name] = idx - return idx def _args_to_order_by(self, order_by): if not isinstance(order_by, (list, tuple)): @@ -451,6 +456,8 @@ class Table(object): row = table.find_one(country='United States') """ + if not self.exists: + return None kwargs['_limit'] = 1 kwargs['_step'] = None resiter = self.find(*args, **kwargs) @@ -465,6 +472,8 @@ class Table(object): # NOTE: this does not have support for limit and offset since I can't # see how this is useful. Still, there might be compatibility issues # with people using these flags. Let's see how it goes. + if not self.exists: + return 0 self._check_dropped() args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses) query = select([func.count()], whereclause=args) diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index 7fcd3bb..545429b 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -9,6 +9,7 @@ except ImportError: # pragma: no cover from ordereddict import OrderedDict from six import string_types +from hashlib import sha1 row_type = OrderedDict @@ -70,3 +71,10 @@ def safe_url(url): pwd = ':%s@' % parsed.password url = url.replace(pwd, ':*****@') return url + + +def index_name(table, columns): + """Generate an artificial index name.""" + sig = '||'.join(columns) + key = sha1(sig.encode('utf-8')).hexdigest()[:16] + return 'ix_%s_%s' % (table, key)