Include support for custom result row types, fixes #41.

This commit is contained in:
Friedrich Lindenberg 2015-05-23 16:15:17 +02:00
parent 8b411636f3
commit 56d5b232d8
9 changed files with 97 additions and 15 deletions

View File

@ -3,6 +3,9 @@
*The changelog has only been started with version 0.3.12, previous *The changelog has only been started with version 0.3.12, previous
changes must be reconstructed from revision history.* changes must be reconstructed from revision history.*
* 0.5.7: dataset Databases can now have customized row types. This allows,
for example, information to be retrieved in attribute-accessible dict
subclasses, such as stuf.
* 0.5.4: Context manager for transactions, thanks to @victorkashirin. * 0.5.4: Context manager for transactions, thanks to @victorkashirin.
* 0.5.1: Fix a regression where empty queries would raise an exception. * 0.5.1: Fix a regression where empty queries would raise an exception.
* 0.5: Improve overall code quality and testing, including Travis CI. * 0.5: Improve overall code quality and testing, including Travis CI.

View File

@ -2,6 +2,7 @@ import os
import warnings import warnings
from dataset.persistence.database import Database from dataset.persistence.database import Database
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.persistence.util import row_type
from dataset.freeze.app import freeze from dataset.freeze.app import freeze
# shut up useless SA warning: # shut up useless SA warning:
@ -14,20 +15,20 @@ __all__ = ['Database', 'Table', 'freeze', 'connect']
def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None, def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
reflect_views=True): reflect_views=True, row_type=row_type):
""" """
Opens a new connection to a database. Opens a new connection to a database.
*url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined *url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined
it will try to use *DATABASE_URL* from environment variable. Returns an it will try to use *DATABASE_URL* from environment variable. Returns an
instance of :py:class:`Database <dataset.Database>`. Set *reflectMetadata* instance of :py:class:`Database <dataset.Database>`. Set *reflect_metadata*
to False if you don't want the entire database schema to be pre-loaded. to False if you don't want the entire database schema to be pre-loaded.
This significantly speeds up connecting to large databases with lots of This significantly speeds up connecting to large databases with lots of
tables. *reflect_views* can be set to False if you don't want views to be tables. *reflect_views* can be set to False if you don't want views to be
loaded. Additionally, *engine_kwargs* will be directly passed to loaded. Additionally, *engine_kwargs* will be directly passed to
SQLAlchemy, e.g. set *engine_kwargs={'pool_recycle': 3600}* will avoid `DB SQLAlchemy, e.g. set *engine_kwargs={'pool_recycle': 3600}* will avoid `DB
connection timeout`_. connection timeout`_. Set *row_type* to an alternate dict-like class to
change the type of container rows are stored in.
:: ::
db = dataset.connect('sqlite:///factbook.db') db = dataset.connect('sqlite:///factbook.db')
@ -38,4 +39,5 @@ def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
url = os.environ.get('DATABASE_URL', 'sqlite://') url = os.environ.get('DATABASE_URL', 'sqlite://')
return Database(url, schema=schema, reflect_metadata=reflect_metadata, return Database(url, schema=schema, reflect_metadata=reflect_metadata,
engine_kwargs=engine_kwargs, reflect_views=reflect_views) engine_kwargs=engine_kwargs, reflect_views=reflect_views,
row_type=row_type)

View File

@ -16,7 +16,7 @@ from alembic.migration import MigrationContext
from alembic.operations import Operations from alembic.operations import Operations
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter, row_type
from dataset.util import DatasetException from dataset.util import DatasetException
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -24,7 +24,8 @@ log = logging.getLogger(__name__)
class Database(object): class Database(object):
def __init__(self, url, schema=None, reflect_metadata=True, def __init__(self, url, schema=None, reflect_metadata=True,
engine_kwargs=None, reflect_views=True): engine_kwargs=None, reflect_views=True,
row_type=row_type):
if engine_kwargs is None: if engine_kwargs is None:
engine_kwargs = {} engine_kwargs = {}
@ -47,6 +48,7 @@ class Database(object):
self.metadata.bind = self.engine self.metadata.bind = self.engine
if reflect_metadata: if reflect_metadata:
self.metadata.reflect(self.engine, views=reflect_views) self.metadata.reflect(self.engine, views=reflect_views)
self.row_type = row_type
self._tables = {} self._tables = {}
@property @property
@ -275,7 +277,8 @@ class Database(object):
""" """
if isinstance(query, six.string_types): if isinstance(query, six.string_types):
query = text(query) query = text(query)
return ResultIter(self.executable.execute(query, **kw)) return ResultIter(self.executable.execute(query, **kw),
row_type=self.row_type)
def __repr__(self): def __repr__(self):
return '<Database(%s)>' % self.url return '<Database(%s)>' % self.url

View File

@ -371,7 +371,8 @@ class Table(object):
break break
queries.append(self.table.select(whereclause=args, limit=qlimit, queries.append(self.table.select(whereclause=args, limit=qlimit,
offset=qoffset, order_by=order_by)) offset=qoffset, order_by=order_by))
return ResultIter((self.database.executable.execute(q) for q in queries)) return ResultIter((self.database.executable.execute(q) for q in queries),
row_type=self.database.row_type)
def count(self, **_filter): def count(self, **_filter):
""" """

View File

@ -9,6 +9,8 @@ except ImportError: # pragma: no cover
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
from six import string_types from six import string_types
row_type = OrderedDict
def guess_type(sample): def guess_type(sample):
if isinstance(sample, bool): if isinstance(sample, bool):
@ -22,10 +24,10 @@ def guess_type(sample):
return UnicodeText return UnicodeText
def convert_row(row): def convert_row(row_type, row):
if row is None: if row is None:
return None return None
return OrderedDict(row.items()) return row_type(row.items())
def normalize_column_name(name): def normalize_column_name(name):
@ -41,7 +43,8 @@ class ResultIter(object):
""" SQLAlchemy ResultProxies are not iterable to get a """ SQLAlchemy ResultProxies are not iterable to get a
list of dictionaries. This is to wrap them. """ list of dictionaries. This is to wrap them. """
def __init__(self, result_proxies): def __init__(self, result_proxies, row_type=row_type):
self.row_type = row_type
if not isgenerator(result_proxies): if not isgenerator(result_proxies):
result_proxies = iter((result_proxies, )) result_proxies = iter((result_proxies, ))
self.result_proxies = result_proxies self.result_proxies = result_proxies
@ -61,7 +64,7 @@ class ResultIter(object):
if not self._next_rp(): if not self._next_rp():
raise StopIteration raise StopIteration
try: try:
return convert_row(next(self._iter)) return convert_row(self.row_type, next(self._iter))
except StopIteration: except StopIteration:
self._iter = None self._iter = None
return self.__next__() return self.__next__()

View File

@ -144,6 +144,17 @@ with unique values in one or more columns::
# Get one user per country # Get one user per country
db['user'].distinct('country') db['user'].distinct('country')
Finally, you can use the ``row_type`` parameter to choose the data type in which
results will be returned::
import dataset
from stuf import stuf
db = dataset.connect('sqlite:///mydatabase.db', row_type=stuf)
Now contents will be returned in ``stuf`` objects (basically, ``dict``
objects whose elements can be accessed as attributes (``item.name``) as well as
by key (``item['name']``)).
Running custom SQL queries Running custom SQL queries
-------------------------- --------------------------

View File

@ -65,7 +65,7 @@ class FreezeTestCase(unittest.TestCase):
class SerializerTestCase(unittest.TestCase): class SerializerTestCase(unittest.TestCase):
def test_Serializer(self): def test_serializer(self):
from dataset.freeze.format.common import Serializer from dataset.freeze.format.common import Serializer
from dataset.freeze.config import Export from dataset.freeze.config import Export
from dataset.util import FreezeException from dataset.util import FreezeException

View File

@ -17,7 +17,7 @@ from dataset.freeze.app import create_parser, freeze_with_config, freeze_export
from .sample_data import TEST_DATA from .sample_data import TEST_DATA
class FreezeTestCase(TestCase): class FreezeAppTestCase(TestCase):
""" """
Base TestCase class, sets up a CLI parser Base TestCase class, sets up a CLI parser
""" """

View File

@ -345,3 +345,62 @@ class TableTestCase(unittest.TestCase):
m = self.tbl.find(place='not in data') m = self.tbl.find(place='not in data')
l = list(m) # exhaust iterator l = list(m) # exhaust iterator
assert len(l) == 0 assert len(l) == 0
class Constructor(dict):
    """Minimal ``dict`` subclass exposing its keys as attributes.

    ``obj.name`` is resolved as ``obj['name']``. It is only consulted when
    normal attribute lookup fails, so regular ``dict`` methods are not
    shadowed by keys of the same name.
    """

    def __getattr__(self, name):
        """Return ``self[name]``.

        Raises ``AttributeError`` (not ``KeyError``) when *name* is not a
        key, so that ``hasattr()`` and ``getattr(obj, name, default)``
        behave as they do for ordinary objects.
        """
        try:
            return self[name]
        except KeyError:
            # The attribute protocol only recognises AttributeError as
            # "no such attribute"; a leaked KeyError propagates out of
            # hasattr()/getattr() instead of yielding False/the default.
            raise AttributeError(name)
class RowTypeTestCase(unittest.TestCase):
    """Exercise table queries on a database opened with a custom ``row_type``."""

    def setUp(self):
        # In-memory SQLite database connected with ``row_type=Constructor``,
        # pre-populated with the shared sample records.
        self.db = connect('sqlite:///:memory:', row_type=Constructor)
        self.tbl = self.db['weather']
        for record in TEST_DATA:
            self.tbl.insert(record)

    def tearDown(self):
        # Drop every table so each test starts from a clean slate.
        for name in self.db.tables:
            self.db[name].drop()

    def test_find_one(self):
        """``find_one`` returns an attribute-accessible row, or None."""
        berlin = {
            'date': datetime(2011, 1, 2),
            'temperature': -10,
            'place': 'Berlin',
        }
        self.tbl.insert(berlin)

        found = self.tbl.find_one(place='Berlin')
        assert found['temperature'] == -10, found
        assert found.temperature == -10, found

        missing = self.tbl.find_one(place='Atlantis')
        assert missing is None, missing

    def test_find(self):
        """``find`` yields ``Constructor`` rows, with and without a limit."""
        cases = (
            ({}, 3),
            ({'_limit': 2}, 2),
        )
        for extra, expected in cases:
            rows = list(self.tbl.find(place=TEST_CITY_1, **extra))
            assert len(rows) == expected, rows
            for row in rows:
                assert isinstance(row, Constructor), row

    def test_distinct(self):
        """``distinct`` yields ``Constructor`` rows for one or more columns."""
        cases = (
            (('place',), 2),
            (('place', 'date'), 6),
        )
        for columns, expected in cases:
            rows = list(self.tbl.distinct(*columns))
            assert len(rows) == expected, rows
            for row in rows:
                assert isinstance(row, Constructor), row

    def test_iter(self):
        """Iterating the table yields ``Constructor`` rows, one per record."""
        seen = 0
        for seen, row in enumerate(self.tbl, 1):
            assert isinstance(row, Constructor), row
        assert seen == len(self.tbl)