From b7bef6ddc40c17ffffc382fd3a3d3785e42e4019 Mon Sep 17 00:00:00 2001 From: Gregor Aisch Date: Mon, 1 Apr 2013 14:04:23 +0200 Subject: [PATCH] added object-oriented API --- sqlaload/__init__.py | 2 ++ sqlaload/db.py | 82 ++++++++++++++++++++++++++++++++++++++++++++ test.py | 38 ++++++++++++++------ 3 files changed, 112 insertions(+), 10 deletions(-) create mode 100644 sqlaload/db.py diff --git a/sqlaload/__init__.py b/sqlaload/__init__.py index 48f9b06..cdac3ed 100644 --- a/sqlaload/__init__.py +++ b/sqlaload/__init__.py @@ -4,7 +4,9 @@ from sqlaload.schema import create_column from sqlaload.write import add_row, update_row from sqlaload.write import upsert, update, delete from sqlaload.query import distinct, resultiter, all, find_one, find, query +from sqlaload.db import create # shut up useless SA warning: import warnings warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.') + diff --git a/sqlaload/db.py b/sqlaload/db.py new file mode 100644 index 0000000..5cdcf17 --- /dev/null +++ b/sqlaload/db.py @@ -0,0 +1,82 @@ + + +from sqlaload.schema import connect, get_table, create_table, load_table, drop_table, create_column, create_index +from sqlalchemy.engine import Engine +from sqlaload.write import add_row, update_row +from sqlaload.write import upsert, update, delete +from sqlaload.query import distinct, all, find_one, find, query + + +class DB(object): + def __init__(self, engine): + if isinstance(engine, (str, unicode)): + self.engine = connect(engine) + elif isinstance(engine, Engine): + self.engine = engine + else: + raise Exception('unknown engine format') + + def create_table(self, table_name): + return Table(self.engine, create_table(self.engine, table_name)) + + def load_table(self, table_name): + return Table(self.engine, load_table(self.engine, table_name)) + + def get_table(self, table_name): + return Table(self.engine, get_table(self.engine, table_name)) + + def drop_table(self, table_name): + drop_table(self.engine, table_name) + + def query(self, sql): + return query(self.engine, sql) + + +class Table(object): + def __init__(self, engine, table): + self.engine = engine + self.table = table + + # sqlaload.write + + def add_row(self, *args, **kwargs): + return add_row(self.engine, self.table, *args, **kwargs) + + def update_row(self, *args, **kwargs): + return update_row(self.engine, self.table, *args, **kwargs) + + def upsert(self, *args, **kwargs): + return upsert(self.engine, self.table, *args, **kwargs) + + def update(self, *args, **kwargs): + return update(self.engine, self.table, *args, **kwargs) + + def delete(self, *args, **kwargs): + return delete(self.engine, self.table, *args, **kwargs) + + # sqlaload.query + + def find_one(self, *args, **kwargs): + return find_one(self.engine, self.table, *args, **kwargs) + + def find(self, *args, **kwargs): + return find(self.engine, self.table, *args, **kwargs) + + def all(self): + return all(self.engine, self.table) + + def distinct(self, *args, **kwargs): + return distinct(self.engine, self.table, *args, **kwargs) + + # sqlaload.schema + + def create_column(self, name, type): + return create_column(self.engine, self.table, name, type) + + def create_index(self, columns, name=None): + return create_index(self.engine, self.table, columns, name=name) + + +def create(engine): + """ returns a DB object """ + return DB(engine) diff --git a/test.py b/test.py index 8421052..907c1c4 100644 --- a/test.py +++ b/test.py @@ -34,17 +34,18 @@ TEST_DATA = [ 'temperature': 5, 'place': 'Berkeley' } - ] +] + class TableTestCase(unittest.TestCase): - + def setUp(self): self.engine = connect('sqlite:///:memory:') def test_create_table(self): table = create_table(self.engine, 'foo') assert table.exists() - assert len(table.columns)==1, table.columns + assert len(table.columns) == 1, table.columns class QueryTestCase(unittest.TestCase): @@ -54,22 +55,39 @@ class QueryTestCase(unittest.TestCase): self.table = create_table(self.engine, 'weather') for entry in TEST_DATA: upsert(self.engine, self.table, entry, ['place', 'date'], - ensure=True) - + ensure=True) + def test_all(self): x = list(all(self.engine, self.table)) - assert len(x)==len(TEST_DATA), len(x) - + assert len(x) == len(TEST_DATA), len(x) + def test_distinct(self): x = list(distinct(self.engine, self.table, 'place')) - assert len(x)==2, x + assert len(x) == 2, x p = [i['place'] for i in x] assert 'Berkeley' in p, p assert 'Galway' in p, p + +class ObjectAPITestCase(unittest.TestCase): + + def setUp(self): + db = create('sqlite:///:memory:') + self.table = db.get_table('weather2') + for entry in TEST_DATA: + self.table.upsert(entry, ['place', 'date'], ensure=True) + def test_all(self): - x = list(all(self.engine, self.table)) - assert len(x)==len(TEST_DATA), len(x) + x = list(self.table.all()) + assert len(x) == len(TEST_DATA), len(x) + + def test_distinct(self): + x = list(self.table.distinct('place')) + assert len(x) == 2, x + p = [i['place'] for i in x] + assert 'Berkeley' in p, p + assert 'Galway' in p, p + if __name__ == '__main__': unittest.main()