Merge branch 'master' into create_column_by_example
This commit is contained in:
commit
edc41e4d82
1
.gitignore
vendored
1
.gitignore
vendored
@ -2,6 +2,7 @@
|
||||
*.egg-info
|
||||
*.egg
|
||||
dist/*
|
||||
.tox/*
|
||||
build/*
|
||||
.DS_Store
|
||||
.watchr
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
---
|
||||
language: python
|
||||
python:
|
||||
- '3.6'
|
||||
- '3.5'
|
||||
- '3.4'
|
||||
- '3.3'
|
||||
@ -19,7 +20,7 @@ before_script:
|
||||
- sh -c "if [ '$DATABASE_URL' = 'mysql+pymysql://travis@127.0.0.1/dataset?charset=utf8' ]; then mysql -e 'create database IF NOT EXISTS dataset DEFAULT CHARACTER SET utf8 DEFAULT COLLATE utf8_general_ci;'; fi"
|
||||
script:
|
||||
- flake8 --ignore=E501,E123,E124,E126,E127,E128 dataset test
|
||||
- nosetests
|
||||
- nosetests -v
|
||||
cache:
|
||||
directories:
|
||||
- $HOME/.cache/pip
|
||||
|
||||
@ -161,6 +161,7 @@ def main(): # pragma: no cover
|
||||
except FreezeException as fe:
|
||||
log.error(fe)
|
||||
|
||||
|
||||
if __name__ == '__main__': # pragma: no cover
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
main()
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
import json
|
||||
from datetime import datetime, date
|
||||
from collections import defaultdict
|
||||
from collections import defaultdict, OrderedDict
|
||||
from decimal import Decimal
|
||||
|
||||
from six import PY3
|
||||
@ -29,10 +29,10 @@ class JSONSerializer(Serializer):
|
||||
if self.mode == 'item':
|
||||
result = result[0]
|
||||
if self.export.get_bool('wrap', True):
|
||||
result = {
|
||||
'count': len(result),
|
||||
'results': result
|
||||
}
|
||||
result = OrderedDict([
|
||||
('count', len(result)),
|
||||
('results', result),
|
||||
])
|
||||
meta = self.export.get('meta', {})
|
||||
if meta is not None:
|
||||
result['meta'] = meta
|
||||
|
||||
@ -41,6 +41,7 @@ class Database(object):
|
||||
|
||||
self.lock = threading.RLock()
|
||||
self.local = threading.local()
|
||||
|
||||
if len(parsed_url.query):
|
||||
query = parse_qs(parsed_url.query)
|
||||
if schema is None:
|
||||
@ -56,6 +57,7 @@ class Database(object):
|
||||
self._tables = {}
|
||||
self.metadata = MetaData(schema=schema)
|
||||
self.metadata.bind = self.engine
|
||||
|
||||
if reflect_metadata:
|
||||
self.metadata.reflect(self.engine, views=reflect_views)
|
||||
for table_name in self.metadata.tables.keys():
|
||||
@ -63,39 +65,22 @@ class Database(object):
|
||||
|
||||
@property
|
||||
def executable(self):
|
||||
"""Connection or engine against which statements will be executed."""
|
||||
if hasattr(self.local, 'connection'):
|
||||
return self.local.connection
|
||||
return self.engine
|
||||
"""Connection against which statements will be executed."""
|
||||
if not hasattr(self.local, 'conn'):
|
||||
self.local.conn = self.engine.connect()
|
||||
return self.local.conn
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
"""Get an alembic operations context."""
|
||||
ctx = MigrationContext.configure(self.engine)
|
||||
ctx = MigrationContext.configure(self.executable)
|
||||
return Operations(ctx)
|
||||
|
||||
def _acquire(self):
|
||||
self.lock.acquire()
|
||||
|
||||
def _release(self):
|
||||
if not hasattr(self.local, 'tx'):
|
||||
self.lock.release()
|
||||
else:
|
||||
self.local.lock_count[-1] += 1
|
||||
|
||||
def _release_internal(self):
|
||||
for index in range(self.local.lock_count[-1]):
|
||||
self.lock.release()
|
||||
del self.local.lock_count[-1]
|
||||
|
||||
def _dispose_transaction(self):
|
||||
self._release_internal()
|
||||
self.local.tx.remove(self.local.tx[-1])
|
||||
if not self.local.tx:
|
||||
del self.local.tx
|
||||
del self.local.lock_count
|
||||
self.local.connection.close()
|
||||
del self.local.connection
|
||||
self.lock.release()
|
||||
|
||||
def begin(self):
|
||||
"""
|
||||
@ -105,13 +90,9 @@ class Database(object):
|
||||
**NOTICE:** Schema modification operations, such as the creation
|
||||
of tables or columns will not be part of the transactional context.
|
||||
"""
|
||||
if not hasattr(self.local, 'connection'):
|
||||
self.local.connection = self.engine.connect()
|
||||
if not hasattr(self.local, 'tx'):
|
||||
self.local.tx = []
|
||||
self.local.lock_count = []
|
||||
self.local.tx.append(self.local.connection.begin())
|
||||
self.local.lock_count.append(0)
|
||||
self.local.tx.append(self.executable.begin())
|
||||
|
||||
def commit(self):
|
||||
"""
|
||||
@ -120,8 +101,8 @@ class Database(object):
|
||||
Make all statements executed since the transaction was begun permanent.
|
||||
"""
|
||||
if hasattr(self.local, 'tx') and self.local.tx:
|
||||
self.local.tx[-1].commit()
|
||||
self._dispose_transaction()
|
||||
tx = self.local.tx.pop()
|
||||
tx.commit()
|
||||
|
||||
def rollback(self):
|
||||
"""
|
||||
@ -130,8 +111,8 @@ class Database(object):
|
||||
Discard all statements executed since the transaction was begun.
|
||||
"""
|
||||
if hasattr(self.local, 'tx') and self.local.tx:
|
||||
self.local.tx[-1].rollback()
|
||||
self._dispose_transaction()
|
||||
tx = self.local.tx.pop()
|
||||
tx.rollback()
|
||||
|
||||
def __enter__(self):
|
||||
"""Start a transaction."""
|
||||
|
||||
@ -4,7 +4,7 @@ from hashlib import sha1
|
||||
from sqlalchemy.sql import and_, expression
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
from sqlalchemy.schema import Column, Index
|
||||
from sqlalchemy import alias, false
|
||||
from sqlalchemy import func, select, false
|
||||
from dataset.persistence.util import guess_type, normalize_column_name
|
||||
from dataset.persistence.util import ResultIter
|
||||
from dataset.util import DatasetException
|
||||
@ -249,7 +249,7 @@ class Table(object):
|
||||
If no arguments are given, all records are deleted.
|
||||
"""
|
||||
self._check_dropped()
|
||||
if _filter:
|
||||
if _filter or _clauses:
|
||||
q = self._args_to_clause(_filter, clauses=_clauses)
|
||||
stmt = self.table.delete(q)
|
||||
else:
|
||||
@ -298,15 +298,17 @@ class Table(object):
|
||||
"""
|
||||
self._check_dropped()
|
||||
self.database._acquire()
|
||||
|
||||
try:
|
||||
if normalize_column_name(name) not in self._normalized_columns:
|
||||
self.database.op.add_column(
|
||||
self.table.name,
|
||||
Column(name, type),
|
||||
self.table.schema
|
||||
)
|
||||
self.table = self.database.update_table(self.table.name)
|
||||
if normalize_column_name(name) in self._normalized_columns:
|
||||
log.warn("Column with similar name exists: %s" % name)
|
||||
return
|
||||
|
||||
self.database.op.add_column(
|
||||
self.table.name,
|
||||
Column(name, type),
|
||||
self.table.schema
|
||||
)
|
||||
self.table = self.database.update_table(self.table.name)
|
||||
finally:
|
||||
self.database._release()
|
||||
|
||||
@ -399,10 +401,20 @@ class Table(object):
|
||||
return None
|
||||
|
||||
def _args_to_order_by(self, order_by):
|
||||
if order_by[0] == '-':
|
||||
return self.table.c[order_by[1:]].desc()
|
||||
else:
|
||||
return self.table.c[order_by].asc()
|
||||
if not isinstance(order_by, (list, tuple)):
|
||||
order_by = [order_by]
|
||||
orderings = []
|
||||
for ordering in order_by:
|
||||
if ordering is None:
|
||||
continue
|
||||
column = ordering.lstrip('-')
|
||||
if column not in self.table.columns:
|
||||
continue
|
||||
if ordering.startswith('-'):
|
||||
orderings.append(self.table.c[column].desc())
|
||||
else:
|
||||
orderings.append(self.table.c[column].asc())
|
||||
return orderings
|
||||
|
||||
def find(self, *_clauses, **kwargs):
|
||||
"""
|
||||
@ -433,43 +445,33 @@ class Table(object):
|
||||
_limit = kwargs.pop('_limit', None)
|
||||
_offset = kwargs.pop('_offset', 0)
|
||||
_step = kwargs.pop('_step', 5000)
|
||||
order_by = kwargs.pop('order_by', 'id')
|
||||
return_count = kwargs.pop('return_count', False)
|
||||
return_query = kwargs.pop('return_query', False)
|
||||
_filter = kwargs
|
||||
order_by = kwargs.pop('order_by', None)
|
||||
|
||||
self._check_dropped()
|
||||
if not isinstance(order_by, (list, tuple)):
|
||||
order_by = [order_by]
|
||||
order_by = [o for o in order_by if (o.startswith('-') and o[1:] or o) in self.table.columns]
|
||||
order_by = [self._args_to_order_by(o) for o in order_by]
|
||||
order_by = self._args_to_order_by(order_by)
|
||||
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
|
||||
|
||||
args = self._args_to_clause(_filter, ensure=False, clauses=_clauses)
|
||||
|
||||
# query total number of rows first
|
||||
count_query = alias(self.table.select(whereclause=args, limit=_limit, offset=_offset),
|
||||
name='count_query_alias').count()
|
||||
rp = self.database.executable.execute(count_query)
|
||||
total_row_count = rp.fetchone()[0]
|
||||
if return_count:
|
||||
return total_row_count
|
||||
|
||||
if _limit is None:
|
||||
_limit = total_row_count
|
||||
|
||||
if _step is None or _step is False or _step == 0:
|
||||
_step = total_row_count
|
||||
if _step is False or _step == 0:
|
||||
_step = None
|
||||
|
||||
query = self.table.select(whereclause=args, limit=_limit,
|
||||
offset=_offset, order_by=order_by)
|
||||
if return_query:
|
||||
return query
|
||||
offset=_offset)
|
||||
if len(order_by):
|
||||
query = query.order_by(*order_by)
|
||||
return ResultIter(self.database.executable.execute(query),
|
||||
row_type=self.database.row_type, step=_step)
|
||||
|
||||
def count(self, *args, **kwargs):
|
||||
def count(self, *_clauses, **kwargs):
|
||||
"""Return the count of results for the given filter set."""
|
||||
return self.find(*args, return_count=True, **kwargs)
|
||||
# NOTE: this does not have support for limit and offset since I can't
|
||||
# see how this is useful. Still, there might be compatibility issues
|
||||
# with people using these flags. Let's see how it goes.
|
||||
self._check_dropped()
|
||||
args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
|
||||
query = select([func.count()], whereclause=args)
|
||||
query = query.select_from(self.table)
|
||||
rp = self.database.executable.execute(query)
|
||||
return rp.fetchone()[0]
|
||||
|
||||
def __len__(self):
|
||||
"""Return the number of rows in the table."""
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from datetime import datetime
|
||||
from datetime import datetime, date
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
@ -9,7 +9,7 @@ try:
|
||||
except ImportError: # pragma: no cover
|
||||
from ordereddict import OrderedDict
|
||||
|
||||
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
|
||||
from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean
|
||||
from six import string_types
|
||||
|
||||
row_type = OrderedDict
|
||||
@ -24,6 +24,8 @@ def guess_type(sample):
|
||||
return Float
|
||||
elif isinstance(sample, datetime):
|
||||
return DateTime
|
||||
elif isinstance(sample, date):
|
||||
return Date
|
||||
return UnicodeText
|
||||
|
||||
|
||||
@ -42,39 +44,30 @@ def normalize_column_name(name):
|
||||
return name
|
||||
|
||||
|
||||
def iter_result_proxy(rp, step=None):
|
||||
"""Iterate over the ResultProxy."""
|
||||
while True:
|
||||
if step is None:
|
||||
chunk = rp.fetchall()
|
||||
else:
|
||||
chunk = rp.fetchmany(step)
|
||||
if not chunk:
|
||||
break
|
||||
for row in chunk:
|
||||
yield row
|
||||
|
||||
|
||||
class ResultIter(object):
|
||||
""" SQLAlchemy ResultProxies are not iterable to get a
|
||||
list of dictionaries. This is to wrap them. """
|
||||
|
||||
def __init__(self, result_proxy, row_type=row_type, step=None):
|
||||
self.result_proxy = result_proxy
|
||||
self.row_type = row_type
|
||||
self.step = step
|
||||
self.keys = list(result_proxy.keys())
|
||||
self._iter = None
|
||||
|
||||
def _next_chunk(self):
|
||||
if self.result_proxy.closed:
|
||||
return False
|
||||
if not self.step:
|
||||
chunk = self.result_proxy.fetchall()
|
||||
else:
|
||||
chunk = self.result_proxy.fetchmany(self.step)
|
||||
if chunk:
|
||||
self._iter = iter(chunk)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
self._iter = iter_result_proxy(result_proxy, step=step)
|
||||
|
||||
def __next__(self):
|
||||
if self._iter is None:
|
||||
if not self._next_chunk():
|
||||
raise StopIteration
|
||||
try:
|
||||
return convert_row(self.row_type, next(self._iter))
|
||||
except StopIteration:
|
||||
self._iter = None
|
||||
return self.__next__()
|
||||
return convert_row(self.row_type, next(self._iter))
|
||||
|
||||
next = __next__
|
||||
|
||||
|
||||
11
setup.py
11
setup.py
@ -8,7 +8,7 @@ if sys.version_info[:2] <= (2, 6):
|
||||
|
||||
setup(
|
||||
name='dataset',
|
||||
version='0.7.0',
|
||||
version='0.8.0',
|
||||
description="Toolkit for Python-based data processing.",
|
||||
long_description="",
|
||||
classifiers=[
|
||||
@ -16,15 +16,14 @@ setup(
|
||||
"Intended Audience :: Developers",
|
||||
"License :: OSI Approved :: MIT License",
|
||||
"Operating System :: OS Independent",
|
||||
'Programming Language :: Python :: 2.6',
|
||||
'Programming Language :: Python :: 2.7',
|
||||
'Programming Language :: Python :: 3.3'
|
||||
'Programming Language :: Python :: 3.4'
|
||||
'Programming Language :: Python :: 3.3',
|
||||
'Programming Language :: Python :: 3.4',
|
||||
'Programming Language :: Python :: 3.5'
|
||||
],
|
||||
keywords='sql sqlalchemy etl loading utility',
|
||||
author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer',
|
||||
author_email='info@okfn.org',
|
||||
author_email='friedrich@pudo.org',
|
||||
url='http://github.com/pudo/dataset',
|
||||
license='MIT',
|
||||
packages=find_packages(exclude=['ez_setup', 'examples', 'test']),
|
||||
@ -34,7 +33,7 @@ setup(
|
||||
install_requires=[
|
||||
'sqlalchemy >= 0.9.1',
|
||||
'alembic >= 0.6.2',
|
||||
'normality >= 0.2.2',
|
||||
'normality >= 0.3.9',
|
||||
"PyYAML >= 3.10",
|
||||
"six >= 1.7.3"
|
||||
] + py26_dependency,
|
||||
|
||||
@ -8,7 +8,7 @@ except ImportError: # pragma: no cover
|
||||
from ordereddict import OrderedDict # Python < 2.7 drop-in
|
||||
|
||||
from sqlalchemy import FLOAT, INTEGER, TEXT
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
|
||||
|
||||
from dataset import connect
|
||||
from dataset.util import DatasetException
|
||||
@ -252,7 +252,13 @@ class TableTestCase(unittest.TestCase):
|
||||
'temperature': -10,
|
||||
'place': 'Berlin'}
|
||||
)
|
||||
original_count = len(self.tbl)
|
||||
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
|
||||
# Test bad use of API
|
||||
with self.assertRaises(ArgumentError):
|
||||
self.tbl.delete({'place': 'Berlin'})
|
||||
assert len(self.tbl) == original_count, len(self.tbl)
|
||||
|
||||
assert self.tbl.delete(place='Berlin') is True, 'should return 1'
|
||||
assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
|
||||
assert self.tbl.delete() is True, 'should return non zero'
|
||||
|
||||
Loading…
Reference in New Issue
Block a user