From 56d5b232d835d70db657c24d71129f56f85ceb1c Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Sat, 23 May 2015 16:15:17 +0200 Subject: [PATCH] Include support for custom result row types, fixes #41. --- CHANGELOG.md | 3 ++ dataset/__init__.py | 12 ++++--- dataset/persistence/database.py | 9 +++-- dataset/persistence/table.py | 3 +- dataset/persistence/util.py | 11 +++--- docs/quickstart.rst | 11 ++++++ test/test_freeze.py | 2 +- test/test_freeze_app.py | 2 +- test/test_persistence.py | 59 +++++++++++++++++++++++++++++++++ 9 files changed, 97 insertions(+), 15 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd9b296..86a4a63 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ *The changelog has only been started with version 0.3.12, previous changes must be reconstructed from revision history.* +* 0.5.7: dataset Databases can now have customized row types. This allows, + for example, information to be retrieved in attribute-accessible dict + subclasses, such as stuf. * 0.5.4: Context manager for transactions, thanks to @victorkashirin. * 0.5.1: Fix a regression where empty queries would raise an exception. * 0.5: Improve overall code quality and testing, including Travis CI. diff --git a/dataset/__init__.py b/dataset/__init__.py index 7ebc9bb..dbc465f 100644 --- a/dataset/__init__.py +++ b/dataset/__init__.py @@ -2,6 +2,7 @@ import os import warnings from dataset.persistence.database import Database from dataset.persistence.table import Table +from dataset.persistence.util import row_type from dataset.freeze.app import freeze # shut up useless SA warning: @@ -14,20 +15,20 @@ __all__ = ['Database', 'Table', 'freeze', 'connect'] def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None, - reflect_views=True): + reflect_views=True, row_type=row_type): """ Opens a new connection to a database. *url* can be any valid `SQLAlchemy engine URL`_. 
If *url* is not defined it will try to use *DATABASE_URL* from environment variable. Returns an - instance of :py:class:`Database `. Set *reflectMetadata* + instance of :py:class:`Database `. Set *reflect_metadata* to False if you don't want the entire database schema to be pre-loaded. This significantly speeds up connecting to large databases with lots of tables. *reflect_views* can be set to False if you don't want views to be loaded. Additionally, *engine_kwargs* will be directly passed to SQLAlchemy, e.g. set *engine_kwargs={'pool_recycle': 3600}* will avoid `DB - connection timeout`_. - + connection timeout`_. Set *row_type* to an alternate dict-like class to + change the type of container rows are stored in. :: db = dataset.connect('sqlite:///factbook.db') @@ -38,4 +39,5 @@ def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None, url = os.environ.get('DATABASE_URL', 'sqlite://') return Database(url, schema=schema, reflect_metadata=reflect_metadata, - engine_kwargs=engine_kwargs, reflect_views=reflect_views) + engine_kwargs=engine_kwargs, reflect_views=reflect_views, + row_type=row_type) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index f8b38de..12ea7e3 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -16,7 +16,7 @@ from alembic.migration import MigrationContext from alembic.operations import Operations from dataset.persistence.table import Table -from dataset.persistence.util import ResultIter +from dataset.persistence.util import ResultIter, row_type from dataset.util import DatasetException log = logging.getLogger(__name__) @@ -24,7 +24,8 @@ log = logging.getLogger(__name__) class Database(object): def __init__(self, url, schema=None, reflect_metadata=True, - engine_kwargs=None, reflect_views=True): + engine_kwargs=None, reflect_views=True, + row_type=row_type): if engine_kwargs is None: engine_kwargs = {} @@ -47,6 +48,7 @@ class Database(object): self.metadata.bind 
= self.engine if reflect_metadata: self.metadata.reflect(self.engine, views=reflect_views) + self.row_type = row_type self._tables = {} @property @@ -275,7 +277,8 @@ class Database(object): """ if isinstance(query, six.string_types): query = text(query) - return ResultIter(self.executable.execute(query, **kw)) + return ResultIter(self.executable.execute(query, **kw), + row_type=self.row_type) def __repr__(self): return '' % self.url diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index a00c887..ff8f631 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -371,7 +371,8 @@ class Table(object): break queries.append(self.table.select(whereclause=args, limit=qlimit, offset=qoffset, order_by=order_by)) - return ResultIter((self.database.executable.execute(q) for q in queries)) + return ResultIter((self.database.executable.execute(q) for q in queries), + row_type=self.database.row_type) def count(self, **_filter): """ diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index 4da08a9..5a57552 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -9,6 +9,8 @@ except ImportError: # pragma: no cover from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean from six import string_types +row_type = OrderedDict + def guess_type(sample): if isinstance(sample, bool): @@ -22,10 +24,10 @@ def guess_type(sample): return UnicodeText -def convert_row(row): +def convert_row(row_type, row): if row is None: return None - return OrderedDict(row.items()) + return row_type(row.items()) def normalize_column_name(name): @@ -41,7 +43,8 @@ class ResultIter(object): """ SQLAlchemy ResultProxies are not iterable to get a list of dictionaries. This is to wrap them. 
""" - def __init__(self, result_proxies): + def __init__(self, result_proxies, row_type=row_type): + self.row_type = row_type if not isgenerator(result_proxies): result_proxies = iter((result_proxies, )) self.result_proxies = result_proxies @@ -61,7 +64,7 @@ class ResultIter(object): if not self._next_rp(): raise StopIteration try: - return convert_row(next(self._iter)) + return convert_row(self.row_type, next(self._iter)) except StopIteration: self._iter = None return self.__next__() diff --git a/docs/quickstart.rst b/docs/quickstart.rst index f5f0cc4..f3d656d 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -144,6 +144,17 @@ with unique values in one or more columns:: # Get one user per country db['user'].distinct('country') +Finally, you can use the ``row_type`` parameter to choose the data type in which +results will be returned:: + + import dataset + from stuf import stuf + + db = dataset.connect('sqlite:///mydatabase.db', row_type=stuf) + +Now contents will be returned in ``stuf`` objects (basically, ``dict`` +objects whose elements can be accessed as attributes (``item.name``) as well as +by index (``item['name']``)).
Running custom SQL queries -------------------------- diff --git a/test/test_freeze.py b/test/test_freeze.py index 3cb3ea5..9605a29 100644 --- a/test/test_freeze.py +++ b/test/test_freeze.py @@ -65,7 +65,7 @@ class FreezeTestCase(unittest.TestCase): class SerializerTestCase(unittest.TestCase): - def test_Serializer(self): + def test_serializer(self): from dataset.freeze.format.common import Serializer from dataset.freeze.config import Export from dataset.util import FreezeException diff --git a/test/test_freeze_app.py b/test/test_freeze_app.py index 2135385..20826f6 100644 --- a/test/test_freeze_app.py +++ b/test/test_freeze_app.py @@ -17,7 +17,7 @@ from dataset.freeze.app import create_parser, freeze_with_config, freeze_export from .sample_data import TEST_DATA -class FreezeTestCase(TestCase): +class FreezeAppTestCase(TestCase): """ Base TestCase class, sets up a CLI parser """ diff --git a/test/test_persistence.py b/test/test_persistence.py index 8ed09d0..6c8cf4c 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -345,3 +345,62 @@ class TableTestCase(unittest.TestCase): m = self.tbl.find(place='not in data') l = list(m) # exhaust iterator assert len(l) == 0 + + +class Constructor(dict): + """ Very simple low-functionality extension to ``dict`` to + provide attribute access to dictionary contents""" + def __getattr__(self, name): + return self[name] + + +class RowTypeTestCase(unittest.TestCase): + + def setUp(self): + self.db = connect('sqlite:///:memory:', row_type=Constructor) + self.tbl = self.db['weather'] + for row in TEST_DATA: + self.tbl.insert(row) + + def tearDown(self): + for table in self.db.tables: + self.db[table].drop() + + def test_find_one(self): + self.tbl.insert({ + 'date': datetime(2011, 1, 2), + 'temperature': -10, + 'place': 'Berlin'} + ) + d = self.tbl.find_one(place='Berlin') + assert d['temperature'] == -10, d + assert d.temperature == -10, d + d = self.tbl.find_one(place='Atlantis') + assert d is None, d + + def 
test_find(self): + ds = list(self.tbl.find(place=TEST_CITY_1)) + assert len(ds) == 3, ds + for item in ds: + assert isinstance(item, Constructor), item + ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2)) + assert len(ds) == 2, ds + for item in ds: + assert isinstance(item, Constructor), item + + def test_distinct(self): + x = list(self.tbl.distinct('place')) + assert len(x) == 2, x + for item in x: + assert isinstance(item, Constructor), item + x = list(self.tbl.distinct('place', 'date')) + assert len(x) == 6, x + for item in x: + assert isinstance(item, Constructor), item + + def test_iter(self): + c = 0 + for row in self.tbl: + c += 1 + assert isinstance(row, Constructor), row + assert c == len(self.tbl)