Rewrite data change functions on table.

Friedrich Lindenberg 2017-09-03 10:05:17 +02:00
parent 213a7ce857
commit e30cf24195
5 changed files with 201 additions and 236 deletions


@@ -17,7 +17,7 @@ from alembic.migration import MigrationContext
 from alembic.operations import Operations
 from dataset.persistence.table import Table
-from dataset.persistence.util import ResultIter, row_type, safe_url
+from dataset.persistence.util import ResultIter, row_type, safe_url, QUERY_STEP
 from dataset.persistence.types import Types

 log = logging.getLogger(__name__)
@@ -238,28 +238,34 @@ class Database(object):
         """Get a given table."""
         return self.get_table(table_name)

-    def query(self, query, **kwargs):
+    def query(self, query, *args, **kwargs):
         """Run a statement on the database directly.

-        Allows for the execution of arbitrary read/write queries. A query can either be
-        a plain text string, or a `SQLAlchemy expression <http://docs.sqlalchemy.org/en/latest/core/tutorial.html#selecting>`_.
-        If a plain string is passed in, it will be converted to an expression automatically.
+        Allows for the execution of arbitrary read/write queries. A query can
+        either be a plain text string, or a `SQLAlchemy expression
+        <http://docs.sqlalchemy.org/en/latest/core/tutorial.html#selecting>`_.
+        If a plain string is passed in, it will be converted to an expression
+        automatically.

-        Keyword arguments will be used for parameter binding. See the `SQLAlchemy
-        documentation <http://docs.sqlalchemy.org/en/rel_0_9/core/connections.html#sqlalchemy.engine.Connection.execute>`_ for details.
+        Further positional and keyword arguments will be used for parameter
+        binding. To include a positional argument in your query, use question
+        marks in the query (i.e. `SELECT * FROM tbl WHERE a = ?`). For keyword
+        arguments, use a bind parameter (i.e. `SELECT * FROM tbl WHERE a =
+        :foo`).

         The returned iterator will yield each result sequentially.
         ::

-            res = db.query('SELECT user, COUNT(*) c FROM photos GROUP BY user')
-            for row in res:
+            statement = 'SELECT user, COUNT(*) c FROM photos GROUP BY user'
+            for row in db.query(statement):
                 print(row['user'], row['c'])
         """
         if isinstance(query, six.string_types):
             query = text(query)
-        _step = kwargs.pop('_step', 5000)
-        return ResultIter(self.executable.execute(query, **kwargs),
-                          row_type=self.row_type, step=_step)
+        _step = kwargs.pop('_step', QUERY_STEP)
+        rp = self.executable.execute(query, *args, **kwargs)
+        return ResultIter(rp, row_type=self.row_type, step=_step)

     def __repr__(self):
         """Text representation contains the URL."""


@@ -7,7 +7,7 @@ from sqlalchemy import func, select, false
 from sqlalchemy.engine.reflection import Inspector
 from dataset.persistence.util import normalize_column_name, index_name
-from dataset.persistence.util import ResultIter
+from dataset.persistence.util import ensure_tuple, ResultIter, QUERY_STEP
 from dataset.util import DatasetException
@@ -49,17 +49,12 @@ class Table(object):
         dropping the table. If you want to re-create the table, make
         sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
         """
-        self._check_dropped()
         with self.database.lock:
+            if not self.exists:
+                return
             self.table.drop(self.database.executable, checkfirst=True)
-            # self.database._tables.pop(self.name, None)
             self.table = None

-    def _check_dropped(self):
-        # self.table = self.database._reflect_table(self.name)
-        if self.table is None:
-            raise DatasetException('The table has been dropped.')
-
     def insert(self, row, ensure=None, types=None):
         """
         Add a row (type: dict) by inserting it into the table.
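
Note: with ``_check_dropped()`` gone, a dropped table no longer raises ``DatasetException``; read operations consult ``self.exists`` instead, as the rewritten tests below assert. A small sketch (table name invented)::

    import dataset

    db = dataset.connect('sqlite:///:memory:')
    tbl = db['weather']
    tbl.insert({'place': 'Berlin'})
    tbl.drop()
    print(list(tbl.find()))   # []  -- no exception any more
    print(tbl.count())        # 0
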
@@ -79,15 +74,11 @@ class Table(object):
         Returns the inserted row's primary key.
         """
-        self._check_dropped()
-        ensure = self.database.ensure_schema if ensure is None else ensure
-        if ensure:
-            self._ensure_columns(row, types=types)
-        else:
-            row = self._prune_row(row)
+        row = self._sync_columns(row, ensure, types=types)
         res = self.database.executable.execute(self.table.insert(row))
         if len(res.inserted_primary_key) > 0:
             return res.inserted_primary_key[0]
+        return True

     def insert_ignore(self, row, keys, ensure=None, types=None):
         """
@@ -108,11 +99,13 @@ class Table(object):
             data = dict(id=10, title='I am a banana!')
             table.insert_ignore(data, ['id'])
         """
-        row, res = self._upsert_pre_check(row, keys, ensure)
-        if res is None:
-            return self.insert(row, ensure=ensure, types=types)
-        else:
-            return False
+        row = self._sync_columns(row, ensure, types=types)
+        if self._check_ensure(ensure):
+            self.create_index(keys)
+        args, _ = self._keys_to_args(row, keys)
+        if self.count(**args) == 0:
+            return self.insert(row, ensure=False)
+        return False

     def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
         """
@@ -129,28 +122,18 @@ class Table(object):
             rows = [dict(name='Dolly')] * 10000
             table.insert_many(rows)
         """
-        ensure = self.database.ensure_schema if ensure is None else ensure
-
-        def _process_chunk(chunk):
-            if ensure:
-                for row in chunk:
-                    self._ensure_columns(row, types=types)
-            else:
-                chunk = [self._prune_row(r) for r in chunk]
-            self.table.insert().execute(chunk)
-        self._check_dropped()
-
         chunk = []
-        for i, row in enumerate(rows, start=1):
+        for row in rows:
+            row = self._sync_columns(row, ensure, types=types)
             chunk.append(row)
-            if i % chunk_size == 0:
-                _process_chunk(chunk)
+            if len(chunk) == chunk_size:
+                self.table.insert().execute(chunk)
                 chunk = []

-        if chunk:
-            _process_chunk(chunk)
+        if len(chunk):
+            self.table.insert().execute(chunk)

-    def update(self, row, keys, ensure=None, types=None):
+    def update(self, row, keys, ensure=None, types=None, return_count=False):
         """
         Update a row in the table.
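
Note: ``insert_many()`` now syncs columns per row rather than per chunk; ``chunk_size`` still just bounds the size of each bulk INSERT, as exercised by the updated ``test_insert_many`` below. For instance (data invented)::

    import dataset

    db = dataset.connect('sqlite:///:memory:')
    rows = [{'name': 'Dolly', 'n': i} for i in range(100)]
    db['clones'].insert_many(rows, chunk_size=13)
    print(db['clones'].count())   # 100
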
@@ -167,54 +150,20 @@ class Table(object):
         they will be created based on the settings of ``ensure`` and
         ``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`.
         """
-        self._check_dropped()
-        # check whether keys arg is a string and format as a list
-        if not isinstance(keys, (list, tuple)):
-            keys = [keys]
-        if not keys or len(keys) == len(row):
-            return False
-        clause = [(u, row.get(u)) for u in keys]
-
-        ensure = self.database.ensure_schema if ensure is None else ensure
-        if ensure:
-            self._ensure_columns(row, types=types)
-        else:
-            row = self._prune_row(row)
-
-        # Don't update the key itself, so remove any keys from the row dict
-        clean_row = row.copy()
-        for key in keys:
-            if key in clean_row.keys():
-                del clean_row[key]
-
-        try:
-            filters = self._args_to_clause(dict(clause))
-            stmt = self.table.update(filters, clean_row)
-            rp = self.database.executable.execute(stmt)
-            return rp.rowcount
-        except KeyError:
-            return 0
-
-    def _upsert_pre_check(self, row, keys, ensure):
-        # check whether keys arg is a string and format as a list
-        if not isinstance(keys, (list, tuple)):
-            keys = [keys]
-        self._check_dropped()
-        ensure = self.database.ensure_schema if ensure is None else ensure
-        if ensure:
-            self.create_index(keys)
-        else:
-            row = self._prune_row(row)
-        filters = {}
-        for key in keys:
-            filters[key] = row.get(key)
-        return row, self.find_one(**filters)
+        row = self._sync_columns(row, ensure, types=types)
+        args, row = self._keys_to_args(row, keys)
+        clause = self._args_to_clause(args)
+        if not len(row):
+            return self.count(clause)
+        stmt = self.table.update(whereclause=clause, values=row)
+        rp = self.database.executable.execute(stmt)
+        if rp.supports_sane_rowcount():
+            return rp.rowcount
+        if return_count:
+            return self.count(clause)

     def upsert(self, row, keys, ensure=None, types=None):
-        """
-        An UPSERT is a smart combination of insert and update.
+        """An UPSERT is a smart combination of insert and update.

         If rows with matching ``keys`` exist they will be updated, otherwise a
         new row is inserted in the table.
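
Note: a sketch of the updated ``update()`` contract: the rowcount is returned where the driver reports one, and ``return_count=True`` forces an extra COUNT query otherwise (table and values invented)::

    import dataset

    db = dataset.connect('sqlite:///:memory:')
    table = db['people']
    table.insert({'name': 'John', 'age': 46})
    changed = table.update({'name': 'John', 'age': 47}, ['name'])
    print(changed)                              # 1 where rowcounts are sane
    print(table.find_one(name='John')['age'])   # 47
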
@@ -223,22 +172,16 @@ class Table(object):
             data = dict(id=10, title='I am a banana!')
             table.upsert(data, ['id'])
         """
-        row, res = self._upsert_pre_check(row, keys, ensure)
-        if res is None:
-            return self.insert(row, ensure=ensure, types=types)
-        else:
-            row_count = self.update(row, keys, ensure=ensure, types=types)
-            try:
-                result = (row_count > 0, res['id'])[row_count == 1]
-            except KeyError:
-                result = row_count > 0
-            return result
+        row = self._sync_columns(row, ensure, types=types)
+        if self._check_ensure(ensure):
+            self.create_index(keys)
+        row_count = self.update(row, keys, ensure=False, return_count=True)
+        if row_count == 0:
+            return self.insert(row, ensure=False)
+        return True

-    def delete(self, *_clauses, **_filter):
-        """
-        Delete rows from the table.
+    def delete(self, *clauses, **filters):
+        """Delete rows from the table.

         Keyword arguments can be used to add column-based filters. The filter
         criterion will always be equality:
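
Note: the new ``upsert()`` returns the primary key when it inserts and ``True`` when a matching row already exists, mirroring the added ``test_upsert_single_column`` below::

    import dataset

    db = dataset.connect('sqlite:///:memory:')
    table = db['banana_single_col']
    table.upsert({'color': 'Yellow'}, ['color'])   # inserts, returns the new id
    table.upsert({'color': 'Yellow'}, ['color'])   # matches the existing row, returns True
    print(len(table))                              # 1
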
@@ -249,55 +192,60 @@ class Table(object):
         If no arguments are given, all records are deleted.
         """
-        self._check_dropped()
         if not self.exists:
-            return
-        if _filter or _clauses:
-            q = self._args_to_clause(_filter, clauses=_clauses)
-            stmt = self.table.delete(q)
-        else:
-            stmt = self.table.delete()
-        rows = self.database.executable.execute(stmt)
-        return rows.rowcount > 0
+            return False
+        clause = self._args_to_clause(filters, clauses=clauses)
+        stmt = self.table.delete(whereclause=clause)
+        rp = self.database.executable.execute(stmt)
+        return rp.rowcount > 0

-    def _has_column(self, column):
+    def has_column(self, column):
+        """Check if a column with the given name exists on this table."""
         return normalize_column_name(column) in self.columns

-    def _ensure_columns(self, row, types=None):
-        # Keep order of inserted columns
-        for column in row.keys():
-            if self._has_column(column):
-                continue
-            if types is not None and column in types:
-                _type = types[column]
-            else:
-                _type = self.database.types.guess(row[column])
-            log.debug("Creating column: %s (%s) on %r" % (column,
-                                                          _type,
-                                                          self.table.name))
-            self.create_column(column, _type)
+    def _sync_columns(self, row, ensure, types=None):
+        columns = self.columns
+        ensure = self._check_ensure(ensure)
+        types = types or {}
+        types = {normalize_column_name(k): v for (k, v) in types.items()}
+        out = {}
+        for name, value in row.items():
+            name = normalize_column_name(name)
+            if ensure and name not in columns:
+                _type = types.get(name)
+                if _type is None:
+                    _type = self.database.types.guess(value)
+                log.debug("Create column: %s on %s", name, self.name)
+                self.create_column(name, _type)
+                columns.append(name)
+            if name in columns:
+                out[name] = value
+        return out

-    def _prune_row(self, row):
-        """Remove keys from row not in column set."""
-        # normalize keys
-        row = {normalize_column_name(k): v for k, v in row.items()}
-        # filter out keys not in column set
-        return {k: row[k] for k in row if k in self.columns}
+    def _check_ensure(self, ensure):
+        if ensure is None:
+            return self.database.ensure_schema
+        return ensure

-    def _args_to_clause(self, args, ensure=None, clauses=()):
-        ensure = self.database.ensure_schema if ensure is None else ensure
-        if ensure:
-            self._ensure_columns(args)
+    def _args_to_clause(self, args, clauses=()):
         clauses = list(clauses)
-        for k, v in args.items():
-            if not self._has_column(k):
+        for column, value in args.items():
+            if not self.has_column(column):
                 clauses.append(false())
-            elif isinstance(v, (list, tuple)):
-                clauses.append(self.table.c[k].in_(v))
+            elif isinstance(value, (list, tuple)):
+                clauses.append(self.table.c[column].in_(value))
             else:
-                clauses.append(self.table.c[k] == v)
+                clauses.append(self.table.c[column] == value)
         return and_(*clauses)

+    def _keys_to_args(self, row, keys):
+        keys = ensure_tuple(keys)
+        keys = [normalize_column_name(k) for k in keys]
+        # keys = [self.has_column(k) for k in keys]
+        row = row.copy()
+        args = {k: row.pop(k) for k in keys if k in row}
+        return args, row
+
     def create_column(self, name, type):
         """
         Explicitly create a new column ``name`` of a specified type.
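
Note: the effect of ``_sync_columns()`` in practice: with ``ensure`` on (the default), unknown keys create columns with guessed types; with it off, they are silently dropped from the row, replacing the old ``_ensure_columns()``/``_prune_row()`` pair. Roughly (column names invented)::

    import dataset

    db = dataset.connect('sqlite:///:memory:')
    table = db['people']
    table.insert({'name': 'Ada'})
    table.insert({'name': 'Grace', 'age': 36})                      # 'age' column created on the fly
    table.insert({'name': 'Linus', 'shoe_size': 44}, ensure=False)  # 'shoe_size' is pruned
    print(table.columns)                                            # ['id', 'name', 'age']
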
@@ -307,10 +255,10 @@ class Table(object):
             table.create_column('created_at', db.types.datetime)
         """
-        self._check_dropped()
+        # TODO: create the table if it does not exist.
         with self.database.lock:
             name = normalize_column_name(name)
-            if name in self.columns:
+            if self.has_column(name):
                 log.debug("Column exists: %s" % name)
                 return
@@ -341,10 +289,8 @@ class Table(object):
         """
         if self.database.engine.dialect.name == 'sqlite':
             raise NotImplementedError("SQLite does not support dropping columns.")
-        if not self.exists:
-            return
-        self._check_dropped()
-        if name not in self.columns:
+        name = normalize_column_name(name)
+        if not self.exists or not self.has_column(name):
             log.debug("Column does not exist: %s", name)
             return
         with self.database.lock:
@@ -357,9 +303,14 @@ class Table(object):
     def has_index(self, columns):
         """Check if an index exists to cover the given `columns`."""
+        if not self.exists:
+            return False
         columns = set([normalize_column_name(c) for c in columns])
         if columns in self._indexes:
             return True
+        for column in columns:
+            if not self.has_column(column):
+                raise DatasetException("Column does not exist: %s" % column)
         inspector = Inspector.from_engine(self.database.executable)
         indexes = inspector.get_indexes(self.name, schema=self.database.schema)
         for index in indexes:
@@ -377,9 +328,11 @@ class Table(object):
             table.create_index(['name', 'country'])
         """
-        self._check_dropped()
         columns = [normalize_column_name(c) for c in columns]
         with self.database.lock:
+            if not self.exists:
+                # TODO
+                pass
             if not self.has_index(columns):
                 name = name or index_name(self.name, columns)
                 columns = [self.table.c[c] for c in columns]
@@ -387,10 +340,8 @@ class Table(object):
                 idx.create(self.database.executable)

     def _args_to_order_by(self, order_by):
-        if not isinstance(order_by, (list, tuple)):
-            order_by = [order_by]
         orderings = []
-        for ordering in order_by:
+        for ordering in ensure_tuple(order_by):
             if ordering is None:
                 continue
             column = ordering.lstrip('-')
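
Note: thanks to ``ensure_tuple()``, a single ordering string and a list of columns are now handled identically; a ``-`` prefix still means descending. For example (data invented)::

    import dataset

    db = dataset.connect('sqlite:///:memory:')
    table = db['events']
    table.insert_many([{'name': 'a', 'rank': 2}, {'name': 'b', 'rank': 1}])

    for row in table.find(order_by='-rank'):           # a single string ...
        print(row['name'])
    for row in table.find(order_by=['name', 'rank']):  # ... or a list of columns
        print(row['name'])
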
@@ -430,22 +381,25 @@ class Table(object):
         """
         _limit = kwargs.pop('_limit', None)
         _offset = kwargs.pop('_offset', 0)
-        _step = kwargs.pop('_step', 5000)
+        _step = kwargs.pop('_step', QUERY_STEP)
         order_by = kwargs.pop('order_by', None)

-        self._check_dropped()
+        if not self.exists:
+            return []
+
         order_by = self._args_to_order_by(order_by)
-        args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
+        args = self._args_to_clause(kwargs, clauses=_clauses)

         if _step is False or _step == 0:
             _step = None

-        query = self.table.select(whereclause=args, limit=_limit,
+        query = self.table.select(whereclause=args,
+                                  limit=_limit,
                                   offset=_offset)
         if len(order_by):
             query = query.order_by(*order_by)
         return ResultIter(self.database.executable.execute(query),
-                          row_type=self.database.row_type, step=_step)
+                          row_type=self.database.row_type,
+                          step=_step)

     def find_one(self, *args, **kwargs):
         """Get a single result from the table.
@@ -458,6 +412,7 @@ class Table(object):
         """
         if not self.exists:
             return None
+
         kwargs['_limit'] = 1
         kwargs['_step'] = None
         resiter = self.find(*args, **kwargs)
@@ -474,8 +429,8 @@ class Table(object):
         # with people using these flags. Let's see how it goes.
         if not self.exists:
             return 0
-        self._check_dropped()
-        args = self._args_to_clause(kwargs, ensure=False, clauses=_clauses)
+
+        args = self._args_to_clause(kwargs, clauses=_clauses)
         query = select([func.count()], whereclause=args)
         query = query.select_from(self.table)
         rp = self.database.executable.execute(query)
@@ -486,10 +441,7 @@ class Table(object):
         return self.count()

     def distinct(self, *args, **_filter):
-        """
-        Return all rows of a table, but remove rows in with duplicate values in ``columns``.
-
-        Interally this creates a `DISTINCT statement <http://www.w3schools.com/sql/sql_distinct.asp>`_.
+        """Return all the unique (distinct) values for the given ``columns``.
         ::

             # returns only one row per year, ignoring the rest
@@ -499,62 +451,47 @@ class Table(object):
             # you can also combine this with a filter
             table.distinct('year', country='China')
         """
-        self._check_dropped()
-        qargs = []
-        columns = []
-        try:
-            for c in args:
-                if isinstance(c, ClauseElement):
-                    qargs.append(c)
-                else:
-                    columns.append(self.table.c[c])
-            for col, val in _filter.items():
-                qargs.append(self.table.c[col] == val)
-        except KeyError:
+        if not self.exists:
             return []

-        q = expression.select(columns, distinct=True,
-                              whereclause=and_(*qargs),
+        filters = []
+        for column, value in _filter.items():
+            if not self.has_column(column):
+                raise DatasetException("No such column: %s" % column)
+            filters.append(self.table.c[column] == value)
+
+        columns = []
+        for column in args:
+            if isinstance(column, ClauseElement):
+                filters.append(column)
+            else:
+                if not self.has_column(column):
+                    raise DatasetException("No such column: %s" % column)
+                columns.append(self.table.c[column])
+
+        if not len(columns):
+            return []
+
+        q = expression.select(columns,
+                              distinct=True,
+                              whereclause=and_(*filters),
                               order_by=[c.asc() for c in columns])
         return self.database.query(q)

-    def __getitem__(self, item):
-        """
-        Get distinct column values.
-
-        This is an alias for distinct which allows the table to be queried as using
-        square bracket syntax.
-        ::
-
-            # Same as distinct:
-            print list(table['year'])
-        """
-        if not isinstance(item, tuple):
-            item = item,
-        return self.distinct(*item)
-
-    def all(self):
-        """
-        Return all rows of the table as simple dictionaries.
-
-        This is simply a shortcut to *find()* called with no arguments.
-        ::
-
-            rows = table.all()
-        """
-        return self.find()
+    # Legacy methods for running find queries.
+    all = find

     def __iter__(self):
-        """
-        Return all rows of the table as simple dictionaries.
+        """Return all rows of the table as simple dictionaries.

         Allows for iterating over all rows in the table without explicetly
-        calling :py:meth:`all() <dataset.Table.all>`.
+        calling :py:meth:`find() <dataset.Table.find>`.
         ::

             for row in table:
                 print(row)
         """
-        return self.all()
+        return self.find()

     def __repr__(self):
         """Get table representation."""


@@ -9,8 +9,10 @@ except ImportError: # pragma: no cover
     from ordereddict import OrderedDict

 from six import string_types
+from collections import Sequence
 from hashlib import sha1

+QUERY_STEP = 1000
 row_type = OrderedDict
@@ -78,3 +80,12 @@ def index_name(table, columns):
     sig = '||'.join(columns)
     key = sha1(sig.encode('utf-8')).hexdigest()[:16]
     return 'ix_%s_%s' % (table, key)
+
+
+def ensure_tuple(obj):
+    """Try and make the given argument into a tuple."""
+    if obj is None:
+        return tuple()
+    if isinstance(obj, Sequence) and not isinstance(obj, string_types):
+        return tuple(obj)
+    return obj,
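
Note: given the definition above, the helper turns ``None`` into an empty tuple, wraps strings and scalars, and passes real sequences through::

    from dataset.persistence.util import ensure_tuple

    ensure_tuple(None)         # ()
    ensure_tuple('id')         # ('id',)
    ensure_tuple(['a', 'b'])   # ('a', 'b')
    ensure_tuple(42)           # (42,)
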


@@ -31,7 +31,7 @@ setup(
     include_package_data=False,
     zip_safe=False,
     install_requires=[
-        'sqlalchemy >= 0.9.1',
+        'sqlalchemy >= 1.1.0',
         'alembic >= 0.6.2',
         'normality >= 0.3.9',
         "PyYAML >= 3.10",


@@ -11,7 +11,6 @@ from sqlalchemy import FLOAT, INTEGER, TEXT
 from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
 from dataset import connect
-from dataset.util import DatasetException
 from .sample_data import TEST_DATA, TEST_CITY_1
@@ -185,13 +184,14 @@ class TableTestCase(unittest.TestCase):
         assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)

     def test_insert_ignore_all_key(self):
-        for i in range(0, 2):
+        for i in range(0, 4):
             self.tbl.insert_ignore({
                 'date': datetime(2011, 1, 2),
                 'temperature': -10,
                 'place': 'Berlin'},
                 ['date', 'temperature', 'place']
             )
+        assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)

     def test_upsert(self):
         self.tbl.upsert({
@@ -209,7 +209,21 @@ class TableTestCase(unittest.TestCase):
         )
         assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)

+    def test_upsert_single_column(self):
+        table = self.db['banana_single_col']
+        table.upsert({
+            'color': 'Yellow'},
+            ['color']
+        )
+        assert len(table) == 1, len(table)
+        table.upsert({
+            'color': 'Yellow'},
+            ['color']
+        )
+        assert len(table) == 1, len(table)
+
     def test_upsert_all_key(self):
+        assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
         for i in range(0, 2):
             self.tbl.upsert({
                 'date': datetime(2011, 1, 2),
@@ -217,11 +231,13 @@ class TableTestCase(unittest.TestCase):
                 'place': 'Berlin'},
                 ['date', 'temperature', 'place']
             )
+        assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)

     def test_update_while_iter(self):
         for row in self.tbl:
             row['foo'] = 'bar'
             self.tbl.update(row, ['place', 'date'])
+        assert len(self.tbl) == len(TEST_DATA), len(self.tbl)

     def test_weird_column_names(self):
         with self.assertRaises(ValueError):
@@ -262,17 +278,19 @@ class TableTestCase(unittest.TestCase):
         assert len(self.tbl) == 0, len(self.tbl)

     def test_repr(self):
-        assert repr(self.tbl) == '<Table(weather)>', 'the representation should be <Table(weather)>'
+        assert repr(self.tbl) == '<Table(weather)>', \
+            'the representation should be <Table(weather)>'

     def test_delete_nonexist_entry(self):
-        assert self.tbl.delete(place='Berlin') is False, 'entry not exist, should fail to delete'
+        assert self.tbl.delete(place='Berlin') is False, \
+            'entry not exist, should fail to delete'

     def test_find_one(self):
         self.tbl.insert({
             'date': datetime(2011, 1, 2),
             'temperature': -10,
-            'place': 'Berlin'}
-        )
+            'place': 'Berlin'
+        })
         d = self.tbl.find_one(place='Berlin')
         assert d['temperature'] == -10, d
         d = self.tbl.find_one(place='Atlantis')
@@ -317,29 +335,17 @@ class TableTestCase(unittest.TestCase):
             self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0)))
         assert len(x) == 4, x

-    def test_get_items(self):
-        x = list(self.tbl['place'])
-        y = list(self.tbl.distinct('place'))
-        assert x == y, (x, y)
-        x = list(self.tbl['place', 'date'])
-        y = list(self.tbl.distinct('place', 'date'))
-        assert x == y, (x, y)
-
     def test_insert_many(self):
         data = TEST_DATA * 100
         self.tbl.insert_many(data, chunk_size=13)
         assert len(self.tbl) == len(data) + 6

-    def test_drop_warning(self):
+    def test_drop_operations(self):
         assert self.tbl.table is not None, 'table shouldn\'t be dropped yet'
         self.tbl.drop()
         assert self.tbl.table is None, 'table should be dropped now'
-        try:
-            list(self.tbl.all())
-        except DatasetException:
-            pass
-        else:
-            assert False, 'we should not reach else block, no exception raised!'
+        assert list(self.tbl.all()) == [], self.tbl.all()
+        assert self.tbl.count() == 0, self.tbl.count()

     def test_table_drop(self):
         assert 'weather' in self.db
@@ -380,26 +386,31 @@ class TableTestCase(unittest.TestCase):
         )
         assert res, 'update should return True'
         m = self.tbl.find_one(place=TEST_CITY_1, date=date)
-        assert m['temperature'] == -10, 'new temp. should be -10 but is %d' % m['temperature']
+        assert m['temperature'] == -10, \
+            'new temp. should be -10 but is %d' % m['temperature']

     def test_create_column(self):
         tbl = self.tbl
         tbl.create_column('foo', FLOAT)
         assert 'foo' in tbl.table.c, tbl.table.c
-        assert isinstance(tbl.table.c['foo'].type, FLOAT), tbl.table.c['foo'].type
+        assert isinstance(tbl.table.c['foo'].type, FLOAT), \
+            tbl.table.c['foo'].type
         assert 'foo' in tbl.columns, tbl.columns

     def test_ensure_column(self):
         tbl = self.tbl
         tbl.create_column_by_example('foo', 0.1)
         assert 'foo' in tbl.table.c, tbl.table.c
-        assert isinstance(tbl.table.c['foo'].type, FLOAT), tbl.table.c['bar'].type
+        assert isinstance(tbl.table.c['foo'].type, FLOAT), \
+            tbl.table.c['bar'].type
         tbl.create_column_by_example('bar', 1)
         assert 'bar' in tbl.table.c, tbl.table.c
-        assert isinstance(tbl.table.c['bar'].type, INTEGER), tbl.table.c['bar'].type
+        assert isinstance(tbl.table.c['bar'].type, INTEGER), \
+            tbl.table.c['bar'].type
         tbl.create_column_by_example('pippo', 'test')
         assert 'pippo' in tbl.table.c, tbl.table.c
-        assert isinstance(tbl.table.c['pippo'].type, TEXT), tbl.table.c['pippo'].type
+        assert isinstance(tbl.table.c['pippo'].type, TEXT), \
+            tbl.table.c['pippo'].type

     def test_key_order(self):
         res = self.db.query('SELECT temperature, place FROM weather LIMIT 1')