Add support for with statement

This commit is contained in:
Victor Kashirin 2014-06-09 01:36:42 +04:00
parent 14facc15d0
commit 38b1114bde
2 changed files with 42 additions and 9 deletions

View File

@ -1,6 +1,7 @@
import logging import logging
import threading import threading
import re import re
from sqlalchemy.util import safe_reraise
try: try:
from urllib.parse import urlencode from urllib.parse import urlencode
@ -26,7 +27,6 @@ log = logging.getLogger(__name__)
class Database(object): class Database(object):
def __init__(self, url, schema=None, reflectMetadata=True, def __init__(self, url, schema=None, reflectMetadata=True,
engine_kwargs=None): engine_kwargs=None):
if engine_kwargs is None: if engine_kwargs is None:
@ -84,6 +84,14 @@ class Database(object):
self.lock.release() self.lock.release()
self.local.must_release = False self.local.must_release = False
def _dispose_transaction(self):
self.local.tx.remove(self.local.tx[-1])
if not self.local.tx:
del self.local.tx
self.local.connection.close()
del self.local.connection
self._release_internal()
def begin(self): def begin(self):
""" Enter a transaction explicitly. No data will be written """ Enter a transaction explicitly. No data will be written
until the transaction has been committed. until the transaction has been committed.
@ -93,21 +101,36 @@ class Database(object):
if not hasattr(self.local, 'connection'): if not hasattr(self.local, 'connection'):
self.local.connection = self.engine.connect() self.local.connection = self.engine.connect()
if not hasattr(self.local, 'tx'): if not hasattr(self.local, 'tx'):
self.local.tx = self.local.connection.begin() self.local.tx = []
self.local.tx.append(self.local.connection.begin())
def commit(self): def commit(self):
""" Commit the current transaction, making all statements executed """ Commit the current transaction, making all statements executed
since the transaction was begun permanent. """ since the transaction was begun permanent. """
self.local.tx.commit() if hasattr(self.local, 'tx') and self.local.tx:
del self.local.tx self.local.tx[-1].commit()
self._release_internal() self._dispose_transaction()
def rollback(self): def rollback(self):
""" Roll back the current transaction, discarding all statements """ Roll back the current transaction, discarding all statements
executed since the transaction was begun. """ executed since the transaction was begun. """
self.local.tx.rollback() if hasattr(self.local, 'tx') and self.local.tx:
del self.local.tx self.local.tx[-1].rollback()
self._release_internal() self._dispose_transaction()
def __enter__(self):
self.begin()
return self
def __exit__(self, error_type, error_value, traceback):
if error_type is None:
try:
self.commit()
except:
with safe_reraise():
self.rollback()
else:
self.rollback()
@property @property
def tables(self): def tables(self):

View File

@ -7,7 +7,7 @@ try:
except ImportError: except ImportError:
from ordereddict import OrderedDict # Python < 2.7 drop-in from ordereddict import OrderedDict # Python < 2.7 drop-in
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from dataset import connect from dataset import connect
from dataset.util import DatasetException from dataset.util import DatasetException
@ -114,6 +114,16 @@ class DatabaseTestCase(unittest.TestCase):
'string_id': 'foobar'}) 'string_id': 'foobar'})
assert table.find_one(string_id='foobar')['string_id'] == 'foobar' assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_with(self):
init_length = len(self.db['weather'])
try:
with self.db:
self.db['weather'].insert({'date': datetime(2011, 1, 1), 'temperature': 1, 'place': u'tmp_place'})
self.db['weather'].insert({'date': True, 'temperature': 'wrong_value', 'place': u'tmp_place'})
except SQLAlchemyError:
pass
assert len(self.db['weather']) == init_length
def test_load_table(self): def test_load_table(self):
tbl = self.db.load_table('weather') tbl = self.db.load_table('weather')
assert tbl.table == self.tbl.table assert tbl.table == self.tbl.table