From b5181a41056ccd606bc78d3d0a9c65d23e70c89d Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Fri, 10 May 2013 22:58:23 +0200 Subject: [PATCH] Testing with transactions support. --- dataset/persistence/database.py | 37 +++++++++++++++++++++++++++++---- dataset/persistence/table.py | 34 +++++++++++++++++------------- 2 files changed, 53 insertions(+), 18 deletions(-) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 3536068..4bca5da 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -1,5 +1,5 @@ import logging -from threading import RLock +import threading from sqlalchemy import create_engine from migrate.versioning.util import construct_engine @@ -22,7 +22,8 @@ class Database(object): if url.startswith('postgres'): kw['poolclass'] = NullPool engine = create_engine(url, **kw) - self.lock = RLock() + self.lock = threading.RLock() + self.local = threading.local() self.url = url self.engine = construct_engine(engine) self.metadata = MetaData() @@ -30,6 +31,34 @@ class Database(object): self.metadata.reflect(self.engine) self._tables = {} + @property + def executable(self): + """ The current connection or engine against which statements + will be executed. """ + if hasattr(self.local, 'connection'): + return self.local.connection + return self.engine + + def begin(self): + """ Enter a transaction explicitly. No data will be written + until the transaction has been committed. """ + if not hasattr(self.local, 'connection'): + self.local.connection = self.engine.connect() + if not hasattr(self.local, 'tx'): + self.local.tx = self.local.connection.begin() + + def commit(self): + """ Commit the current transaction, making all statements executed + since the transaction was begun permanent. """ + self.local.tx.commit() + del self.local.tx + + def rollback(self): + """ Roll back the current transaction, discarding all statements + executed since the transaction was begun. """ + self.local.tx.rollback() + del self.local.tx + @property def tables(self): """ Get a listing of all tables that exist in the database. @@ -55,7 +84,7 @@ class Database(object): table = SQLATable(table_name, self.metadata) col = Column('id', Integer, primary_key=True) table.append_column(col) - table.create(self.engine) + table.create(self.executable) self._tables[table_name] = table return Table(self, table) @@ -112,7 +141,7 @@ class Database(object): for row in res: print row['user'], row['c'] """ - return ResultIter(self.engine.execute(query)) + return ResultIter(self.executable.execute(query)) def __repr__(self): return '' % self.url diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index fac32a6..ac05362 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -15,7 +15,7 @@ log = logging.getLogger(__name__) class Table(object): def __init__(self, database, table): - self.indexes = {} + self.indexes = dict([(i.name, i) for i in table.indexes]) self.database = database self.table = table self._is_dropped = False @@ -42,7 +42,7 @@ class Table(object): self._is_dropped = True with self.database.lock: self.database._tables.pop(self.table.name, None) - self.table.drop(self.database.engine) + self.table.drop(self.database.executable) def _check_dropped(self): if self._is_dropped: @@ -67,7 +67,7 @@ class Table(object): self._check_dropped() if ensure: self._ensure_columns(row, types=types) - res = self.database.engine.execute(self.table.insert(row)) + res = self.database.executable.execute(self.table.insert(row)) return res.lastrowid def insert_many(self, rows, chunk_size=1000, ensure=True, types={}): @@ -126,7 +126,7 @@ class Table(object): try: filters = self._args_to_clause(dict(clause)) stmt = self.table.update(filters, row) - rp = self.database.engine.execute(stmt) + rp = self.database.executable.execute(stmt) return rp.rowcount > 0 except KeyError: return False @@ -144,7 +144,13 @@ class Table(object): if ensure: self.create_index(keys) - if not self.update(row, keys, ensure=ensure, types=types): + filters = {} + for key in keys: + filters[key] = row.get(key) + + if self.find_one(**filters) is not None: + self.update(row, keys, ensure=ensure, types=types) + else: self.insert(row, ensure=ensure, types=types) def delete(self, **_filter): @@ -164,7 +170,7 @@ class Table(object): stmt = self.table.delete(q) else: stmt = self.table.delete() - self.database.engine.execute(stmt) + self.database.executable.execute(stmt) def _ensure_columns(self, row, types={}): for column in set(row.keys()) - set(self.table.columns.keys()): @@ -196,7 +202,7 @@ class Table(object): if name not in self.table.columns.keys(): col = Column(name, type) col.create(self.table, - connection=self.database.engine) + connection=self.database.executable) def create_index(self, columns, name=None): """ @@ -215,7 +221,7 @@ class Table(object): try: columns = [self.table.c[c] for c in columns] idx = Index(name, *columns) - idx.create(self.database.engine) + idx.create(self.database.executable) except: idx = None self.indexes[name] = idx @@ -229,10 +235,10 @@ class Table(object): row = table.find_one(country='United States') """ self._check_dropped() - res = list(self.find(_limit=1, **_filter)) - if not len(res): - return None - return res[0] + args = self._args_to_clause(_filter) + query = self.table.select(whereclause=args, limit=1) + rp = self.database.executable.execute(query) + return rp.fetchone() def _args_to_order_by(self, order_by): if order_by[0] == '-': @@ -278,7 +284,7 @@ class Table(object): # query total number of rows first count_query = self.table.count(whereclause=args, limit=_limit, offset=_offset) - rp = self.database.engine.execute(count_query) + rp = self.database.executable.execute(count_query) total_row_count = rp.fetchone()[0] if _step is None or _step is False or _step == 0: @@ -301,7 +307,7 @@ class Table(object): break queries.append(self.table.select(whereclause=args, limit=qlimit, offset=qoffset, order_by=order_by)) - return ResultIter((self.database.engine.execute(q) for q in queries)) + return ResultIter((self.database.executable.execute(q) for q in queries)) def __len__(self): """