diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 068a704..cb4cd85 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -79,6 +79,31 @@ class Table(object): if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] + def insert_ignore(self, row, keys, ensure=None, types={}): + """ + Add a row (type: dict) into the table if the row does not exist. + + If rows with matching ``keys`` exist, the row will not be added to the table. + + Setting ``ensure`` results in automatically creating missing columns, + i.e., keys of the row that are not table columns. + + During column creation, ``types`` will be checked for a key + matching the name of a column to be created, and the given + SQLAlchemy column type will be used. Otherwise, the type is + guessed from the row value, defaulting to a simple unicode + field. + :: + + data = dict(id=10, title='I am a banana!') + table.insert_ignore(data, ['id']) + """ + res = self._upsert_pre_check(row, keys, ensure) + if res is None: + return self.insert(row, ensure=ensure, types=types) + else: + return False + def insert_many(self, rows, chunk_size=1000, ensure=None, types={}): """ Add many rows at a time. @@ -156,6 +181,27 @@ class Table(object): except KeyError: return 0 + def _upsert_pre_check(self, row, keys, ensure): + # check whether keys arg is a string and format as a list + try: + if not isinstance(keys, (list, tuple)): + keys = [keys] + self._check_dropped() + + ensure = self.database.ensure_schema if ensure is None else ensure + if ensure: + self.create_index(keys) + + filters = {} + for key in keys: + filters[key] = row.get(key) + + res = self.find_one(**filters) + except: + res = None + + return res + def upsert(self, row, keys, ensure=None, types={}): """ An UPSERT is a smart combination of insert and update. 
@@ -167,33 +213,17 @@ class Table(object): data = dict(id=10, title='I am a banana!') table.upsert(data, ['id']) """ - # check whether keys arg is a string and format as a list - if not isinstance(keys, (list, tuple)): - keys = [keys] - self._check_dropped() - - ensure = self.database.ensure_schema if ensure is None else ensure - if ensure: - self.create_index(keys) - - filters = {} - for key in keys: - filters[key] = row.get(key) - - res = self.find_one(**filters) - if res is not None: - row_count = self.update(row, keys, ensure=ensure, types=types) - if row_count == 0: - return False - elif row_count == 1: - try: - return res['id'] - except KeyError: - return True - else: - return True - else: + res = self._upsert_pre_check(row, keys, ensure) + if res is None: return self.insert(row, ensure=ensure, types=types) + else: + row_count = self.update(row, keys, ensure=ensure, types=types) + try: + result = (row_count > 0, res['id'])[row_count == 1] + except KeyError: + result = row_count > 0 + + return result def delete(self, *_clauses, **_filter): """ diff --git a/docs/api.rst b/docs/api.rst index 89d4416..5770e59 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -16,7 +16,7 @@ Table ----- .. 
autoclass:: dataset.Table - :members: columns, find, find_one, all, count, distinct, insert, insert_many, update, upsert, delete, create_column, drop_column, create_index, drop + :members: columns, find, find_one, all, count, distinct, insert, insert_ignore, insert_many, update, upsert, delete, create_column, drop_column, create_index, drop :special-members: __len__, __iter__ diff --git a/test/test_persistence.py b/test/test_persistence.py index db2836c..9dac6be 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -170,6 +170,31 @@ class TableTestCase(unittest.TestCase): assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert self.tbl.find_one(id=last_id)['place'] == 'Berlin' + def test_insert_ignore(self): + self.tbl.insert_ignore({ + 'date': datetime(2011, 1, 2), + 'temperature': -10, + 'place': 'Berlin'}, + ['place'] + ) + assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) + self.tbl.insert_ignore({ + 'date': datetime(2011, 1, 2), + 'temperature': -10, + 'place': 'Berlin'}, + ['place'] + ) + assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) + + def test_insert_ignore_all_key(self): + for i in range(0, 2): + self.tbl.insert_ignore({ + 'date': datetime(2011, 1, 2), + 'temperature': -10, + 'place': 'Berlin'}, + ['date', 'temperature', 'place'] + ) + def test_upsert(self): self.tbl.upsert({ 'date': datetime(2011, 1, 2),