Merge pull request #314 from mynameisfiber/feat/chunked-callback
Add optional callback to ChunkedInsert
This commit is contained in:
commit
c84ea316db
@ -1,23 +1,35 @@
|
||||
|
||||
|
||||
class InvalidCallback(ValueError):
    """Raised when a supplied callback argument is not callable."""
|
||||
|
||||
|
||||
class ChunkedInsert(object):
|
||||
"""Batch up insert operations
|
||||
with ChunkedInsert(my_table) as storer:
|
||||
storer.insert(row)
|
||||
|
||||
Rows will be inserted in groups of 1000
|
||||
Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
|
||||
optional callback can be provided that will be called before the insert.
|
||||
This callback takes one parameter which is the queue which is about to be
|
||||
inserted into the database
|
||||
"""
|
||||
|
||||
def __init__(self, table, chunksize=1000, callback=None):
    """Create an empty insert queue bound to ``table``.

    ``table`` is the object whose ``insert_many`` receives each flushed
    batch. ``chunksize`` is the batch size setting (1000 by default) and is
    stored on the instance. ``callback``, when given, is called with the
    pending queue right before that queue is inserted; ``None`` disables it.

    Raises ``InvalidCallback`` if ``callback`` is neither ``None`` nor
    callable.
    """
    # Compare against None rather than relying on truthiness: flush() calls
    # anything that is not None, so a falsy non-callable (0, "", []) has to
    # be rejected here instead of surfacing later as a TypeError.
    if callback is not None and not callable(callback):
        raise InvalidCallback
    self.queue = []
    self.fields = set()
    self.table = table
    self.chunksize = chunksize
    self.callback = callback
|
||||
|
||||
def flush(self):
    """Insert every queued row into the table and reset the queue.

    Each queued row is first padded so that it has an entry for every name
    in ``self.fields`` (missing entries become ``None``). The optional
    callback then receives the queue, the queue is handed to
    ``self.table.insert_many``, and a fresh empty queue is installed.
    """
    pending = self.queue
    known_fields = self.fields
    for row in pending:
        for name in known_fields:
            # Re-assigning an existing value is a no-op; an absent key is
            # filled with None so all rows share the same set of keys.
            row[name] = row.get(name)

    notify = self.callback
    if notify is not None:
        notify(pending)

    self.table.insert_many(pending)
    self.queue = []
|
||||
|
||||
|
||||
@ -9,7 +9,7 @@ from datetime import datetime
|
||||
from sqlalchemy import FLOAT, TEXT, BIGINT
|
||||
from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
|
||||
|
||||
from dataset import connect
|
||||
from dataset import connect, chunked
|
||||
|
||||
from .sample_data import TEST_DATA, TEST_CITY_1
|
||||
|
||||
@ -394,6 +394,26 @@ class TableTestCase(unittest.TestCase):
|
||||
self.tbl.insert_many(data, chunk_size=13)
|
||||
assert len(self.tbl) == len(data) + 6
|
||||
|
||||
def test_chunked_insert(self):
    """Rows pushed through ChunkedInsert all end up in the table."""
    rows = TEST_DATA * 100
    with chunked.ChunkedInsert(self.tbl) as chunk_tbl:
        for row in rows:
            chunk_tbl.insert(row)
    # NOTE(review): the "+ 6" presumably accounts for rows already present
    # in the fixture table before this test runs -- confirm against setUp.
    assert len(self.tbl) == len(rows) + 6
|
||||
|
||||
def test_chunked_insert_callback(self):
    """The ChunkedInsert callback is handed every queued row exactly once."""
    rows = TEST_DATA * 100
    batch_sizes = []

    def callback(queue):
        # Record the size of each batch the callback is given.
        batch_sizes.append(len(queue))

    with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl:
        for row in rows:
            chunk_tbl.insert(row)
    assert len(rows) == sum(batch_sizes)
    # NOTE(review): the "+ 6" presumably accounts for rows already present
    # in the fixture table before this test runs -- confirm against setUp.
    assert len(self.tbl) == len(rows) + 6
|
||||
|
||||
def test_update_many(self):
|
||||
tbl = self.db['update_many_test']
|
||||
tbl.insert_many([
|
||||
|
||||
Loading…
Reference in New Issue
Block a user