# shut up useless SA warning:
import warnings
warnings.filterwarnings(
    'ignore', 'Unicode type received non-unicode bind param value.')

from dataset.persistence.database import Database
from dataset.persistence.table import Table


def connect(url):
    """Open a connection to the database at ``url``.

    Returns a :class:`Database` wrapper bound to that URL.
    """
    return Database(url)
class Table(object):
    """Dict-oriented CRUD helpers around a single SQLAlchemy table.

    All statements are executed through ``self.database.engine``; schema
    changes (new columns, indexes) are serialized with ``self.database.lock``.
    """

    def __init__(self, database, table):
        # Cache of indexes created via create_index(), keyed by index name.
        self.indexes = {}
        self.database = database
        self.table = table

    def drop(self):
        """Drop the table and remove it from the parent database's cache."""
        self.database.tables.pop(self.table.name, None)
        # NOTE(review): the original referenced a bare `engine`; the parent
        # database's engine is the only engine in scope here — confirm.
        self.table.drop(self.database.engine)

    def insert(self, row, ensure=True, types={}):
        """Add a row (type: dict). If ``ensure`` is set, any of
        the keys of the row are not table columns, they will be type
        guessed and created. """
        if ensure:
            self._ensure_columns(row, types=types)
        self.database.engine.execute(self.table.insert(row))

    def update(self, row, unique, ensure=True, types={}):
        """Update rows matching ``row``'s values for the ``unique`` columns.

        Returns True when at least one row was updated, False otherwise
        (including when ``unique`` is empty or a filter column is missing).
        """
        if not len(unique):
            return False
        clause = [(u, row.get(u)) for u in unique]
        if ensure:
            self._ensure_columns(row, types=types)
        try:
            filters = self._args_to_clause(dict(clause))
            stmt = self.table.update(filters, row)
            rp = self.database.engine.execute(stmt)
            return rp.rowcount > 0
        except KeyError as ke:
            # Consistent with the former write.py: warn instead of silently
            # ignoring a missing filter column.
            log.warn("UPDATE: '%s' filter column does not exist: %s",
                     self.table.name, ke)
            return False

    def upsert(self, row, unique, ensure=True, types={}):
        """Update a row matched on ``unique``, inserting it when absent."""
        if ensure:
            self.create_index(unique)
        if not self.update(row, unique, ensure=ensure, types=types):
            self.insert(row, ensure=ensure, types=types)

    def delete(self, **kw):
        """Delete all rows matching the given column=value filters."""
        # Bug fix: the original executed the bare and_() clause, which is not
        # an executable statement; issue an actual DELETE with that clause.
        stmt = self.table.delete(whereclause=self._args_to_clause(kw))
        self.database.engine.execute(stmt)

    def _ensure_columns(self, row, types={}):
        # Create any column present in `row` but absent from the table,
        # using an explicit type from `types` or guessing from the value.
        # Bug fix: original read `self.columns`, which does not exist.
        for column in set(row.keys()) - set(self.table.columns.keys()):
            if column in types:
                _type = types[column]
            else:
                _type = guess_type(row[column])
            log.debug("Creating column: %s (%s) on %r" % (column,
                                                          _type, self.table.name))
            self.create_column(column, _type)

    def _args_to_clause(self, args):
        # Bug fix: original called self._ensure_columns(kw) with the
        # undefined name `kw`; the filter dict here is `args`.
        self._ensure_columns(args)
        clauses = []
        for k, v in args.items():
            clauses.append(self.table.c[k] == v)
        return and_(*clauses)

    def create_column(self, name, type):
        """Add a column to the table unless one with ``name`` exists."""
        with self.database.lock:
            if name not in self.table.columns.keys():
                col = Column(name, type)
                col.create(self.table,
                           connection=self.database.engine)

    def create_index(self, columns, name=None):
        """Create (best-effort) and cache an index over ``columns``."""
        with self.database.lock:
            if not name:
                sig = abs(hash('||'.join(columns)))
                name = 'ix_%s_%s' % (self.table.name, sig)
            if name in self.indexes:
                return self.indexes[name]
            try:
                columns = [self.table.c[c] for c in columns]
                idx = Index(name, *columns)
                idx.create(self.database.engine)
            except Exception:
                # Best-effort: cache the failure as None so we don't retry.
                idx = None
            self.indexes[name] = idx
            return idx

    def find_one(self, **kw):
        """Return the first row matching the filters, or None."""
        # Bug fix: original lacked `self` and passed engine/table as
        # positional arguments that find() does not accept.
        res = list(self.find(_limit=1, **kw))
        if not len(res):
            return None
        return res[0]

    def find(self, _limit=None, _step=5000, _offset=0,
             order_by='id', **kw):
        """Yield rows matching the column=value filters in ``kw``,
        paging through results ``_step`` rows at a time."""
        # Bug fix: original lacked `self` and referenced undefined `table`.
        order_by = [self.table.c[order_by].asc()]
        args = self._args_to_clause(kw)

        for i in count():
            qoffset = _offset + (_step * i)
            qlimit = _step
            if _limit is not None:
                qlimit = min(_limit - (_step * i), _step)
            if qlimit <= 0:
                break
            q = self.table.select(whereclause=args, limit=qlimit,
                                  offset=qoffset, order_by=order_by)
            rows = list(resultiter(self.database.engine.execute(q)))
            if not len(rows):
                return
            for row in rows:
                yield row

    def distinct(self, *columns, **kw):
        """Return distinct values of ``columns``, optionally filtered by
        the column=value pairs in ``kw``; [] when a column is unknown."""
        qargs = []
        try:
            columns = [self.table.c[c] for c in columns]
            for col, val in kw.items():
                qargs.append(self.table.c[col] == val)
        except KeyError:
            return []

        q = expression.select(columns, distinct=True,
                              whereclause=and_(*qargs),
                              order_by=[c.asc() for c in columns])
        return resultiter(self.database.engine.execute(q))

    def all(self):
        """Iterate over every row in the table."""
        return self.find()
min(_limit-(_step*i), _step) - if qlimit <= 0: - break - q = table.select(whereclause=and_(*qargs), limit=qlimit, - offset=qoffset, order_by=order_by) - #print q - rows = list(resultiter(engine.execute(q))) - if not len(rows): - return - for row in rows: - yield row - -def query(engine, query): - for res in resultiter(engine.execute(query)): - yield res - -def distinct(engine, table, *columns, **kw): - table = get_table(engine, table) - qargs = [] - try: - columns = [table.c[c] for c in columns] - for col, val in kw.items(): - qargs.append(table.c[col]==val) - except KeyError: - return [] - - q = expression.select(columns, distinct=True, - whereclause=and_(*qargs), - order_by=[c.asc() for c in columns]) - return list(resultiter(engine.execute(q))) - -def all(engine, table): - return find(engine, table) - - - - diff --git a/dataset/schema.py b/dataset/schema.py deleted file mode 100644 index 0ade8ed..0000000 --- a/dataset/schema.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from datetime import datetime -from collections import defaultdict - -from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean -from sqlalchemy.schema import Table, MetaData, Column, Index -from sqlalchemy.sql import and_, expression - -log = logging.getLogger(__name__) - - - - - - -def _ensure_columns(engine, table, row, types={}): - columns = set(row.keys()) - set(table.columns.keys()) - for column in columns: - if column in types: - _type = types[column] - else: - _type = _guess_type(row[column]) - log.debug("Creating column: %s (%s) on %r" % (column, - _type, table.name)) - create_column(engine, table, column, _type) - -def _args_to_clause(table, args): - clauses = [] - for k, v in args.items(): - clauses.append(table.c[k] == v) - return and_(*clauses) - -def create_column(engine, table, name, type): - table = get_table(engine, table) - with lock: - if name not in table.columns.keys(): - col = Column(name, type) - col.create(table, connection=engine) - -def 
create_index(engine, table, columns, name=None): - table = get_table(engine, table) - with lock: - if not name: - sig = abs(hash('||'.join(columns))) - name = 'ix_%s_%s' % (table.name, sig) - if name in engine._indexes: - return engine._indexes[name] - try: - columns = [table.c[c] for c in columns] - idx = Index(name, *columns) - idx.create(engine) - except: - idx = None - engine._indexes[name] = idx - return idx - diff --git a/dataset/write.py b/dataset/write.py deleted file mode 100644 index 8d007e4..0000000 --- a/dataset/write.py +++ /dev/null @@ -1,64 +0,0 @@ -import logging - -from dataset.schema import _ensure_columns, _args_to_clause -from dataset.schema import create_index, get_table - -log = logging.getLogger(__name__) - -def add_row(engine, table, row, ensure=True, types={}): - """ Add a row (type: dict). If ``ensure`` is set, any of - the keys of the row are not table columns, they will be type - guessed and created. """ - table = get_table(engine, table) - if ensure: - _ensure_columns(engine, table, row, types=types) - engine.execute(table.insert(row)) - -def update_row(engine, table, row, unique, ensure=True, types={}): - if not len(unique): - return False - table = get_table(engine, table) - clause = dict([(u, row.get(u)) for u in unique]) - if ensure: - _ensure_columns(engine, table, row, types=types) - try: - stmt = table.update(_args_to_clause(table, clause), row) - rp = engine.execute(stmt) - return rp.rowcount > 0 - except KeyError, ke: - log.warn("UPDATE: '%s' filter column does not exist: %s", table.name, ke) - return False - -def upsert(engine, table, row, unique, ensure=True, types={}): - table = get_table(engine, table) - if ensure: - create_index(engine, table, unique) - - if not update_row(engine, table, row, unique, ensure=ensure, types=types): - add_row(engine, table, row, ensure=ensure, types=types) - -def update(engine, table, criteria, values, ensure=True, types={}): - table = get_table(engine, table) - if ensure: - 
_ensure_columns(engine, table, values, types=types) - q = table.update().values(values) - for column, value in criteria.items(): - q = q.where(table.c[column]==value) - engine.execute(q) - -def delete(engine, table, **kw): - table = get_table(engine, table) - _ensure_columns(engine, table, kw) - - qargs = [] - try: - for col, val in kw.items(): - qargs.append(table.c[col]==val) - except KeyError: - return - - q = table.delete() - for k, v in kw.items(): - q= q.where(table.c[k]==v) - engine.execute(q) -