added object-oriented API
This commit is contained in:
parent
ec7741223b
commit
b7bef6ddc4
@ -4,7 +4,9 @@ from sqlaload.schema import create_column
|
|||||||
from sqlaload.write import add_row, update_row
|
from sqlaload.write import add_row, update_row
|
||||||
from sqlaload.write import upsert, update, delete
|
from sqlaload.write import upsert, update, delete
|
||||||
from sqlaload.query import distinct, resultiter, all, find_one, find, query
|
from sqlaload.query import distinct, resultiter, all, find_one, find, query
|
||||||
|
from sqlaload.db import create
|
||||||
|
|
||||||
# shut up useless SA warning:
|
# shut up useless SA warning:
|
||||||
import warnings
|
import warnings
|
||||||
warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.')
|
warnings.filterwarnings('ignore', 'Unicode type received non-unicode bind param value.')
|
||||||
|
|
||||||
|
|||||||
82
sqlaload/db.py
Normal file
82
sqlaload/db.py
Normal file
@ -0,0 +1,82 @@
|
|||||||
|
|
||||||
|
|
||||||
|
from sqlaload.schema import connect, get_table, create_table, load_table, drop_table, create_column, create_index
|
||||||
|
from sqlalchemy.engine import Engine
|
||||||
|
from sqlaload.write import add_row, update_row
|
||||||
|
from sqlaload.write import upsert, update, delete
|
||||||
|
from sqlaload.query import distinct, all, find_one, find, query
|
||||||
|
|
||||||
|
|
||||||
|
class DB(object):
|
||||||
|
def __init__(self, engine):
|
||||||
|
if isinstance(engine, (str, unicode)):
|
||||||
|
self.engine = connect(engine)
|
||||||
|
elif isinstance(engine, Engine):
|
||||||
|
self.engine = engine
|
||||||
|
else:
|
||||||
|
raise Exception('unknown engine format')
|
||||||
|
|
||||||
|
def create_table(self, table_name):
|
||||||
|
return Table(self.engine, create_table(self.engine, table_name))
|
||||||
|
|
||||||
|
def load_table(self, table_name):
|
||||||
|
return Table(self.engine, load_table(self.engine, table_name))
|
||||||
|
|
||||||
|
def get_table(self, table_name):
|
||||||
|
return Table(self.engine, get_table(self.engine, table_name))
|
||||||
|
|
||||||
|
def drop_table(self, table_name):
|
||||||
|
drop_table(self.engine, table_name)
|
||||||
|
|
||||||
|
def query(self, sql):
|
||||||
|
return query(self.engine, sql)
|
||||||
|
|
||||||
|
|
||||||
|
class Table(object):
|
||||||
|
def __init__(self, engine, table):
|
||||||
|
self.engine = engine
|
||||||
|
self.table = table
|
||||||
|
|
||||||
|
# sqlaload.write
|
||||||
|
|
||||||
|
def add_row(self, *args, **kwargs):
|
||||||
|
return add_row(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
def update_row(self, *args, **kwargs):
|
||||||
|
return update_row(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
def upsert(self, *args, **kwargs):
|
||||||
|
return upsert(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
def update(self, *args, **kwargs):
|
||||||
|
return update(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
def delete(self, *args, **kwargs):
|
||||||
|
return delete(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
# sqlaload.query
|
||||||
|
|
||||||
|
def find_one(self, *args, **kwargs):
|
||||||
|
return find_one(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
def find(self, *args, **kwargs):
|
||||||
|
return find(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
def all(self):
|
||||||
|
return all(self.engine, self.table)
|
||||||
|
|
||||||
|
def distinct(self, *args, **kwargs):
|
||||||
|
return distinct(self.engine, self.table, *args, **kwargs)
|
||||||
|
|
||||||
|
# sqlaload.schema
|
||||||
|
|
||||||
|
def create_column(self, name, type):
|
||||||
|
return create_column(self.engine, self.table, name, type)
|
||||||
|
|
||||||
|
def create_index(self, columns, name=None):
|
||||||
|
return create_index(self.engine, self.table, columns, name=name)
|
||||||
|
|
||||||
|
|
||||||
|
def create(engine):
|
||||||
|
""" returns a DB object """
|
||||||
|
return DB(engine)
|
||||||
32
test.py
32
test.py
@ -34,7 +34,8 @@ TEST_DATA = [
|
|||||||
'temperature': 5,
|
'temperature': 5,
|
||||||
'place': 'Berkeley'
|
'place': 'Berkeley'
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
class TableTestCase(unittest.TestCase):
|
class TableTestCase(unittest.TestCase):
|
||||||
|
|
||||||
@ -44,7 +45,7 @@ class TableTestCase(unittest.TestCase):
|
|||||||
def test_create_table(self):
|
def test_create_table(self):
|
||||||
table = create_table(self.engine, 'foo')
|
table = create_table(self.engine, 'foo')
|
||||||
assert table.exists()
|
assert table.exists()
|
||||||
assert len(table.columns)==1, table.columns
|
assert len(table.columns) == 1, table.columns
|
||||||
|
|
||||||
|
|
||||||
class QueryTestCase(unittest.TestCase):
|
class QueryTestCase(unittest.TestCase):
|
||||||
@ -54,22 +55,39 @@ class QueryTestCase(unittest.TestCase):
|
|||||||
self.table = create_table(self.engine, 'weather')
|
self.table = create_table(self.engine, 'weather')
|
||||||
for entry in TEST_DATA:
|
for entry in TEST_DATA:
|
||||||
upsert(self.engine, self.table, entry, ['place', 'date'],
|
upsert(self.engine, self.table, entry, ['place', 'date'],
|
||||||
ensure=True)
|
ensure=True)
|
||||||
|
|
||||||
def test_all(self):
|
def test_all(self):
|
||||||
x = list(all(self.engine, self.table))
|
x = list(all(self.engine, self.table))
|
||||||
assert len(x)==len(TEST_DATA), len(x)
|
assert len(x) == len(TEST_DATA), len(x)
|
||||||
|
|
||||||
def test_distinct(self):
|
def test_distinct(self):
|
||||||
x = list(distinct(self.engine, self.table, 'place'))
|
x = list(distinct(self.engine, self.table, 'place'))
|
||||||
assert len(x)==2, x
|
assert len(x) == 2, x
|
||||||
p = [i['place'] for i in x]
|
p = [i['place'] for i in x]
|
||||||
assert 'Berkeley' in p, p
|
assert 'Berkeley' in p, p
|
||||||
assert 'Galway' in p, p
|
assert 'Galway' in p, p
|
||||||
|
|
||||||
|
|
||||||
|
class ObjectAPITestCase(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
db = create('sqlite:///:memory:')
|
||||||
|
self.table = db.get_table('weather2')
|
||||||
|
for entry in TEST_DATA:
|
||||||
|
self.table.upsert(entry, ['place', 'date'], ensure=True)
|
||||||
|
|
||||||
def test_all(self):
|
def test_all(self):
|
||||||
x = list(all(self.engine, self.table))
|
x = list(self.table.all())
|
||||||
assert len(x)==len(TEST_DATA), len(x)
|
assert len(x) == len(TEST_DATA), len(x)
|
||||||
|
|
||||||
|
def test_distinct(self):
|
||||||
|
x = list(self.table.distinct('place'))
|
||||||
|
assert len(x) == 2, x
|
||||||
|
p = [i['place'] for i in x]
|
||||||
|
assert 'Berkeley' in p, p
|
||||||
|
assert 'Galway' in p, p
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user