diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 5f7cdc5..a0a9f56 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -25,7 +25,7 @@ log = logging.getLogger(__name__) class Database(object): def __init__(self, url, schema=None, reflect_metadata=True, engine_kwargs=None, reflect_views=True, - row_type=row_type): + ensure_schema=True, row_type=row_type): if engine_kwargs is None: engine_kwargs = {} @@ -45,6 +45,7 @@ class Database(object): self.engine = create_engine(url, **engine_kwargs) self.url = url self.row_type = row_type + self.ensure_schema = ensure_schema self._tables = {} self.metadata = MetaData(schema=schema) self.metadata.bind = self.engine diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 4c64fe3..0a9fb32 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -11,7 +11,6 @@ from dataset.util import DatasetException log = logging.getLogger(__name__) -ensure_session = True class Table(object): @@ -53,7 +52,7 @@ class Table(object): if self._is_dropped: raise DatasetException('the table has been dropped. this object should not be used again.') - def insert(self, row, ensure=ensure_session, types={}): + def insert(self, row, ensure=None, types={}): """ Add a row (type: dict) by inserting it into the table. If ``ensure`` is set, any of the keys of the row are not @@ -72,13 +71,14 @@ class Table(object): Returns the inserted row's primary key. """ self._check_dropped() + ensure = self.database.ensure_schema if ensure is None else ensure if ensure: self._ensure_columns(row, types=types) res = self.database.executable.execute(self.table.insert(row)) if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] - def insert_many(self, rows, chunk_size=1000, ensure=ensure_session, types={}): + def insert_many(self, rows, chunk_size=1000, ensure=None, types={}): """ Add many rows at a time, which is significantly faster than adding them one by one. Per default the rows are processed in chunks of @@ -91,6 +91,8 @@ class Table(object): rows = [dict(name='Dolly')] * 10000 table.insert_many(rows) """ + ensure = self.database.ensure_schema if ensure is None else ensure + def _process_chunk(chunk): if ensure: for row in chunk: @@ -108,7 +110,7 @@ class Table(object): if chunk: _process_chunk(chunk) - def update(self, row, keys, ensure=ensure_session, types={}): + def update(self, row, keys, ensure=None, types={}): """ Update a row in the table. The update is managed via the set of column names stated in ``keys``: they will be @@ -132,6 +134,7 @@ class Table(object): return False clause = [(u, row.get(u)) for u in keys] + ensure = self.database.ensure_schema if ensure is None else ensure if ensure: self._ensure_columns(row, types=types) @@ -149,7 +152,7 @@ class Table(object): except KeyError: return 0 - def upsert(self, row, keys, ensure=ensure_session, types={}): + def upsert(self, row, keys, ensure=None, types={}): """ An UPSERT is a smart combination of insert and update. If rows with matching ``keys`` exist they will be updated, otherwise a new row is inserted in the table. @@ -162,6 +165,8 @@ class Table(object): if not isinstance(keys, (list, tuple)): keys = [keys] self._check_dropped() + + ensure = self.database.ensure_schema if ensure is None else ensure if ensure: self.create_index(keys) @@ -220,7 +225,8 @@ class Table(object): _type, self.table.name)) self.create_column(column, _type) - def _args_to_clause(self, args, ensure=ensure_session, clauses=()): + def _args_to_clause(self, args, ensure=None, clauses=()): + ensure = self.database.ensure_schema if ensure is None else ensure if ensure: self._ensure_columns(args) clauses = list(clauses)