diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index f4c04a0..1d29de8 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -2,6 +2,7 @@ import logging from hashlib import sha1 from sqlalchemy.sql import and_, expression +from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.schema import Column, Index from sqlalchemy import alias, func from dataset.persistence.util import guess_type, normalize_column_name @@ -182,7 +183,7 @@ class Table(object): else: return self.insert(row, ensure=ensure, types=types) - def delete(self, **_filter): + def delete(self, *_clauses, **_filter): """ Delete rows from the table. Keyword arguments can be used to add column-based filters. The filter criterion will always be equality: @@ -195,7 +196,7 @@ class Table(object): """ self._check_dropped() if _filter: - q = self._args_to_clause(_filter) + q = self._args_to_clause(_filter, clauses=_clauses) stmt = self.table.delete(q) else: stmt = self.table.delete() @@ -218,10 +219,10 @@ class Table(object): _type, self.table.name)) self.create_column(column, _type) - def _args_to_clause(self, args, ensure=True): + def _args_to_clause(self, args, ensure=True, clauses=()): if ensure: self._ensure_columns(args) - clauses = [] + clauses = list(clauses) for k, v in args.items(): if not self._has_column(k): clauses.append(func.sum(1) == 2) @@ -309,7 +310,7 @@ class Table(object): self.indexes[name] = idx return idx - def find_one(self, **kwargs): + def find_one(self, *args, **kwargs): """ Works just like :py:meth:`find() ` but returns one result, or None. :: @@ -317,7 +318,7 @@ class Table(object): row = table.find_one(country='United States') """ kwargs['_limit'] = 1 - iterator = self.find(**kwargs) + iterator = self.find(*args, **kwargs) try: return next(iterator) except StopIteration: @@ -329,8 +330,7 @@ class Table(object): else: return self.table.c[order_by].asc() - def find(self, _limit=None, _offset=0, _step=5000, - order_by='id', return_count=False, **_filter): + def find(self, *_clauses, **kwargs): """ Performs a simple search on the table. Simply pass keyword arguments as ``filter``. :: @@ -353,13 +353,20 @@ class Table(object): For more complex queries, please use :py:meth:`db.query() ` instead.""" + _limit = kwargs.pop('_limit', None) + _offset = kwargs.pop('_offset', 0) + _step = kwargs.pop('_step', 5000) + order_by = kwargs.pop('order_by', 'id') + return_count = kwargs.pop('return_count', False) + _filter = kwargs + self._check_dropped() if not isinstance(order_by, (list, tuple)): order_by = [order_by] order_by = [o for o in order_by if (o.startswith('-') and o[1:] or o) in self.table.columns] order_by = [self._args_to_order_by(o) for o in order_by] - args = self._args_to_clause(_filter, ensure=False) + args = self._args_to_clause(_filter, ensure=False, clauses=_clauses) # query total number of rows first count_query = alias(self.table.select(whereclause=args, limit=_limit, offset=_offset), @@ -380,11 +387,11 @@ class Table(object): return ResultIter(self.database.executable.execute(query), row_type=self.database.row_type, step=_step) - def count(self, **_filter): + def count(self, *args, **kwargs): """ Return the count of results for the given filter set (same filter options as with ``find()``). """ - return self.find(return_count=True, **_filter) + return self.find(*args, return_count=True, **kwargs) def __len__(self): """ @@ -392,7 +399,7 @@ class Table(object): """ return self.count() - def distinct(self, *columns, **_filter): + def distinct(self, *args, **_filter): """ Returns all rows of a table, but removes rows in with duplicate values in ``columns``. Interally this creates a `DISTINCT statement `_. @@ -407,8 +414,13 @@ class Table(object): """ self._check_dropped() qargs = [] + columns = [] try: - columns = [self.table.c[c] for c in columns] + for c in args: + if isinstance(c, ClauseElement): + qargs.append(c) + else: + columns.append(self.table.c[c]) for col, val in _filter.items(): qargs.append(self.table.c[col] == val) except KeyError: diff --git a/docs/quickstart.rst b/docs/quickstart.rst index f3d656d..d6a9370 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -138,6 +138,9 @@ We can search for specific entries using :py:meth:`find() ` # Get a specific user john = table.find_one(name='John Doe') + # Find by comparison + elderly_users = table.find(table.table.columns.age >= 70) + Using :py:meth:`distinct() ` we can grab a set of rows with unique values in one or more columns:: diff --git a/test/test_persistence.py b/test/test_persistence.py index b864d24..db2836c 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -267,6 +267,8 @@ class TableTestCase(unittest.TestCase): assert ds[0]['temperature'] == -1, ds ds = list(self.tbl.find(order_by=['-temperature'])) assert ds[0]['temperature'] == 8, ds + ds = list(self.tbl.find(self.tbl.table.columns.temperature > 4)) + assert len(ds) == 3, ds def test_offset(self): ds = list(self.tbl.find(place=TEST_CITY_1, _offset=1)) @@ -279,6 +281,10 @@ class TableTestCase(unittest.TestCase): assert len(x) == 2, x x = list(self.tbl.distinct('place', 'date')) assert len(x) == 6, x + x = list(self.tbl.distinct( + 'place', 'date', + self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0))) + assert len(x) == 4, x def test_get_items(self): x = list(self.tbl['place'])