From e30cf241953d72294f8cefb421a2531a0adef492 Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Sun, 3 Sep 2017 10:05:17 +0200 Subject: [PATCH] Rewrite data change functions on table. --- dataset/persistence/database.py | 30 +-- dataset/persistence/table.py | 331 +++++++++++++------------------- dataset/persistence/util.py | 11 ++ setup.py | 2 +- test/test_persistence.py | 63 +++--- 5 files changed, 201 insertions(+), 236 deletions(-) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index ca84604..4a28690 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -17,7 +17,7 @@ from alembic.migration import MigrationContext from alembic.operations import Operations from dataset.persistence.table import Table -from dataset.persistence.util import ResultIter, row_type, safe_url +from dataset.persistence.util import ResultIter, row_type, safe_url, QUERY_STEP from dataset.persistence.types import Types log = logging.getLogger(__name__) @@ -238,28 +238,34 @@ class Database(object): """Get a given table.""" return self.get_table(table_name) - def query(self, query, **kwargs): + def query(self, query, *args, **kwargs): """Run a statement on the database directly. - Allows for the execution of arbitrary read/write queries. A query can either be - a plain text string, or a `SQLAlchemy expression `_. - If a plain string is passed in, it will be converted to an expression automatically. + Allows for the execution of arbitrary read/write queries. A query can + either be a plain text string, or a `SQLAlchemy expression + `_. + If a plain string is passed in, it will be converted to an expression + automatically. - Keyword arguments will be used for parameter binding. See the `SQLAlchemy - documentation `_ for details. + Further positional and keyword arguments will be used for parameter + binding. To include a positional argument in your query, use question + marks in the query (i.e. `SELECT * FROM tbl WHERE a = ?`). For keyword + arguments, use a bind parameter (i.e. `SELECT * FROM tbl WHERE a = + :foo`). The returned iterator will yield each result sequentially. :: - res = db.query('SELECT user, COUNT(*) c FROM photos GROUP BY user') - for row in res: + statement = 'SELECT user, COUNT(*) c FROM photos GROUP BY user' + for row in db.query(statement): print(row['user'], row['c']) + """ if isinstance(query, six.string_types): query = text(query) - _step = kwargs.pop('_step', 5000) - return ResultIter(self.executable.execute(query, **kwargs), - row_type=self.row_type, step=_step) + _step = kwargs.pop('_step', QUERY_STEP) + rp = self.executable.execute(query, *args, **kwargs) + return ResultIter(rp, row_type=self.row_type, step=_step) def __repr__(self): """Text representation contains the URL.""" diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index b5e8266..581c690 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -7,7 +7,7 @@ from sqlalchemy import func, select, false from sqlalchemy.engine.reflection import Inspector from dataset.persistence.util import normalize_column_name, index_name -from dataset.persistence.util import ResultIter +from dataset.persistence.util import ensure_tuple, ResultIter, QUERY_STEP from dataset.util import DatasetException @@ -49,17 +49,12 @@ class Table(object): dropping the table. If you want to re-create the table, make sure to get a fresh instance from the :py:class:`Database `. """ - self._check_dropped() with self.database.lock: + if not self.exists: + return self.table.drop(self.database.executable, checkfirst=True) - # self.database._tables.pop(self.name, None) self.table = None - def _check_dropped(self): - # self.table = self.database._reflect_table(self.name) - if self.table is None: - raise DatasetException('The table has been dropped.') - def insert(self, row, ensure=None, types=None): """ Add a row (type: dict) by inserting it into the table. @@ -79,15 +74,11 @@ class Table(object): Returns the inserted row's primary key. """ - self._check_dropped() - ensure = self.database.ensure_schema if ensure is None else ensure - if ensure: - self._ensure_columns(row, types=types) - else: - row = self._prune_row(row) + row = self._sync_columns(row, ensure, types=types) res = self.database.executable.execute(self.table.insert(row)) if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] + return True def insert_ignore(self, row, keys, ensure=None, types=None): """ @@ -108,11 +99,13 @@ class Table(object): data = dict(id=10, title='I am a banana!') table.insert_ignore(data, ['id']) """ - row, res = self._upsert_pre_check(row, keys, ensure) - if res is None: - return self.insert(row, ensure=ensure, types=types) - else: - return False + row = self._sync_columns(row, ensure, types=types) + if self._check_ensure(ensure): + self.create_index(keys) + args, _ = self._keys_to_args(row, keys) + if self.count(**args) == 0: + return self.insert(row, ensure=False) + return False def insert_many(self, rows, chunk_size=1000, ensure=None, types=None): """ @@ -129,28 +122,18 @@ class Table(object): rows = [dict(name='Dolly')] * 10000 table.insert_many(rows) """ - ensure = self.database.ensure_schema if ensure is None else ensure - - def _process_chunk(chunk): - if ensure: - for row in chunk: - self._ensure_columns(row, types=types) - else: - chunk = [self._prune_row(r) for r in chunk] - self.table.insert().execute(chunk) - self._check_dropped() - chunk = [] - for i, row in enumerate(rows, start=1): + for row in rows: + row = self._sync_columns(row, ensure, types=types) chunk.append(row) - if i % chunk_size == 0: - _process_chunk(chunk) + if len(chunk) == chunk_size: + self.table.insert().execute(chunk) chunk = [] - if chunk: - _process_chunk(chunk) + if len(chunk): + self.table.insert().execute(chunk) - def update(self, row, keys, ensure=None, types=None): + def update(self, row, keys, ensure=None, types=None, return_count=False): """ Update a row in the table. @@ -167,54 +150,20 @@ class Table(object): they will be created based on the settings of ``ensure`` and ``types``, matching the behavior of :py:meth:`insert() `. """ - self._check_dropped() - # check whether keys arg is a string and format as a list - if not isinstance(keys, (list, tuple)): - keys = [keys] - if not keys or len(keys) == len(row): - return False - clause = [(u, row.get(u)) for u in keys] - - ensure = self.database.ensure_schema if ensure is None else ensure - if ensure: - self._ensure_columns(row, types=types) - else: - row = self._prune_row(row) - - # Don't update the key itself, so remove any keys from the row dict - clean_row = row.copy() - for key in keys: - if key in clean_row.keys(): - del clean_row[key] - - try: - filters = self._args_to_clause(dict(clause)) - stmt = self.table.update(filters, clean_row) - rp = self.database.executable.execute(stmt) + row = self._sync_columns(row, ensure, types=types) + args, row = self._keys_to_args(row, keys) + clause = self._args_to_clause(args) + if not len(row): + return self.count(clause) + stmt = self.table.update(whereclause=clause, values=row) + rp = self.database.executable.execute(stmt) + if rp.supports_sane_rowcount(): return rp.rowcount - except KeyError: - return 0 - - def _upsert_pre_check(self, row, keys, ensure): - # check whether keys arg is a string and format as a list - if not isinstance(keys, (list, tuple)): - keys = [keys] - self._check_dropped() - - ensure = self.database.ensure_schema if ensure is None else ensure - if ensure: - self.create_index(keys) - else: - row = self._prune_row(row) - - filters = {} - for key in keys: - filters[key] = row.get(key) - return row, self.find_one(**filters) + if return_count: + return self.count(clause) def upsert(self, row, keys, ensure=None, types=None): - """ - An UPSERT is a smart combination of insert and update. + """An UPSERT is a smart combination of insert and update. If rows with matching ``keys`` exist they will be updated, otherwise a new row is inserted in the table. @@ -223,22 +172,16 @@ class Table(object): data = dict(id=10, title='I am a banana!') table.upsert(data, ['id']) """ - row, res = self._upsert_pre_check(row, keys, ensure) - if res is None: - return self.insert(row, ensure=ensure, types=types) - else: - row_count = self.update(row, keys, ensure=ensure, types=types) - try: - result = (row_count > 0, res['id'])[row_count == 1] - except KeyError: - result = row_count > 0 + row = self._sync_columns(row, ensure, types=types) + if self._check_ensure(ensure): + self.create_index(keys) + row_count = self.update(row, keys, ensure=False, return_count=True) + if row_count == 0: + return self.insert(row, ensure=False) + return True - return result - - def delete(self, *_clauses, **_filter): - """ - - Delete rows from the table. + def delete(self, *clauses, **filters): + """Delete rows from the table. Keyword arguments can be used to add column-based filters. The filter criterion will always be equality: @@ -249,55 +192,60 @@ class Table(object): If no arguments are given, all records are deleted. """ - self._check_dropped() if not self.exists: - return - if _filter or _clauses: - q = self._args_to_clause(_filter, clauses=_clauses) - stmt = self.table.delete(q) - else: - stmt = self.table.delete() - rows = self.database.executable.execute(stmt) - return rows.rowcount > 0 + return False + clause = self._args_to_clause(filters, clauses=clauses) + stmt = self.table.delete(whereclause=clause) + rp = self.database.executable.execute(stmt) + return rp.rowcount > 0 - def _has_column(self, column): + def has_column(self, column): + """Check if a column with the given name exists on this table.""" return normalize_column_name(column) in self.columns - def _ensure_columns(self, row, types=None): - # Keep order of inserted columns - for column in row.keys(): - if self._has_column(column): - continue - if types is not None and column in types: - _type = types[column] - else: - _type = self.database.types.guess(row[column]) - log.debug("Creating column: %s (%s) on %r" % (column, - _type, - self.table.name)) - self.create_column(column, _type) + def _sync_columns(self, row, ensure, types=None): + columns = self.columns + ensure = self._check_ensure(ensure) + types = types or {} + types = {normalize_column_name(k): v for (k, v) in types.items()} + out = {} + for name, value in row.items(): + name = normalize_column_name(name) + if ensure and name not in columns: + _type = types.get(name) + if _type is None: + _type = self.database.types.guess(value) + log.debug("Create column: %s on %s", name, self.name) + self.create_column(name, _type) + columns.append(name) + if name in columns: + out[name] = value + return out - def _prune_row(self, row): - """Remove keys from row not in column set.""" - # normalize keys - row = {normalize_column_name(k): v for k, v in row.items()} - # filter out keys not in column set - return {k: row[k] for k in row if k in self.columns} + def _check_ensure(self, ensure): + if ensure is None: + return self.database.ensure_schema + return ensure - def _args_to_clause(self, args, ensure=None, clauses=()): - ensure = self.database.ensure_schema if ensure is None else ensure - if ensure: - self._ensure_columns(args) + def _args_to_clause(self, args, clauses=()): clauses = list(clauses) - for k, v in args.items(): - if not self._has_column(k): + for column, value in args.items(): + if not self.has_column(column): clauses.append(false()) - elif isinstance(v, (list, tuple)): - clauses.append(self.table.c[k].in_(v)) + elif isinstance(value, (list, tuple)): + clauses.append(self.table.c[column].in_(value)) else: - clauses.append(self.table.c[k] == v) + clauses.append(self.table.c[column] == value) return and_(*clauses) + def _keys_to_args(self, row, keys): + keys = ensure_tuple(keys) + keys = [normalize_column_name(k) for k in keys] + # keys = [self.has_column(k) for k in keys] + row = row.copy() + args = {k: row.pop(k) for k in keys if k in row} + return args, row + def create_column(self, name, type): """ Explicitly create a new column ``name`` of a specified type. @@ -307,10 +255,10 @@ class Table(object): table.create_column('created_at', db.types.datetime) """ - self._check_dropped() + # TODO: create the table if it does not exist. with self.database.lock: name = normalize_column_name(name) - if name in self.columns: + if self.has_column(name): log.debug("Column exists: %s" % name) return @@ -341,10 +289,8 @@ class Table(object): """ if self.database.engine.dialect.name == 'sqlite': raise NotImplementedError("SQLite does not support dropping columns.") - if not self.exists: - return - self._check_dropped() - if name not in self.columns: + name = normalize_column_name(name) + if not self.exists or not self.has_column(name): log.debug("Column does not exist: %s", name) return with self.database.lock: @@ -357,9 +303,14 @@ class Table(object): def has_index(self, columns): """Check if an index exists to cover the given `columns`.""" + if not self.exists: + return False columns = set([normalize_column_name(c) for c in columns]) if columns in self._indexes: return True + for column in columns: + if not self.has_column(column): + raise DatasetException("Column does not exist: %s" % column) inspector = Inspector.from_engine(self.database.executable) indexes = inspector.get_indexes(self.name, schema=self.database.schema) for index in indexes: @@ -377,9 +328,11 @@ class Table(object): table.create_index(['name', 'country']) """ - self._check_dropped() columns = [normalize_column_name(c) for c in columns] with self.database.lock: + if not self.exists: + # TODO + pass if not self.has_index(columns): name = name or index_name(self.name, columns) columns = [self.table.c[c] for c in columns] @@ -387,10 +340,8 @@ class Table(object): idx.create(self.database.executable) def _args_to_order_by(self, order_by): - if not isinstance(order_by, (list, tuple)): - order_by = [order_by] orderings = [] - for ordering in order_by: + for ordering in ensure_tuple(order_by): if ordering is None: continue column = ordering.lstrip('-') @@ -430,22 +381,25 @@ class Table(object): """ _limit = kwargs.pop('_limit', None) _offset = kwargs.pop('_offset', 0) - _step = kwargs.pop('_step', 5000) + _step = kwargs.pop('_step', QUERY_STEP) order_by = kwargs.pop('order_by', None) - self._check_dropped() + if not self.exists: + return [] order_by = self._args_to_order_by(order_by) - args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses) + args = self._args_to_clause(kwargs, clauses=_clauses) if _step is False or _step == 0: _step = None - query = self.table.select(whereclause=args, limit=_limit, + query = self.table.select(whereclause=args, + limit=_limit, offset=_offset) if len(order_by): query = query.order_by(*order_by) return ResultIter(self.database.executable.execute(query), - row_type=self.database.row_type, step=_step) + row_type=self.database.row_type, + step=_step) def find_one(self, *args, **kwargs): """Get a single result from the table. @@ -458,6 +412,7 @@ class Table(object): """ if not self.exists: return None + kwargs['_limit'] = 1 kwargs['_step'] = None resiter = self.find(*args, **kwargs) @@ -474,8 +429,8 @@ class Table(object): # with people using these flags. Let's see how it goes. if not self.exists: return 0 - self._check_dropped() - args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses) + + args = self._args_to_clause(kwargs, clauses=_clauses) query = select([func.count()], whereclause=args) query = query.select_from(self.table) rp = self.database.executable.execute(query) @@ -486,10 +441,7 @@ class Table(object): return self.count() def distinct(self, *args, **_filter): - """ - Return all rows of a table, but remove rows in with duplicate values in ``columns``. - - Interally this creates a `DISTINCT statement `_. + """Return all the unique (distinct) values for the given ``columns``. :: # returns only one row per year, ignoring the rest @@ -499,62 +451,47 @@ class Table(object): # you can also combine this with a filter table.distinct('year', country='China') """ - self._check_dropped() - qargs = [] - columns = [] - try: - for c in args: - if isinstance(c, ClauseElement): - qargs.append(c) - else: - columns.append(self.table.c[c]) - for col, val in _filter.items(): - qargs.append(self.table.c[col] == val) - except KeyError: + if not self.exists: return [] - q = expression.select(columns, distinct=True, - whereclause=and_(*qargs), + filters = [] + for column, value in _filter.items(): + if not self.has_column(column): + raise DatasetException("No such column: %s" % column) + filters.append(self.table.c[column] == value) + + columns = [] + for column in args: + if isinstance(column, ClauseElement): + filters.append(column) + else: + if not self.has_column(column): + raise DatasetException("No such column: %s" % column) + columns.append(self.table.c[column]) + + if not len(columns): + return [] + + q = expression.select(columns, + distinct=True, + whereclause=and_(*filters), order_by=[c.asc() for c in columns]) return self.database.query(q) - def __getitem__(self, item): - """ - Get distinct column values. - - This is an alias for distinct which allows the table to be queried as using - square bracket syntax. - :: - # Same as distinct: - print list(table['year']) - """ - if not isinstance(item, tuple): - item = item, - return self.distinct(*item) - - def all(self): - """ - Return all rows of the table as simple dictionaries. - - This is simply a shortcut to *find()* called with no arguments. - :: - - rows = table.all() - """ - return self.find() + # Legacy methods for running find queries. + all = find def __iter__(self): - """ - Return all rows of the table as simple dictionaries. + """Return all rows of the table as simple dictionaries. Allows for iterating over all rows in the table without explicetly - calling :py:meth:`all() `. + calling :py:meth:`find() `. :: for row in table: print(row) """ - return self.all() + return self.find() def __repr__(self): """Get table representation.""" diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index 545429b..94cf79f 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -9,8 +9,10 @@ except ImportError: # pragma: no cover from ordereddict import OrderedDict from six import string_types +from collections import Sequence from hashlib import sha1 +QUERY_STEP = 1000 row_type = OrderedDict @@ -78,3 +80,12 @@ def index_name(table, columns): sig = '||'.join(columns) key = sha1(sig.encode('utf-8')).hexdigest()[:16] return 'ix_%s_%s' % (table, key) + + +def ensure_tuple(obj): + """Try and make the given argument into a tuple.""" + if obj is None: + return tuple() + if isinstance(obj, Sequence) and not isinstance(obj, string_types): + return tuple(obj) + return obj, diff --git a/setup.py b/setup.py index 137314c..1e1dee0 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ setup( include_package_data=False, zip_safe=False, install_requires=[ - 'sqlalchemy >= 0.9.1', + 'sqlalchemy >= 1.1.0', 'alembic >= 0.6.2', 'normality >= 0.3.9', "PyYAML >= 3.10", diff --git a/test/test_persistence.py b/test/test_persistence.py index 86fd85f..0559bfa 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -11,7 +11,6 @@ from sqlalchemy import FLOAT, INTEGER, TEXT from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError from dataset import connect -from dataset.util import DatasetException from .sample_data import TEST_DATA, TEST_CITY_1 @@ -185,13 +184,14 @@ class TableTestCase(unittest.TestCase): assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) def test_insert_ignore_all_key(self): - for i in range(0, 2): + for i in range(0, 4): self.tbl.insert_ignore({ 'date': datetime(2011, 1, 2), 'temperature': -10, 'place': 'Berlin'}, ['date', 'temperature', 'place'] ) + assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) def test_upsert(self): self.tbl.upsert({ @@ -209,7 +209,21 @@ class TableTestCase(unittest.TestCase): ) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) + def test_upsert_single_column(self): + table = self.db['banana_single_col'] + table.upsert({ + 'color': 'Yellow'}, + ['color'] + ) + assert len(table) == 1, len(table) + table.upsert({ + 'color': 'Yellow'}, + ['color'] + ) + assert len(table) == 1, len(table) + def test_upsert_all_key(self): + assert len(self.tbl) == len(TEST_DATA), len(self.tbl) for i in range(0, 2): self.tbl.upsert({ 'date': datetime(2011, 1, 2), @@ -217,11 +231,13 @@ class TableTestCase(unittest.TestCase): 'place': 'Berlin'}, ['date', 'temperature', 'place'] ) + assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) def test_update_while_iter(self): for row in self.tbl: row['foo'] = 'bar' self.tbl.update(row, ['place', 'date']) + assert len(self.tbl) == len(TEST_DATA), len(self.tbl) def test_weird_column_names(self): with self.assertRaises(ValueError): @@ -262,17 +278,19 @@ class TableTestCase(unittest.TestCase): assert len(self.tbl) == 0, len(self.tbl) def test_repr(self): - assert repr(self.tbl) == '', 'the representation should be ' + assert repr(self.tbl) == '', \ + 'the representation should be ' def test_delete_nonexist_entry(self): - assert self.tbl.delete(place='Berlin') is False, 'entry not exist, should fail to delete' + assert self.tbl.delete(place='Berlin') is False, \ + 'entry not exist, should fail to delete' def test_find_one(self): self.tbl.insert({ 'date': datetime(2011, 1, 2), 'temperature': -10, - 'place': 'Berlin'} - ) + 'place': 'Berlin' + }) d = self.tbl.find_one(place='Berlin') assert d['temperature'] == -10, d d = self.tbl.find_one(place='Atlantis') @@ -317,29 +335,17 @@ class TableTestCase(unittest.TestCase): self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0))) assert len(x) == 4, x - def test_get_items(self): - x = list(self.tbl['place']) - y = list(self.tbl.distinct('place')) - assert x == y, (x, y) - x = list(self.tbl['place', 'date']) - y = list(self.tbl.distinct('place', 'date')) - assert x == y, (x, y) - def test_insert_many(self): data = TEST_DATA * 100 self.tbl.insert_many(data, chunk_size=13) assert len(self.tbl) == len(data) + 6 - def test_drop_warning(self): + def test_drop_operations(self): assert self.tbl.table is not None, 'table shouldn\'t be dropped yet' self.tbl.drop() assert self.tbl.table is None, 'table should be dropped now' - try: - list(self.tbl.all()) - except DatasetException: - pass - else: - assert False, 'we should not reach else block, no exception raised!' + assert list(self.tbl.all()) == [], self.tbl.all() + assert self.tbl.count() == 0, self.tbl.count() def test_table_drop(self): assert 'weather' in self.db @@ -380,26 +386,31 @@ class TableTestCase(unittest.TestCase): ) assert res, 'update should return True' m = self.tbl.find_one(place=TEST_CITY_1, date=date) - assert m['temperature'] == -10, 'new temp. should be -10 but is %d' % m['temperature'] + assert m['temperature'] == -10, \ + 'new temp. should be -10 but is %d' % m['temperature'] def test_create_column(self): tbl = self.tbl tbl.create_column('foo', FLOAT) assert 'foo' in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c['foo'].type, FLOAT), tbl.table.c['foo'].type + assert isinstance(tbl.table.c['foo'].type, FLOAT), \ + tbl.table.c['foo'].type assert 'foo' in tbl.columns, tbl.columns def test_ensure_column(self): tbl = self.tbl tbl.create_column_by_example('foo', 0.1) assert 'foo' in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c['foo'].type, FLOAT), tbl.table.c['bar'].type + assert isinstance(tbl.table.c['foo'].type, FLOAT), \ + tbl.table.c['bar'].type tbl.create_column_by_example('bar', 1) assert 'bar' in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c['bar'].type, INTEGER), tbl.table.c['bar'].type + assert isinstance(tbl.table.c['bar'].type, INTEGER), \ + tbl.table.c['bar'].type tbl.create_column_by_example('pippo', 'test') assert 'pippo' in tbl.table.c, tbl.table.c - assert isinstance(tbl.table.c['pippo'].type, TEXT), tbl.table.c['pippo'].type + assert isinstance(tbl.table.c['pippo'].type, TEXT), \ + tbl.table.c['pippo'].type def test_key_order(self): res = self.db.query('SELECT temperature, place FROM weather LIMIT 1')