From cc7787036b405310ad1a48ec5c3de390355624f5 Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Sat, 2 Sep 2017 17:17:24 +0200 Subject: [PATCH] still more cases of using the engine and not the executable in transaction --- dataset/persistence/table.py | 40 +++++++++++++++++++++--------------- dataset/persistence/util.py | 4 ++++ test/test_persistence.py | 5 ++--- 3 files changed, 29 insertions(+), 20 deletions(-) diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index a4187e9..cd115a3 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -44,11 +44,13 @@ class Table(object): sure to get a fresh instance from the :py:class:`Database `. """ self.database._acquire() - self._is_dropped = True - self.database._tables.pop(self.table.name, None) - self.table.drop(self.database.engine) - self.database._release() - return True + try: + self._is_dropped = True + self.database._tables.pop(self.table.name, None) + self.table.drop(self.database.executable) + return True + finally: + self.database._release() def _check_dropped(self): if self._is_dropped: @@ -338,15 +340,17 @@ class Table(object): if self.database.engine.dialect.name == 'sqlite': raise NotImplementedError("SQLite does not support dropping columns.") self._check_dropped() + if name not in self.table.columns.keys(): + log.debug("Column does not exist: %s", name) + return self.database._acquire() try: - if name in self.table.columns.keys(): - self.database.op.drop_column( - self.table.name, - name, - self.table.schema - ) - self.table = self.database.update_table(self.table.name) + self.database.op.drop_column( + self.table.name, + name, + self.table.schema + ) + self.table = self.database.update_table(self.table.name) finally: self.database._release() @@ -379,7 +383,7 @@ class Table(object): self.database._acquire() columns = [self.table.c[c] for c in columns] idx = Index(name, *columns, **kw) - idx.create(self.database.engine) + idx.create(self.database.executable) except: idx = None finally: @@ -398,11 +402,13 @@ class Table(object): row = table.find_one(country='United States') """ kwargs['_limit'] = 1 - iterator = self.find(*args, **kwargs) + kwargs['_step'] = None + resiter = self.find(*args, **kwargs) try: - return next(iterator) - except StopIteration: - return None + for row in resiter: + return row + finally: + resiter.close() def _args_to_order_by(self, order_by): if not isinstance(order_by, (list, tuple)): diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index c354339..7fcd3bb 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -47,6 +47,7 @@ class ResultIter(object): def __init__(self, result_proxy, row_type=row_type, step=None): self.row_type = row_type + self.result_proxy = result_proxy self.keys = list(result_proxy.keys()) self._iter = iter_result_proxy(result_proxy, step=step) @@ -58,6 +59,9 @@ class ResultIter(object): def __iter__(self): return self + def close(self): + self.result_proxy.close() + def safe_url(url): """Remove password from printed connection URLs.""" diff --git a/test/test_persistence.py b/test/test_persistence.py index d868e33..5ffe80b 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -22,8 +22,7 @@ class DatabaseTestCase(unittest.TestCase): os.environ.setdefault('DATABASE_URL', 'sqlite:///:memory:') self.db = connect(os.environ['DATABASE_URL']) self.tbl = self.db['weather'] - for row in TEST_DATA: - self.tbl.insert(row) + self.tbl.insert_many(TEST_DATA) def tearDown(self): for table in self.db.tables: @@ -54,9 +53,9 @@ class DatabaseTestCase(unittest.TestCase): assert table.table.exists() assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c - table.insert({pid: 'foobar'}) assert table.find_one(string_id='foobar')[pid] == 'foobar' + table.drop() def test_create_table_custom_id2(self): pid = "string_id"