set default of the ensure_schema flag per session

This commit is contained in:
Friedrich Lindenberg 2016-04-08 09:23:24 +03:00
parent 6909ba5c2f
commit 4fb4845efc
2 changed files with 14 additions and 7 deletions

View File

@ -25,7 +25,7 @@ log = logging.getLogger(__name__)
class Database(object): class Database(object):
def __init__(self, url, schema=None, reflect_metadata=True, def __init__(self, url, schema=None, reflect_metadata=True,
engine_kwargs=None, reflect_views=True, engine_kwargs=None, reflect_views=True,
row_type=row_type): ensure_schema=True, row_type=row_type):
if engine_kwargs is None: if engine_kwargs is None:
engine_kwargs = {} engine_kwargs = {}
@ -45,6 +45,7 @@ class Database(object):
self.engine = create_engine(url, **engine_kwargs) self.engine = create_engine(url, **engine_kwargs)
self.url = url self.url = url
self.row_type = row_type self.row_type = row_type
self.ensure_schema = ensure_schema
self._tables = {} self._tables = {}
self.metadata = MetaData(schema=schema) self.metadata = MetaData(schema=schema)
self.metadata.bind = self.engine self.metadata.bind = self.engine

View File

@ -11,7 +11,6 @@ from dataset.util import DatasetException
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ensure_session = True
class Table(object): class Table(object):
@ -53,7 +52,7 @@ class Table(object):
if self._is_dropped: if self._is_dropped:
raise DatasetException('the table has been dropped. this object should not be used again.') raise DatasetException('the table has been dropped. this object should not be used again.')
def insert(self, row, ensure=ensure_session, types={}): def insert(self, row, ensure=None, types={}):
""" """
Add a row (type: dict) by inserting it into the table. Add a row (type: dict) by inserting it into the table.
If ``ensure`` is set, any of the keys of the row are not If ``ensure`` is set, any of the keys of the row are not
@ -72,13 +71,14 @@ class Table(object):
Returns the inserted row's primary key. Returns the inserted row's primary key.
""" """
self._check_dropped() self._check_dropped()
ensure = self.database.ensure_schema if ensure is None else ensure
if ensure: if ensure:
self._ensure_columns(row, types=types) self._ensure_columns(row, types=types)
res = self.database.executable.execute(self.table.insert(row)) res = self.database.executable.execute(self.table.insert(row))
if len(res.inserted_primary_key) > 0: if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0] return res.inserted_primary_key[0]
def insert_many(self, rows, chunk_size=1000, ensure=ensure_session, types={}): def insert_many(self, rows, chunk_size=1000, ensure=None, types={}):
""" """
Add many rows at a time, which is significantly faster than adding Add many rows at a time, which is significantly faster than adding
them one by one. Per default the rows are processed in chunks of them one by one. Per default the rows are processed in chunks of
@ -91,6 +91,8 @@ class Table(object):
rows = [dict(name='Dolly')] * 10000 rows = [dict(name='Dolly')] * 10000
table.insert_many(rows) table.insert_many(rows)
""" """
ensure = self.database.ensure_schema if ensure is None else ensure
def _process_chunk(chunk): def _process_chunk(chunk):
if ensure: if ensure:
for row in chunk: for row in chunk:
@ -108,7 +110,7 @@ class Table(object):
if chunk: if chunk:
_process_chunk(chunk) _process_chunk(chunk)
def update(self, row, keys, ensure=ensure_session, types={}): def update(self, row, keys, ensure=None, types={}):
""" """
Update a row in the table. The update is managed via Update a row in the table. The update is managed via
the set of column names stated in ``keys``: they will be the set of column names stated in ``keys``: they will be
@ -132,6 +134,7 @@ class Table(object):
return False return False
clause = [(u, row.get(u)) for u in keys] clause = [(u, row.get(u)) for u in keys]
ensure = self.database.ensure_schema if ensure is None else ensure
if ensure: if ensure:
self._ensure_columns(row, types=types) self._ensure_columns(row, types=types)
@ -149,7 +152,7 @@ class Table(object):
except KeyError: except KeyError:
return 0 return 0
def upsert(self, row, keys, ensure=ensure_session, types={}): def upsert(self, row, keys, ensure=None, types={}):
""" """
An UPSERT is a smart combination of insert and update. If rows with matching ``keys`` exist An UPSERT is a smart combination of insert and update. If rows with matching ``keys`` exist
they will be updated, otherwise a new row is inserted in the table. they will be updated, otherwise a new row is inserted in the table.
@ -162,6 +165,8 @@ class Table(object):
if not isinstance(keys, (list, tuple)): if not isinstance(keys, (list, tuple)):
keys = [keys] keys = [keys]
self._check_dropped() self._check_dropped()
ensure = self.database.ensure_schema if ensure is None else ensure
if ensure: if ensure:
self.create_index(keys) self.create_index(keys)
@ -220,7 +225,8 @@ class Table(object):
_type, self.table.name)) _type, self.table.name))
self.create_column(column, _type) self.create_column(column, _type)
def _args_to_clause(self, args, ensure=ensure_session, clauses=()): def _args_to_clause(self, args, ensure=None, clauses=()):
ensure = self.database.ensure_schema if ensure is None else ensure
if ensure: if ensure:
self._ensure_columns(args) self._ensure_columns(args)
clauses = list(clauses) clauses = list(clauses)