Merge branch 'master' into improve-testing

Conflicts:
	test/test_persistence.py
This commit is contained in:
Stefan Wehrmeyer 2014-01-25 20:52:22 +01:00
commit 6c8f83b7c2
6 changed files with 55 additions and 10 deletions

View File

@ -3,6 +3,9 @@
*The changelog has only been started with version 0.3.12, previous *The changelog has only been started with version 0.3.12, previous
changes must be reconstructed from revision history.* changes must be reconstructed from revision history.*
* 0.4: Python 3 support and switch to alembic for migrations.
* 0.3.15: Fixes to update and insertion of data, thanks to @cli248
and @abhinav-upadhyay.
* 0.3.14: dataset went viral somehow. Thanks to @gtsafas for * 0.3.14: dataset went viral somehow. Thanks to @gtsafas for
refactorings, @alasdairnicol for fixing the Freezfile example in refactorings, @alasdairnicol for fixing the Freezfile example in
the documentation. @diegoguimaraes fixed the behaviour of insert to the documentation. @diegoguimaraes fixed the behaviour of insert to

View File

@ -4,6 +4,7 @@ import warnings
warnings.filterwarnings( warnings.filterwarnings(
'ignore', 'Unicode type received non-unicode bind param value.') 'ignore', 'Unicode type received non-unicode bind param value.')
from dataset.persistence.util import sqlite_datetime_fix
from dataset.persistence.database import Database from dataset.persistence.database import Database
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.freeze.app import freeze from dataset.freeze.app import freeze
@ -27,4 +28,8 @@ def connect(url=None, schema=None, reflectMetadata=True):
""" """
if url is None: if url is None:
url = os.environ.get('DATABASE_URL', url) url = os.environ.get('DATABASE_URL', url)
if url.startswith("sqlite://"):
sqlite_datetime_fix()
return Database(url, schema=schema, reflectMetadata=reflectMetadata) return Database(url, schema=schema, reflectMetadata=reflectMetadata)

View File

@ -1,9 +1,13 @@
import logging import logging
from itertools import count from itertools import count
try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict # Python < 2.7 drop-in
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import alias
from dataset.persistence.util import guess_type from dataset.persistence.util import guess_type
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -68,7 +72,8 @@ class Table(object):
if ensure: if ensure:
self._ensure_columns(row, types=types) self._ensure_columns(row, types=types)
res = self.database.executable.execute(self.table.insert(row)) res = self.database.executable.execute(self.table.insert(row))
return res.inserted_primary_key[0] if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
def insert_many(self, rows, chunk_size=1000, ensure=True, types={}): def insert_many(self, rows, chunk_size=1000, ensure=True, types={}):
""" """
@ -283,7 +288,7 @@ class Table(object):
rp = self.database.executable.execute(query) rp = self.database.executable.execute(query)
data = rp.fetchone() data = rp.fetchone()
if data is not None: if data is not None:
return dict(zip(rp.keys(), data)) return OrderedDict(zip(rp.keys(), data))
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
if order_by[0] == '-': if order_by[0] == '-':
@ -328,7 +333,7 @@ class Table(object):
args = self._args_to_clause(_filter) args = self._args_to_clause(_filter)
# query total number of rows first # query total number of rows first
count_query = self.table.count(whereclause=args, limit=_limit, offset=_offset) count_query = alias(self.table.select(whereclause=args, limit=_limit, offset=_offset), name='count_query_alias').count()
rp = self.database.executable.execute(count_query) rp = self.database.executable.execute(count_query)
total_row_count = rp.fetchone()[0] total_row_count = rp.fetchone()[0]
@ -348,8 +353,6 @@ class Table(object):
qlimit = min(_limit - (_step * i), _step) qlimit = min(_limit - (_step * i), _step)
if qlimit <= 0: if qlimit <= 0:
break break
if qoffset > total_row_count:
break
queries.append(self.table.select(whereclause=args, limit=qlimit, queries.append(self.table.select(whereclause=args, limit=qlimit,
offset=qoffset, order_by=order_by)) offset=qoffset, order_by=order_by))
return ResultIter((self.database.executable.execute(q) for q in queries)) return ResultIter((self.database.executable.execute(q) for q in queries))

View File

@ -1,7 +1,11 @@
from datetime import datetime from datetime import datetime, timedelta
from inspect import isgenerator from inspect import isgenerator
try:
from collections import OrderedDict
except ImportError:
from ordereddict import OrderedDict # Python < 2.7 drop-in
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean, types, Table, event
def guess_type(sample): def guess_type(sample):
@ -46,9 +50,32 @@ class ResultIter(object):
else: else:
# stop here # stop here
raise StopIteration raise StopIteration
return dict(zip(self.keys, row)) return OrderedDict(zip(self.keys, row))
next = __next__ next = __next__
def __iter__(self): def __iter__(self):
return self return self
def sqlite_datetime_fix():
class SQLiteDateTimeType(types.TypeDecorator):
impl = types.Integer
epoch = datetime(1970, 1, 1, 0, 0, 0)
def process_bind_param(self, value, dialect):
return (value / 1000 - self.epoch).total_seconds()
def process_result_value(self, value, dialect):
return self.epoch + timedelta(seconds=value / 1000)
def is_sqlite(inspector):
return inspector.engine.dialect.name == "sqlite"
def is_datetime(column_info):
return isinstance(column_info['type'], types.DateTime)
@event.listens_for(Table, "column_reflect")
def setup_epoch(inspector, table, column_info):
if is_sqlite(inspector) and is_datetime(column_info):
column_info['type'] = SQLiteDateTimeType()

View File

@ -8,7 +8,7 @@ if sys.version_info <= (2, 6):
setup( setup(
name='dataset', name='dataset',
version='0.3.14', version='0.4.0',
description="Toolkit for Python-based data processing.", description="Toolkit for Python-based data processing.",
long_description="", long_description="",
classifiers=[ classifiers=[

View File

@ -8,6 +8,7 @@ from dataset import connect
from dataset.util import DatasetException from dataset.util import DatasetException
from .sample_data import TEST_DATA, TEST_CITY_1 from .sample_data import TEST_DATA, TEST_CITY_1
from sqlalchemy.exc import IntegrityError
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
@ -247,3 +248,9 @@ class TableTestCase(unittest.TestCase):
assert 'foo' in tbl.table.c, tbl.table.c assert 'foo' in tbl.table.c, tbl.table.c
assert FLOAT == type(tbl.table.c['foo'].type), tbl.table.c['foo'].type assert FLOAT == type(tbl.table.c['foo'].type), tbl.table.c['foo'].type
assert 'foo' in tbl.columns, tbl.columns assert 'foo' in tbl.columns, tbl.columns
def test_key_order(self):
res = self.db.query('SELECT temperature, place FROM weather LIMIT 1')
keys = list(res.next().keys())
assert keys[0] == 'temperature'
assert keys[1] == 'place'