diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 685a55c..c74a056 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -29,7 +29,7 @@ class Database(object): self.metadata.bind = self.engine self.tables = {} - def create_table(table_name): + def create_table(self, table_name): with self.lock: log.debug("Creating table: %s on %r" % (table_name, self.engine)) table = SQLATable(table_name, self.metadata) @@ -51,15 +51,15 @@ class Database(object): if table_name in self.tables: return Table(self, self.tables[table_name]) if self.engine.has_table(table_name): - return load_table(engine, table_name) + return self.load_table(table_name) else: - return create_table(engine, table_name) + return self.create_table(table_name) def __getitem__(self, table_name): return self.get_table(table_name) def query(self, query): - return resultiter(self.engine.execute(query)): + return resultiter(self.engine.execute(query)) def __repr__(self): return '' % self.url diff --git a/test/__init__.py b/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/test/sample_data.py b/test/sample_data.py new file mode 100644 index 0000000..d69bf2f --- /dev/null +++ b/test/sample_data.py @@ -0,0 +1,34 @@ +from datetime import datetime + +TEST_DATA = [ + { + 'date': datetime(2011, 01, 01), + 'temperature': 1, + 'place': 'Galway' + }, + { + 'date': datetime(2011, 01, 02), + 'temperature': -1, + 'place': 'Galway' + }, + { + 'date': datetime(2011, 01, 03), + 'temperature': 0, + 'place': 'Galway' + }, + { + 'date': datetime(2011, 01, 01), + 'temperature': 6, + 'place': 'Berkeley' + }, + { + 'date': datetime(2011, 01, 02), + 'temperature': 8, + 'place': 'Berkeley' + }, + { + 'date': datetime(2011, 01, 03), + 'temperature': 5, + 'place': 'Berkeley' + } +] diff --git a/test/test_persistence.py b/test/test_persistence.py new file mode 100644 index 0000000..68c4914 --- /dev/null +++ b/test/test_persistence.py @@ -0,0 +1,19 @@ +import unittest +from datetime import datetime + +from dataset import connect + +class DatabaseTestCase(unittest.TestCase): + + def setUp(self): + self.db = connect('sqlite:///:memory:') + + def test_create_table(self): + table = self.db['foo'] + assert table.table.exists() + assert len(table.table.columns) == 1, table.table.columns + + +if __name__ == '__main__': + unittest.main() +