Merge pull request #98 from victorkashirin/support_with_statement

Fix #90: added `with` statement support to the Database class.
This commit is contained in:
Friedrich Lindenberg 2014-06-12 08:48:42 +03:00
commit 75be7d3b62
3 changed files with 72 additions and 9 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
import threading import threading
import re import re
from sqlalchemy.util import safe_reraise
try: try:
from urllib.parse import urlencode from urllib.parse import urlencode
@ -26,7 +27,6 @@ log = logging.getLogger(__name__)
class Database(object): class Database(object):
def __init__(self, url, schema=None, reflectMetadata=True, def __init__(self, url, schema=None, reflectMetadata=True,
engine_kwargs=None): engine_kwargs=None):
if engine_kwargs is None: if engine_kwargs is None:
@ -84,6 +84,14 @@ class Database(object):
self.lock.release() self.lock.release()
self.local.must_release = False self.local.must_release = False
def _dispose_transaction(self):
self.local.tx.remove(self.local.tx[-1])
if not self.local.tx:
del self.local.tx
self.local.connection.close()
del self.local.connection
self._release_internal()
def begin(self): def begin(self):
""" Enter a transaction explicitly. No data will be written """ Enter a transaction explicitly. No data will be written
until the transaction has been committed. until the transaction has been committed.
@ -93,21 +101,36 @@ class Database(object):
if not hasattr(self.local, 'connection'): if not hasattr(self.local, 'connection'):
self.local.connection = self.engine.connect() self.local.connection = self.engine.connect()
if not hasattr(self.local, 'tx'): if not hasattr(self.local, 'tx'):
self.local.tx = self.local.connection.begin() self.local.tx = []
self.local.tx.append(self.local.connection.begin())
def commit(self): def commit(self):
""" Commit the current transaction, making all statements executed """ Commit the current transaction, making all statements executed
since the transaction was begun permanent. """ since the transaction was begun permanent. """
self.local.tx.commit() if hasattr(self.local, 'tx') and self.local.tx:
del self.local.tx self.local.tx[-1].commit()
self._release_internal() self._dispose_transaction()
def rollback(self): def rollback(self):
""" Roll back the current transaction, discarding all statements """ Roll back the current transaction, discarding all statements
executed since the transaction was begun. """ executed since the transaction was begun. """
self.local.tx.rollback() if hasattr(self.local, 'tx') and self.local.tx:
del self.local.tx self.local.tx[-1].rollback()
self._release_internal() self._dispose_transaction()
def __enter__(self):
self.begin()
return self
def __exit__(self, error_type, error_value, traceback):
if error_type is None:
try:
self.commit()
except:
with safe_reraise():
self.rollback()
else:
self.rollback()
@property @property
def tables(self): def tables(self):

View File

@ -64,6 +64,36 @@ The list of filter columns given as the second argument filter using the
values in the first column. If you don't want to update over a values in the first column. If you don't want to update over a
particular value, just use the auto-generated ``id`` column. particular value, just use the auto-generated ``id`` column.
Using Transactions
------------------
You can group a set of database updates within a transaction, thus all updates are
committed at once or, in case of exception, all of them are reverted. Transactions are
supported by ``dataset`` context manager, and initiated by ``with`` statement::
with dataset.connect() as tx:
tx['user'].insert(dict(name='John Doe', age=46, country='China'))
You can get same functionality with datase methods :py:meth:`all() <dataset.Table.begin>`,
:py:meth:`all() <dataset.Table.commit>` and :py:meth:`rollback() <dataset.Table.rollback>`::
db = dataset.connect()
db.begin()
try:
db['user'].insert(dict(name='John Doe', age=46, country='China'))
db.commit()
except:
db.rollback()
Nested transactions are supported too::
db = dataset.connect()
with db as tx1:
tx1['user'].insert(dict(name='John Doe', age=46, country='China'))
with db sa tx2:
tx2['user'].insert(dict(name='Jane Doe', age=37, country='France', gender='female'))
Inspecting databases and tables Inspecting databases and tables
------------------------------- -------------------------------

View File

@ -7,7 +7,7 @@ try:
except ImportError: except ImportError:
from ordereddict import OrderedDict # Python < 2.7 drop-in from ordereddict import OrderedDict # Python < 2.7 drop-in
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from dataset import connect from dataset import connect
from dataset.util import DatasetException from dataset.util import DatasetException
@ -114,6 +114,16 @@ class DatabaseTestCase(unittest.TestCase):
'string_id': 'foobar'}) 'string_id': 'foobar'})
assert table.find_one(string_id='foobar')['string_id'] == 'foobar' assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_with(self):
init_length = len(self.db['weather'])
try:
with self.db as tx:
tx['weather'].insert({'date': datetime(2011, 1, 1), 'temperature': 1, 'place': u'tmp_place'})
tx['weather'].insert({'date': True, 'temperature': 'wrong_value', 'place': u'tmp_place'})
except SQLAlchemyError:
pass
assert len(self.db['weather']) == init_length
def test_load_table(self): def test_load_table(self):
tbl = self.db.load_table('weather') tbl = self.db.load_table('weather')
assert tbl.table == self.tbl.table assert tbl.table == self.tbl.table