Merge branch 'python-3' of github.com:stefanw/dataset into stefanw-python-3

This commit is contained in:
Friedrich Lindenberg 2014-01-05 17:52:59 +01:00
commit 3abe9d2c8d
12 changed files with 123 additions and 83 deletions

View File

@ -91,7 +91,7 @@ def freeze_export(export, result=None):
serializer_cls = get_serializer(export) serializer_cls = get_serializer(export)
serializer = serializer_cls(export, query) serializer = serializer_cls(export, query)
serializer.serialize() serializer.serialize()
except ProgrammingError, pe: except ProgrammingError as pe:
raise FreezeException("Invalid query: %s" % pe) raise FreezeException("Invalid query: %s" % pe)
@ -110,7 +110,7 @@ def main():
continue continue
log.info("Running: %s", export.name) log.info("Running: %s", export.name)
freeze_export(export) freeze_export(export)
except FreezeException, fe: except FreezeException as fe:
log.error(fe) log.error(fe)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,6 +1,11 @@
import json import json
import yaml import yaml
try:
str = unicode
except NameError:
pass
from dataset.util import FreezeException from dataset.util import FreezeException
@ -31,11 +36,11 @@ class Configuration(object):
fh = open(file_name, 'rb') fh = open(file_name, 'rb')
try: try:
self.data = loader.load(fh) self.data = loader.load(fh)
except ValueError, ve: except ValueError as ve:
raise FreezeException("Invalid freeze file: %s" % ve) raise FreezeException("Invalid freeze file: %s" % ve)
fh.close() fh.close()
except IOError, ioe: except IOError as ioe:
raise FreezeException(unicode(ioe)) raise FreezeException(str(ioe))
@property @property
def exports(self): def exports(self):
@ -59,7 +64,7 @@ class Export(object):
def get_normalized(self, name, default=None): def get_normalized(self, name, default=None):
value = self.get(name, default=default) value = self.get(name, default=default)
if not value in [None, default]: if not value in [None, default]:
value = unicode(value).lower().strip() value = str(value).lower().strip()
return value return value
def get_bool(self, name, default=False): def get_bool(self, name, default=False):

View File

@ -3,6 +3,11 @@ import logging
import re import re
import locale import locale
try:
str = unicode
except NameError:
pass
from dataset.util import FreezeException from dataset.util import FreezeException
from slugify import slugify from slugify import slugify
@ -11,7 +16,7 @@ TMPL_KEY = re.compile("{{([^}]*)}}")
OPERATIONS = { OPERATIONS = {
'identity': lambda x: x, 'identity': lambda x: x,
'lower': lambda x: unicode(x).lower(), 'lower': lambda x: str(x).lower(),
'slug': slugify 'slug': slugify
} }
@ -39,7 +44,7 @@ class Serializer(object):
op, key = 'identity', m.group(1) op, key = 'identity', m.group(1)
if ':' in key: if ':' in key:
op, key = key.split(':', 1) op, key = key.split(':', 1)
return unicode(OPERATIONS.get(op)(data.get(key, ''))) return str(OPERATIONS.get(op)(data.get(key, '')))
path = TMPL_KEY.sub(repl, self._basepath) path = TMPL_KEY.sub(repl, self._basepath)
enc = locale.getpreferredencoding() enc = locale.getpreferredencoding()
return os.path.realpath(path.encode(enc, 'replace')) return os.path.realpath(path.encode(enc, 'replace'))
@ -77,5 +82,3 @@ class Serializer(object):
self.write(self.file_name(row), row) self.write(self.file_name(row), row)
self.close() self.close()

View File

@ -7,7 +7,7 @@ from dataset.freeze.format.common import Serializer
def value_to_str(value): def value_to_str(value):
if isinstance(value, datetime): if isinstance(value, datetime):
return value.isoformat() return value.isoformat()
if isinstance(value, unicode): if hasattr(value, 'encode'):
return value.encode('utf-8') return value.encode('utf-8')
if value is None: if value is None:
return '' return ''
@ -20,7 +20,7 @@ class CSVSerializer(Serializer):
self.handles = {} self.handles = {}
def write(self, path, result): def write(self, path, result):
keys = result.keys() keys = list(result.keys())
if not path in self.handles: if not path in self.handles:
fh = open(path, 'wb') fh = open(path, 'wb')
writer = csv.writer(fh) writer = csv.writer(fh)
@ -31,7 +31,5 @@ class CSVSerializer(Serializer):
writer.writerow(values) writer.writerow(values)
def close(self): def close(self):
for (writer, fh) in self.handles.values(): for writer, fh in self.handles.values():
fh.close() fh.close()

View File

@ -46,4 +46,3 @@ class JSONSerializer(Serializer):
data) data)
fh.write(data) fh.write(data)
fh.close() fh.close()

View File

@ -7,7 +7,7 @@ class TabsonSerializer(JSONSerializer):
fields = [] fields = []
data = [] data = []
if len(result): if len(result):
keys = result[0].keys() keys = list(result[0].keys())
fields = [{'id': k} for k in keys] fields = [{'id': k} for k in keys]
for row in result: for row in result:
d = [row.get(k) for k in keys] d = [row.get(k) for k in keys]

View File

@ -1,16 +1,23 @@
import logging import logging
import threading import threading
import re import re
from urlparse import parse_qs
try:
from urllib.parse import urlencode
from urllib.parse import parse_qs
except ImportError:
from urllib import urlencode from urllib import urlencode
from urlparse import parse_qs
from sqlalchemy import create_engine from sqlalchemy import create_engine
from migrate.versioning.util import construct_engine
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from sqlalchemy.schema import MetaData, Column, Index from sqlalchemy.schema import MetaData, Column, Index
from sqlalchemy.schema import Table as SQLATable from sqlalchemy.schema import Table as SQLATable
from sqlalchemy import Integer, Text, String from sqlalchemy import Integer, Text, String
from alembic.migration import MigrationContext
from alembic.operations import Operations
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -37,9 +44,8 @@ class Database(object):
if len(query): if len(query):
url = url + '?' + urlencode(query, doseq=True) url = url + '?' + urlencode(query, doseq=True)
self.schema = schema self.schema = schema
engine = create_engine(url, **kw) self.engine = create_engine(url, **kw)
self.url = url self.url = url
self.engine = construct_engine(engine)
self.metadata = MetaData(schema=schema) self.metadata = MetaData(schema=schema)
self.metadata.bind = self.engine self.metadata.bind = self.engine
if reflectMetadata: if reflectMetadata:
@ -54,6 +60,14 @@ class Database(object):
return self.local.connection return self.local.connection
return self.engine return self.engine
@property
def op(self):
if hasattr(self.local, 'connection'):
ctx = MigrationContext.configure(self.local.connection)
else:
ctx = MigrationContext.configure(self.engine)
return Operations(ctx)
def _acquire(self): def _acquire(self):
self.lock.acquire() self.lock.acquire()
@ -98,8 +112,9 @@ class Database(object):
>>> print db.tables >>> print db.tables
set([u'user', u'action']) set([u'user', u'action'])
""" """
return list(set(self.metadata.tables.keys() + return list(
self._tables.keys())) set(self.metadata.tables.keys()) | set(self._tables.keys())
)
def create_table(self, table_name, primary_id='id', primary_type='Integer'): def create_table(self, table_name, primary_id='id', primary_type='Integer'):
""" """
@ -173,6 +188,12 @@ class Database(object):
finally: finally:
self._release() self._release()
def update_table(self, table_name):
self.metadata = MetaData(schema=self.schema)
self.metadata.bind = self.engine
self.metadata.reflect(self.engine)
return SQLATable(table_name, self.metadata)
def get_table(self, table_name, primary_id='id', primary_type='Integer'): def get_table(self, table_name, primary_id='id', primary_type='Integer'):
""" """
Smart wrapper around *load_table* and *create_table*. Either loads a table Smart wrapper around *load_table* and *create_table*. Either loads a table

View File

@ -123,9 +123,8 @@ class Table(object):
``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`. ``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`.
""" """
# check whether keys arg is a string and format as a list # check whether keys arg is a string and format as a list
if isinstance(keys, basestring): if not isinstance(keys, (list, tuple)):
keys = [keys] keys = [keys]
self._check_dropped() self._check_dropped()
if not keys or len(keys)==len(row): if not keys or len(keys)==len(row):
return False return False
@ -158,9 +157,8 @@ class Table(object):
table.upsert(data, ['id']) table.upsert(data, ['id'])
""" """
# check whether keys arg is a string and format as a list # check whether keys arg is a string and format as a list
if isinstance(keys, basestring): if not isinstance(keys, (list, tuple)):
keys = [keys] keys = [keys]
self._check_dropped() self._check_dropped()
if ensure: if ensure:
self.create_index(keys) self.create_index(keys)
@ -225,9 +223,11 @@ class Table(object):
self.database._acquire() self.database._acquire()
try: try:
if name not in self.table.columns.keys(): if name not in self.table.columns.keys():
col = Column(name, type) self.database.op.add_column(
col.create(self.table, self.table.name,
connection=self.database.executable) Column(name, type)
)
self.table = self.database.update_table(self.table.name)
finally: finally:
self.database._release() self.database._release()
@ -242,8 +242,11 @@ class Table(object):
self.database._acquire() self.database._acquire()
try: try:
if name in self.table.columns.keys(): if name in self.table.columns.keys():
col = self.table.columns[name] self.database.op.drop_column(
col.drop() self.table.name,
name
)
self.table = self.database.update_table(self.table.name)
finally: finally:
self.database._release() self.database._release()
@ -322,9 +325,9 @@ class Table(object):
For more complex queries, please use :py:meth:`db.query() <dataset.Database.query>` For more complex queries, please use :py:meth:`db.query() <dataset.Database.query>`
instead.""" instead."""
self._check_dropped() self._check_dropped()
if isinstance(order_by, (str, unicode)): if not isinstance(order_by, (list, tuple)):
order_by = [order_by] order_by = [order_by]
order_by = filter(lambda o: o in self.table.columns, order_by) order_by = [o for o in order_by if o in self.table.columns]
order_by = [self._args_to_order_by(o) for o in order_by] order_by = [self._args_to_order_by(o) for o in order_by]
args = self._args_to_clause(_filter) args = self._args_to_clause(_filter)
@ -360,8 +363,8 @@ class Table(object):
""" """
Returns the number of rows in the table. Returns the number of rows in the table.
""" """
d = self.database.query(self.table.count()).next() d = next(self.database.query(self.table.count()))
return d.values().pop() return list(d.values()).pop()
def distinct(self, *columns, **_filter): def distinct(self, *columns, **_filter):
""" """

View File

@ -35,22 +35,24 @@ class ResultIter(object):
def _next_rp(self): def _next_rp(self):
try: try:
self.rp = self.result_proxies.next() self.rp = next(self.result_proxies)
self.count += self.rp.rowcount self.count += self.rp.rowcount
self.keys = self.rp.keys() self.keys = list(self.rp.keys())
return True return True
except StopIteration: except StopIteration:
return False return False
def next(self): def __next__(self):
row = self.rp.fetchone() row = self.rp.fetchone()
if row is None: if row is None:
if self._next_rp(): if self._next_rp():
return self.next() return next(self)
else: else:
# stop here # stop here
raise StopIteration raise StopIteration
return OrderedDict(zip(self.keys, row)) return OrderedDict(zip(self.keys, row))
next = __next__
def __iter__(self): def __iter__(self):
return self return self

View File

@ -1,5 +1,11 @@
import sys
from setuptools import setup, find_packages from setuptools import setup, find_packages
py26_dependency = []
if sys.version_info <= (2, 6):
py26_dependency = ["argparse >= 1.2.1"]
setup( setup(
name='dataset', name='dataset',
version='0.3.15', version='0.3.15',
@ -10,7 +16,9 @@ setup(
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
"Programming Language :: Python", 'Programming Language :: Python :: 2.6',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3.3'
], ],
keywords='sql sqlalchemy etl loading utility', keywords='sql sqlalchemy etl loading utility',
author='Friedrich Lindenberg, Gregor Aisch', author='Friedrich Lindenberg, Gregor Aisch',
@ -23,11 +31,10 @@ setup(
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=[
'sqlalchemy >= 0.8.1', 'sqlalchemy >= 0.8.1',
'sqlalchemy-migrate >= 0.7', 'alembic >= 0.6.1',
"argparse >= 1.2.1",
'python-slugify >= 0.0.6', 'python-slugify >= 0.0.6',
"PyYAML >= 3.10" "PyYAML >= 3.10"
], ] + py26_dependency,
tests_require=[], tests_require=[],
entry_points={ entry_points={
'console_scripts': [ 'console_scripts': [

View File

@ -1,34 +1,38 @@
# -*- encoding: utf-8 -*-
from datetime import datetime from datetime import datetime
TEST_CITY_1 = u'B€rkeley'
TEST_CITY_2 = u'G€lway'
TEST_DATA = [ TEST_DATA = [
{ {
'date': datetime(2011, 01, 01), 'date': datetime(2011, 1, 1),
'temperature': 1, 'temperature': 1,
'place': 'Galway' 'place': TEST_CITY_2
}, },
{ {
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -1, 'temperature': -1,
'place': 'Galway' 'place': TEST_CITY_2
}, },
{ {
'date': datetime(2011, 01, 03), 'date': datetime(2011, 1, 3),
'temperature': 0, 'temperature': 0,
'place': 'Galway' 'place': TEST_CITY_2
}, },
{ {
'date': datetime(2011, 01, 01), 'date': datetime(2011, 1, 1),
'temperature': 6, 'temperature': 6,
'place': 'Berkeley' 'place': TEST_CITY_1
}, },
{ {
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': 8, 'temperature': 8,
'place': 'Berkeley' 'place': TEST_CITY_1
}, },
{ {
'date': datetime(2011, 01, 03), 'date': datetime(2011, 1, 3),
'temperature': 5, 'temperature': 5,
'place': 'Berkeley' 'place': TEST_CITY_1
} }
] ]

View File

@ -4,7 +4,7 @@ from datetime import datetime
from dataset import connect from dataset import connect
from dataset.util import DatasetException from dataset.util import DatasetException
from sample_data import TEST_DATA from sample_data import TEST_DATA, TEST_CITY_1
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
@ -70,8 +70,7 @@ class DatabaseTestCase(unittest.TestCase):
table.insert({'int_id': 124}) table.insert({'int_id': 124})
assert table.find_one(int_id = 123)['int_id'] == 123 assert table.find_one(int_id = 123)['int_id'] == 123
assert table.find_one(int_id = 124)['int_id'] == 124 assert table.find_one(int_id = 124)['int_id'] == 124
with self.assertRaises(IntegrityError): self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123}))
table.insert({'int_id': 123})
def test_create_table_shorthand1(self): def test_create_table_shorthand1(self):
pid = "int_id" pid = "int_id"
@ -84,8 +83,7 @@ class DatabaseTestCase(unittest.TestCase):
table.insert({'int_id': 124}) table.insert({'int_id': 124})
assert table.find_one(int_id = 123)['int_id'] == 123 assert table.find_one(int_id = 123)['int_id'] == 123
assert table.find_one(int_id = 124)['int_id'] == 124 assert table.find_one(int_id = 124)['int_id'] == 124
with self.assertRaises(IntegrityError): self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123}))
table.insert({'int_id': 123})
def test_create_table_shorthand2(self): def test_create_table_shorthand2(self):
pid = "string_id" pid = "string_id"
@ -129,7 +127,7 @@ class TableTestCase(unittest.TestCase):
def test_insert(self): def test_insert(self):
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
last_id = self.tbl.insert({ last_id = self.tbl.insert({
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -10, 'temperature': -10,
'place': 'Berlin'} 'place': 'Berlin'}
) )
@ -138,14 +136,14 @@ class TableTestCase(unittest.TestCase):
def test_upsert(self): def test_upsert(self):
self.tbl.upsert({ self.tbl.upsert({
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -10, 'temperature': -10,
'place': 'Berlin'}, 'place': 'Berlin'},
['place'] ['place']
) )
assert len(self.tbl) == len(TEST_DATA)+1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA)+1, len(self.tbl)
self.tbl.upsert({ self.tbl.upsert({
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -10, 'temperature': -10,
'place': 'Berlin'}, 'place': 'Berlin'},
['place'] ['place']
@ -155,7 +153,7 @@ class TableTestCase(unittest.TestCase):
def test_upsert_all_key(self): def test_upsert_all_key(self):
for i in range(0,2): for i in range(0,2):
self.tbl.upsert({ self.tbl.upsert({
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -10, 'temperature': -10,
'place': 'Berlin'}, 'place': 'Berlin'},
['date', 'temperature', 'place'] ['date', 'temperature', 'place']
@ -163,7 +161,7 @@ class TableTestCase(unittest.TestCase):
def test_delete(self): def test_delete(self):
self.tbl.insert({ self.tbl.insert({
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -10, 'temperature': -10,
'place': 'Berlin'} 'place': 'Berlin'}
) )
@ -175,7 +173,7 @@ class TableTestCase(unittest.TestCase):
def test_find_one(self): def test_find_one(self):
self.tbl.insert({ self.tbl.insert({
'date': datetime(2011, 01, 02), 'date': datetime(2011, 1, 2),
'temperature': -10, 'temperature': -10,
'place': 'Berlin'} 'place': 'Berlin'}
) )
@ -185,9 +183,9 @@ class TableTestCase(unittest.TestCase):
assert d is None, d assert d is None, d
def test_find(self): def test_find(self):
ds = list(self.tbl.find(place='Berkeley')) ds = list(self.tbl.find(place=TEST_CITY_1))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(place='Berkeley', _limit=2)) ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2))
assert len(ds) == 2, ds assert len(ds) == 2, ds
def test_distinct(self): def test_distinct(self):
@ -225,15 +223,15 @@ class TableTestCase(unittest.TestCase):
assert c == len(self.tbl) assert c == len(self.tbl)
def test_update(self): def test_update(self):
date = datetime(2011, 01, 02) date = datetime(2011, 1, 2)
res = self.tbl.update({ res = self.tbl.update({
'date': date, 'date': date,
'temperature': -10, 'temperature': -10,
'place': 'Berkeley'}, 'place': TEST_CITY_1},
['place', 'date'] ['place', 'date']
) )
assert res, 'update should return True' assert res, 'update should return True'
m = self.tbl.find_one(place='Berkeley', date=date) m = self.tbl.find_one(place=TEST_CITY_1, date=date)
assert m['temperature'] == -10, 'new temp. should be -10 but is %d' % m['temperature'] assert m['temperature'] == -10, 'new temp. should be -10 but is %d' % m['temperature']
def test_create_column(self): def test_create_column(self):