diff --git a/dataset/chunked.py b/dataset/chunked.py index b6820ce..312b045 100644 --- a/dataset/chunked.py +++ b/dataset/chunked.py @@ -1,23 +1,35 @@ +class InvalidCallback(ValueError): + pass + + class ChunkedInsert(object): """Batch up insert operations with ChunkedStorer(my_table) as storer: table.insert(row) - Rows will be inserted in groups of 1000 + Rows will be inserted in groups of `chunksize` (defaulting to 1000). An + optional callback can be provided that will be called before the insert. + This callback takes one parameter which is the queue which is about to be + inserted into the database """ - def __init__(self, table, chunksize=1000): + def __init__(self, table, chunksize=1000, callback=None): self.queue = [] self.fields = set() self.table = table self.chunksize = chunksize + if callback and not callable(callback): + raise InvalidCallback + self.callback = callback def flush(self): for item in self.queue: for field in self.fields: item[field] = item.get(field) + if self.callback is not None: + self.callback(self.queue) self.table.insert_many(self.queue) self.queue = [] diff --git a/test/test_dataset.py b/test/test_dataset.py index ec38e2e..9d5b46e 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -9,7 +9,7 @@ from datetime import datetime from sqlalchemy import FLOAT, TEXT, BIGINT from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError -from dataset import connect +from dataset import connect, chunked from .sample_data import TEST_DATA, TEST_CITY_1 @@ -394,6 +394,25 @@ class TableTestCase(unittest.TestCase): self.tbl.insert_many(data, chunk_size=13) assert len(self.tbl) == len(data) + 6 + def test_chunked_insert(self): + data = TEST_DATA * 100 + with chunked.ChunkedInsert(self.tbl) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(self.tbl) == len(data) + 6 + + def test_chunked_insert_callback(self): + data = TEST_DATA * 100 + N = 0 + def callback(queue): + nonlocal N + N += len(queue) + with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl: + for item in data: + chunk_tbl.insert(item) + assert len(data) == N + assert len(self.tbl) == len(data) + 6 + def test_update_many(self): tbl = self.db['update_many_test'] tbl.insert_many([