import logging
from itertools import count
from sqlalchemy.sql import and_, expression
from sqlalchemy.schema import Column, Index
from dataset.persistence.util import guess_type
from dataset.persistence.util import ResultIter
from dataset.util import DatasetException

log = logging.getLogger(__name__)


class Table(object):

    def __init__(self, database, table):
        self.indexes = {}
        self.database = database
        self.table = table
        self._is_dropped = False

    @property
    def columns(self):
        """
        Get a listing of all columns that exist in the table.

        >>> print 'age' in table.columns
        True
        """
        return set(self.table.columns.keys())

    def drop(self):
        """
        Drop the table from the database, deleting both the schema
        and all the contents within it.

        Note: the object will raise an Exception if you use it after
        dropping the table. If you want to re-create the table, make
        sure to get a fresh instance from the :py:class:`Database <dataset.Database>`.
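
        For example (an illustrative sketch; ``db`` is assumed to be the
        parent :py:class:`Database <dataset.Database>` instance)::

            table.drop()
            table = db['my_table']   # get a fresh table object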
2013-04-03 00:51:33 +02:00
"""
2013-04-05 00:31:13 +02:00
self._is_dropped = True
2013-04-01 18:05:41 +02:00
with self.database.lock:
2013-04-05 00:47:28 +02:00
self.database._tables.pop(self.table.name, None)
self.table.drop(self.database.engine)
2013-04-01 18:05:41 +02:00
2013-04-05 00:31:13 +02:00
def _check_dropped(self):
if self._is_dropped:
raise DatasetException('the table has been dropped. this object should not be used again.')
    def insert(self, row, ensure=True, types={}):
        """
        Add a row (type: dict) by inserting it into the table.

        If ``ensure`` is set, any keys of the row that are not
        table columns will be created automatically.

        During column creation, ``types`` will be checked for a key
        matching the name of a column to be created, and the given
        SQLAlchemy column type will be used. Otherwise, the type is
        guessed from the row value, defaulting to a simple unicode
        field.
        ::

            data = dict(title='I am a banana!')
            table.insert(data)
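
        To control the type of a column that is about to be created, pass a
        SQLAlchemy type via ``types`` (an illustrative sketch; the column name
        is arbitrary)::

            from sqlalchemy import Float
            table.insert(dict(temperature=21.3),
                         types={'temperature': Float})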
        """
        self._check_dropped()
        if ensure:
            self._ensure_columns(row, types=types)
        res = self.database.engine.execute(self.table.insert(row))
        return res.lastrowid

    def insert_many(self, rows, chunk_size=1000, ensure=True, types={}):
        """
        Add many rows at a time, which is significantly faster than adding
        them one by one. By default the rows are processed in chunks of
        1000 per commit, unless you specify a different ``chunk_size``.
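        For instance, to commit in batches of 500 rows::

            table.insert_many(rows, chunk_size=500)
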
        See :py:meth:`insert() <dataset.Table.insert>` for details on
        the other parameters.
        ::

            rows = [dict(name='Dolly')] * 10000
            table.insert_many(rows)
        """
        def _process_chunk(chunk):
            if ensure:
                for row in chunk:
                    self._ensure_columns(row, types=types)
            self.table.insert().execute(chunk)

        self._check_dropped()
        chunk = []
        i = 0
        for row in rows:
            chunk.append(row)
            i += 1
            # insert a full batch whenever ``chunk_size`` rows have accumulated
            if i == chunk_size:
                _process_chunk(chunk)
                chunk = []
                i = 0
        # insert whatever is left over in the final, partial batch
        if i > 0:
            _process_chunk(chunk)

    def update(self, row, keys, ensure=True, types={}):
        """
        Update a row in the table. The update is managed via
        the set of column names stated in ``keys``: they will be
        used as filters for the data to be updated, using the values
        in ``row``.
        ::

            # update all entries with id matching 10, setting their title column
            data = dict(id=10, title='I am a banana!')
            table.update(data, ['id'])

        If keys in ``row`` update columns not present in the table,
        they will be created based on the settings of ``ensure`` and
        ``types``, matching the behavior of :py:meth:`insert() <dataset.Table.insert>`.
        """
        self._check_dropped()
        if not len(keys):
            return False
        clause = [(u, row.get(u)) for u in keys]
        if ensure:
            self._ensure_columns(row, types=types)
        try:
            filters = self._args_to_clause(dict(clause))
            stmt = self.table.update(filters, row)
            rp = self.database.engine.execute(stmt)
            return rp.rowcount > 0
        except KeyError:
            return False

    def upsert(self, row, keys, ensure=True, types={}):
        """
        An UPSERT is a smart combination of insert and update. If rows with matching ``keys`` exist
        they will be updated, otherwise a new row is inserted in the table.
        ::

            data = dict(id=10, title='I am a banana!')
            table.upsert(data, ['id'])
        """
        self._check_dropped()
        if ensure:
            # index the key columns so the update lookup stays fast
            self.create_index(keys)
        if not self.update(row, keys, ensure=ensure, types=types):
            self.insert(row, ensure=ensure, types=types)

    def delete(self, **filter):
        """ Delete rows from the table. Keyword arguments can be used
        to add column-based filters. The filter criterion will always
        be equality:

        .. code-block:: python

            table.delete(place='Berlin')

        If no arguments are given, all records are deleted.
        """
        self._check_dropped()
        if len(filter) > 0:
            q = self._args_to_clause(filter)
            stmt = self.table.delete(q)
        else:
            stmt = self.table.delete()
        self.database.engine.execute(stmt)

    def _ensure_columns(self, row, types={}):
        # Create any columns in ``row`` that the table does not have yet,
        # using an explicitly given type or one guessed from the value.
        for column in set(row.keys()) - set(self.table.columns.keys()):
            if column in types:
                _type = types[column]
            else:
                _type = guess_type(row[column])
            log.debug("Creating column: %s (%s) on %r" % (column,
                      _type, self.table.name))
            self.create_column(column, _type)

    def _args_to_clause(self, args):
        # Convert a dict of column/value pairs into an equality-based
        # SQLAlchemy WHERE clause, creating missing columns on the way.
        self._ensure_columns(args)
        clauses = []
        for k, v in args.items():
            clauses.append(self.table.c[k] == v)
        return and_(*clauses)

    def create_column(self, name, type):
        """
        Explicitly create a new column ``name`` of a specified type.
        ``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_.
        ::

            table.create_column('created_at', sqlalchemy.DateTime)
        """
        self._check_dropped()
        with self.database.lock:
            if name not in self.table.columns.keys():
                col = Column(name, type)
                col.create(self.table,
                           connection=self.database.engine)

    def create_index(self, columns, name=None):
        """
        Create an index to speed up queries on a table. If no ``name`` is
        given, one is generated from a hash of the column names.
        ::

            table.create_index(['name', 'country'])
        """
        self._check_dropped()
        with self.database.lock:
            if not name:
                sig = abs(hash('||'.join(columns)))
                name = 'ix_%s_%s' % (self.table.name, sig)
            if name in self.indexes:
                return self.indexes[name]
            try:
                columns = [self.table.c[c] for c in columns]
                idx = Index(name, *columns)
                idx.create(self.database.engine)
            except:
                # if the index cannot be created (e.g. a column does not
                # exist), remember that by caching ``None``
                idx = None
            self.indexes[name] = idx
            return idx

    def find_one(self, **filter):
        """
        Works just like :py:meth:`find() <dataset.Table.find>` but returns only one result.
        ::

            row = table.find_one(country='United States')
        """
        self._check_dropped()
        res = list(self.find(_limit=1, **filter))
        if not len(res):
            return None
        return res[0]

def _args_to_order_by(self, order_by):
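        # A column name prefixed with '-' is sorted in descending order;
        # otherwise, ascending order is used.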
if order_by[0] == '-':
return self.table.c[order_by[1:]].desc()
else:
return self.table.c[order_by].asc()

    def find(self, _limit=None, _offset=0, _step=5000,
             order_by='id', **filter):
        """
        Performs a simple search on the table. Simply pass keyword arguments as ``filter``.
        ::

            results = table.find(country='France')
            results = table.find(country='France', year=1980)

        Using ``_limit``::

            # just return the first 10 rows
            results = table.find(country='France', _limit=10)
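
            # return the next 10 rows, skipping the first 10
            results = table.find(country='France', _limit=10, _offset=10)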

        You can sort the results by single or multiple columns. Append a minus sign
        to the column name for descending order::

            # sort results by a column 'year'
            results = table.find(country='France', order_by='year')
            # return all rows sorted by multiple columns (by year in descending order)
            results = table.find(order_by=['country', '-year'])

        Internally, results are fetched in batches of ``_step`` rows at a time.

        For more complex queries, please use :py:meth:`db.query() <dataset.Database.query>`
        instead.
        """
        self._check_dropped()
        if isinstance(order_by, (str, unicode)):
            order_by = [order_by]
        # silently drop order_by columns that do not exist in the table
        # (a list comprehension is used here because the builtin ``filter``
        # is shadowed by the ``**filter`` keyword arguments)
        order_by = [o for o in order_by if o in self.table.columns]
        order_by = [self._args_to_order_by(o) for o in order_by]
        args = self._args_to_clause(filter)

        # query total number of rows first
        count_query = self.table.count(whereclause=args, limit=_limit, offset=_offset)
        rp = self.database.engine.execute(count_query)
        total_row_count = rp.fetchone()[0]

        # an unordered result set cannot be paged reliably, so fetch it in one go
        if total_row_count > _step and len(order_by) == 0:
            _step = total_row_count
            log.warn("query cannot be broken into smaller sections because it is unordered")

        # build one SELECT per chunk of ``_step`` rows until the limit or the
        # total row count is reached
        queries = []
        for i in count():
            qoffset = _offset + (_step * i)
            qlimit = _step
            if _limit is not None:
                qlimit = min(_limit - (_step * i), _step)
            if qlimit <= 0:
                break
            if qoffset > total_row_count:
                break
            queries.append(self.table.select(whereclause=args, limit=qlimit,
                                             offset=qoffset, order_by=order_by))
        return ResultIter((self.database.engine.execute(q) for q in queries))

    def __len__(self):
"""
Returns the number of rows in the table.
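        ::

            print 'Number of rows:', len(table)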
"""
        d = self.database.query(self.table.count()).next()
        return d.values().pop()

    def distinct(self, *columns, **filter):
        """
        Returns all rows of a table, but removes rows with duplicate values in ``columns``.
        Internally this creates a `DISTINCT statement <http://www.w3schools.com/sql/sql_distinct.asp>`_.
        ::

            # returns only one row per year, ignoring the rest
            table.distinct('year')
            # works with multiple columns, too
            table.distinct('year', 'country')
            # you can also combine this with a filter
            table.distinct('year', country='China')
        """
        self._check_dropped()
        qargs = []
        try:
            columns = [self.table.c[c] for c in columns]
            for col, val in filter.items():
                qargs.append(self.table.c[col] == val)
        except KeyError:
            return []
        q = expression.select(columns, distinct=True,
                              whereclause=and_(*qargs),
                              order_by=[c.asc() for c in columns])
        return self.database.query(q)

    def all(self):
        """
        Returns all rows of the table as simple dictionaries. This is simply a shortcut
        to *find()* called with no arguments.
        ::

            rows = table.all()
        """
        return self.find()

    def __iter__(self):
        """
        Allows for iterating over all rows in the table without explicitly
        calling :py:meth:`all() <dataset.Table.all>`.
        ::

            for row in table:
                print row
        """
        return self.all()