diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index b03239f..a46e0a4 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -79,6 +79,11 @@ class Database(object): """Return a SQLAlchemy schema cache object.""" return MetaData(schema=self.schema, bind=self.executable) + @property + def in_transaction(self): + """Check if this database is in a transactional context.""" + return len(self.local.tx) > 0 + def _flush_tables(self): """Clear the table metadata after transaction rollbacks.""" for table in self._tables.values(): diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 2d3291e..98cf4f7 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -1,4 +1,6 @@ import logging +import warnings +import threading from sqlalchemy.sql import and_, expression from sqlalchemy.sql.expression import ClauseElement @@ -207,6 +209,13 @@ class Table(object): except NoSuchTableError: pass + def _threading_warn(self): + if self.db.in_transaction and threading.active_count() > 1: + warnings.warn("Changing the database schema inside a transaction " + "in a multi-threaded environment is likely to lead " + "to race conditions and synchronization issues.", + RuntimeWarning) + def _sync_table(self, columns): """Lazy load, create or adapt the table structure in the database.""" if self._table is None: @@ -218,6 +227,7 @@ class Table(object): raise DatasetException("Table does not exist: %s" % self.name) # Keep the lock scope small because this is run very often. with self.db.lock: + self._threading_warn() self._table = SQLATable(self.name, self.db.metadata, schema=self.db.schema) @@ -236,6 +246,7 @@ class Table(object): self._table.create(self.db.executable, checkfirst=True) elif len(columns): with self.db.lock: + self._threading_warn() for column in columns: self.db.op.add_column(self.name, column, self.db.schema) self._reflect_table() @@ -317,7 +328,7 @@ class Table(object): log.debug("Column exists: %s" % name) return - log.debug("Create column: %s on %s", name, self.name) + self._threading_warn() self.db.op.add_column( self.name, Column(name, type), @@ -353,6 +364,7 @@ class Table(object): log.debug("Column does not exist: %s", name) return + self._threading_warn() self.db.op.drop_column( self.table.name, name, @@ -367,6 +379,7 @@ class Table(object): """ with self.db.lock: if self.exists: + self._threading_warn() self.table.drop(self.db.executable, checkfirst=True) self._table = None @@ -401,6 +414,7 @@ class Table(object): raise DatasetException("Table has not been created yet.") if not self.has_index(columns): + self._threading_warn() name = name or index_name(self.name, columns) columns = [self.table.c[c] for c in columns] idx = Index(name, *columns, **kw)