diff --git a/dataset/table.py b/dataset/table.py index 88c5c8a..e21d251 100644 --- a/dataset/table.py +++ b/dataset/table.py @@ -3,7 +3,7 @@ import warnings import threading from sqlalchemy.sql import and_, expression -from sqlalchemy.sql.expression import ClauseElement +from sqlalchemy.sql.expression import bindparam, ClauseElement from sqlalchemy.schema import Column, Index from sqlalchemy import func, select, false from sqlalchemy.schema import Table as SQLATable @@ -128,19 +128,28 @@ class Table(object): rows = [dict(name='Dolly')] * 10000 table.insert_many(rows) """ - chunk = [] + # Sync table before inputting rows. + sync_row = {} for row in rows: - row = self._sync_columns(row, ensure, types=types) + # Only get non-existing columns. + for key in set(row.keys()).difference(set(sync_row.keys())): + # Get a sample of the new column(s) from the row. + sync_row[key] = row[key] + self._sync_columns(sync_row, ensure, types=types) + + # Get columns name list to be used for padding later. + columns = sync_row.keys() + + chunk = [] + for index, row in enumerate(rows): chunk.append(row) - if len(chunk) == chunk_size: - chunk = pad_chunk_columns(chunk) + + # Insert when chunk_size is fulfilled or this is the last row + if len(chunk) == chunk_size or index == len(rows) - 1: + chunk = pad_chunk_columns(chunk, columns) self.table.insert().execute(chunk) chunk = [] - if len(chunk): - chunk = pad_chunk_columns(chunk) - self.table.insert().execute(chunk) - def update(self, row, keys, ensure=None, types=None, return_count=False): """Update a row in the table. @@ -169,6 +178,42 @@ class Table(object): if return_count: return self.count(clause) + def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): + """Update many rows in the table at a time. + + This is significantly faster than updating them one by one. Per default + the rows are processed in chunks of 1000 per commit, unless you specify + a different ``chunk_size``. 
+ + See :py:meth:`update() <dataset.Table.update>` for details on + the other parameters. + """ + # Convert keys to a list if not a list or tuple. + keys = keys if type(keys) in (list, tuple) else [keys] + + chunk = [] + columns = set() + for index, row in enumerate(rows): + chunk.append(row) + columns = columns.union(set(row.keys())) + + # bindparam requires names to not conflict (cannot be "id" for id) + for key in keys: + row['_%s' % key] = row[key] + + # Update when chunk_size is fulfilled or this is the last row + if len(chunk) == chunk_size or index == len(rows) - 1: + stmt = self.table.update( + whereclause=and_( + *[self.table.c[k] == bindparam('_%s' % k) for k in keys] + ), + values={ + col: bindparam(col, required=False) for col in columns + } + ) + self.db.executable.execute(stmt, chunk) + chunk = [] + def upsert(self, row, keys, ensure=None, types=None): """An UPSERT is a smart combination of insert and update. @@ -187,6 +232,33 @@ class Table(object): return self.insert(row, ensure=False) return True + def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None): + """ + Sorts multiple input rows into updates and inserts. Rows to insert are + passed to insert_many and rows to update are passed to update_many. + + See :py:meth:`upsert() <dataset.Table.upsert>` and + :py:meth:`insert_many() <dataset.Table.insert_many>`. + """ + # Convert keys to a list if not a list or tuple. + keys = keys if type(keys) in (list, tuple) else [keys] + + to_insert = [] + to_update = [] + for row in rows: + if self.find_one(**{key: row.get(key) for key in keys}): + # Row exists - update it. + to_update.append(row) + else: + # Row doesn't exist - insert it. + to_insert.append(row) + + # Insert non-existing rows. + self.insert_many(to_insert, chunk_size, ensure, types) + + # Update existing rows. + self.update_many(to_update, keys, chunk_size, ensure, types) + def delete(self, *clauses, **filters): """Delete rows from the table. 
diff --git a/dataset/util.py b/dataset/util.py index 60e129c..b79a607 100644 --- a/dataset/util.py +++ b/dataset/util.py @@ -108,12 +108,9 @@ def ensure_tuple(obj): return obj, -def pad_chunk_columns(chunk): +def pad_chunk_columns(chunk, columns): """Given a set of items to be inserted, make sure they all have the same columns by padding columns with None if they are missing.""" - columns = set() - for record in chunk: - columns.update(record.keys()) for record in chunk: for column in columns: record.setdefault(column, None) diff --git a/test/test_dataset.py b/test/test_dataset.py index 2f5fda3..a48f163 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -374,6 +374,29 @@ class TableTestCase(unittest.TestCase): self.tbl.insert_many(data, chunk_size=13) assert len(self.tbl) == len(data) + 6 + def test_update_many(self): + tbl = self.db['update_many_test'] + tbl.insert_many([ + dict(temp=10), dict(temp=20), dict(temp=30) + ]) + tbl.update_many( + [dict(id=1, temp=50), dict(id=3, temp=50)], 'id' + ) + + # Ensure data has been updated. + assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] + + def test_upsert_many(self): + # Also tests updating on records with different attributes + tbl = self.db['upsert_many_test'] + + W = 100 + tbl.upsert_many([dict(age=10), dict(weight=W)], 'id') + assert tbl.find_one(id=1)['age'] == 10 + + tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], 'id') + assert tbl.find_one(id=2)['weight'] == W / 2 + def test_drop_operations(self): assert self.tbl._table is not None, \ 'table shouldn\'t be dropped yet'