Moved everything into the OO API; not really tested …

This commit is contained in:
Friedrich Lindenberg 2013-04-01 19:06:03 +02:00
parent 58bf7bf27f
commit d3fe91bc73
7 changed files with 137 additions and 285 deletions

View File

@ -1,12 +1,10 @@
from dataset.schema import connect
from dataset.schema import create_table, load_table, get_table, drop_table
from dataset.schema import create_column
from dataset.write import add_row, update_row
from dataset.write import upsert, update, delete
from dataset.query import distinct, resultiter, all, find_one, find, query
from dataset.db import create
# shut up useless SA warning: # shut up useless SA warning:
import warnings import warnings
warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.') warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.')
from dataset.persistence.database import Database
from dataset.persistence.table import Table
def connect(url):
    """Return a ``Database`` object constructed from ``url``."""
    database = Database(url)
    return database

View File

@ -1,82 +0,0 @@
from sqlaload.schema import connect, get_table, create_table, load_table, drop_table, create_column, create_index
from sqlalchemy.engine import Engine
from sqlaload.write import add_row, update_row
from sqlaload.write import upsert, update, delete
from sqlaload.query import distinct, all, find_one, find, query
class DB(object):
    """Object wrapper that binds the module-level table helpers to one engine."""

    def __init__(self, engine):
        """Store ``engine``, calling ``connect`` first when it is a string.

        Raises ``Exception`` when ``engine`` is neither a string nor an
        ``Engine`` instance.
        """
        if isinstance(engine, (str, unicode)):
            engine = connect(engine)
        elif not isinstance(engine, Engine):
            raise Exception('unknown engine format')
        self.engine = engine

    def create_table(self, table_name):
        """Call ``create_table`` for ``table_name`` and wrap the result."""
        created = create_table(self.engine, table_name)
        return Table(self.engine, created)

    def load_table(self, table_name):
        """Call ``load_table`` for ``table_name`` and wrap the result."""
        loaded = load_table(self.engine, table_name)
        return Table(self.engine, loaded)

    def get_table(self, table_name):
        """Call ``get_table`` for ``table_name`` and wrap the result."""
        found = get_table(self.engine, table_name)
        return Table(self.engine, found)

    def drop_table(self, table_name):
        """Call ``drop_table`` for ``table_name``; returns nothing."""
        drop_table(self.engine, table_name)

    def query(self, sql):
        """Pass ``sql`` to the module-level ``query`` helper on this engine."""
        result = query(self.engine, sql)
        return result
class Table(object):
    """Bind an engine/table pair to the module-level sqlaload helpers."""

    def __init__(self, engine, table):
        self.table = table
        self.engine = engine

    def _call(self, helper, *args, **kwargs):
        """Invoke ``helper`` with this object's engine and table prepended."""
        return helper(self.engine, self.table, *args, **kwargs)

    # sqlaload.write
    def add_row(self, *args, **kwargs):
        """Delegate to ``add_row``."""
        return self._call(add_row, *args, **kwargs)

    def update_row(self, *args, **kwargs):
        """Delegate to ``update_row``."""
        return self._call(update_row, *args, **kwargs)

    def upsert(self, *args, **kwargs):
        """Delegate to ``upsert``."""
        return self._call(upsert, *args, **kwargs)

    def update(self, *args, **kwargs):
        """Delegate to ``update``."""
        return self._call(update, *args, **kwargs)

    def delete(self, *args, **kwargs):
        """Delegate to ``delete``."""
        return self._call(delete, *args, **kwargs)

    # sqlaload.query
    def find_one(self, *args, **kwargs):
        """Delegate to ``find_one``."""
        return self._call(find_one, *args, **kwargs)

    def find(self, *args, **kwargs):
        """Delegate to ``find``."""
        return self._call(find, *args, **kwargs)

    def all(self):
        """Delegate to ``all`` (the imported helper, not the builtin)."""
        return self._call(all)

    def distinct(self, *args, **kwargs):
        """Delegate to ``distinct``."""
        return self._call(distinct, *args, **kwargs)

    # sqlaload.schema
    def create_column(self, name, type):
        """Delegate to ``create_column`` with ``name`` and ``type``."""
        return self._call(create_column, name, type)

    def create_index(self, columns, name=None):
        """Delegate to ``create_index``; ``name`` is passed by keyword."""
        return self._call(create_index, columns, name=name)
def create(engine):
    """Build and return a ``DB`` object for ``engine``."""
    db = DB(engine)
    return db

View File

@ -9,6 +9,7 @@ from sqlalchemy.schema import Table as SQLATable
from sqlalchemy import Integer from sqlalchemy import Integer
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.persistence.util import resultiter
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -27,7 +28,6 @@ class Database(object):
self.metadata = MetaData() self.metadata = MetaData()
self.metadata.bind = self.engine self.metadata.bind = self.engine
self.tables = {} self.tables = {}
self.indexes = {}
def create_table(table_name): def create_table(table_name):
with self.lock: with self.lock:
@ -58,9 +58,9 @@ class Database(object):
def __getitem__(self, table_name): def __getitem__(self, table_name):
return self.get_table(table_name) return self.get_table(table_name)
@classmethod def query(self, query):
def connect(self, url): for res in resultiter(self.engine.execute(query)):
return Database(url) yield res
def __repr__(self): def __repr__(self):
return '<Database(%s)>' % self.url return '<Database(%s)>' % self.url

View File

@ -1,8 +1,19 @@
import logging
from itertools import count
from sqlalchemy.sql import and_, expression
from sqlalchemy.schema import Column, Index
from dataset.persistence.util import guess_type, resultiter
log = logging.getLogger(__name__)
class Table(object): class Table(object):
def __init__(self, database, table): def __init__(self, database, table):
self.indexes = {}
self.database = database self.database = database
self.table = table self.table = table
@ -11,4 +22,120 @@ class Table(object):
self.database.tables.pop(self.table.name, None) self.database.tables.pop(self.table.name, None)
self.table.drop(engine) self.table.drop(engine)
def insert(self, row, ensure=True, types={}):
    """Insert ``row`` (a dict) into the table.

    When ``ensure`` is set, ``_ensure_columns`` is called first so that keys
    of ``row`` which are not table columns yet get created; ``types`` is
    forwarded to it as the explicit column-type mapping.
    """
    if ensure:
        self._ensure_columns(row, types=types)
    statement = self.table.insert(row)
    self.database.engine.execute(statement)
def update(self, row, unique, ensure=True, types={}):
    """Update the rows whose ``unique`` columns match the values in ``row``.

    ``unique`` is a sequence of column names; the filter compares each of
    them with ``row.get(name)``. When ``ensure`` is set, ``_ensure_columns``
    is called for ``row`` first (``types`` is forwarded to it).

    Returns ``True`` when the executed statement reports at least one
    affected row, ``False`` when ``unique`` is empty, nothing matched, or a
    ``KeyError`` was raised while building/executing the statement.
    """
    if not unique:
        return False
    if ensure:
        self._ensure_columns(row, types=types)
    criteria = dict((key, row.get(key)) for key in unique)
    try:
        filters = self._args_to_clause(criteria)
        stmt = self.table.update(filters, row)
        result = self.database.engine.execute(stmt)
        return result.rowcount > 0
    except KeyError:
        # A missing filter column is reported to the caller as "not updated".
        return False
def upsert(self, row, unique, ensure=True, types={}):
    """Update the row identified by ``unique``, inserting it if none matched.

    When ``ensure`` is set, ``create_index`` is called on the ``unique``
    columns first. ``ensure`` and ``types`` are forwarded to both
    ``update`` and ``insert``.
    """
    if ensure:
        self.create_index(unique)
    updated = self.update(row, unique, ensure=ensure, types=types)
    if not updated:
        self.insert(row, ensure=ensure, types=types)
def delete(self, **kw):
    """Delete the rows matching the ``column=value`` filters in ``kw``.

    The filters are combined by ``_args_to_clause``; the resulting clause is
    used as the WHERE condition of a DELETE statement on this table.
    """
    where = self._args_to_clause(kw)
    # Execute a DELETE statement; the bare clause on its own is not one.
    stmt = self.table.delete(where)
    self.database.engine.execute(stmt)
def _ensure_columns(self, row, types={}):
    """Create every column named in ``row`` that the table does not have yet.

    The column type is taken from ``types`` when the column is listed there,
    otherwise it is derived from the row's value via ``guess_type``.
    """
    # Compare against the wrapped table's columns, as ``create_column`` does.
    missing = set(row.keys()) - set(self.table.columns.keys())
    for column in missing:
        if column in types:
            _type = types[column]
        else:
            _type = guess_type(row[column])
        log.debug("Creating column: %s (%s) on %r" % (column,
                  _type, self.table.name))
        self.create_column(column, _type)
def _args_to_clause(self, args):
    """Turn a ``{column: value}`` dict into an AND-ed equality clause.

    ``_ensure_columns`` is called on ``args`` first, so filter keys that are
    not table columns yet are created before they are referenced.
    """
    self._ensure_columns(args)
    clauses = [self.table.c[name] == value for name, value in args.items()]
    return and_(*clauses)
def create_column(self, name, type):
    """Add a column ``name`` of the given ``type`` unless it already exists.

    The check and the creation both happen while holding the database lock.
    """
    with self.database.lock:
        if name in self.table.columns.keys():
            return
        column = Column(name, type)
        column.create(self.table,
                      connection=self.database.engine)
def create_index(self, columns, name=None):
    """Create an index over ``columns`` and cache it in ``self.indexes``.

    When ``name`` is not given, one is derived from the table name and a hash
    of the joined column names. A name that is already cached returns the
    cached value without touching the database. Creation is best effort: if
    it fails, ``None`` is cached and returned.

    NOTE(review): the generated name depends on ``hash()`` of a string, which
    is not guaranteed to be stable between interpreter runs - confirm this is
    acceptable for persistent index names.
    """
    with self.database.lock:
        if not name:
            sig = abs(hash('||'.join(columns)))
            name = 'ix_%s_%s' % (self.table.name, sig)
        if name in self.indexes:
            return self.indexes[name]
        try:
            columns = [self.table.c[c] for c in columns]
            idx = Index(name, *columns)
            idx.create(self.database.engine)
        except Exception:
            # Best effort only; KeyboardInterrupt/SystemExit still propagate.
            idx = None
        self.indexes[name] = idx
        return idx
def find_one(self, **kw):
    """Return the first row matching the ``kw`` filters, or ``None``.

    Delegates to ``find`` with ``_limit=1``.
    """
    res = list(self.find(_limit=1, **kw))
    if not res:
        return None
    return res[0]
def find(self, _limit=None, _step=5000, _offset=0,
         order_by='id', **kw):
    """Yield the rows matching the ``kw`` equality filters, page by page.

    Each query asks for at most ``_step`` rows, starting at ``_offset``.
    Iteration stops once ``_limit`` rows have been requested (``None`` means
    no limit) or as soon as a page comes back empty. Rows are ordered
    ascending by the ``order_by`` column.
    """
    ordering = [self.table.c[order_by].asc()]
    where = self._args_to_clause(kw)
    for page in count():
        page_start = _step * page
        if _limit is None:
            page_size = _step
        else:
            # Never ask for more than what is left of the overall limit.
            page_size = min(_limit - page_start, _step)
        if page_size <= 0:
            break
        q = self.table.select(whereclause=where, limit=page_size,
                              offset=_offset + page_start,
                              order_by=ordering)
        rows = list(resultiter(self.database.engine.execute(q)))
        if not rows:
            return
        for row in rows:
            yield row
def distinct(self, *columns, **kw):
    """Return the distinct value combinations of ``columns``.

    ``kw`` holds ``column=value`` equality filters. An unknown column name,
    in either ``columns`` or ``kw``, yields an empty list. Results are
    ordered ascending by the selected columns.
    """
    try:
        selected = [self.table.c[name] for name in columns]
        filters = [self.table.c[name] == value for name, value in kw.items()]
    except KeyError:
        return []
    ordering = [col.asc() for col in selected]
    q = expression.select(selected, distinct=True,
                          whereclause=and_(*filters),
                          order_by=ordering)
    return resultiter(self.database.engine.execute(q))
def all(self):
    """Return what ``find`` yields when called without any filters."""
    rows = self.find()
    return rows

View File

@ -1,71 +0,0 @@
import logging
from itertools import count
from sqlalchemy.sql import expression, and_
from dataset.schema import _ensure_columns, get_table
from dataset.persistence.util import resultiter
log = logging.getLogger(__name__)
def find_one(engine, table, **kw):
    """Return the first row that ``find`` yields for ``kw``, or ``None``."""
    table = get_table(engine, table)
    matches = list(find(engine, table, _limit=1, **kw))
    if matches:
        return matches[0]
    return None
def find(engine, table, _limit=None, _step=5000, _offset=0,
         order_by='id', **kw):
    """Yield rows of ``table`` matching the ``kw`` equality filters.

    Rows are fetched ``_step`` at a time starting at ``_offset``, ordered
    ascending by ``order_by``. Iteration ends after ``_limit`` rows have been
    requested (``None`` = unlimited), on the first empty page, or right away
    when a filter column raises ``KeyError``.
    """
    table = get_table(engine, table)
    _ensure_columns(engine, table, kw)
    ordering = [table.c[order_by].asc()]
    try:
        conditions = [table.c[name] == value for name, value in kw.items()]
    except KeyError:
        return

    for page in count():
        page_start = _step * page
        if _limit is None:
            page_size = _step
        else:
            page_size = min(_limit - page_start, _step)
        if page_size <= 0:
            break
        stmt = table.select(whereclause=and_(*conditions), limit=page_size,
                            offset=_offset + page_start, order_by=ordering)
        batch = list(resultiter(engine.execute(stmt)))
        if not batch:
            return
        for row in batch:
            yield row
def query(engine, query):
    """Execute ``query`` on ``engine`` and yield each row from ``resultiter``."""
    result_proxy = engine.execute(query)
    for row in resultiter(result_proxy):
        yield row
def distinct(engine, table, *columns, **kw):
    """Return a list of the distinct value combinations of ``columns``.

    ``kw`` holds ``column=value`` equality filters; an unknown column name
    results in an empty list. Results are ordered ascending by the selected
    columns.
    """
    table = get_table(engine, table)
    try:
        selected = [table.c[name] for name in columns]
        filters = [table.c[name] == value for name, value in kw.items()]
    except KeyError:
        return []
    ordering = [col.asc() for col in selected]
    stmt = expression.select(selected, distinct=True,
                             whereclause=and_(*filters),
                             order_by=ordering)
    return list(resultiter(engine.execute(stmt)))
def all(engine, table):
    """Return what ``find`` yields for ``table`` without any filters."""
    rows = find(engine, table)
    return rows

View File

@ -1,56 +0,0 @@
import logging
from datetime import datetime
from collections import defaultdict
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
from sqlalchemy.schema import Table, MetaData, Column, Index
from sqlalchemy.sql import and_, expression
log = logging.getLogger(__name__)
def _ensure_columns(engine, table, row, types={}):
    """Create each column named in ``row`` that ``table`` does not have yet.

    A column's type comes from ``types`` when listed there; otherwise it is
    derived from the row value with ``_guess_type``.
    """
    missing = set(row.keys()) - set(table.columns.keys())
    for column in missing:
        _type = types[column] if column in types else _guess_type(row[column])
        log.debug("Creating column: %s (%s) on %r" % (column,
                  _type, table.name))
        create_column(engine, table, column, _type)
def _args_to_clause(table, args):
    """Combine ``{column: value}`` pairs into one AND-ed equality clause."""
    comparisons = [table.c[name] == value for name, value in args.items()]
    return and_(*comparisons)
def create_column(engine, table, name, type):
    """Add column ``name`` of ``type`` to ``table`` unless it already exists.

    NOTE(review): ``get_table`` and ``lock`` are neither defined nor imported
    in the visible part of this module - confirm where they come from.
    """
    table = get_table(engine, table)
    with lock:
        if name in table.columns.keys():
            return
        column = Column(name, type)
        column.create(table, connection=engine)
def create_index(engine, table, columns, name=None):
    """Create an index on ``columns`` of ``table``, cached on the engine.

    Without ``name``, one is derived from the table name and a hash of the
    joined column names. A name already present in ``engine._indexes`` returns
    the cached value. Creation is best effort: on failure ``None`` is cached
    and returned.
    """
    table = get_table(engine, table)
    with lock:
        if not name:
            sig = abs(hash('||'.join(columns)))
            name = 'ix_%s_%s' % (table.name, sig)
        if name in engine._indexes:
            return engine._indexes[name]
        try:
            columns = [table.c[c] for c in columns]
            idx = Index(name, *columns)
            idx.create(engine)
        except Exception:
            # Best effort only; KeyboardInterrupt/SystemExit still propagate.
            idx = None
        engine._indexes[name] = idx
        return idx

View File

@ -1,64 +0,0 @@
import logging
from dataset.schema import _ensure_columns, _args_to_clause
from dataset.schema import create_index, get_table
log = logging.getLogger(__name__)
def add_row(engine, table, row, ensure=True, types={}):
    """Insert ``row`` (a dict) into ``table``.

    When ``ensure`` is set, keys of ``row`` that are not table columns yet
    are type guessed and created first; ``types`` is forwarded to
    ``_ensure_columns``.
    """
    table = get_table(engine, table)
    if ensure:
        _ensure_columns(engine, table, row, types=types)
    statement = table.insert(row)
    engine.execute(statement)
def update_row(engine, table, row, unique, ensure=True, types={}):
    """Update the rows whose ``unique`` columns match the values in ``row``.

    Returns ``True`` when the statement reports at least one affected row and
    ``False`` when ``unique`` is empty, nothing matched, or a ``KeyError`` was
    raised (which is logged as a missing filter column).
    """
    if not unique:
        return False
    table = get_table(engine, table)
    clause = dict((u, row.get(u)) for u in unique)
    if ensure:
        _ensure_columns(engine, table, row, types=types)
    try:
        stmt = table.update(_args_to_clause(table, clause), row)
        rp = engine.execute(stmt)
        return rp.rowcount > 0
    except KeyError as ke:
        log.warning("UPDATE: '%s' filter column does not exist: %s", table.name, ke)
        return False
def upsert(engine, table, row, unique, ensure=True, types={}):
    """Update the row identified by ``unique``; insert it if none matched.

    When ``ensure`` is set, ``create_index`` is called on the ``unique``
    columns first.
    """
    table = get_table(engine, table)
    if ensure:
        create_index(engine, table, unique)
    updated = update_row(engine, table, row, unique,
                         ensure=ensure, types=types)
    if not updated:
        add_row(engine, table, row, ensure=ensure, types=types)
def update(engine, table, criteria, values, ensure=True, types={}):
    """Set ``values`` on every row of ``table`` matching ``criteria``.

    ``criteria`` maps column names to the values they must equal. When
    ``ensure`` is set, ``_ensure_columns`` is called for ``values`` first.
    """
    table = get_table(engine, table)
    if ensure:
        _ensure_columns(engine, table, values, types=types)
    statement = table.update().values(values)
    for name, expected in criteria.items():
        statement = statement.where(table.c[name] == expected)
    engine.execute(statement)
def delete(engine, table, **kw):
    """Delete the rows of ``table`` matching the ``column=value`` filters.

    ``_ensure_columns`` is called on ``kw`` before the filters are built.
    If a filter column still raises ``KeyError``, nothing is executed.
    """
    table = get_table(engine, table)
    _ensure_columns(engine, table, kw)
    statement = table.delete()
    try:
        for column, value in kw.items():
            statement = statement.where(table.c[column] == value)
    except KeyError:
        return
    engine.execute(statement)