diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 8f930f2..4d729bf 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -56,7 +56,7 @@ class Table(object): self.database.engine.execute(q) def _ensure_columns(self, row, types={}): - for column in set(row.keys()) - set(self.columns.keys()): + for column in set(row.keys()) - set(self.table.columns.keys()): if column in types: _type = types[column] else: diff --git a/test/test_persistence.py b/test/test_persistence.py index 68c4914..e10884f 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -2,16 +2,30 @@ import unittest from datetime import datetime from dataset import connect +from sample_data import TEST_DATA class DatabaseTestCase(unittest.TestCase): def setUp(self): self.db = connect('sqlite:///:memory:') + self.tbl = self.db['weather'] + for row in TEST_DATA: + self.tbl.insert(row) def test_create_table(self): table = self.db['foo'] assert table.table.exists() assert len(table.table.columns) == 1, table.table.columns + assert 'id' in table.table.c, table.table.c + + def test_load_table(self): + tbl = self.db.load_table('weather') + assert tbl.table==self.tbl.table + + def test_query(self): + r = self.db.query('SELECT COUNT(*) AS num FROM weather').next() + assert r['num']==len(TEST_DATA), r + if __name__ == '__main__':