Merge pull request #314 from mynameisfiber/feat/chunked-callback

Add optional callback to ChunkedInsert
This commit is contained in:
Friedrich Lindenberg 2020-03-15 12:46:38 +01:00 committed by GitHub
commit c84ea316db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 35 additions and 3 deletions

View File

@ -1,23 +1,35 @@
class InvalidCallback(ValueError):
pass
class ChunkedInsert(object): class ChunkedInsert(object):
"""Batch up insert operations """Batch up insert operations
with ChunkedStorer(my_table) as storer: with ChunkedStorer(my_table) as storer:
table.insert(row) table.insert(row)
Rows will be inserted in groups of 1000 Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
optional callback can be provided that will be called before the insert.
This callback takes one parameter which is the queue which is about to be
inserted into the database
""" """
def __init__(self, table, chunksize=1000): def __init__(self, table, chunksize=1000, callback=None):
self.queue = [] self.queue = []
self.fields = set() self.fields = set()
self.table = table self.table = table
self.chunksize = chunksize self.chunksize = chunksize
if callback and not callable(callback):
raise InvalidCallback
self.callback = callback
def flush(self): def flush(self):
for item in self.queue: for item in self.queue:
for field in self.fields: for field in self.fields:
item[field] = item.get(field) item[field] = item.get(field)
if self.callback is not None:
self.callback(self.queue)
self.table.insert_many(self.queue) self.table.insert_many(self.queue)
self.queue = [] self.queue = []

View File

@ -9,7 +9,7 @@ from datetime import datetime
from sqlalchemy import FLOAT, TEXT, BIGINT from sqlalchemy import FLOAT, TEXT, BIGINT
from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
from dataset import connect from dataset import connect, chunked
from .sample_data import TEST_DATA, TEST_CITY_1 from .sample_data import TEST_DATA, TEST_CITY_1
@ -394,6 +394,26 @@ class TableTestCase(unittest.TestCase):
self.tbl.insert_many(data, chunk_size=13) self.tbl.insert_many(data, chunk_size=13)
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6
def test_chunked_insert(self):
data = TEST_DATA * 100
with chunked.ChunkedInsert(self.tbl) as chunk_tbl:
for item in data:
chunk_tbl.insert(item)
assert len(self.tbl) == len(data) + 6
def test_chunked_insert_callback(self):
data = TEST_DATA * 100
N = 0
def callback(queue):
nonlocal N
N += len(queue)
with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl:
for item in data:
chunk_tbl.insert(item)
assert len(data) == N
assert len(self.tbl) == len(data) + 6
def test_update_many(self): def test_update_many(self):
tbl = self.db['update_many_test'] tbl = self.db['update_many_test']
tbl.insert_many([ tbl.insert_many([