added object-oriented API

This commit is contained in:
Gregor Aisch 2013-04-01 14:04:23 +02:00
parent ec7741223b
commit b7bef6ddc4
3 changed files with 112 additions and 10 deletions

View File

@ -4,7 +4,9 @@ from sqlaload.schema import create_column
from sqlaload.write import add_row, update_row from sqlaload.write import add_row, update_row
from sqlaload.write import upsert, update, delete from sqlaload.write import upsert, update, delete
from sqlaload.query import distinct, resultiter, all, find_one, find, query from sqlaload.query import distinct, resultiter, all, find_one, find, query
from sqlaload.db import create
# shut up useless SA warning: # shut up useless SA warning:
import warnings import warnings
warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.') warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.')

82
sqlaload/db.py Normal file
View File

@ -0,0 +1,82 @@
from sqlaload.schema import connect, get_table, create_table, load_table, drop_table, create_column, create_index
from sqlalchemy.engine import Engine
from sqlaload.write import add_row, update_row
from sqlaload.write import upsert, update, delete
from sqlaload.query import distinct, all, find_one, find, query
class DB(object):
def __init__(self, engine):
if isinstance(engine, (str, unicode)):
self.engine = connect(engine)
elif isinstance(engine, Engine):
self.engine = engine
else:
raise Exception('unknown engine format')
def create_table(self, table_name):
return Table(self.engine, create_table(self.engine, table_name))
def load_table(self, table_name):
return Table(self.engine, load_table(self.engine, table_name))
def get_table(self, table_name):
return Table(self.engine, get_table(self.engine, table_name))
def drop_table(self, table_name):
drop_table(self.engine, table_name)
def query(self, sql):
return query(self.engine, sql)
class Table(object):
def __init__(self, engine, table):
self.engine = engine
self.table = table
# sqlaload.write
def add_row(self, *args, **kwargs):
return add_row(self.engine, self.table, *args, **kwargs)
def update_row(self, *args, **kwargs):
return update_row(self.engine, self.table, *args, **kwargs)
def upsert(self, *args, **kwargs):
return upsert(self.engine, self.table, *args, **kwargs)
def update(self, *args, **kwargs):
return update(self.engine, self.table, *args, **kwargs)
def delete(self, *args, **kwargs):
return delete(self.engine, self.table, *args, **kwargs)
# sqlaload.query
def find_one(self, *args, **kwargs):
return find_one(self.engine, self.table, *args, **kwargs)
def find(self, *args, **kwargs):
return find(self.engine, self.table, *args, **kwargs)
def all(self):
return all(self.engine, self.table)
def distinct(self, *args, **kwargs):
return distinct(self.engine, self.table, *args, **kwargs)
# sqlaload.schema
def create_column(self, name, type):
return create_column(self.engine, self.table, name, type)
def create_index(self, columns, name=None):
return create_index(self.engine, self.table, columns, name=name)
def create(engine):
""" returns a DB object """
return DB(engine)

32
test.py
View File

@ -34,7 +34,8 @@ TEST_DATA = [
'temperature': 5, 'temperature': 5,
'place': 'Berkeley' 'place': 'Berkeley'
} }
] ]
class TableTestCase(unittest.TestCase): class TableTestCase(unittest.TestCase):
@ -44,7 +45,7 @@ class TableTestCase(unittest.TestCase):
def test_create_table(self): def test_create_table(self):
table = create_table(self.engine, 'foo') table = create_table(self.engine, 'foo')
assert table.exists() assert table.exists()
assert len(table.columns)==1, table.columns assert len(table.columns) == 1, table.columns
class QueryTestCase(unittest.TestCase): class QueryTestCase(unittest.TestCase):
@ -54,22 +55,39 @@ class QueryTestCase(unittest.TestCase):
self.table = create_table(self.engine, 'weather') self.table = create_table(self.engine, 'weather')
for entry in TEST_DATA: for entry in TEST_DATA:
upsert(self.engine, self.table, entry, ['place', 'date'], upsert(self.engine, self.table, entry, ['place', 'date'],
ensure=True) ensure=True)
def test_all(self): def test_all(self):
x = list(all(self.engine, self.table)) x = list(all(self.engine, self.table))
assert len(x)==len(TEST_DATA), len(x) assert len(x) == len(TEST_DATA), len(x)
def test_distinct(self): def test_distinct(self):
x = list(distinct(self.engine, self.table, 'place')) x = list(distinct(self.engine, self.table, 'place'))
assert len(x)==2, x assert len(x) == 2, x
p = [i['place'] for i in x] p = [i['place'] for i in x]
assert 'Berkeley' in p, p assert 'Berkeley' in p, p
assert 'Galway' in p, p assert 'Galway' in p, p
class ObjectAPITestCase(unittest.TestCase):
def setUp(self):
db = create('sqlite:///:memory:')
self.table = db.get_table('weather2')
for entry in TEST_DATA:
self.table.upsert(entry, ['place', 'date'], ensure=True)
def test_all(self): def test_all(self):
x = list(all(self.engine, self.table)) x = list(self.table.all())
assert len(x)==len(TEST_DATA), len(x) assert len(x) == len(TEST_DATA), len(x)
def test_distinct(self):
x = list(self.table.distinct('place'))
assert len(x) == 2, x
p = [i['place'] for i in x]
assert 'Berkeley' in p, p
assert 'Galway' in p, p
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()