2013-04-01 19:06:03 +02:00
|
|
|
import logging
|
|
|
|
|
from itertools import count
|
|
|
|
|
|
|
|
|
|
from sqlalchemy.sql import and_, expression
|
|
|
|
|
from sqlalchemy.schema import Column, Index
|
|
|
|
|
|
2013-04-01 19:28:22 +02:00
|
|
|
from dataset.persistence.util import guess_type
|
2013-04-01 19:06:03 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
# Module-level logger for this module's diagnostics (e.g. column creation).
log = logging.getLogger(__name__)
|
2013-04-01 18:05:41 +02:00
|
|
|
|
|
|
|
|
|
|
|
|
|
class Table(object):
    """Wrapper around a single SQLAlchemy table owned by ``database``.

    ``database`` must provide ``lock``, ``tables`` (a registry dict keyed by
    table name), ``engine`` and ``query`` — all of which are used below.
    ``table`` is the underlying SQLAlchemy table object. When ``ensure`` is
    enabled, missing columns are created on the fly.
    """

    def __init__(self, database, table):
        # Cache of indexes created by create_index(), keyed by index name.
        # A value of None records a failed creation attempt.
        self.indexes = {}
        self.database = database
        self.table = table

    def drop(self):
        """Drop the table and remove it from the database's table registry."""
        with self.database.lock:
            self.database.tables.pop(self.table.name, None)
            self.table.drop(self.database.engine)

    def insert(self, row, ensure=True, types=None):
        """Add a row (a dict mapping column names to values).

        If ``ensure`` is set, any keys of ``row`` that are not yet table
        columns are created first, typed from ``types`` (column name -> type)
        or by guessing from the row's value.
        """
        if ensure:
            self._ensure_columns(row, types=types)
        self.database.engine.execute(self.table.insert(row))

    def update(self, row, unique, ensure=True, types=None):
        """Update rows whose ``unique`` columns equal the values in ``row``.

        :param row: dict of column name -> new value; also supplies the
            filter values for the ``unique`` columns (missing keys filter on
            ``None``).
        :param unique: sequence of column names used to match rows.
        :returns: ``True`` if the statement reported at least one affected
            row, ``False`` if ``unique`` is empty, nothing matched, or a
            ``KeyError`` occurred while building or running the statement.
        """
        if not unique:
            return False
        criteria = dict((u, row.get(u)) for u in unique)
        if ensure:
            self._ensure_columns(row, types=types)
        try:
            filters = self._args_to_clause(criteria)
            stmt = self.table.update(filters, row)
            rp = self.database.engine.execute(stmt)
            return rp.rowcount > 0
        except KeyError:
            return False

    def upsert(self, row, unique, ensure=True, types=None):
        """Update the row matching ``unique``; insert ``row`` if none matched.

        With ``ensure`` set, missing columns are created and an index over
        the ``unique`` columns is requested before the update is attempted.
        """
        if ensure:
            # Columns must exist before the index can reference them;
            # otherwise create_index fails and caches the failure for good.
            self._ensure_columns(row, types=types)
            self.create_index(unique)
        if not self.update(row, unique, ensure=ensure, types=types):
            self.insert(row, ensure=ensure, types=types)

    def delete(self, **kw):
        """Delete rows whose columns equal the given keyword values.

        Filter columns that do not exist yet are created (via
        ``_args_to_clause``) before the statement runs.
        """
        q = self._args_to_clause(kw)
        stmt = self.table.delete(q)
        self.database.engine.execute(stmt)

    def _ensure_columns(self, row, types=None):
        """Create a column for every key of ``row`` missing from the table.

        The column type comes from ``types`` when given for that key,
        otherwise it is guessed from the row's value.
        """
        if types is None:
            types = {}
        existing = set(self.table.columns.keys())
        for column in [c for c in row.keys() if c not in existing]:
            if column in types:
                _type = types[column]
            else:
                _type = guess_type(row[column])
            log.debug("Creating column: %s (%s) on %r",
                      column, _type, self.table.name)
            self.create_column(column, _type)

    def _args_to_clause(self, args):
        """Build an AND of ``column == value`` conditions from ``args``.

        Note: this creates any missing columns named in ``args`` as a side
        effect, so filtering on an unknown column adds it to the table.
        """
        self._ensure_columns(args)
        return and_(*[self.table.c[k] == v for k, v in args.items()])

    def create_column(self, name, type):
        """Add column ``name`` of ``type`` unless it already exists."""
        with self.database.lock:
            if name not in self.table.columns.keys():
                col = Column(name, type)
                col.create(self.table,
                           connection=self.database.engine)

    def create_index(self, columns, name=None):
        """Create (at most once) an index over the named ``columns``.

        :param columns: sequence of column names.
        :param name: index name; derived from the table name and the column
            names when omitted.
        :returns: the created index, the cached result of an earlier call
            with the same name, or ``None`` if creation failed.

        A failure (e.g. an unknown column name) is logged at debug level and
        cached as ``None`` so it is not retried on every call.
        """
        import hashlib  # local import keeps the module header untouched

        with self.database.lock:
            if not name:
                # Use a digest that is stable across processes: the builtin
                # hash() of a str is salted per process on Python 3, which
                # would give the same column set a new index name each run.
                key = '||'.join(columns)
                if not isinstance(key, bytes):
                    key = key.encode('utf-8')
                sig = hashlib.sha1(key).hexdigest()[:16]
                name = 'ix_%s_%s' % (self.table.name, sig)
            if name in self.indexes:
                return self.indexes[name]
            try:
                cols = [self.table.c[c] for c in columns]
                idx = Index(name, *cols)
                idx.create(self.database.engine)
            except Exception:
                log.debug("Could not create index %s on %r",
                          name, self.table.name, exc_info=True)
                idx = None
            self.indexes[name] = idx
            return idx

    def find_one(self, **kw):
        """Return the first row matching ``kw``, or ``None`` if none match."""
        res = list(self.find(_limit=1, **kw))
        if not res:
            return None
        return res[0]

    def find(self, _limit=None, _step=5000, _offset=0,
             order_by='id', **kw):
        """Yield rows matching the keyword filters, fetched in pages.

        :param _limit: maximum number of rows to yield (``None``: no limit).
        :param _step: page size used for each underlying query.
        :param _offset: number of leading rows to skip.
        :param order_by: name of the column to sort ascending by.
        """
        ordering = [self.table.c[order_by].asc()]
        filters = self._args_to_clause(kw)
        for page in count():
            qoffset = _offset + (_step * page)
            qlimit = _step
            if _limit is not None:
                qlimit = min(_limit - (_step * page), _step)
            if qlimit <= 0:
                break
            q = self.table.select(whereclause=filters, limit=qlimit,
                                  offset=qoffset, order_by=ordering)
            rows = list(self.database.query(q))
            for row in rows:
                yield row
            # A page shorter than requested means the result set is
            # exhausted; stop instead of issuing another (empty) query.
            if len(rows) < qlimit:
                return

    def __len__(self):
        """Return the number of rows in the table (a COUNT query)."""
        # assumes database.query yields mapping rows (as the dict-style
        # access here requires); the count is that row's only value.
        row = next(iter(self.database.query(self.table.count())))
        return list(row.values())[-1]

    def distinct(self, *columns, **kw):
        """Return the distinct value combinations of ``columns``.

        Keyword arguments filter rows by equality. An empty list is returned
        when any referenced column does not exist.
        """
        try:
            cols = [self.table.c[c] for c in columns]
            filters = [self.table.c[col] == val for col, val in kw.items()]
        except KeyError:
            return []
        q = expression.select(cols, distinct=True,
                              whereclause=and_(*filters),
                              order_by=[c.asc() for c in cols])
        return self.database.query(q)

    def all(self):
        """Iterate over every row of the table (see ``find``)."""
        return self.find()
|
2013-04-01 18:05:41 +02:00
|
|
|
|