From 70874a2501c91304731b6393b6f7318008954079 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Mon, 8 Jul 2019 17:48:05 +0100
Subject: [PATCH 1/8] Implement update_many and upsert_many

---
 dataset/table.py | 75 +++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 74 insertions(+), 1 deletion(-)

diff --git a/dataset/table.py b/dataset/table.py
index bf829df..668e9be 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -3,7 +3,7 @@ import warnings
 import threading
 
 from sqlalchemy.sql import and_, expression
-from sqlalchemy.sql.expression import ClauseElement
+from sqlalchemy.sql.expression import bindparam, ClauseElement
 from sqlalchemy.schema import Column, Index
 from sqlalchemy import func, select, false
 from sqlalchemy.schema import Table as SQLATable
@@ -163,6 +163,51 @@ class Table(object):
         if return_count:
             return self.count(clause)
 
+    def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
+        """Update many rows in the table at a time.
+
+        This is significantly faster than updating them one by one. Per default
+        the rows are processed in chunks of 1000 per commit, unless you specify
+        a different ``chunk_size``.
+
+        See :py:meth:`update() <dataset.table.Table.update>` for details on
+        the other parameters.
+        """
+        chunk = []
+        columns = set()
+        for row in rows:
+            chunk.append(row)
+            columns = columns.union(set(row.keys()))
+
+            # bindparam requires names to not conflict (cannot be "id" for id)
+            for key in keys:
+                row[f'_{key}'] = row[f'{key}']
+
+            if len(chunk) == chunk_size:
+                stmt = self.table.update(
+                    whereclause=and_(
+                        *[self.table.c[key] == bindparam(f'_{key}') for key in keys]
+                    ),
+                    values={
+                        column: bindparam(column, required=False) for column in columns
+                    }
+                )
+                self.db.executable.execute(stmt, chunk)
+                chunk = []
+                columns = set()
+
+        if len(chunk):
+            stmt = self.table.update(
+                whereclause=and_(
+                    *[self.table.c[key] == bindparam(f'_{key}') for key in keys]
+                ),
+                values={
+                    column: bindparam(column, required=False) for column in columns
+                }
+            )
+            self.db.executable.execute(stmt, chunk)
+
+
     def upsert(self, row, keys, ensure=None, types=None):
         """An UPSERT is a smart combination of insert and update.
 
@@ -181,6 +226,34 @@ class Table(object):
             return self.insert(row, ensure=False)
         return True
 
+    def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
+        """
+        Sorts multiple input rows into upserts and inserts. Inserts are passed
+        to insert_many and upserts are updated.
+
+        See :py:meth:`upsert() <dataset.table.Table.upsert>` and
+        :py:meth:`insert_many() <dataset.table.Table.insert_many>`.
+        """
+        # Convert keys to a list if not a list or tuple.
+        keys = keys if type(keys) in (list, tuple) else [keys]
+
+        to_insert = []
+        to_update = []
+        for row in rows:
+            if self.find_one(**{key: row.get(key) for key in keys}):
+                # Row exists - update it.
+                to_update.append(row)
+            else:
+                # Row doesn't exist - insert it.
+                to_insert.append(row)
+
+        # Insert non-existing rows.
+        self.insert_many(to_insert, chunk_size, ensure, types)
+
+        # Update existing rows.
+        self.update_many(to_update, keys, chunk_size, ensure, types)
+
+
     def delete(self, *clauses, **filters):
         """Delete rows from the table.
 
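For context, the calls below sketch how the API introduced in PATCH 1/8 is meant to be used, mirroring the tests added later in the series (PATCH 3/8). This is an illustrative sketch only: the table and column names are made up, and it assumes an in-memory SQLite database opened through dataset.connect().

    # Illustrative usage sketch (assumed names; not part of the patches).
    import dataset

    db = dataset.connect('sqlite:///:memory:')
    table = db['readings']

    # Seed three rows; dataset adds an auto-incrementing 'id' primary key.
    table.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])

    # Bulk-update rows matched on 'id'; rows are sent to the database in
    # chunks of up to chunk_size (default 1000) per commit.
    table.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], ['id'])
    assert table.find_one(id=1)['temp'] == 50

    # upsert_many splits the input: rows whose 'id' matches an existing row
    # are passed to update_many, the rest go to insert_many.
    table.upsert_many([dict(id=2, temp=25), dict(humidity=40)], ['id'])
    assert table.find_one(id=2)['temp'] == 25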
From 85d974b0c3752b87468039969a22d8e11051d2f9 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Mon, 8 Jul 2019 17:52:06 +0100
Subject: [PATCH 2/8] Refactor to remove duplicate code

---
 dataset/table.py | 19 +++++--------------
 1 file changed, 5 insertions(+), 14 deletions(-)

diff --git a/dataset/table.py b/dataset/table.py
index 668e9be..f85f9fa 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -173,6 +173,9 @@ class Table(object):
         See :py:meth:`update() <dataset.table.Table.update>` for details on
         the other parameters.
         """
+        # Convert keys to a list if not a list or tuple.
+        keys = keys if type(keys) in (list, tuple) else [keys]
+
         chunk = []
         columns = set()
         for row in rows:
@@ -183,7 +186,8 @@ class Table(object):
             for key in keys:
                 row[f'_{key}'] = row[f'{key}']
 
-            if len(chunk) == chunk_size:
+            # Update when chunksize is fulfilled or this is the last row
+            if len(chunk) == chunk_size or rows.index(row) == len(rows)-1:
                 stmt = self.table.update(
                     whereclause=and_(
                         *[self.table.c[key] == bindparam(f'_{key}') for key in keys]
                     ),
@@ -196,18 +200,6 @@ class Table(object):
                 )
                 self.db.executable.execute(stmt, chunk)
                 chunk = []
                 columns = set()
 
-        if len(chunk):
-            stmt = self.table.update(
-                whereclause=and_(
-                    *[self.table.c[key] == bindparam(f'_{key}') for key in keys]
-                ),
-                values={
-                    column: bindparam(column, required=False) for column in columns
-                }
-            )
-            self.db.executable.execute(stmt, chunk)
-
-
     def upsert(self, row, keys, ensure=None, types=None):
         """An UPSERT is a smart combination of insert and update.
 
@@ -253,7 +245,6 @@ class Table(object):
         # Update existing rows.
         self.update_many(to_update, keys, chunk_size, ensure, types)
 
-
     def delete(self, *clauses, **filters):
         """Delete rows from the table.
 

From a9f3eb86b26d47393db02ab9ba24166473144172 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Mon, 8 Jul 2019 18:10:00 +0100
Subject: [PATCH 3/8] Add tests for new functions

---
 test/test_dataset.py | 23 +++++++++++++++++++++
 1 file changed, 23 insertions(+)

diff --git a/test/test_dataset.py b/test/test_dataset.py
index 99521ce..a40cef2 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -368,6 +368,29 @@ class TableTestCase(unittest.TestCase):
         self.tbl.insert_many(data, chunk_size=13)
         assert len(self.tbl) == len(data) + 6
 
+    def test_update_many(self):
+        tbl = self.db['update_many_test']
+        tbl.insert_many([
+            dict(temp=10), dict(temp=20), dict(temp=30)
+        ])
+        tbl.update_many(
+            [dict(id=1, temp=50), dict(id=3, temp=50)], 'id'
+        )
+
+        # Ensure data has been updated.
+        assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp']
+
+    def test_upsert_many(self):
+        # Also tests updating on records with different attributes
+        tbl = self.db['upsert_many_test']
+
+        W = 100
+        tbl.upsert_many([dict(age=10), dict(weight=W)], 'id')
+        assert tbl.find_one(id=1)['age'] == 10
+
+        tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W/2)], 'id')
+        assert tbl.find_one(id=2)['weight'] == W/2
+
     def test_drop_operations(self):
         assert self.tbl._table is not None, \
             'table shouldn\'t be dropped yet'

From 76b61651816c1f2ae9c3b2abe617caf391509e68 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Mon, 8 Jul 2019 23:08:30 +0100
Subject: [PATCH 4/8] Speed up insert_many by sync columns before input, not on the go

---
 dataset/table.py | 18 ++++++++++++++----
 dataset/util.py  |  5 +----
 2 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/dataset/table.py b/dataset/table.py
index f85f9fa..3ef7d35 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -122,17 +122,28 @@ class Table(object):
             rows = [dict(name='Dolly')] * 10000
             table.insert_many(rows)
         """
+        # Sync table before inputting rows.
+        sync_row = {}
+        for row in rows:
+            # Only get non-existing columns.
+            for key in set(row.keys()).difference(set(sync_row.keys())):
+                # Get a sample of the new column(s) from the row.
+                sync_row[key] = row[key]
+        self._sync_columns(sync_row, ensure, types=types)
+
+        # Get columns name list to be used for padding later.
+        columns = sync_row.keys()
+
         chunk = []
         for row in rows:
-            row = self._sync_columns(row, ensure, types=types)
             chunk.append(row)
             if len(chunk) == chunk_size:
-                chunk = pad_chunk_columns(chunk)
+                chunk = pad_chunk_columns(chunk, columns)
                 self.table.insert().execute(chunk)
                 chunk = []
 
         if len(chunk):
-            chunk = pad_chunk_columns(chunk)
+            chunk = pad_chunk_columns(chunk, columns)
             self.table.insert().execute(chunk)
 
     def update(self, row, keys, ensure=None, types=None, return_count=False):
         """Update a row in the table.
@@ -198,7 +209,6 @@ class Table(object):
                 )
                 self.db.executable.execute(stmt, chunk)
                 chunk = []
-                columns = set()
 
     def upsert(self, row, keys, ensure=None, types=None):
         """An UPSERT is a smart combination of insert and update.
 
diff --git a/dataset/util.py b/dataset/util.py
index 60e129c..b79a607 100644
--- a/dataset/util.py
+++ b/dataset/util.py
@@ -108,12 +108,9 @@ def ensure_tuple(obj):
     return obj,
 
 
-def pad_chunk_columns(chunk):
+def pad_chunk_columns(chunk, columns):
     """Given a set of items to be inserted, make sure they all have the same
     columns by padding columns with None if they are missing."""
-    columns = set()
-    for record in chunk:
-        columns.update(record.keys())
     for record in chunk:
         for column in columns:
             record.setdefault(column, None)

From 7fd9241f258d8266eff0e359dcf09ad534de14d4 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Tue, 9 Jul 2019 09:37:48 +0100
Subject: [PATCH 5/8] Refactor insert_many to remove duplicate code, fix some pep8 problems

---
 dataset/table.py | 20 +++++++++-----------
 1 file changed, 9 insertions(+), 11 deletions(-)

diff --git a/dataset/table.py b/dataset/table.py
index 3ef7d35..031ede2 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -135,17 +135,15 @@ class Table(object):
         columns = sync_row.keys()
 
         chunk = []
-        for row in rows:
+        for index, row in enumerate(rows):
            chunk.append(row)
-            if len(chunk) == chunk_size:
+
+            # Insert when chunk_size is fulfilled or this is the last row
+            if len(chunk) == chunk_size or index == len(row) - 1:
                 chunk = pad_chunk_columns(chunk, columns)
                 self.table.insert().execute(chunk)
                 chunk = []
 
-        if len(chunk):
-            chunk = pad_chunk_columns(chunk, columns)
-            self.table.insert().execute(chunk)
-
     def update(self, row, keys, ensure=None, types=None, return_count=False):
         """Update a row in the table.
 
@@ -189,7 +187,7 @@ class Table(object):
 
         chunk = []
         columns = set()
-        for row in rows:
+        for index, row in enumerate(rows):
             chunk.append(row)
             columns = columns.union(set(row.keys()))
 
@@ -198,14 +196,14 @@ class Table(object):
             for key in keys:
                 row[f'_{key}'] = row[f'{key}']
 
-            # Update when chunksize is fulfilled or this is the last row
-            if len(chunk) == chunk_size or rows.index(row) == len(rows)-1:
+            # Update when chunk_size is fulfilled or this is the last row
+            if len(chunk) == chunk_size or index == len(rows) - 1:
                 stmt = self.table.update(
                     whereclause=and_(
-                        *[self.table.c[key] == bindparam(f'_{key}') for key in keys]
+                        *[self.table.c[k] == bindparam(f'_{k}') for k in keys]
                     ),
                     values={
-                        column: bindparam(column, required=False) for column in columns
+                        col: bindparam(col, required=False) for col in columns
                     }
                 )
                 self.db.executable.execute(stmt, chunk)

From 82c6cdc990fadf0b169adc6815e808b0a5c13012 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Tue, 9 Jul 2019 09:41:21 +0100
Subject: [PATCH 6/8] Add whitespace around arithmetic operator so flake8 test passes

---
 test/test_dataset.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/test/test_dataset.py b/test/test_dataset.py
index a40cef2..2c4cacb 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -388,8 +388,8 @@ class TableTestCase(unittest.TestCase):
         tbl.upsert_many([dict(age=10), dict(weight=W)], 'id')
         assert tbl.find_one(id=1)['age'] == 10
 
-        tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W/2)], 'id')
-        assert tbl.find_one(id=2)['weight'] == W/2
+        tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], 'id')
+        assert tbl.find_one(id=2)['weight'] == W / 2
 
     def test_drop_operations(self):
         assert self.tbl._table is not None, \
             'table shouldn\'t be dropped yet'

From 68748895916bf79e5bbfbfd32cc728cc160c9b37 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Tue, 9 Jul 2019 09:45:18 +0100
Subject: [PATCH 7/8] Fix logic error in insert_many

---
 dataset/table.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dataset/table.py b/dataset/table.py
index 031ede2..e8bc9c8 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -139,7 +139,7 @@ class Table(object):
             chunk.append(row)
 
             # Insert when chunk_size is fulfilled or this is the last row
-            if len(chunk) == chunk_size or index == len(row) - 1:
+            if len(chunk) == chunk_size or index == len(rows) - 1:
                 chunk = pad_chunk_columns(chunk, columns)
                 self.table.insert().execute(chunk)
                 chunk = []

From 19a73759cafc359a5337fac9d61233fa59c4c637 Mon Sep 17 00:00:00 2001
From: Abdurrahmaan Iqbal
Date: Tue, 9 Jul 2019 09:50:36 +0100
Subject: [PATCH 8/8] Remove f-string for wider compatibility

---
 dataset/table.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/dataset/table.py b/dataset/table.py
index e8bc9c8..ba7018f 100644
--- a/dataset/table.py
+++ b/dataset/table.py
@@ -193,13 +193,13 @@ class Table(object):
             # bindparam requires names to not conflict (cannot be "id" for id)
             for key in keys:
-                row[f'_{key}'] = row[f'{key}']
+                row['_%s' % key] = row[key]
 
             # Update when chunk_size is fulfilled or this is the last row
             if len(chunk) == chunk_size or index == len(rows) - 1:
                 stmt = self.table.update(
                     whereclause=and_(
-                        *[self.table.c[k] == bindparam(f'_{k}') for k in keys]
+                        *[self.table.c[k] == bindparam('_%s' % k) for k in keys]
                     ),
                     values={
                         col: bindparam(col, required=False) for col in columns