From 72ecf561fe3f4b37ce2d4f95fc328d34aa67de72 Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Sun, 23 Feb 2020 19:13:34 +0100 Subject: [PATCH] Make column lookups properly case-insensitive, refs #310. --- dataset/table.py | 67 ++++++++++++++++++++++++++++++-------------- dataset/util.py | 7 +++++ test/test_dataset.py | 16 +++++++++++ 3 files changed, 69 insertions(+), 21 deletions(-) diff --git a/dataset/table.py b/dataset/table.py index 7371026..ca008e1 100644 --- a/dataset/table.py +++ b/dataset/table.py @@ -10,9 +10,10 @@ from sqlalchemy.schema import Table as SQLATable from sqlalchemy.exc import NoSuchTableError from dataset.types import Types -from dataset.util import normalize_column_name, index_name, ensure_tuple +from dataset.util import index_name, ensure_tuple from dataset.util import DatasetException, ResultIter, QUERY_STEP from dataset.util import normalize_table_name, pad_chunk_columns +from dataset.util import normalize_column_name, normalize_column_key log = logging.getLogger(__name__) @@ -28,6 +29,7 @@ class Table(object): self.db = database self.name = normalize_table_name(table_name) self._table = None + self._columns = None self._indexes = [] self._primary_id = primary_id if primary_id is not None \ else self.PRIMARY_DEFAULT @@ -49,17 +51,36 @@ class Table(object): self._sync_table(()) return self._table + @property + def _column_keys(self): + """Get a dictionary of all columns and their case mapping.""" + if not self.exists: + return {} + if self._columns is None: + self._columns = {} + for column in self.table.columns: + name = normalize_column_name(column.name) + key = normalize_column_key(name) + if key in self._columns: + log.warning("Duplicate column: %s", name) + self._columns[key] = name + return self._columns + @property def columns(self): """Get a listing of all columns that exist in the table.""" - if not self.exists: - return [] - return self.table.columns.keys() + return list(self._column_keys.values()) def has_column(self, column): """Check if a column with the given name exists on this table.""" - columns = [c.lower() for c in self.columns] - return normalize_column_name(column).lower() in columns + key = normalize_column_key(normalize_column_name(column)) + return key in self._column_keys + + def _get_column_name(self, name): + """Find the best column name with case-insensitive matching.""" + name = normalize_column_name(name) + key = normalize_column_key(name) + return self._column_keys.get(key, name) def insert(self, row, ensure=None, types=None): """Add a ``row`` dict by inserting it into the table. @@ -282,6 +303,7 @@ class Table(object): def _reflect_table(self): """Load the tables definition from the database.""" with self.db.lock: + self._columns = None try: self._table = SQLATable(self.name, self.db.metadata, @@ -299,6 +321,7 @@ class Table(object): def _sync_table(self, columns): """Lazy load, create or adapt the table structure in the database.""" + self._columns = None if self._table is None: # Load an existing table from the database. self._reflect_table() @@ -344,23 +367,22 @@ class Table(object): this will remove any keys from the ``row`` for which there is no matching column. """ - columns = [c.lower() for c in self.columns] ensure = self._check_ensure(ensure) types = types or {} - types = {normalize_column_name(k): v for (k, v) in types.items()} + types = {self._get_column_name(k): v for (k, v) in types.items()} out = {} - sync_columns = [] + sync_columns = {} for name, value in row.items(): - name = normalize_column_name(name) - if ensure and name.lower() not in columns: + name = self._get_column_name(name) + if self.has_column(name): + out[name] = value + elif ensure: _type = types.get(name) if _type is None: _type = self.db.types.guess(value) - sync_columns.append(Column(name, _type)) - columns.append(name) - if name in columns: + sync_columns[name] = Column(name, _type) out[name] = value - self._sync_table(sync_columns) + self._sync_table(sync_columns.values()) return out def _check_ensure(self, ensure): @@ -395,6 +417,7 @@ class Table(object): def _args_to_clause(self, args, clauses=()): clauses = list(clauses) for column, value in args.items(): + column = self._get_column_name(column) if not self.has_column(column): clauses.append(false()) elif isinstance(value, (list, tuple, set)): @@ -412,7 +435,8 @@ class Table(object): if ordering is None: continue column = ordering.lstrip('-') - if column not in self.table.columns: + column = self._get_column_name(column) + if not self.has_column(column): continue if ordering.startswith('-'): orderings.append(self.table.c[column].desc()) @@ -422,7 +446,7 @@ class Table(object): def _keys_to_args(self, row, keys): keys = ensure_tuple(keys) - keys = [normalize_column_name(k) for k in keys] + keys = [self._get_column_name(k) for k in keys] row = row.copy() args = {k: row.pop(k, None) for k in keys} return args, row @@ -442,7 +466,7 @@ class Table(object): table.create_column('key', unique=True, nullable=False) table.create_column('food', default='banana') """ - name = normalize_column_name(name) + name = self._get_column_name(name) if self.has_column(name): log.debug("Column exists: %s" % name) return @@ -470,7 +494,7 @@ class Table(object): """ if self.db.engine.dialect.name == 'sqlite': raise RuntimeError("SQLite does not support dropping columns.") - name = normalize_column_name(name) + name = self._get_column_name(name) with self.db.lock: if not self.exists or not self.has_column(name): log.debug("Column does not exist: %s", name) @@ -494,12 +518,13 @@ class Table(object): self._threading_warn() self.table.drop(self.db.executable, checkfirst=True) self._table = None + self._columns = None def has_index(self, columns): """Check if an index exists to cover the given ``columns``.""" if not self.exists: return False - columns = set([normalize_column_name(c) for c in columns]) + columns = set([self._get_column_name(c) for c in columns]) if columns in self._indexes: return True for column in columns: @@ -520,7 +545,7 @@ class Table(object): table.create_index(['name', 'country']) """ - columns = [normalize_column_name(c) for c in ensure_tuple(columns)] + columns = [self._get_column_name(c) for c in ensure_tuple(columns)] with self.db.lock: if not self.exists: raise DatasetException("Table has not been created yet.") diff --git a/dataset/util.py b/dataset/util.py index b951fd1..2b0f8c3 100644 --- a/dataset/util.py +++ b/dataset/util.py @@ -73,6 +73,13 @@ def normalize_column_name(name): return name +def normalize_column_key(name): + """Return a comparable column name.""" + if name is None or not isinstance(name, str): + return None + return name.upper().strip().replace(' ', '') + + def normalize_table_name(name): """Check if the table name is obviously invalid.""" if not isinstance(name, str): diff --git a/test/test_dataset.py b/test/test_dataset.py index 588c550..ec38e2e 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -253,6 +253,22 @@ class TableTestCase(unittest.TestCase): 'qux.bar': 'Huhu' }) + def test_cased_column_names(self): + tbl = self.db['cased_column_names'] + tbl.insert({ + 'place': 'Berlin', + }) + tbl.insert({ + 'Place': 'Berlin', + }) + tbl.insert({ + 'PLACE ': 'Berlin', + }) + assert len(tbl.columns) == 2, tbl.columns + assert len(list(tbl.find(Place='Berlin'))) == 3 + assert len(list(tbl.find(place='Berlin'))) == 3 + assert len(list(tbl.find(PLACE='Berlin'))) == 3 + def test_invalid_column_names(self): tbl = self.db['weather'] with self.assertRaises(ValueError):