diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index e0b4f30..7072853 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -86,8 +86,6 @@ class Database(object): Enter a transaction explicitly. No data will be written until the transaction has been committed. - **NOTICE:** Schema modification operations, such as the creation - of tables or columns will not be part of the transactional context. """ if not hasattr(self.local, 'tx'): self.local.tx = [] @@ -179,6 +177,9 @@ class Database(object): assert not isinstance(primary_type, six.string_types), \ 'Text-based primary_type support is dropped, use db.types.' table_name = self._valid_table_name(table_name) + table = self._reflect_table(table_name) + if table is not None: + return Table(self, table) self._acquire() try: log.debug("Creating table: %s" % (table_name)) @@ -192,7 +193,7 @@ class Database(object): primary_key=True, autoincrement=autoincrement) table.append_column(col) - table.create(self.executable) + table.create(self.executable, checkfirst=True) return Table(self, table) finally: self._release() @@ -222,7 +223,8 @@ class Database(object): return SQLATable(table_name, self.metadata, schema=self.schema, - autoload=True) + autoload=True, + autoload_with=self.executable) except NoSuchTableError: return None @@ -242,9 +244,6 @@ class Database(object): table = db['population'] """ - table = self._reflect_table(table_name) - if table is not None: - return Table(self, table) return self.create_table(table_name, primary_id, primary_type) def __getitem__(self, table_name): diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 8c3d9c7..0f57dc3 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -30,10 +30,6 @@ class Table(object): """Get a listing of all columns that exist in the table.""" return list(self.table.columns.keys()) - @property - def _normalized_columns(self): - return map(normalize_column_name, self.columns) - def drop(self): """ Drop the table from the database. @@ -46,13 +42,13 @@ class Table(object): self._check_dropped() self.database._acquire() try: - self._is_dropped = True - self.table.drop(self.database.executable) + self.table.drop(self.database.executable, checkfirst=True) + self.table = None finally: self.database._release() def _check_dropped(self): - if self._is_dropped: + if self.table is None: raise DatasetException('the table has been dropped. this object should not be used again.') def _prune_row(self, row): @@ -60,7 +56,7 @@ class Table(object): # normalize keys row = {normalize_column_name(k): v for k, v in row.items()} # filter out keys not in column set - return {k: row[k] for k in row if k in self._normalized_columns} + return {k: row[k] for k in row if k in self.columns} def insert(self, row, ensure=None, types=None): """ @@ -261,7 +257,7 @@ class Table(object): return rows.rowcount > 0 def _has_column(self, column): - return normalize_column_name(column) in self._normalized_columns + return normalize_column_name(column) in self.columns def _ensure_columns(self, row, types=None): # Keep order of inserted columns @@ -303,7 +299,8 @@ class Table(object): self._check_dropped() self.database._acquire() try: - if normalize_column_name(name) in self._normalized_columns: + name = normalize_column_name(name) + if name in self.columns: log.debug("Column exists: %s" % name) return diff --git a/test/test_persistence.py b/test/test_persistence.py index 64f9d85..b5789bf 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -21,6 +21,10 @@ class DatabaseTestCase(unittest.TestCase): def setUp(self): os.environ.setdefault('DATABASE_URL', 'sqlite:///:memory:') self.db = connect(os.environ['DATABASE_URL']) + for table in self.db.tables: + self.db[table].drop() + self.db.commit() + print 'XXX', self.db.tables self.tbl = self.db['weather'] self.tbl.insert_many(TEST_DATA) @@ -331,9 +335,9 @@ class TableTestCase(unittest.TestCase): assert len(self.tbl) == len(data) + 6 def test_drop_warning(self): - assert self.tbl._is_dropped is False, 'table shouldn\'t be dropped yet' + assert self.tbl.table is not None, 'table shouldn\'t be dropped yet' self.tbl.drop() - assert self.tbl._is_dropped is True, 'table should be dropped now' + assert self.tbl.table is None, 'table should be dropped now' try: list(self.tbl.all()) except DatasetException: