Merge pull request #298 from abmyii/master

Implement update_many, upsert_many and refactor for a 2x speed-up of insert_many
This commit is contained in:
Friedrich Lindenberg 2019-07-13 14:47:04 +02:00 committed by GitHub
commit 081cb5ec7a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 105 additions and 13 deletions

View File

@ -3,7 +3,7 @@ import warnings
import threading import threading
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import bindparam, ClauseElement
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false from sqlalchemy import func, select, false
from sqlalchemy.schema import Table as SQLATable from sqlalchemy.schema import Table as SQLATable
@ -128,19 +128,28 @@ class Table(object):
rows = [dict(name='Dolly')] * 10000 rows = [dict(name='Dolly')] * 10000
table.insert_many(rows) table.insert_many(rows)
""" """
chunk = [] # Sync table before inputting rows.
sync_row = {}
for row in rows: for row in rows:
row = self._sync_columns(row, ensure, types=types) # Only get non-existing columns.
for key in set(row.keys()).difference(set(sync_row.keys())):
# Get a sample of the new column(s) from the row.
sync_row[key] = row[key]
self._sync_columns(sync_row, ensure, types=types)
# Get columns name list to be used for padding later.
columns = sync_row.keys()
chunk = []
for index, row in enumerate(rows):
chunk.append(row) chunk.append(row)
if len(chunk) == chunk_size:
chunk = pad_chunk_columns(chunk) # Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
self.table.insert().execute(chunk) self.table.insert().execute(chunk)
chunk = [] chunk = []
if len(chunk):
chunk = pad_chunk_columns(chunk)
self.table.insert().execute(chunk)
def update(self, row, keys, ensure=None, types=None, return_count=False): def update(self, row, keys, ensure=None, types=None, return_count=False):
"""Update a row in the table. """Update a row in the table.
@ -169,6 +178,42 @@ class Table(object):
if return_count: if return_count:
return self.count(clause) return self.count(clause)
def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
"""Update many rows in the table at a time.
This is significantly faster than updating them one by one. Per default
the rows are processed in chunks of 1000 per commit, unless you specify
a different ``chunk_size``.
See :py:meth:`update() <dataset.Table.update>` for details on
the other parameters.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]
chunk = []
columns = set()
for index, row in enumerate(rows):
chunk.append(row)
columns = columns.union(set(row.keys()))
# bindparam requires names to not conflict (cannot be "id" for id)
for key in keys:
row['_%s' % key] = row[key]
# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
stmt = self.table.update(
whereclause=and_(
*[self.table.c[k] == bindparam('_%s' % k) for k in keys]
),
values={
col: bindparam(col, required=False) for col in columns
}
)
self.db.executable.execute(stmt, chunk)
chunk = []
def upsert(self, row, keys, ensure=None, types=None): def upsert(self, row, keys, ensure=None, types=None):
"""An UPSERT is a smart combination of insert and update. """An UPSERT is a smart combination of insert and update.
@ -187,6 +232,33 @@ class Table(object):
return self.insert(row, ensure=False) return self.insert(row, ensure=False)
return True return True
def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
"""
Sorts multiple input rows into upserts and inserts. Inserts are passed
to insert_many and upserts are updated.
See :py:meth:`upsert() <dataset.Table.upsert>` and
:py:meth:`insert_many() <dataset.Table.insert_many>`.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]
to_insert = []
to_update = []
for row in rows:
if self.find_one(**{key: row.get(key) for key in keys}):
# Row exists - update it.
to_update.append(row)
else:
# Row doesn't exist - insert it.
to_insert.append(row)
# Insert non-existing rows.
self.insert_many(to_insert, chunk_size, ensure, types)
# Update existing rows.
self.update_many(to_update, keys, chunk_size, ensure, types)
def delete(self, *clauses, **filters): def delete(self, *clauses, **filters):
"""Delete rows from the table. """Delete rows from the table.

View File

@ -108,12 +108,9 @@ def ensure_tuple(obj):
return obj, return obj,
def pad_chunk_columns(chunk): def pad_chunk_columns(chunk, columns):
"""Given a set of items to be inserted, make sure they all have the """Given a set of items to be inserted, make sure they all have the
same columns by padding columns with None if they are missing.""" same columns by padding columns with None if they are missing."""
columns = set()
for record in chunk:
columns.update(record.keys())
for record in chunk: for record in chunk:
for column in columns: for column in columns:
record.setdefault(column, None) record.setdefault(column, None)

View File

@ -374,6 +374,29 @@ class TableTestCase(unittest.TestCase):
self.tbl.insert_many(data, chunk_size=13) self.tbl.insert_many(data, chunk_size=13)
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6
def test_update_many(self):
tbl = self.db['update_many_test']
tbl.insert_many([
dict(temp=10), dict(temp=20), dict(temp=30)
])
tbl.update_many(
[dict(id=1, temp=50), dict(id=3, temp=50)], 'id'
)
# Ensure data has been updated.
assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp']
def test_upsert_many(self):
# Also tests updating on records with different attributes
tbl = self.db['upsert_many_test']
W = 100
tbl.upsert_many([dict(age=10), dict(weight=W)], 'id')
assert tbl.find_one(id=1)['age'] == 10
tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], 'id')
assert tbl.find_one(id=2)['weight'] == W / 2
def test_drop_operations(self): def test_drop_operations(self):
assert self.tbl._table is not None, \ assert self.tbl._table is not None, \
'table shouldn\'t be dropped yet' 'table shouldn\'t be dropped yet'