From a9a15966e2bc3b1399241b06b18bc06769d579fa Mon Sep 17 00:00:00 2001 From: Micha Gorelick Date: Fri, 13 Mar 2020 14:29:03 +0100 Subject: [PATCH] Add optional callback to ChunkedInsert This commit adds an optional `callback` argument to the `ChunkedInsert` object. This callback is a callable object which gets called before the chunked insert happens. This is useful for clearing any local caches that may be in place to deal with the eventual consistency resulting from the delayed nature of the chunked inserts. For example, ``` cache = set() chunked_table = ChunkedInsert(table, callback=lambda queue: cache.clear()) while True: data = get_data_id() key = data['key'] if key in cache or table.find_one(key=key) continue cache.add(key) chunked_table.insert(data) ``` --- dataset/chunked.py | 16 ++++++++++++++-- test/test_dataset.py | 21 ++++++++++++++++++++- 2 files changed, 34 insertions(+), 3 deletions(-) diff --git a/dataset/chunked.py b/dataset/chunked.py index b6820ce..312b045 100644 --- a/dataset/chunked.py +++ b/dataset/chunked.py @@ -1,23 +1,35 @@ +class InvalidCallback(ValueError): + pass + + class ChunkedInsert(object): """Batch up insert operations with ChunkedStorer(my_table) as storer: table.insert(row) - Rows will be inserted in groups of 1000 + Rows will be inserted in groups of `chunksize` (defaulting to 1000). An + optional callback can be provided that will be called before the insert. + This callback takes one parameter which is the queue which is about to be + inserted into the database """ - def __init__(self, table, chunksize=1000): + def __init__(self, table, chunksize=1000, callback=None): self.queue = [] self.fields = set() self.table = table self.chunksize = chunksize + if callback and not callable(callback): + raise InvalidCallback + self.callback = callback def flush(self): for item in self.queue: for field in self.fields: item[field] = item.get(field) + if self.callback is not None: + self.callback(self.queue) self.table.insert_many(self.queue) self.queue = [] diff --git a/test/test_dataset.py b/test/test_dataset.py index ec38e2e..9d5b46e 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -9,7 +9,7 @@ from datetime import datetime from sqlalchemy import FLOAT, TEXT, BIGINT from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError -from dataset import connect +from dataset import connect, chunked from .sample_data import TEST_DATA, TEST_CITY_1 @@ -394,6 +394,25 @@ class TableTestCase(unittest.TestCase): self.tbl.insert_many(data, chunk_size=13) assert len(self.tbl) == len(data) + 6 + def test_chunked_insert(self): + data = TEST_DATA * 100 + with chunked.ChunkedInsert(self.tbl) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(self.tbl) == len(data) + 6 + + def test_chunked_insert_callback(self): + data = TEST_DATA * 100 + N = 0 + def callback(queue): + nonlocal N + N += len(queue) + with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(data) == N + assert len(self.tbl) == len(data) + 6 + def test_update_many(self): tbl = self.db['update_many_test'] tbl.insert_many([