Merge branch 'master' into create_column_by_example

This commit is contained in:
Friedrich Lindenberg 2017-09-02 08:22:22 +02:00 committed by GitHub
commit edc41e4d82
9 changed files with 97 additions and 113 deletions

1
.gitignore vendored
View File

@ -2,6 +2,7 @@
*.egg-info *.egg-info
*.egg *.egg
dist/* dist/*
.tox/*
build/* build/*
.DS_Store .DS_Store
.watchr .watchr

View File

@ -1,6 +1,7 @@
--- ---
language: python language: python
python: python:
- '3.6'
- '3.5' - '3.5'
- '3.4' - '3.4'
- '3.3' - '3.3'
@ -19,7 +20,7 @@ before_script:
- sh -c "if [ '$DATABASE_URL' = 'mysql+pymysql://travis@127.0.0.1/dataset?charset=utf8' ]; then mysql -e 'create database IF NOT EXISTS dataset DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'; fi" - sh -c "if [ '$DATABASE_URL' = 'mysql+pymysql://travis@127.0.0.1/dataset?charset=utf8' ]; then mysql -e 'create database IF NOT EXISTS dataset DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'; fi"
script: script:
- flake8 --ignore=E501,E123,E124,E126,E127,E128 dataset test - flake8 --ignore=E501,E123,E124,E126,E127,E128 dataset test
- nosetests - nosetests -v
cache: cache:
directories: directories:
- $HOME/.cache/pip - $HOME/.cache/pip

View File

@ -161,6 +161,7 @@ def main(): # pragma: no cover
except FreezeException as fe: except FreezeException as fe:
log.error(fe) log.error(fe)
if __name__ == '__main__': # pragma: no cover if __name__ == '__main__': # pragma: no cover
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
main() main()

View File

@ -1,6 +1,6 @@
import json import json
from datetime import datetime, date from datetime import datetime, date
from collections import defaultdict from collections import defaultdict, OrderedDict
from decimal import Decimal from decimal import Decimal
from six import PY3 from six import PY3
@ -29,10 +29,10 @@ class JSONSerializer(Serializer):
if self.mode == 'item': if self.mode == 'item':
result = result[0] result = result[0]
if self.export.get_bool('wrap', True): if self.export.get_bool('wrap', True):
result = { result = OrderedDict([
'count': len(result), ('count', len(result)),
'results': result ('results', result),
} ])
meta = self.export.get('meta', {}) meta = self.export.get('meta', {})
if meta is not None: if meta is not None:
result['meta'] = meta result['meta'] = meta

View File

@ -41,6 +41,7 @@ class Database(object):
self.lock = threading.RLock() self.lock = threading.RLock()
self.local = threading.local() self.local = threading.local()
if len(parsed_url.query): if len(parsed_url.query):
query = parse_qs(parsed_url.query) query = parse_qs(parsed_url.query)
if schema is None: if schema is None:
@ -56,6 +57,7 @@ class Database(object):
self._tables = {} self._tables = {}
self.metadata = MetaData(schema=schema) self.metadata = MetaData(schema=schema)
self.metadata.bind = self.engine self.metadata.bind = self.engine
if reflect_metadata: if reflect_metadata:
self.metadata.reflect(self.engine, views=reflect_views) self.metadata.reflect(self.engine, views=reflect_views)
for table_name in self.metadata.tables.keys(): for table_name in self.metadata.tables.keys():
@ -63,39 +65,22 @@ class Database(object):
@property @property
def executable(self): def executable(self):
"""Connection or engine against which statements will be executed.""" """Connection against which statements will be executed."""
if hasattr(self.local, 'connection'): if not hasattr(self.local, 'conn'):
return self.local.connection self.local.conn = self.engine.connect()
return self.engine return self.local.conn
@property @property
def op(self): def op(self):
"""Get an alembic operations context.""" """Get an alembic operations context."""
ctx = MigrationContext.configure(self.engine) ctx = MigrationContext.configure(self.executable)
return Operations(ctx) return Operations(ctx)
def _acquire(self): def _acquire(self):
self.lock.acquire() self.lock.acquire()
def _release(self): def _release(self):
if not hasattr(self.local, 'tx'): self.lock.release()
self.lock.release()
else:
self.local.lock_count[-1] += 1
def _release_internal(self):
for index in range(self.local.lock_count[-1]):
self.lock.release()
del self.local.lock_count[-1]
def _dispose_transaction(self):
self._release_internal()
self.local.tx.remove(self.local.tx[-1])
if not self.local.tx:
del self.local.tx
del self.local.lock_count
self.local.connection.close()
del self.local.connection
def begin(self): def begin(self):
""" """
@ -105,13 +90,9 @@ class Database(object):
**NOTICE:** Schema modification operations, such as the creation **NOTICE:** Schema modification operations, such as the creation
of tables or columns will not be part of the transactional context. of tables or columns will not be part of the transactional context.
""" """
if not hasattr(self.local, 'connection'):
self.local.connection = self.engine.connect()
if not hasattr(self.local, 'tx'): if not hasattr(self.local, 'tx'):
self.local.tx = [] self.local.tx = []
self.local.lock_count = [] self.local.tx.append(self.executable.begin())
self.local.tx.append(self.local.connection.begin())
self.local.lock_count.append(0)
def commit(self): def commit(self):
""" """
@ -120,8 +101,8 @@ class Database(object):
Make all statements executed since the transaction was begun permanent. Make all statements executed since the transaction was begun permanent.
""" """
if hasattr(self.local, 'tx') and self.local.tx: if hasattr(self.local, 'tx') and self.local.tx:
self.local.tx[-1].commit() tx = self.local.tx.pop()
self._dispose_transaction() tx.commit()
def rollback(self): def rollback(self):
""" """
@ -130,8 +111,8 @@ class Database(object):
Discard all statements executed since the transaction was begun. Discard all statements executed since the transaction was begun.
""" """
if hasattr(self.local, 'tx') and self.local.tx: if hasattr(self.local, 'tx') and self.local.tx:
self.local.tx[-1].rollback() tx = self.local.tx.pop()
self._dispose_transaction() tx.rollback()
def __enter__(self): def __enter__(self):
"""Start a transaction.""" """Start a transaction."""

View File

@ -4,7 +4,7 @@ from hashlib import sha1
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import alias, false from sqlalchemy import func, select, false
from dataset.persistence.util import guess_type, normalize_column_name from dataset.persistence.util import guess_type, normalize_column_name
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -249,7 +249,7 @@ class Table(object):
If no arguments are given, all records are deleted. If no arguments are given, all records are deleted.
""" """
self._check_dropped() self._check_dropped()
if _filter: if _filter or _clauses:
q = self._args_to_clause(_filter, clauses=_clauses) q = self._args_to_clause(_filter, clauses=_clauses)
stmt = self.table.delete(q) stmt = self.table.delete(q)
else: else:
@ -298,15 +298,17 @@ class Table(object):
""" """
self._check_dropped() self._check_dropped()
self.database._acquire() self.database._acquire()
try: try:
if normalize_column_name(name) not in self._normalized_columns: if normalize_column_name(name) in self._normalized_columns:
self.database.op.add_column( log.warn("Column with similar name exists: %s" % name)
self.table.name, return
Column(name, type),
self.table.schema self.database.op.add_column(
) self.table.name,
self.table = self.database.update_table(self.table.name) Column(name, type),
self.table.schema
)
self.table = self.database.update_table(self.table.name)
finally: finally:
self.database._release() self.database._release()
@ -399,10 +401,20 @@ class Table(object):
return None return None
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
if order_by[0] == '-': if not isinstance(order_by, (list, tuple)):
return self.table.c[order_by[1:]].desc() order_by = [order_by]
else: orderings = []
return self.table.c[order_by].asc() for ordering in order_by:
if ordering is None:
continue
column = ordering.lstrip('-')
if column not in self.table.columns:
continue
if ordering.startswith('-'):
orderings.append(self.table.c[column].desc())
else:
orderings.append(self.table.c[column].asc())
return orderings
def find(self, *_clauses, **kwargs): def find(self, *_clauses, **kwargs):
""" """
@ -433,43 +445,33 @@ class Table(object):
_limit = kwargs.pop('_limit', None) _limit = kwargs.pop('_limit', None)
_offset = kwargs.pop('_offset', 0) _offset = kwargs.pop('_offset', 0)
_step = kwargs.pop('_step', 5000) _step = kwargs.pop('_step', 5000)
order_by = kwargs.pop('order_by', 'id') order_by = kwargs.pop('order_by', None)
return_count = kwargs.pop('return_count', False)
return_query = kwargs.pop('return_query', False)
_filter = kwargs
self._check_dropped() self._check_dropped()
if not isinstance(order_by, (list, tuple)): order_by = self._args_to_order_by(order_by)
order_by = [order_by] args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
order_by = [o for o in order_by if (o.startswith('-') and o[1:] or o) in self.table.columns]
order_by = [self._args_to_order_by(o) for o in order_by]
args = self._args_to_clause(_filter, ensure=False, clauses=_clauses) if _step is False or _step == 0:
_step = None
# query total number of rows first
count_query = alias(self.table.select(whereclause=args, limit=_limit, offset=_offset),
name='count_query_alias').count()
rp = self.database.executable.execute(count_query)
total_row_count = rp.fetchone()[0]
if return_count:
return total_row_count
if _limit is None:
_limit = total_row_count
if _step is None or _step is False or _step == 0:
_step = total_row_count
query = self.table.select(whereclause=args, limit=_limit, query = self.table.select(whereclause=args, limit=_limit,
offset=_offset, order_by=order_by) offset=_offset)
if return_query: if len(order_by):
return query query = query.order_by(*order_by)
return ResultIter(self.database.executable.execute(query), return ResultIter(self.database.executable.execute(query),
row_type=self.database.row_type, step=_step) row_type=self.database.row_type, step=_step)
def count(self, *args, **kwargs): def count(self, *_clauses, **kwargs):
"""Return the count of results for the given filter set.""" """Return the count of results for the given filter set."""
return self.find(*args, return_count=True, **kwargs) # NOTE: this does not have support for limit and offset since I can't
# see how this is useful. Still, there might be compatibility issues
# with people using these flags. Let's see how it goes.
self._check_dropped()
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
query = select([func.count()], whereclause=args)
query = query.select_from(self.table)
rp = self.database.executable.execute(query)
return rp.fetchone()[0]
def __len__(self): def __len__(self):
"""Return the number of rows in the table.""" """Return the number of rows in the table."""

View File

@ -1,4 +1,4 @@
from datetime import datetime from datetime import datetime, date
try: try:
from urlparse import urlparse from urlparse import urlparse
except ImportError: except ImportError:
@ -9,7 +9,7 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
from ordereddict import OrderedDict from ordereddict import OrderedDict
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean
from six import string_types from six import string_types
row_type = OrderedDict row_type = OrderedDict
@ -24,6 +24,8 @@ def guess_type(sample):
return Float return Float
elif isinstance(sample, datetime): elif isinstance(sample, datetime):
return DateTime return DateTime
elif isinstance(sample, date):
return Date
return UnicodeText return UnicodeText
@ -42,39 +44,30 @@ def normalize_column_name(name):
return name return name
def iter_result_proxy(rp, step=None):
"""Iterate over the ResultProxy."""
while True:
if step is None:
chunk = rp.fetchall()
else:
chunk = rp.fetchmany(step)
if not chunk:
break
for row in chunk:
yield row
class ResultIter(object): class ResultIter(object):
""" SQLAlchemy ResultProxies are not iterable to get a """ SQLAlchemy ResultProxies are not iterable to get a
list of dictionaries. This is to wrap them. """ list of dictionaries. This is to wrap them. """
def __init__(self, result_proxy, row_type=row_type, step=None): def __init__(self, result_proxy, row_type=row_type, step=None):
self.result_proxy = result_proxy
self.row_type = row_type self.row_type = row_type
self.step = step
self.keys = list(result_proxy.keys()) self.keys = list(result_proxy.keys())
self._iter = None self._iter = iter_result_proxy(result_proxy, step=step)
def _next_chunk(self):
if self.result_proxy.closed:
return False
if not self.step:
chunk = self.result_proxy.fetchall()
else:
chunk = self.result_proxy.fetchmany(self.step)
if chunk:
self._iter = iter(chunk)
return True
else:
return False
def __next__(self): def __next__(self):
if self._iter is None: return convert_row(self.row_type, next(self._iter))
if not self._next_chunk():
raise StopIteration
try:
return convert_row(self.row_type, next(self._iter))
except StopIteration:
self._iter = None
return self.__next__()
next = __next__ next = __next__

View File

@ -8,7 +8,7 @@ if sys.version_info[:2] <= (2, 6):
setup( setup(
name='dataset', name='dataset',
version='0.7.0', version='0.8.0',
description="Toolkit for Python-based data processing.", description="Toolkit for Python-based data processing.",
long_description="", long_description="",
classifiers=[ classifiers=[
@ -16,15 +16,14 @@ setup(
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.3' 'Programming Language :: Python :: 3.3',
'Programming Language :: Python :: 3.4' 'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5' 'Programming Language :: Python :: 3.5'
], ],
keywords='sql sqlalchemy etl loading utility', keywords='sql sqlalchemy etl loading utility',
author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer',
author_email='info@okfn.org', author_email='friedrich@pudo.org',
url='http://github.com/pudo/dataset', url='http://github.com/pudo/dataset',
license='MIT', license='MIT',
packages=find_packages(exclude=['ez_setup', 'examples', 'test']), packages=find_packages(exclude=['ez_setup', 'examples', 'test']),
@ -34,7 +33,7 @@ setup(
install_requires=[ install_requires=[
'sqlalchemy >= 0.9.1', 'sqlalchemy >= 0.9.1',
'alembic >= 0.6.2', 'alembic >= 0.6.2',
'normality >= 0.2.2', 'normality >= 0.3.9',
"PyYAML >= 3.10", "PyYAML >= 3.10",
"six >= 1.7.3" "six >= 1.7.3"
] + py26_dependency, ] + py26_dependency,

View File

@ -8,7 +8,7 @@ except ImportError: # pragma: no cover
from ordereddict import OrderedDict # Python < 2.7 drop-in from ordereddict import OrderedDict # Python < 2.7 drop-in
from sqlalchemy import FLOAT, INTEGER, TEXT from sqlalchemy import FLOAT, INTEGER, TEXT
from sqlalchemy.exc import IntegrityError, SQLAlchemyError from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
from dataset import connect from dataset import connect
from dataset.util import DatasetException from dataset.util import DatasetException
@ -252,7 +252,13 @@ class TableTestCase(unittest.TestCase):
'temperature': -10, 'temperature': -10,
'place': 'Berlin'} 'place': 'Berlin'}
) )
original_count = len(self.tbl)
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
# Test bad use of API
with self.assertRaises(ArgumentError):
self.tbl.delete({'place': 'Berlin'})
assert len(self.tbl) == original_count, len(self.tbl)
assert self.tbl.delete(place='Berlin') is True, 'should return 1' assert self.tbl.delete(place='Berlin') is True, 'should return 1'
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
assert self.tbl.delete() is True, 'should return non zero' assert self.tbl.delete() is True, 'should return non zero'