diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index cfefb99..417517e 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -190,7 +190,8 @@ class Database(object): self.metadata = MetaData(schema=self.schema) self.metadata.bind = self.engine self.metadata.reflect(self.engine) - return SQLATable(table_name, self.metadata) + self._tables[table_name] = SQLATable(table_name, self.metadata) + return self._tables[table_name] def get_table(self, table_name, primary_id='id', primary_type='Integer'): """ diff --git a/test/test_persistence.py b/test/test_persistence.py index 696f746..ca52c03 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -121,6 +121,13 @@ class DatabaseTestCase(unittest.TestCase): r = self.db.query('SELECT COUNT(*) AS num FROM weather').next() assert r['num'] == len(TEST_DATA), r + def test_table_cache_updates(self): + tbl1 = self.db.get_table('people') + tbl1.insert(dict(first_name='John', last_name='Smith')) + tbl2 = self.db.get_table('people') + + assert list(tbl2.all()) == [(1, 'John', 'Smith')] + class TableTestCase(unittest.TestCase):