Include support for custom result row types, fixes #41.
This commit is contained in:
parent
8b411636f3
commit
56d5b232d8
@ -3,6 +3,9 @@
|
||||
*The changelog has only been started with version 0.3.12, previous
|
||||
changes must be reconstructed from revision history.*
|
||||
|
||||
* 0.5.7: dataset Databases can now have customized row types. This allows,
|
||||
for example, information to be retrieved in attribute-accessible dict
|
||||
subclasses, such as stuf.
|
||||
* 0.5.4: Context manager for transactions, thanks to @victorkashirin.
|
||||
* 0.5.1: Fix a regression where empty queries would raise an exception.
|
||||
* 0.5: Improve overall code quality and testing, including Travis CI.
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
import warnings
|
||||
from dataset.persistence.database import Database
|
||||
from dataset.persistence.table import Table
|
||||
from dataset.persistence.util import row_type
|
||||
from dataset.freeze.app import freeze
|
||||
|
||||
# shut up useless SA warning:
|
||||
@ -14,20 +15,20 @@ __all__ = ['Database', 'Table', 'freeze', 'connect']
|
||||
|
||||
|
||||
def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
|
||||
reflect_views=True):
|
||||
reflect_views=True, row_type=row_type):
|
||||
"""
|
||||
Opens a new connection to a database.
|
||||
|
||||
*url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined
|
||||
it will try to use *DATABASE_URL* from environment variable. Returns an
|
||||
instance of :py:class:`Database <dataset.Database>`. Set *reflectMetadata*
|
||||
instance of :py:class:`Database <dataset.Database>`. Set *reflect_metadata*
|
||||
to False if you don't want the entire database schema to be pre-loaded.
|
||||
This significantly speeds up connecting to large databases with lots of
|
||||
tables. *reflect_views* can be set to False if you don't want views to be
|
||||
loaded. Additionally, *engine_kwargs* will be directly passed to
|
||||
SQLAlchemy, e.g. set *engine_kwargs={'pool_recycle': 3600}* will avoid `DB
|
||||
connection timeout`_.
|
||||
|
||||
connection timeout`_. Set *row_type* to an alternate dict-like class to
|
||||
change the type of container rows are stored in.
|
||||
::
|
||||
db = dataset.connect('sqlite:///factbook.db')
|
||||
|
||||
@ -38,4 +39,5 @@ def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
|
||||
url = os.environ.get('DATABASE_URL', 'sqlite://')
|
||||
|
||||
return Database(url, schema=schema, reflect_metadata=reflect_metadata,
|
||||
engine_kwargs=engine_kwargs, reflect_views=reflect_views)
|
||||
engine_kwargs=engine_kwargs, reflect_views=reflect_views,
|
||||
row_type=row_type)
|
||||
|
||||
@ -16,7 +16,7 @@ from alembic.migration import MigrationContext
|
||||
from alembic.operations import Operations
|
||||
|
||||
from dataset.persistence.table import Table
|
||||
from dataset.persistence.util import ResultIter
|
||||
from dataset.persistence.util import ResultIter, row_type
|
||||
from dataset.util import DatasetException
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
@ -24,7 +24,8 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class Database(object):
|
||||
def __init__(self, url, schema=None, reflect_metadata=True,
|
||||
engine_kwargs=None, reflect_views=True):
|
||||
engine_kwargs=None, reflect_views=True,
|
||||
row_type=row_type):
|
||||
if engine_kwargs is None:
|
||||
engine_kwargs = {}
|
||||
|
||||
@ -47,6 +48,7 @@ class Database(object):
|
||||
self.metadata.bind = self.engine
|
||||
if reflect_metadata:
|
||||
self.metadata.reflect(self.engine, views=reflect_views)
|
||||
self.row_type = row_type
|
||||
self._tables = {}
|
||||
|
||||
@property
|
||||
@ -275,7 +277,8 @@ class Database(object):
|
||||
"""
|
||||
if isinstance(query, six.string_types):
|
||||
query = text(query)
|
||||
return ResultIter(self.executable.execute(query, **kw))
|
||||
return ResultIter(self.executable.execute(query, **kw),
|
||||
row_type=self.row_type)
|
||||
|
||||
def __repr__(self):
|
||||
return '<Database(%s)>' % self.url
|
||||
|
||||
@ -371,7 +371,8 @@ class Table(object):
|
||||
break
|
||||
queries.append(self.table.select(whereclause=args, limit=qlimit,
|
||||
offset=qoffset, order_by=order_by))
|
||||
return ResultIter((self.database.executable.execute(q) for q in queries))
|
||||
return ResultIter((self.database.executable.execute(q) for q in queries),
|
||||
row_type=self.database.row_type)
|
||||
|
||||
def count(self, **_filter):
|
||||
"""
|
||||
|
||||
@ -9,6 +9,8 @@ except ImportError: # pragma: no cover
|
||||
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
|
||||
from six import string_types
|
||||
|
||||
row_type = OrderedDict
|
||||
|
||||
|
||||
def guess_type(sample):
|
||||
if isinstance(sample, bool):
|
||||
@ -22,10 +24,10 @@ def guess_type(sample):
|
||||
return UnicodeText
|
||||
|
||||
|
||||
def convert_row(row):
|
||||
def convert_row(row_type, row):
|
||||
if row is None:
|
||||
return None
|
||||
return OrderedDict(row.items())
|
||||
return row_type(row.items())
|
||||
|
||||
|
||||
def normalize_column_name(name):
|
||||
@ -41,7 +43,8 @@ class ResultIter(object):
|
||||
""" SQLAlchemy ResultProxies are not iterable to get a
|
||||
list of dictionaries. This is to wrap them. """
|
||||
|
||||
def __init__(self, result_proxies):
|
||||
def __init__(self, result_proxies, row_type=row_type):
|
||||
self.row_type = row_type
|
||||
if not isgenerator(result_proxies):
|
||||
result_proxies = iter((result_proxies, ))
|
||||
self.result_proxies = result_proxies
|
||||
@ -61,7 +64,7 @@ class ResultIter(object):
|
||||
if not self._next_rp():
|
||||
raise StopIteration
|
||||
try:
|
||||
return convert_row(next(self._iter))
|
||||
return convert_row(self.row_type, next(self._iter))
|
||||
except StopIteration:
|
||||
self._iter = None
|
||||
return self.__next__()
|
||||
|
||||
@ -144,6 +144,17 @@ with unique values in one or more columns::
|
||||
# Get one user per country
|
||||
db['user'].distinct('country')
|
||||
|
||||
Finally, you can use the ``row_type`` parameter to choose the data type in which
|
||||
results will be returned::
|
||||
|
||||
import dataset
|
||||
from stuf import stuf
|
||||
|
||||
db = dataset.connect('sqlite:///mydatabase.db', row_type=stuf)
|
||||
|
||||
Now contents will be returned in ``stuf`` objects (basically, ``dict``
|
||||
objects whose elements can be acessed as attributes (``item.name``) as well as
|
||||
by index (``item['name']``).
|
||||
|
||||
Running custom SQL queries
|
||||
--------------------------
|
||||
|
||||
@ -65,7 +65,7 @@ class FreezeTestCase(unittest.TestCase):
|
||||
|
||||
class SerializerTestCase(unittest.TestCase):
|
||||
|
||||
def test_Serializer(self):
|
||||
def test_serializer(self):
|
||||
from dataset.freeze.format.common import Serializer
|
||||
from dataset.freeze.config import Export
|
||||
from dataset.util import FreezeException
|
||||
|
||||
@ -17,7 +17,7 @@ from dataset.freeze.app import create_parser, freeze_with_config, freeze_export
|
||||
from .sample_data import TEST_DATA
|
||||
|
||||
|
||||
class FreezeTestCase(TestCase):
|
||||
class FreezeAppTestCase(TestCase):
|
||||
"""
|
||||
Base TestCase class, sets up a CLI parser
|
||||
"""
|
||||
|
||||
@ -345,3 +345,62 @@ class TableTestCase(unittest.TestCase):
|
||||
m = self.tbl.find(place='not in data')
|
||||
l = list(m) # exhaust iterator
|
||||
assert len(l) == 0
|
||||
|
||||
|
||||
class Constructor(dict):
|
||||
""" Very simple low-functionality extension to ``dict`` to
|
||||
provide attribute access to dictionary contents"""
|
||||
def __getattr__(self, name):
|
||||
return self[name]
|
||||
|
||||
|
||||
class RowTypeTestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self.db = connect('sqlite:///:memory:', row_type=Constructor)
|
||||
self.tbl = self.db['weather']
|
||||
for row in TEST_DATA:
|
||||
self.tbl.insert(row)
|
||||
|
||||
def tearDown(self):
|
||||
for table in self.db.tables:
|
||||
self.db[table].drop()
|
||||
|
||||
def test_find_one(self):
|
||||
self.tbl.insert({
|
||||
'date': datetime(2011, 1, 2),
|
||||
'temperature': -10,
|
||||
'place': 'Berlin'}
|
||||
)
|
||||
d = self.tbl.find_one(place='Berlin')
|
||||
assert d['temperature'] == -10, d
|
||||
assert d.temperature == -10, d
|
||||
d = self.tbl.find_one(place='Atlantis')
|
||||
assert d is None, d
|
||||
|
||||
def test_find(self):
|
||||
ds = list(self.tbl.find(place=TEST_CITY_1))
|
||||
assert len(ds) == 3, ds
|
||||
for item in ds:
|
||||
assert isinstance(item, Constructor), item
|
||||
ds = list(self.tbl.find(place=TEST_CITY_1, _limit=2))
|
||||
assert len(ds) == 2, ds
|
||||
for item in ds:
|
||||
assert isinstance(item, Constructor), item
|
||||
|
||||
def test_distinct(self):
|
||||
x = list(self.tbl.distinct('place'))
|
||||
assert len(x) == 2, x
|
||||
for item in x:
|
||||
assert isinstance(item, Constructor), item
|
||||
x = list(self.tbl.distinct('place', 'date'))
|
||||
assert len(x) == 6, x
|
||||
for item in x:
|
||||
assert isinstance(item, Constructor), item
|
||||
|
||||
def test_iter(self):
|
||||
c = 0
|
||||
for row in self.tbl:
|
||||
c += 1
|
||||
assert isinstance(row, Constructor), row
|
||||
assert c == len(self.tbl)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user