Add optional callback to ChunkedInsert

This commit adds an optional `callback` argument to the `ChunkedInsert`
object. This callback is a callable object which gets called before the
chunked insert happens. This is useful for clearing any local caches
that may be in place to deal with the eventual consistency resulting
from the delayed nature of the chunked inserts.

For example,

```
cache = set()
chunked_table = ChunkedInsert(table, callback=lambda queue: cache.clear())
while True:
    data = get_data_id()
    key = data['key']
    if key in cache or table.find_one(key=key)
        continue
    cache.add(key)
    chunked_table.insert(data)
```
This commit is contained in:
Micha Gorelick 2020-03-13 14:29:03 +01:00
parent 73ff374513
commit a9a15966e2
2 changed files with 34 additions and 3 deletions

View File

@ -1,23 +1,35 @@
class InvalidCallback(ValueError):
pass
class ChunkedInsert(object): class ChunkedInsert(object):
"""Batch up insert operations """Batch up insert operations
with ChunkedStorer(my_table) as storer: with ChunkedStorer(my_table) as storer:
table.insert(row) table.insert(row)
Rows will be inserted in groups of 1000 Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
optional callback can be provided that will be called before the insert.
This callback takes one parameter which is the queue which is about to be
inserted into the database
""" """
def __init__(self, table, chunksize=1000): def __init__(self, table, chunksize=1000, callback=None):
self.queue = [] self.queue = []
self.fields = set() self.fields = set()
self.table = table self.table = table
self.chunksize = chunksize self.chunksize = chunksize
if callback and not callable(callback):
raise InvalidCallback
self.callback = callback
def flush(self): def flush(self):
for item in self.queue: for item in self.queue:
for field in self.fields: for field in self.fields:
item[field] = item.get(field) item[field] = item.get(field)
if self.callback is not None:
self.callback(self.queue)
self.table.insert_many(self.queue) self.table.insert_many(self.queue)
self.queue = [] self.queue = []

View File

@ -9,7 +9,7 @@ from datetime import datetime
from sqlalchemy import FLOAT, TEXT, BIGINT from sqlalchemy import FLOAT, TEXT, BIGINT
from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
from dataset import connect from dataset import connect, chunked
from .sample_data import TEST_DATA, TEST_CITY_1 from .sample_data import TEST_DATA, TEST_CITY_1
@ -394,6 +394,25 @@ class TableTestCase(unittest.TestCase):
self.tbl.insert_many(data, chunk_size=13) self.tbl.insert_many(data, chunk_size=13)
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6
def test_chunked_insert(self):
data = TEST_DATA * 100
with chunked.ChunkedInsert(self.tbl) as chunk_tbl:
for item in data:
chunk_tbl.insert(item)
assert len(self.tbl) == len(data) + 6
def test_chunked_insert_callback(self):
data = TEST_DATA * 100
N = 0
def callback(queue):
nonlocal N
N += len(queue)
with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl:
for item in data:
chunk_tbl.insert(item)
assert len(data) == N
assert len(self.tbl) == len(data) + 6
def test_update_many(self): def test_update_many(self):
tbl = self.db['update_many_test'] tbl = self.db['update_many_test']
tbl.insert_many([ tbl.insert_many([