Merge pull request #298 from abmyii/master

Implement update_many, upsert_many and refactor for a 2x speed-up of insert_many
This commit is contained in:
Friedrich Lindenberg 2019-07-13 14:47:04 +02:00 committed by GitHub
commit 081cb5ec7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 13 deletions

View File

@ -3,7 +3,7 @@ import warnings
import threading
from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.sql.expression import bindparam, ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.schema import Table as SQLATable
@ -128,19 +128,28 @@ class Table(object):
rows = [dict(name='Dolly')] * 10000
table.insert_many(rows)
"""
chunk = []
# Sync table before inputting rows.
sync_row = {}
for row in rows:
row = self._sync_columns(row, ensure, types=types)
# Only get non-existing columns.
for key in set(row.keys()).difference(set(sync_row.keys())):
# Get a sample of the new column(s) from the row.
sync_row[key] = row[key]
self._sync_columns(sync_row, ensure, types=types)
# Get columns name list to be used for padding later.
columns = sync_row.keys()
chunk = []
for index, row in enumerate(rows):
chunk.append(row)
if len(chunk) == chunk_size:
chunk = pad_chunk_columns(chunk)
# Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
self.table.insert().execute(chunk)
chunk = []
if len(chunk):
chunk = pad_chunk_columns(chunk)
self.table.insert().execute(chunk)
def update(self, row, keys, ensure=None, types=None, return_count=False):
"""Update a row in the table.
@ -169,6 +178,42 @@ class Table(object):
if return_count:
return self.count(clause)
def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
"""Update many rows in the table at a time.
This is significantly faster than updating them one by one. Per default
the rows are processed in chunks of 1000 per commit, unless you specify
a different ``chunk_size``.
See :py:meth:`update() <dataset.Table.update>` for details on
the other parameters.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]
chunk = []
columns = set()
for index, row in enumerate(rows):
chunk.append(row)
columns = columns.union(set(row.keys()))
# bindparam requires names to not conflict (cannot be "id" for id)
for key in keys:
row['_%s' % key] = row[key]
# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
stmt = self.table.update(
whereclause=and_(
*[self.table.c[k] == bindparam('_%s' % k) for k in keys]
),
values={
col: bindparam(col, required=False) for col in columns
}
)
self.db.executable.execute(stmt, chunk)
chunk = []
def upsert(self, row, keys, ensure=None, types=None):
"""An UPSERT is a smart combination of insert and update.
@ -187,6 +232,33 @@ class Table(object):
return self.insert(row, ensure=False)
return True
def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
"""
Sorts multiple input rows into upserts and inserts. Inserts are passed
to insert_many and upserts are updated.
See :py:meth:`upsert() <dataset.Table.upsert>` and
:py:meth:`insert_many() <dataset.Table.insert_many>`.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]
to_insert = []
to_update = []
for row in rows:
if self.find_one(**{key: row.get(key) for key in keys}):
# Row exists - update it.
to_update.append(row)
else:
# Row doesn't exist - insert it.
to_insert.append(row)
# Insert non-existing rows.
self.insert_many(to_insert, chunk_size, ensure, types)
# Update existing rows.
self.update_many(to_update, keys, chunk_size, ensure, types)
def delete(self, *clauses, **filters):
"""Delete rows from the table.

View File

@ -108,12 +108,9 @@ def ensure_tuple(obj):
return obj,
def pad_chunk_columns(chunk):
def pad_chunk_columns(chunk, columns):
"""Given a set of items to be inserted, make sure they all have the
same columns by padding columns with None if they are missing."""
columns = set()
for record in chunk:
columns.update(record.keys())
for record in chunk:
for column in columns:
record.setdefault(column, None)

View File

@ -374,6 +374,29 @@ class TableTestCase(unittest.TestCase):
self.tbl.insert_many(data, chunk_size=13)
assert len(self.tbl) == len(data) + 6
def test_update_many(self):
tbl = self.db['update_many_test']
tbl.insert_many([
dict(temp=10), dict(temp=20), dict(temp=30)
])
tbl.update_many(
[dict(id=1, temp=50), dict(id=3, temp=50)], 'id'
)
# Ensure data has been updated.
assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp']
def test_upsert_many(self):
# Also tests updating on records with different attributes
tbl = self.db['upsert_many_test']
W = 100
tbl.upsert_many([dict(age=10), dict(weight=W)], 'id')
assert tbl.find_one(id=1)['age'] == 10
tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], 'id')
assert tbl.find_one(id=2)['weight'] == W / 2
def test_drop_operations(self):
assert self.tbl._table is not None, \
'table shouldn\'t be dropped yet'