added object-oriented API
This commit is contained in:
parent
ec7741223b
commit
b7bef6ddc4
@ -4,7 +4,9 @@ from sqlaload.schema import create_column
|
||||
from sqlaload.write import add_row, update_row
|
||||
from sqlaload.write import upsert, update, delete
|
||||
from sqlaload.query import distinct, resultiter, all, find_one, find, query
|
||||
from sqlaload.db import create
|
||||
|
||||
# shut up useless SA warning:
|
||||
import warnings
|
||||
warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.')
|
||||
|
||||
|
||||
82
sqlaload/db.py
Normal file
82
sqlaload/db.py
Normal file
@ -0,0 +1,82 @@
|
||||
|
||||
|
||||
from sqlaload.schema import connect, get_table, create_table, load_table, drop_table, create_column, create_index
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlaload.write import add_row, update_row
|
||||
from sqlaload.write import upsert, update, delete
|
||||
from sqlaload.query import distinct, all, find_one, find, query
|
||||
|
||||
|
||||
class DB(object):
|
||||
def __init__(self, engine):
|
||||
if isinstance(engine, (str, unicode)):
|
||||
self.engine = connect(engine)
|
||||
elif isinstance(engine, Engine):
|
||||
self.engine = engine
|
||||
else:
|
||||
raise Exception('unknown engine format')
|
||||
|
||||
def create_table(self, table_name):
|
||||
return Table(self.engine, create_table(self.engine, table_name))
|
||||
|
||||
def load_table(self, table_name):
|
||||
return Table(self.engine, load_table(self.engine, table_name))
|
||||
|
||||
def get_table(self, table_name):
|
||||
return Table(self.engine, get_table(self.engine, table_name))
|
||||
|
||||
def drop_table(self, table_name):
|
||||
drop_table(self.engine, table_name)
|
||||
|
||||
def query(self, sql):
|
||||
return query(self.engine, sql)
|
||||
|
||||
|
||||
class Table(object):
|
||||
def __init__(self, engine, table):
|
||||
self.engine = engine
|
||||
self.table = table
|
||||
|
||||
# sqlaload.write
|
||||
|
||||
def add_row(self, *args, **kwargs):
|
||||
return add_row(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
def update_row(self, *args, **kwargs):
|
||||
return update_row(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
def upsert(self, *args, **kwargs):
|
||||
return upsert(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
def update(self, *args, **kwargs):
|
||||
return update(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
def delete(self, *args, **kwargs):
|
||||
return delete(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
# sqlaload.query
|
||||
|
||||
def find_one(self, *args, **kwargs):
|
||||
return find_one(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
def find(self, *args, **kwargs):
|
||||
return find(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
def all(self):
|
||||
return all(self.engine, self.table)
|
||||
|
||||
def distinct(self, *args, **kwargs):
|
||||
return distinct(self.engine, self.table, *args, **kwargs)
|
||||
|
||||
# sqlaload.schema
|
||||
|
||||
def create_column(self, name, type):
|
||||
return create_column(self.engine, self.table, name, type)
|
||||
|
||||
def create_index(self, columns, name=None):
|
||||
return create_index(self.engine, self.table, columns, name=name)
|
||||
|
||||
|
||||
def create(engine):
|
||||
""" returns a DB object """
|
||||
return DB(engine)
|
||||
32
test.py
32
test.py
@ -34,7 +34,8 @@ TEST_DATA = [
|
||||
'temperature': 5,
|
||||
'place': 'Berkeley'
|
||||
}
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
class TableTestCase(unittest.TestCase):
|
||||
|
||||
@ -44,7 +45,7 @@ class TableTestCase(unittest.TestCase):
|
||||
def test_create_table(self):
|
||||
table = create_table(self.engine, 'foo')
|
||||
assert table.exists()
|
||||
assert len(table.columns)==1, table.columns
|
||||
assert len(table.columns) == 1, table.columns
|
||||
|
||||
|
||||
class QueryTestCase(unittest.TestCase):
|
||||
@ -54,22 +55,39 @@ class QueryTestCase(unittest.TestCase):
|
||||
self.table = create_table(self.engine, 'weather')
|
||||
for entry in TEST_DATA:
|
||||
upsert(self.engine, self.table, entry, ['place', 'date'],
|
||||
ensure=True)
|
||||
ensure=True)
|
||||
|
||||
def test_all(self):
|
||||
x = list(all(self.engine, self.table))
|
||||
assert len(x)==len(TEST_DATA), len(x)
|
||||
assert len(x) == len(TEST_DATA), len(x)
|
||||
|
||||
def test_distinct(self):
|
||||
x = list(distinct(self.engine, self.table, 'place'))
|
||||
assert len(x)==2, x
|
||||
assert len(x) == 2, x
|
||||
p = [i['place'] for i in x]
|
||||
assert 'Berkeley' in p, p
|
||||
assert 'Galway' in p, p
|
||||
|
||||
|
||||
class ObjectAPITestCase(unittest.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
db = create('sqlite:///:memory:')
|
||||
self.table = db.get_table('weather2')
|
||||
for entry in TEST_DATA:
|
||||
self.table.upsert(entry, ['place', 'date'], ensure=True)
|
||||
|
||||
def test_all(self):
|
||||
x = list(all(self.engine, self.table))
|
||||
assert len(x)==len(TEST_DATA), len(x)
|
||||
x = list(self.table.all())
|
||||
assert len(x) == len(TEST_DATA), len(x)
|
||||
|
||||
def test_distinct(self):
|
||||
x = list(self.table.distinct('place'))
|
||||
assert len(x) == 2, x
|
||||
p = [i['place'] for i in x]
|
||||
assert 'Berkeley' in p, p
|
||||
assert 'Galway' in p, p
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user