Split into three code files.

This commit is contained in:
Friedrich Lindenberg 2011-12-22 19:03:54 +01:00
parent f9be8771ff
commit a21c993538
4 changed files with 143 additions and 121 deletions

View File

@ -1,123 +1,8 @@
import logging
from datetime import datetime
from sqlalchemy import create_engine
from sqlalchemy import Integer, UnicodeText, Float, DateTime
from sqlalchemy.schema import Table, MetaData, Column
from sqlalchemy.sql import and_, expression
from migrate.versioning.util import construct_engine
log = logging.getLogger(__name__)
def connect(url):
    """Create an engine for the given database URL.

    The engine is passed through ``construct_engine`` and receives a new
    ``MetaData`` object, bound to it and stored as ``engine._metadata``.
    """
    engine = construct_engine(create_engine(url))
    metadata = MetaData()
    metadata.bind = engine
    # Keep the metadata on the engine so table helpers can look it up there.
    engine._metadata = metadata
    return engine
def create_table(engine, table_name):
    """Create table ``table_name`` with an integer ``id`` primary key.

    The table is registered on ``engine._metadata``, created through the
    engine and returned.
    """
    message = "Creating table: %s on %r" % (table_name, engine)
    log.debug(message)
    new_table = Table(table_name, engine._metadata)
    new_table.append_column(Column('id', Integer, primary_key=True))
    new_table.create(engine)
    return new_table
def load_table(engine, table_name):
    """Return table ``table_name``, autoloaded via ``engine._metadata``."""
    reflected = Table(table_name, engine._metadata, autoload=True)
    return reflected
def get_table(engine, table_name):
    """Return table ``table_name``: load it if present, else create it."""
    if not engine.has_table(table_name):
        return create_table(engine, table_name)
    return load_table(engine, table_name)
def _guess_type(sample):
    """Pick a column type for ``sample`` based on its Python type.

    ``int`` maps to ``Integer``, ``float`` to ``Float`` and ``datetime`` to
    ``DateTime``; anything else falls back to ``UnicodeText``.
    """
    # Checked in order, first match wins.
    candidates = (
        (int, Integer),
        (float, Float),
        (datetime, DateTime),
    )
    for python_type, column_type in candidates:
        if isinstance(sample, python_type):
            return column_type
    return UnicodeText
def _ensure_columns(engine, table, row, types={}):
columns = set(row.keys()) - set(table.columns.keys())
for column in columns:
if column in types:
_type = types[column]
else:
_type = _guess_type(row[column])
log.debug("Creating column: %s (%s) on %r" % (column,
_type, table.name))
create_column(engine, table, column, _type)
def _args_to_clause(table, args):
    """Combine ``column == value`` for every item of ``args`` with ``and_``."""
    conditions = [table.c[name] == value for name, value in args.items()]
    return and_(*conditions)
def create_column(engine, table, name, type):
    """Add column ``name`` of the given ``type`` to ``table``.

    Does nothing when the table already has a column of that name.
    """
    if name in table.columns.keys():
        return
    Column(name, type).create(table, connection=engine)
def add_row(engine, table, row, ensure=True, types={}):
    """Insert ``row`` (a dict) into ``table``.

    If ``ensure`` is set, keys of the row that are not yet table columns
    are created first, typed from ``types`` or guessed from the value.
    """
    if ensure:
        _ensure_columns(engine, table, row, types=types)
    statement = table.insert(row)
    engine.execute(statement)
def update_row(engine, table, row, unique, ensure=True, types={}):
    """Update the rows that match ``row`` on the ``unique`` columns.

    :param engine: engine used to execute the UPDATE statement.
    :param table: table to update.
    :param row: dict of column values; it also supplies the filter values.
    :param unique: names of the columns to filter on.
    :param ensure: if set, create missing columns for the keys of ``row``
        before updating.
    :param types: mapping of column name to column type, handed to the
        column-creation step.
    :returns: ``True`` if at least one row was updated; ``False`` if no
        row matched, ``unique`` is empty, or a filter column is missing.
    """
    if not unique:
        return False
    criteria = dict((name, row.get(name)) for name in unique)
    if ensure:
        _ensure_columns(engine, table, row, types=types)
    try:
        stmt = table.update(_args_to_clause(table, criteria), row)
        rp = engine.execute(stmt)
        return rp.rowcount > 0
    except KeyError as ke:
        # Presumably raised by the column lookup for an unknown filter
        # column; treated as "nothing updated" rather than an error.
        log.warning("UPDATE: filter column does not exist: %s", ke)
        return False
def upsert(engine, table, row, unique, ensure=True, types={}):
    """Update the rows matching ``row`` on ``unique``; insert if none did.

    The update and the insert are issued as two separate calls.
    """
    updated = update_row(engine, table, row, unique, ensure=ensure,
                         types=types)
    if updated:
        return
    add_row(engine, table, row, ensure=ensure, types=types)
def update(engine, table, criteria, values, ensure=True, types={}):
    """Set ``values`` on the rows of ``table`` matching all of ``criteria``.

    ``criteria`` maps column names to required values. If ``ensure`` is
    set, columns missing for the keys of ``values`` are created first.
    """
    if ensure:
        _ensure_columns(engine, table, values, types=types)
    statement = table.update().values(values)
    for name, expected in criteria.items():
        statement = statement.where(table.c[name] == expected)
    engine.execute(statement)
def resultiter(rp):
    """Yield each row of the result ``rp`` as a dict.

    Rows are fetched one at a time with ``fetchone`` and zipped with the
    result's keys; iteration ends when ``fetchone`` returns ``None``.
    """
    names = rp.keys()
    record = rp.fetchone()
    while record is not None:
        yield dict(zip(names, record))
        record = rp.fetchone()
def distinct(engine, table, *columns):
    """Iterate over the distinct values of the named columns, as dicts."""
    selected = [table.c[name] for name in columns]
    query = expression.select(selected, distinct=True)
    result = engine.execute(query)
    return resultiter(result)
# NOTE: this name shadows the builtin ``all`` within this module.
def all(engine, table):
    """Iterate over every row of ``table``, each as a dict."""
    result = engine.execute(table.select())
    return resultiter(result)
from sqlaload.schema import connect
from sqlaload.schema import create_table, load_table, get_table
from sqlaload.schema import create_column
from sqlaload.write import add_row, update_row
from sqlaload.write import upsert, update
from sqlaload.query import distinct, resultiter, all

28
sqlaload/query.py Normal file
View File

@ -0,0 +1,28 @@
import logging
from sqlalchemy.sql import expression
log = logging.getLogger(__name__)
def resultiter(rp):
    """Yield each row of the result ``rp`` as a dict.

    Rows are fetched one at a time with ``fetchone`` and zipped with the
    result's keys; iteration ends when ``fetchone`` returns ``None``.
    """
    names = rp.keys()
    record = rp.fetchone()
    while record is not None:
        yield dict(zip(names, record))
        record = rp.fetchone()
def distinct(engine, table, *columns):
    """Iterate over the distinct values of the named columns, as dicts."""
    selected = [table.c[name] for name in columns]
    query = expression.select(selected, distinct=True)
    result = engine.execute(query)
    return resultiter(result)
# NOTE: this name shadows the builtin ``all`` within this module.
def all(engine, table):
    """Iterate over every row of ``table``, each as a dict."""
    result = engine.execute(table.select())
    return resultiter(result)

69
sqlaload/schema.py Normal file
View File

@ -0,0 +1,69 @@
import logging
from datetime import datetime
from sqlalchemy import create_engine
from sqlalchemy import Integer, UnicodeText, Float, DateTime
from sqlalchemy.schema import Table, MetaData, Column
from sqlalchemy.sql import and_, expression
from migrate.versioning.util import construct_engine
log = logging.getLogger(__name__)
def connect(url):
    """Create an engine for the given database URL.

    The engine is passed through ``construct_engine`` and receives a new
    ``MetaData`` object, bound to it and stored as ``engine._metadata``.
    """
    engine = construct_engine(create_engine(url))
    metadata = MetaData()
    metadata.bind = engine
    # Keep the metadata on the engine so table helpers can look it up there.
    engine._metadata = metadata
    return engine
def create_table(engine, table_name):
    """Create table ``table_name`` with an integer ``id`` primary key.

    The table is registered on ``engine._metadata``, created through the
    engine and returned.
    """
    message = "Creating table: %s on %r" % (table_name, engine)
    log.debug(message)
    new_table = Table(table_name, engine._metadata)
    new_table.append_column(Column('id', Integer, primary_key=True))
    new_table.create(engine)
    return new_table
def load_table(engine, table_name):
    """Return table ``table_name``, autoloaded via ``engine._metadata``."""
    message = "Loading table: %s on %r" % (table_name, engine)
    log.debug(message)
    reflected = Table(table_name, engine._metadata, autoload=True)
    return reflected
def get_table(engine, table_name):
    """Return table ``table_name``: load it if present, else create it."""
    if not engine.has_table(table_name):
        return create_table(engine, table_name)
    return load_table(engine, table_name)
def _guess_type(sample):
    """Pick a column type for ``sample`` based on its Python type.

    ``int`` maps to ``Integer``, ``float`` to ``Float`` and ``datetime`` to
    ``DateTime``; anything else falls back to ``UnicodeText``.
    """
    # Checked in order, first match wins.
    candidates = (
        (int, Integer),
        (float, Float),
        (datetime, DateTime),
    )
    for python_type, column_type in candidates:
        if isinstance(sample, python_type):
            return column_type
    return UnicodeText
def _ensure_columns(engine, table, row, types={}):
columns = set(row.keys()) - set(table.columns.keys())
for column in columns:
if column in types:
_type = types[column]
else:
_type = _guess_type(row[column])
log.debug("Creating column: %s (%s) on %r" % (column,
_type, table.name))
create_column(engine, table, column, _type)
def _args_to_clause(table, args):
    """Combine ``column == value`` for every item of ``args`` with ``and_``."""
    conditions = [table.c[name] == value for name, value in args.items()]
    return and_(*conditions)
def create_column(engine, table, name, type):
    """Add column ``name`` of the given ``type`` to ``table``.

    Does nothing when the table already has a column of that name.
    """
    if name in table.columns.keys():
        return
    Column(name, type).create(table, connection=engine)

40
sqlaload/write.py Normal file
View File

@ -0,0 +1,40 @@
import logging
from sqlaload.schema import _ensure_columns, _args_to_clause
log = logging.getLogger(__name__)
def add_row(engine, table, row, ensure=True, types={}):
    """Insert ``row`` (a dict) into ``table``.

    If ``ensure`` is set, keys of the row that are not yet table columns
    are created first, typed from ``types`` or guessed from the value.
    """
    if ensure:
        _ensure_columns(engine, table, row, types=types)
    statement = table.insert(row)
    engine.execute(statement)
def update_row(engine, table, row, unique, ensure=True, types={}):
    """Update the rows that match ``row`` on the ``unique`` columns.

    :param engine: engine used to execute the UPDATE statement.
    :param table: table to update.
    :param row: dict of column values; it also supplies the filter values.
    :param unique: names of the columns to filter on.
    :param ensure: if set, create missing columns for the keys of ``row``
        before updating.
    :param types: mapping of column name to column type, handed to the
        column-creation step.
    :returns: ``True`` if at least one row was updated; ``False`` if no
        row matched, ``unique`` is empty, or a filter column is missing.
    """
    if not unique:
        return False
    criteria = dict((name, row.get(name)) for name in unique)
    if ensure:
        _ensure_columns(engine, table, row, types=types)
    try:
        stmt = table.update(_args_to_clause(table, criteria), row)
        rp = engine.execute(stmt)
        return rp.rowcount > 0
    except KeyError as ke:
        # Presumably raised by the column lookup for an unknown filter
        # column; treated as "nothing updated" rather than an error.
        log.warning("UPDATE: filter column does not exist: %s", ke)
        return False
def upsert(engine, table, row, unique, ensure=True, types={}):
    """Update the rows matching ``row`` on ``unique``; insert if none did.

    The update and the insert are issued as two separate calls.
    """
    updated = update_row(engine, table, row, unique, ensure=ensure,
                         types=types)
    if updated:
        return
    add_row(engine, table, row, ensure=ensure, types=types)
def update(engine, table, criteria, values, ensure=True, types={}):
    """Set ``values`` on the rows of ``table`` matching all of ``criteria``.

    ``criteria`` maps column names to required values. If ``ensure`` is
    set, columns missing for the keys of ``values`` are created first.
    """
    if ensure:
        _ensure_columns(engine, table, values, types=types)
    statement = table.update().values(values)
    for name, expected in criteria.items():
        statement = statement.where(table.c[name] == expected)
    engine.execute(statement)