From 70874a2501c91304731b6393b6f7318008954079 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Mon, 8 Jul 2019 17:48:05 +0100
Subject: [PATCH] Implement update_many and upsert_many

---
 dataset/table.py | 88 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 87 insertions(+), 1 deletion(-)

diff --git a/dataset/table.py b/dataset/table.py
index bf829df..668e9be 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -3,7 +3,7 @@ import warnings
 import threading
 
 from sqlalchemy.sql import and_, expression
-from sqlalchemy.sql.expression import ClauseElement
+from sqlalchemy.sql.expression import bindparam, ClauseElement
 from sqlalchemy.schema import Column, Index
 from sqlalchemy import func, select, false
 from sqlalchemy.schema import Table as SQLATable
@@ -163,6 +163,59 @@ class Table(object):
         if return_count:
             return self.count(clause)
 
+    def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
+        """Update many rows in the table at a time.
+
+        This is significantly faster than updating them one by one. Per default
+        the rows are processed in chunks of 1000 per commit, unless you specify
+        a different ``chunk_size``.
+
+        ``keys`` is a column name, or a list or tuple of column names, used to
+        match each row against the table; every row must contain all of them.
+        The ``rows`` passed in are left unmodified.
+
+        ``ensure`` and ``types`` mirror the signature of :py:meth:`update()
+        <dataset.Table.update>` but are not used by this method yet.
+        """
+        # Accept a single column name as well as a list or tuple of names.
+        keys = keys if isinstance(keys, (list, tuple)) else [keys]
+
+        chunk = []
+        columns = set()
+        for row in rows:
+            # bindparam requires names to not conflict (cannot be "id" for
+            # id), so the key values are bound again under a "_" prefix. This
+            # is done on a copy to leave the caller's row untouched.
+            params = dict(row)
+            for key in keys:
+                params[f'_{key}'] = row[key]
+            chunk.append(params)
+            columns.update(row.keys())
+
+            if len(chunk) == chunk_size:
+                self._update_chunk(chunk, columns, keys)
+                chunk = []
+                columns = set()
+
+        # Flush the rows left over after the last full chunk.
+        if chunk:
+            self._update_chunk(chunk, columns, keys)
+
+    def _update_chunk(self, chunk, columns, keys):
+        """Execute one UPDATE statement for all parameter sets in ``chunk``."""
+        stmt = self.table.update(
+            whereclause=and_(
+                *[self.table.c[key] == bindparam(f'_{key}') for key in keys]
+            ),
+            # required=False lets a row omit columns that other rows of the
+            # chunk provide. NOTE(review): such a column is presumably set to
+            # NULL for that row - confirm before mixing row shapes.
+            values={
+                column: bindparam(column, required=False) for column in columns
+            }
+        )
+        self.db.executable.execute(stmt, chunk)
+
     def upsert(self, row, keys, ensure=None, types=None):
         """An UPSERT is a smart combination of insert and update.
 
@@ -181,6 +234,39 @@ class Table(object):
             return self.insert(row, ensure=False)
         return True
 
+    def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
+        """Insert or update many rows at a time.
+
+        Each row is looked up by its ``keys`` columns: rows that already exist
+        in the table are passed to :py:meth:`update_many()
+        <dataset.Table.update_many>`, all others to :py:meth:`insert_many()
+        <dataset.Table.insert_many>`.
+
+        Rows are classified before anything is written, so two new rows that
+        share the same key values are both handed to ``insert_many``.
+
+        See :py:meth:`upsert() <dataset.Table.upsert>` for details on the
+        parameters.
+        """
+        # Accept a single column name as well as a list or tuple of names.
+        keys = keys if isinstance(keys, (list, tuple)) else [keys]
+
+        to_insert = []
+        to_update = []
+        for row in rows:
+            if self.find_one(**{key: row.get(key) for key in keys}):
+                # Row exists - update it.
+                to_update.append(row)
+            else:
+                # Row doesn't exist - insert it.
+                to_insert.append(row)
+
+        # Insert non-existing rows.
+        self.insert_many(to_insert, chunk_size, ensure, types)
+
+        # Update existing rows.
+        self.update_many(to_update, keys, chunk_size, ensure, types)
+
     def delete(self, *clauses, **filters):
         """Delete rows from the table.
 