refactor query code to be simpler

This commit is contained in:
Friedrich Lindenberg 2017-01-29 15:45:05 +01:00
parent fb4512783a
commit 522415a27c
5 changed files with 52 additions and 58 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.egg-info *.egg-info
*.egg *.egg
dist/* dist/*
.tox/*
build/* build/*
.DS_Store .DS_Store
.watchr .watchr

View File

@ -41,6 +41,7 @@ class Database(object):
self.lock = threading.RLock() self.lock = threading.RLock()
self.local = threading.local() self.local = threading.local()
if len(parsed_url.query): if len(parsed_url.query):
query = parse_qs(parsed_url.query) query = parse_qs(parsed_url.query)
if schema is None: if schema is None:
@ -56,6 +57,7 @@ class Database(object):
self._tables = {} self._tables = {}
self.metadata = MetaData(schema=schema) self.metadata = MetaData(schema=schema)
self.metadata.bind = self.engine self.metadata.bind = self.engine
if reflect_metadata: if reflect_metadata:
self.metadata.reflect(self.engine, views=reflect_views) self.metadata.reflect(self.engine, views=reflect_views)
for table_name in self.metadata.tables.keys(): for table_name in self.metadata.tables.keys():

View File

@ -4,7 +4,7 @@ from hashlib import sha1
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import alias, false from sqlalchemy import func, select, false
from dataset.persistence.util import guess_type, normalize_column_name from dataset.persistence.util import guess_type, normalize_column_name
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -388,10 +388,20 @@ class Table(object):
return None return None
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
if order_by[0] == '-': if not isinstance(order_by, (list, tuple)):
return self.table.c[order_by[1:]].desc() order_by = [order_by]
else: orderings = []
return self.table.c[order_by].asc() for ordering in order_by:
if ordering is None:
continue
column = ordering.lstrip('-')
if column not in self.table.columns:
continue
if ordering.startswith('-'):
orderings.append(self.table.c[column].desc())
else:
orderings.append(self.table.c[column].asc())
return orderings
def find(self, *_clauses, **kwargs): def find(self, *_clauses, **kwargs):
""" """
@ -422,43 +432,33 @@ class Table(object):
_limit = kwargs.pop('_limit', None) _limit = kwargs.pop('_limit', None)
_offset = kwargs.pop('_offset', 0) _offset = kwargs.pop('_offset', 0)
_step = kwargs.pop('_step', 5000) _step = kwargs.pop('_step', 5000)
order_by = kwargs.pop('order_by', 'id') order_by = kwargs.pop('order_by', None)
return_count = kwargs.pop('return_count', False)
return_query = kwargs.pop('return_query', False)
_filter = kwargs
self._check_dropped() self._check_dropped()
if not isinstance(order_by, (list, tuple)): order_by = self._args_to_order_by(order_by)
order_by = [order_by] args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
order_by = [o for o in order_by if (o.startswith('-') and o[1:] or o) in self.table.columns]
order_by = [self._args_to_order_by(o) for o in order_by]
args = self._args_to_clause(_filter, ensure=False, clauses=_clauses) if _step is False or _step == 0:
_step = None
# query total number of rows first
count_query = alias(self.table.select(whereclause=args, limit=_limit, offset=_offset),
name='count_query_alias').count()
rp = self.database.executable.execute(count_query)
total_row_count = rp.fetchone()[0]
if return_count:
return total_row_count
if _limit is None:
_limit = total_row_count
if _step is None or _step is False or _step == 0:
_step = total_row_count
query = self.table.select(whereclause=args, limit=_limit, query = self.table.select(whereclause=args, limit=_limit,
offset=_offset, order_by=order_by) offset=_offset)
if return_query: if len(order_by):
return query query = query.order_by(*order_by)
return ResultIter(self.database.executable.execute(query), return ResultIter(self.database.executable.execute(query),
row_type=self.database.row_type, step=_step) row_type=self.database.row_type, step=_step)
def count(self, *args, **kwargs): def count(self, *_clauses, **kwargs):
"""Return the count of results for the given filter set.""" """Return the count of results for the given filter set."""
return self.find(*args, return_count=True, **kwargs) # NOTE: this does not have support for limit and offset since I can't
# see how this is useful. Still, there might be compatibility issues
# with people using these flags. Let's see how it goes.
self._check_dropped()
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
query = select([func.count()], whereclause=args)
query = query.select_from(self.table)
rp = self.database.executable.execute(query)
return rp.fetchone()[0]
def __len__(self): def __len__(self):
"""Return the number of rows in the table.""" """Return the number of rows in the table."""

View File

@ -44,39 +44,30 @@ def normalize_column_name(name):
return name return name
def iter_result_proxy(rp, step=None):
"""Iterate over the ResultProxy."""
while True:
if step is None:
chunk = rp.fetchall()
else:
chunk = rp.fetchmany(step)
if not chunk:
break
for row in chunk:
yield row
class ResultIter(object): class ResultIter(object):
""" SQLAlchemy ResultProxies are not iterable to get a """ SQLAlchemy ResultProxies are not iterable to get a
list of dictionaries. This is to wrap them. """ list of dictionaries. This is to wrap them. """
def __init__(self, result_proxy, row_type=row_type, step=None): def __init__(self, result_proxy, row_type=row_type, step=None):
self.result_proxy = result_proxy
self.row_type = row_type self.row_type = row_type
self.step = step
self.keys = list(result_proxy.keys()) self.keys = list(result_proxy.keys())
self._iter = None self._iter = iter_result_proxy(result_proxy, step=step)
def _next_chunk(self):
if self.result_proxy.closed:
return False
if not self.step:
chunk = self.result_proxy.fetchall()
else:
chunk = self.result_proxy.fetchmany(self.step)
if chunk:
self._iter = iter(chunk)
return True
else:
return False
def __next__(self): def __next__(self):
if self._iter is None: return convert_row(self.row_type, next(self._iter))
if not self._next_chunk():
raise StopIteration
try:
return convert_row(self.row_type, next(self._iter))
except StopIteration:
self._iter = None
return self.__next__()
next = __next__ next = __next__

View File

@ -23,7 +23,7 @@ setup(
], ],
keywords='sql sqlalchemy etl loading utility', keywords='sql sqlalchemy etl loading utility',
author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer',
author_email='info@okfn.org', author_email='friedrich@pudo.org',
url='http://github.com/pudo/dataset', url='http://github.com/pudo/dataset',
license='MIT', license='MIT',
packages=find_packages(exclude=['ez_setup', 'examples', 'test']), packages=find_packages(exclude=['ez_setup', 'examples', 'test']),