Add support for with statement

This commit is contained in:
Victor Kashirin 2014-06-09 01:36:42 +04:00
parent 14facc15d0
commit 38b1114bde
2 changed files with 42 additions and 9 deletions

View File

@ -1,6 +1,7 @@
import logging
import threading
import re
from sqlalchemy.util import safe_reraise
try:
from urllib.parse import urlencode
@ -26,7 +27,6 @@ log = logging.getLogger(__name__)
class Database(object):
def __init__(self, url, schema=None, reflectMetadata=True,
engine_kwargs=None):
if engine_kwargs is None:
@ -84,6 +84,14 @@ class Database(object):
self.lock.release()
self.local.must_release = False
def _dispose_transaction(self):
self.local.tx.remove(self.local.tx[-1])
if not self.local.tx:
del self.local.tx
self.local.connection.close()
del self.local.connection
self._release_internal()
def begin(self):
""" Enter a transaction explicitly. No data will be written
until the transaction has been committed.
@ -93,21 +101,36 @@ class Database(object):
if not hasattr(self.local, 'connection'):
self.local.connection = self.engine.connect()
if not hasattr(self.local, 'tx'):
self.local.tx = self.local.connection.begin()
self.local.tx = []
self.local.tx.append(self.local.connection.begin())
def commit(self):
""" Commit the current transaction, making all statements executed
since the transaction was begun permanent. """
self.local.tx.commit()
del self.local.tx
self._release_internal()
if hasattr(self.local, 'tx') and self.local.tx:
self.local.tx[-1].commit()
self._dispose_transaction()
def rollback(self):
""" Roll back the current transaction, discarding all statements
executed since the transaction was begun. """
self.local.tx.rollback()
del self.local.tx
self._release_internal()
if hasattr(self.local, 'tx') and self.local.tx:
self.local.tx[-1].rollback()
self._dispose_transaction()
def __enter__(self):
self.begin()
return self
def __exit__(self, error_type, error_value, traceback):
if error_type is None:
try:
self.commit()
except:
with safe_reraise():
self.rollback()
else:
self.rollback()
@property
def tables(self):

View File

@ -7,7 +7,7 @@ try:
except ImportError:
from ordereddict import OrderedDict # Python < 2.7 drop-in
from sqlalchemy.exc import IntegrityError
from sqlalchemy.exc import IntegrityError, SQLAlchemyError
from dataset import connect
from dataset.util import DatasetException
@ -114,6 +114,16 @@ class DatabaseTestCase(unittest.TestCase):
'string_id': 'foobar'})
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_with(self):
init_length = len(self.db['weather'])
try:
with self.db:
self.db['weather'].insert({'date': datetime(2011, 1, 1), 'temperature': 1, 'place': u'tmp_place'})
self.db['weather'].insert({'date': True, 'temperature': 'wrong_value', 'place': u'tmp_place'})
except SQLAlchemyError:
pass
assert len(self.db['weather']) == init_length
def test_load_table(self):
tbl = self.db.load_table('weather')
assert tbl.table == self.tbl.table