diff --git a/dataset/chunked.py b/dataset/chunked.py
index 312b045..a5ca158 100644
--- a/dataset/chunked.py
+++ b/dataset/chunked.py
@@ -1,23 +1,19 @@
+import itertools
 
 
 class InvalidCallback(ValueError):
     pass
 
 
-class ChunkedInsert(object):
-    """Batch up insert operations
-    with ChunkedStorer(my_table) as storer:
-        table.insert(row)
-
-    Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
-    optional callback can be provided that will be called before the insert.
-    This callback takes one parameter which is the queue which is about to be
-    inserted into the database
-    """
-
-    def __init__(self, table, chunksize=1000, callback=None):
+class _Chunker(object):
+    """Shared queueing logic for the chunked insert/update helpers.
+
+    Subclasses override `flush` to write the queued rows to `table` and then
+    call `super().flush()` to empty the queue.
+    """
+
+    def __init__(self, table, chunksize, callback):
         self.queue = []
-        self.fields = set()
         self.table = table
         self.chunksize = chunksize
         if callback and not callable(callback):
@@ -25,16 +21,11 @@ class ChunkedInsert(object):
         self.callback = callback
 
     def flush(self):
-        for item in self.queue:
-            for field in self.fields:
-                item[field] = item.get(field)
-        if self.callback is not None:
-            self.callback(self.queue)
-        self.table.insert_many(self.queue)
-        self.queue = []
+        """Drop all queued rows (subclasses write them out first)."""
+        self.queue.clear()
 
-    def insert(self, item):
-        self.fields.update(item.keys())
+    def _queue_add(self, item):
+        """Queue `item` and flush once `chunksize` rows have accumulated."""
         self.queue.append(item)
         if len(self.queue) >= self.chunksize:
             self.flush()
@@ -44,3 +35,64 @@ class ChunkedInsert(object):
 
     def __exit__(self, exc_type, exc_val, exc_tb):
         self.flush()
+
+
+class ChunkedInsert(_Chunker):
+    """Batch up insert operations
+    with ChunkedInsert(my_table) as inserter:
+        inserter.insert(row)
+
+    Rows will be inserted in groups of `chunksize` (defaulting to 1000). An
+    optional callback can be provided that will be called before the insert.
+    This callback takes one parameter which is the queue which is about to be
+    inserted into the database
+    """
+
+    def __init__(self, table, chunksize=1000, callback=None):
+        self.fields = set()
+        super().__init__(table, chunksize, callback)
+
+    def insert(self, item):
+        self.fields.update(item.keys())
+        super()._queue_add(item)
+
+    def flush(self):
+        # Pad each row with None for any field seen so far that it lacks.
+        for item in self.queue:
+            for field in self.fields:
+                item[field] = item.get(field)
+        if self.callback is not None:
+            self.callback(self.queue)
+        self.table.insert_many(self.queue)
+        super().flush()
+
+
+class ChunkedUpdate(_Chunker):
+    """Batch up update operations
+    with ChunkedUpdate(my_table, 'id') as updater:
+        updater.update(row)
+
+    Rows will be updated in groups of `chunksize` (defaulting to 1000). An
+    optional callback can be provided that will be called before the update.
+    This callback takes one parameter which is the queue which is about to be
+    updated in the database. `keys` is passed through to
+    `table.update_many` together with each batch of rows.
+    """
+
+    def __init__(self, table, keys, chunksize=1000, callback=None):
+        self.keys = keys
+        super().__init__(table, chunksize, callback)
+
+    def update(self, item):
+        super()._queue_add(item)
+
+    def flush(self):
+        if self.callback is not None:
+            self.callback(self.queue)
+        # dict key views compare as sets (subset tests), not a total order,
+        # so sort/group on the sorted column names instead; this assumes the
+        # column names are mutually orderable (e.g. all str).
+        self.queue.sort(key=sorted)
+        for _, items in itertools.groupby(self.queue, key=sorted):
+            self.table.update_many(list(items), self.keys)
+        super().flush()
diff --git a/test/test_dataset.py b/test/test_dataset.py
index 8f24dd1..e5cecd6 100644
--- a/test/test_dataset.py
+++ b/test/test_dataset.py
@@ -426,6 +426,22 @@ class TableTestCase(unittest.TestCase):
         # Ensure data has been updated.
         assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp']
 
+    def test_chunked_update(self):
+        tbl = self.db['update_many_test']
+        tbl.insert_many([
+            dict(temp=10, location='asdf'), dict(temp=20, location='qwer'), dict(temp=30, location='asdf')
+        ])
+
+        chunked_tbl = chunked.ChunkedUpdate(tbl, 'id')
+        chunked_tbl.update(dict(id=1, temp=50))
+        chunked_tbl.update(dict(id=2, location='asdf'))
+        chunked_tbl.update(dict(id=3, temp=50))
+        chunked_tbl.flush()
+
+        # Ensure data has been updated.
+        assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] == 50
+        assert tbl.find_one(id=2)['location'] == tbl.find_one(id=3)['location'] == 'asdf'
+
     def test_upsert_many(self):
         # Also tests updating on records with different attributes
         tbl = self.db['upsert_many_test']