diff --git a/.gitignore b/.gitignore index b47312f..d396746 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ *.egg-info *.egg dist/* +.tox/* build/* .DS_Store .watchr diff --git a/.travis.yml b/.travis.yml index c29763b..540255a 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,7 @@ --- language: python python: +- '3.6' - '3.5' - '3.4' - '3.3' @@ -19,7 +20,7 @@ before_script: - sh -c "if [ '$DATABASE_URL' = 'mysql+pymysql://travis@127.0.0.1/dataset?charset=utf8' ]; then mysql -e 'create database IF NOT EXISTS dataset DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'; fi" script: - flake8 --ignore=E501,E123,E124,E126,E127,E128 dataset test -- nosetests +- nosetests -v cache: directories: - $HOME/.cache/pip diff --git a/dataset/freeze/app.py b/dataset/freeze/app.py index 396e26f..141091c 100644 --- a/dataset/freeze/app.py +++ b/dataset/freeze/app.py @@ -161,6 +161,7 @@ def main(): # pragma: no cover except FreezeException as fe: log.error(fe) + if __name__ == '__main__': # pragma: no cover logging.basicConfig(level=logging.DEBUG) main() diff --git a/dataset/freeze/format/fjson.py b/dataset/freeze/format/fjson.py index fd22cd3..8086ccd 100644 --- a/dataset/freeze/format/fjson.py +++ b/dataset/freeze/format/fjson.py @@ -1,6 +1,6 @@ import json from datetime import datetime, date -from collections import defaultdict +from collections import defaultdict, OrderedDict from decimal import Decimal from six import PY3 @@ -29,10 +29,10 @@ class JSONSerializer(Serializer): if self.mode == 'item': result = result[0] if self.export.get_bool('wrap', True): - result = { - 'count': len(result), - 'results': result - } + result = OrderedDict([ + ('count', len(result)), + ('results', result), + ]) meta = self.export.get('meta', {}) if meta is not None: result['meta'] = meta diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 080cc78..d48af22 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -41,6 +41,7 @@ class Database(object): self.lock = threading.RLock() self.local = threading.local() + if len(parsed_url.query): query = parse_qs(parsed_url.query) if schema is None: @@ -56,6 +57,7 @@ class Database(object): self._tables = {} self.metadata = MetaData(schema=schema) self.metadata.bind = self.engine + if reflect_metadata: self.metadata.reflect(self.engine, views=reflect_views) for table_name in self.metadata.tables.keys(): @@ -63,39 +65,22 @@ class Database(object): @property def executable(self): - """Connection or engine against which statements will be executed.""" - if hasattr(self.local, 'connection'): - return self.local.connection - return self.engine + """Connection against which statements will be executed.""" + if not hasattr(self.local, 'conn'): + self.local.conn = self.engine.connect() + return self.local.conn @property def op(self): """Get an alembic operations context.""" - ctx = MigrationContext.configure(self.engine) + ctx = MigrationContext.configure(self.executable) return Operations(ctx) def _acquire(self): self.lock.acquire() def _release(self): - if not hasattr(self.local, 'tx'): - self.lock.release() - else: - self.local.lock_count[-1] += 1 - - def _release_internal(self): - for index in range(self.local.lock_count[-1]): - self.lock.release() - del self.local.lock_count[-1] - - def _dispose_transaction(self): - self._release_internal() - self.local.tx.remove(self.local.tx[-1]) - if not self.local.tx: - del self.local.tx - del self.local.lock_count - self.local.connection.close() - del self.local.connection + self.lock.release() def begin(self): """ @@ -105,13 +90,9 @@ class Database(object): **NOTICE:** Schema modification operations, such as the creation of tables or columns will not be part of the transactional context. """ - if not hasattr(self.local, 'connection'): - self.local.connection = self.engine.connect() if not hasattr(self.local, 'tx'): self.local.tx = [] - self.local.lock_count = [] - self.local.tx.append(self.local.connection.begin()) - self.local.lock_count.append(0) + self.local.tx.append(self.executable.begin()) def commit(self): """ @@ -120,8 +101,8 @@ class Database(object): Make all statements executed since the transaction was begun permanent. """ if hasattr(self.local, 'tx') and self.local.tx: - self.local.tx[-1].commit() - self._dispose_transaction() + tx = self.local.tx.pop() + tx.commit() def rollback(self): """ @@ -130,8 +111,8 @@ class Database(object): Discard all statements executed since the transaction was begun. """ if hasattr(self.local, 'tx') and self.local.tx: - self.local.tx[-1].rollback() - self._dispose_transaction() + tx = self.local.tx.pop() + tx.rollback() def __enter__(self): """Start a transaction.""" diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 4da10ba..2153b28 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -4,7 +4,7 @@ from hashlib import sha1 from sqlalchemy.sql import and_, expression from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.schema import Column, Index -from sqlalchemy import alias, false +from sqlalchemy import func, select, false from dataset.persistence.util import guess_type, normalize_column_name from dataset.persistence.util import ResultIter from dataset.util import DatasetException @@ -249,7 +249,7 @@ class Table(object): If no arguments are given, all records are deleted. """ self._check_dropped() - if _filter: + if _filter or _clauses: q = self._args_to_clause(_filter, clauses=_clauses) stmt = self.table.delete(q) else: @@ -298,15 +298,17 @@ class Table(object): """ self._check_dropped() self.database._acquire() - try: - if normalize_column_name(name) not in self._normalized_columns: - self.database.op.add_column( - self.table.name, - Column(name, type), - self.table.schema - ) - self.table = self.database.update_table(self.table.name) + if normalize_column_name(name) in self._normalized_columns: + log.warn("Column with similar name exists: %s" % name) + return + + self.database.op.add_column( + self.table.name, + Column(name, type), + self.table.schema + ) + self.table = self.database.update_table(self.table.name) finally: self.database._release() @@ -399,10 +401,20 @@ class Table(object): return None def _args_to_order_by(self, order_by): - if order_by[0] == '-': - return self.table.c[order_by[1:]].desc() - else: - return self.table.c[order_by].asc() + if not isinstance(order_by, (list, tuple)): + order_by = [order_by] + orderings = [] + for ordering in order_by: + if ordering is None: + continue + column = ordering.lstrip('-') + if column not in self.table.columns: + continue + if ordering.startswith('-'): + orderings.append(self.table.c[column].desc()) + else: + orderings.append(self.table.c[column].asc()) + return orderings def find(self, *_clauses, **kwargs): """ @@ -433,43 +445,33 @@ class Table(object): _limit = kwargs.pop('_limit', None) _offset = kwargs.pop('_offset', 0) _step = kwargs.pop('_step', 5000) - order_by = kwargs.pop('order_by', 'id') - return_count = kwargs.pop('return_count', False) - return_query = kwargs.pop('return_query', False) - _filter = kwargs + order_by = kwargs.pop('order_by', None) self._check_dropped() - if not isinstance(order_by, (list, tuple)): - order_by = [order_by] - order_by = [o for o in order_by if (o.startswith('-') and o[1:] or o) in self.table.columns] - order_by = [self._args_to_order_by(o) for o in order_by] + order_by = self._args_to_order_by(order_by) + args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses) - args = self._args_to_clause(_filter, ensure=False, clauses=_clauses) - - # query total number of rows first - count_query = alias(self.table.select(whereclause=args, limit=_limit, offset=_offset), - name='count_query_alias').count() - rp = self.database.executable.execute(count_query) - total_row_count = rp.fetchone()[0] - if return_count: - return total_row_count - - if _limit is None: - _limit = total_row_count - - if _step is None or _step is False or _step == 0: - _step = total_row_count + if _step is False or _step == 0: + _step = None query = self.table.select(whereclause=args, limit=_limit, - offset=_offset, order_by=order_by) - if return_query: - return query + offset=_offset) + if len(order_by): + query = query.order_by(*order_by) return ResultIter(self.database.executable.execute(query), row_type=self.database.row_type, step=_step) - def count(self, *args, **kwargs): + def count(self, *_clauses, **kwargs): """Return the count of results for the given filter set.""" - return self.find(*args, return_count=True, **kwargs) + # NOTE: this does not have support for limit and offset since I can't + # see how this is useful. Still, there might be compatibility issues + # with people using these flags. Let's see how it goes. + self._check_dropped() + args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses) + query = select([func.count()], whereclause=args) + query = query.select_from(self.table) + rp = self.database.executable.execute(query) + return rp.fetchone()[0] def __len__(self): """Return the number of rows in the table.""" diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index 8f616b6..8c1b380 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import datetime, date try: from urlparse import urlparse except ImportError: @@ -9,7 +9,7 @@ try: except ImportError: # pragma: no cover from ordereddict import OrderedDict -from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean +from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean from six import string_types row_type = OrderedDict @@ -24,6 +24,8 @@ def guess_type(sample): return Float elif isinstance(sample, datetime): return DateTime + elif isinstance(sample, date): + return Date return UnicodeText @@ -42,39 +44,30 @@ def normalize_column_name(name): return name +def iter_result_proxy(rp, step=None): + """Iterate over the ResultProxy.""" + while True: + if step is None: + chunk = rp.fetchall() + else: + chunk = rp.fetchmany(step) + if not chunk: + break + for row in chunk: + yield row + + class ResultIter(object): """ SQLAlchemy ResultProxies are not iterable to get a list of dictionaries. This is to wrap them. """ def __init__(self, result_proxy, row_type=row_type, step=None): - self.result_proxy = result_proxy self.row_type = row_type - self.step = step self.keys = list(result_proxy.keys()) - self._iter = None - - def _next_chunk(self): - if self.result_proxy.closed: - return False - if not self.step: - chunk = self.result_proxy.fetchall() - else: - chunk = self.result_proxy.fetchmany(self.step) - if chunk: - self._iter = iter(chunk) - return True - else: - return False + self._iter = iter_result_proxy(result_proxy, step=step) def __next__(self): - if self._iter is None: - if not self._next_chunk(): - raise StopIteration - try: - return convert_row(self.row_type, next(self._iter)) - except StopIteration: - self._iter = None - return self.__next__() + return convert_row(self.row_type, next(self._iter)) next = __next__ diff --git a/setup.py b/setup.py index f1c4664..137314c 100644 --- a/setup.py +++ b/setup.py @@ -8,7 +8,7 @@ if sys.version_info[:2] <= (2, 6): setup( name='dataset', - version='0.7.0', + version='0.8.0', description="Toolkit for Python-based data processing.", long_description="", classifiers=[ @@ -16,15 +16,14 @@ setup( "Intended Audience :: Developers", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", - 'Programming Language :: Python :: 2.6', 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3.3' - 'Programming Language :: Python :: 3.4' + 'Programming Language :: Python :: 3.3', + 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5' ], keywords='sql sqlalchemy etl loading utility', author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', - author_email='info@okfn.org', + author_email='friedrich@pudo.org', url='http://github.com/pudo/dataset', license='MIT', packages=find_packages(exclude=['ez_setup', 'examples', 'test']), @@ -34,7 +33,7 @@ setup( install_requires=[ 'sqlalchemy >= 0.9.1', 'alembic >= 0.6.2', - 'normality >= 0.2.2', + 'normality >= 0.3.9', "PyYAML >= 3.10", "six >= 1.7.3" ] + py26_dependency, diff --git a/test/test_persistence.py b/test/test_persistence.py index 11131d9..1e6ab92 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -8,7 +8,7 @@ except ImportError: # pragma: no cover from ordereddict import OrderedDict # Python < 2.7 drop-in from sqlalchemy import FLOAT, INTEGER, TEXT -from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError from dataset import connect from dataset.util import DatasetException @@ -252,7 +252,13 @@ class TableTestCase(unittest.TestCase): 'temperature': -10, 'place': 'Berlin'} ) + original_count = len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) + # Test bad use of API + with self.assertRaises(ArgumentError): + self.tbl.delete({'place': 'Berlin'}) + assert len(self.tbl) == original_count, len(self.tbl) + assert self.tbl.delete(place='Berlin') is True, 'should return 1' assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert self.tbl.delete() is True, 'should return non zero'