From 38b1114bde5c4da79d02828f0a4b5bd53dbfc034 Mon Sep 17 00:00:00 2001 From: Victor Kashirin Date: Mon, 9 Jun 2014 01:36:42 +0400 Subject: [PATCH 1/3] Add support for `with` statement --- dataset/persistence/database.py | 39 ++++++++++++++++++++++++++------- test/test_persistence.py | 12 +++++++++- 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 36b6f6c..d87963a 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -1,6 +1,7 @@ import logging import threading import re +from sqlalchemy.util import safe_reraise try: from urllib.parse import urlencode @@ -26,7 +27,6 @@ log = logging.getLogger(__name__) class Database(object): - def __init__(self, url, schema=None, reflectMetadata=True, engine_kwargs=None): if engine_kwargs is None: @@ -84,6 +84,14 @@ class Database(object): self.lock.release() self.local.must_release = False + def _dispose_transaction(self): + self.local.tx.remove(self.local.tx[-1]) + if not self.local.tx: + del self.local.tx + self.local.connection.close() + del self.local.connection + self._release_internal() + def begin(self): """ Enter a transaction explicitly. No data will be written until the transaction has been committed. @@ -93,21 +101,36 @@ class Database(object): if not hasattr(self.local, 'connection'): self.local.connection = self.engine.connect() if not hasattr(self.local, 'tx'): - self.local.tx = self.local.connection.begin() + self.local.tx = [] + self.local.tx.append(self.local.connection.begin()) def commit(self): """ Commit the current transaction, making all statements executed since the transaction was begun permanent. """ - self.local.tx.commit() - del self.local.tx - self._release_internal() + if hasattr(self.local, 'tx') and self.local.tx: + self.local.tx[-1].commit() + self._dispose_transaction() def rollback(self): """ Roll back the current transaction, discarding all statements executed since the transaction was begun. """ - self.local.tx.rollback() - del self.local.tx - self._release_internal() + if hasattr(self.local, 'tx') and self.local.tx: + self.local.tx[-1].rollback() + self._dispose_transaction() + + def __enter__(self): + self.begin() + return self + + def __exit__(self, error_type, error_value, traceback): + if error_type is None: + try: + self.commit() + except: + with safe_reraise(): + self.rollback() + else: + self.rollback() @property def tables(self): diff --git a/test/test_persistence.py b/test/test_persistence.py index efd689f..4e9c8d6 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -7,7 +7,7 @@ try: except ImportError: from ordereddict import OrderedDict # Python < 2.7 drop-in -from sqlalchemy.exc import IntegrityError +from sqlalchemy.exc import IntegrityError, SQLAlchemyError from dataset import connect from dataset.util import DatasetException @@ -114,6 +114,16 @@ class DatabaseTestCase(unittest.TestCase): 'string_id': 'foobar'}) assert table.find_one(string_id='foobar')['string_id'] == 'foobar' + def test_with(self): + init_length = len(self.db['weather']) + try: + with self.db: + self.db['weather'].insert({'date': datetime(2011, 1, 1), 'temperature': 1, 'place': u'tmp_place'}) + self.db['weather'].insert({'date': True, 'temperature': 'wrong_value', 'place': u'tmp_place'}) + except SQLAlchemyError: + pass + assert len(self.db['weather']) == init_length + def test_load_table(self): tbl = self.db.load_table('weather') assert tbl.table == self.tbl.table From c78d5e00b2c454023278775d5064986c42b6fd2e Mon Sep 17 00:00:00 2001 From: Victor Kashirin Date: Wed, 11 Jun 2014 15:12:44 +0400 Subject: [PATCH 2/3] Add documentation on transactions --- docs/quickstart.rst | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/docs/quickstart.rst b/docs/quickstart.rst index f991449..8be5013 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -64,6 +64,36 @@ The list of filter columns given as the second argument filter using the values in the first column. If you don't want to update over a particular value, just use the auto-generated ``id`` column. +Using Transactions +------------------ + +You can group a set of database updates within a transaction, thus all updates are +committed at once or, in case of exception, all of them are reverted. Transactions are +supported by ``dataset`` context manager, and initiated by ``with`` statement:: + + with dataset.connect() as tx: + tx['user'].insert(dict(name='John Doe', age=46, country='China')) + +You can get same functionality with datase methods :py:meth:`all() `, +:py:meth:`all() ` and :py:meth:`rollback() `:: + + db = dataset.connect() + db.begin() + try: + db['user'].insert(dict(name='John Doe', age=46, country='China')) + db.commit() + except: + db.rollback() + +Nested transactions are supported too:: + db = dataset.connect() + with db as tx1: + tx1['user'].insert(dict(name='John Doe', age=46, country='China')) + with db sa tx2: + tx2['user'].insert(dict(name='Jane Doe', age=37, country='France', gender='female')) + + + Inspecting databases and tables ------------------------------- From 39759c92ab3ecae08b621a92b03b9c6f11275d00 Mon Sep 17 00:00:00 2001 From: Victor Kashirin Date: Wed, 11 Jun 2014 15:14:42 +0400 Subject: [PATCH 3/3] Make more explicit syntaxis for `test_with` transaction test --- test/test_persistence.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/test_persistence.py b/test/test_persistence.py index 4e9c8d6..6040999 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -117,9 +117,9 @@ class DatabaseTestCase(unittest.TestCase): def test_with(self): init_length = len(self.db['weather']) try: - with self.db: - self.db['weather'].insert({'date': datetime(2011, 1, 1), 'temperature': 1, 'place': u'tmp_place'}) - self.db['weather'].insert({'date': True, 'temperature': 'wrong_value', 'place': u'tmp_place'}) + with self.db as tx: + tx['weather'].insert({'date': datetime(2011, 1, 1), 'temperature': 1, 'place': u'tmp_place'}) + tx['weather'].insert({'date': True, 'temperature': 'wrong_value', 'place': u'tmp_place'}) except SQLAlchemyError: pass assert len(self.db['weather']) == init_length