Merge pull request #315 from mynameisfiber/feat/chunked-update

Added ChunkedUpdate to leverage `update_many`
Friedrich Lindenberg committed 2020-03-15 12:48:02 +01:00 (via GitHub)
commit 0600f3eb43
2 changed files with 77 additions and 22 deletions
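
A minimal usage sketch of the new class, pieced together from the diff and the test below. The connection string, table name, and values are illustrative and not part of this commit:

import dataset
from dataset import chunked

db = dataset.connect('sqlite:///:memory:')       # illustrative connection string
tbl = db['readings']                              # hypothetical table name
tbl.insert_many([dict(temp=10), dict(temp=20)])   # ids 1 and 2 are auto-assigned

# Updates accumulate in a queue and are written in bulk via update_many()
# once `chunksize` rows are queued or when the context manager exits.
with chunked.ChunkedUpdate(tbl, 'id', chunksize=1000) as updater:
    updater.update(dict(id=1, temp=50))
    updater.update(dict(id=2, temp=60))

assert tbl.find_one(id=1)['temp'] == 50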


@@ -1,23 +1,13 @@
+import itertools
+
+
 class InvalidCallback(ValueError):
     pass
 
 
-class ChunkedInsert(object):
-    """Batch up insert operations
-    with ChunkedStorer(my_table) as storer:
-        table.insert(row)
-
-    Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
-    optional callback can be provided that will be called before the insert.
-    This callback takes one parameter which is the queue which is about to be
-    inserted into the database
-    """
-
-    def __init__(self, table, chunksize=1000, callback=None):
+class _Chunker(object):
+    def __init__(self, table, chunksize, callback):
         self.queue = []
-        self.fields = set()
         self.table = table
         self.chunksize = chunksize
         if callback and not callable(callback):
@@ -25,16 +15,9 @@ class ChunkedInsert(object):
         self.callback = callback
 
     def flush(self):
-        for item in self.queue:
-            for field in self.fields:
-                item[field] = item.get(field)
-        if self.callback is not None:
-            self.callback(self.queue)
-        self.table.insert_many(self.queue)
-        self.queue = []
+        self.queue.clear()
 
-    def insert(self, item):
-        self.fields.update(item.keys())
+    def _queue_add(self, item):
         self.queue.append(item)
         if len(self.queue) >= self.chunksize:
             self.flush()
@@ -44,3 +27,59 @@ class ChunkedInsert(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.flush()
+
+
+class ChunkedInsert(_Chunker):
+    """Batch up insert operations
+    with ChunkedInsert(my_table) as inserter:
+        inserter(row)
+
+    Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
+    optional callback can be provided that will be called before the insert.
+    This callback takes one parameter which is the queue which is about to be
+    inserted into the database
+    """
+
+    def __init__(self, table, chunksize=1000, callback=None):
+        self.fields = set()
+        super().__init__(table, chunksize, callback)
+
+    def insert(self, item):
+        self.fields.update(item.keys())
+        super()._queue_add(item)
+
+    def flush(self):
+        for item in self.queue:
+            for field in self.fields:
+                item[field] = item.get(field)
+        if self.callback is not None:
+            self.callback(self.queue)
+        self.table.insert_many(self.queue)
+        super().flush()
+
+
+class ChunkedUpdate(_Chunker):
+    """Batch up update operations
+    with ChunkedUpdate(my_table) as updater:
+        updater(row)
+
+    Rows will be updated in groups of `chunksize` (defaulting to 1000). An
+    optional callback can be provided that will be called before the update.
+    This callback takes one parameter which is the queue which is about to be
+    updated into the database
+    """
+
+    def __init__(self, table, keys, chunksize=1000, callback=None):
+        self.keys = keys
+        super().__init__(table, chunksize, callback)
+
+    def update(self, item):
+        super()._queue_add(item)
+
+    def flush(self):
+        if self.callback is not None:
+            self.callback(self.queue)
+        self.queue.sort(key=dict.keys)
+        for fields, items in itertools.groupby(self.queue, key=dict.keys):
+            self.table.update_many(list(items), self.keys)
+        super().flush()
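
The callback hook described in both docstrings receives the queued rows just before each write, which makes it useful for things like progress reporting. A minimal sketch under the same assumptions as above; `report` is a hypothetical helper, not part of this commit:

import dataset
from dataset import chunked

db = dataset.connect('sqlite:///:memory:')   # illustrative connection string
tbl = db['readings']                          # hypothetical table name

def report(queue):
    # Receives the list of queued rows just before they are written.
    print('flushing %d rows' % len(queue))

# Passing a non-callable callback would raise InvalidCallback.
with chunked.ChunkedInsert(tbl, chunksize=2, callback=report) as inserter:
    for temp in (10, 20, 30):
        inserter.insert(dict(temp=temp))
# report() fires when the queue reaches chunksize and once more on exit.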


@@ -426,6 +426,22 @@ class TableTestCase(unittest.TestCase):
         # Ensure data has been updated.
         assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp']
 
+    def test_chunked_update(self):
+        tbl = self.db['update_many_test']
+        tbl.insert_many([
+            dict(temp=10, location='asdf'),
+            dict(temp=20, location='qwer'),
+            dict(temp=30, location='asdf')
+        ])
+
+        chunked_tbl = chunked.ChunkedUpdate(tbl, 'id')
+        chunked_tbl.update(dict(id=1, temp=50))
+        chunked_tbl.update(dict(id=2, location='asdf'))
+        chunked_tbl.update(dict(id=3, temp=50))
+        chunked_tbl.flush()
+        # Ensure data has been updated.
+        assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] == 50
+        assert tbl.find_one(id=2)['location'] == tbl.find_one(id=3)['location'] == 'asdf'
+
     def test_upsert_many(self):
         # Also tests updating on records with different attributes
         tbl = self.db['upsert_many_test']