Make column lookups properly case-insensitive, refs #310.

This commit is contained in:
Friedrich Lindenberg 2020-02-23 19:13:34 +01:00
parent 5ddff753b7
commit 72ecf561fe
3 changed files with 69 additions and 21 deletions

View File

@ -10,9 +10,10 @@ from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError
from dataset.types import Types
from dataset.util import normalize_column_name, index_name, ensure_tuple
from dataset.util import index_name, ensure_tuple
from dataset.util import DatasetException, ResultIter, QUERY_STEP
from dataset.util import normalize_table_name, pad_chunk_columns
from dataset.util import normalize_column_name, normalize_column_key
log = logging.getLogger(__name__)
@ -28,6 +29,7 @@ class Table(object):
self.db = database
self.name = normalize_table_name(table_name)
self._table = None
self._columns = None
self._indexes = []
self._primary_id = primary_id if primary_id is not None \
else self.PRIMARY_DEFAULT
@ -49,17 +51,36 @@ class Table(object):
self._sync_table(())
return self._table
@property
def _column_keys(self):
    """Get a dictionary of all columns and their case mapping.

    Each key is the comparison form returned by ``normalize_column_key``
    and each value is the matching name after ``normalize_column_name``.
    The mapping is cached on ``self._columns`` and rebuilt whenever that
    attribute is ``None``. If two columns share a key, a warning is logged
    and the column seen last is kept.

    Returns an empty dict when the table does not exist.
    """
    if not self.exists:
        return {}
    if self._columns is None:
        # Resolve the table before touching the cache: reading
        # ``self.table`` may lazily run ``_sync_table``, which resets
        # ``self._columns`` to ``None``. Filling a local dict and storing
        # it only once complete keeps that reset from hitting a
        # half-built cache (``key in None`` would raise TypeError).
        table = self.table
        mapping = {}
        for column in table.columns:
            name = normalize_column_name(column.name)
            key = normalize_column_key(name)
            if key in mapping:
                log.warning("Duplicate column: %s", name)
            mapping[key] = name
        self._columns = mapping
    return self._columns
@property
def columns(self):
    """Get a listing of all columns that exist in the table.

    Returns a new list holding the values of the ``_column_keys`` mapping,
    or an empty list when the table does not exist.
    """
    if not self.exists:
        return []
    # The listing is taken from the case mapping rather than from the raw
    # table columns, so it matches what ``has_column`` checks against.
    return list(self._column_keys.values())
def has_column(self, column):
    """Check if a column with the given name exists on this table.

    ``column`` is passed through ``normalize_column_name`` and then
    ``normalize_column_key``; the result is looked up in the
    ``_column_keys`` mapping, so the check does not depend on the case
    used by the caller.

    Returns ``True`` when the key is present, ``False`` otherwise.
    """
    key = normalize_column_key(normalize_column_name(column))
    return key in self._column_keys
def _get_column_name(self, name):
    """Find the best column name with case-insensitive matching.

    Returns the name stored in ``_column_keys`` for the comparison key of
    ``name``; when there is no such entry, the normalized form of ``name``
    is returned instead.
    """
    normalized = normalize_column_name(name)
    lookup_key = normalize_column_key(normalized)
    return self._column_keys.get(lookup_key, normalized)
def insert(self, row, ensure=None, types=None):
"""Add a ``row`` dict by inserting it into the table.
@ -282,6 +303,7 @@ class Table(object):
def _reflect_table(self):
"""Load the tables definition from the database."""
with self.db.lock:
self._columns = None
try:
self._table = SQLATable(self.name,
self.db.metadata,
@ -299,6 +321,7 @@ class Table(object):
def _sync_table(self, columns):
"""Lazy load, create or adapt the table structure in the database."""
self._columns = None
if self._table is None:
# Load an existing table from the database.
self._reflect_table()
@ -344,23 +367,22 @@ class Table(object):
this will remove any keys from the ``row`` for which there is no
matching column.
"""
columns = [c.lower() for c in self.columns]
ensure = self._check_ensure(ensure)
types = types or {}
types = {normalize_column_name(k): v for (k, v) in types.items()}
types = {self._get_column_name(k): v for (k, v) in types.items()}
out = {}
sync_columns = []
sync_columns = {}
for name, value in row.items():
name = normalize_column_name(name)
if ensure and name.lower() not in columns:
name = self._get_column_name(name)
if self.has_column(name):
out[name] = value
elif ensure:
_type = types.get(name)
if _type is None:
_type = self.db.types.guess(value)
sync_columns.append(Column(name, _type))
columns.append(name)
if name in columns:
sync_columns[name] = Column(name, _type)
out[name] = value
self._sync_table(sync_columns)
self._sync_table(sync_columns.values())
return out
def _check_ensure(self, ensure):
@ -395,6 +417,7 @@ class Table(object):
def _args_to_clause(self, args, clauses=()):
clauses = list(clauses)
for column, value in args.items():
column = self._get_column_name(column)
if not self.has_column(column):
clauses.append(false())
elif isinstance(value, (list, tuple, set)):
@ -412,7 +435,8 @@ class Table(object):
if ordering is None:
continue
column = ordering.lstrip('-')
if column not in self.table.columns:
column = self._get_column_name(column)
if not self.has_column(column):
continue
if ordering.startswith('-'):
orderings.append(self.table.c[column].desc())
@ -422,7 +446,7 @@ class Table(object):
def _keys_to_args(self, row, keys):
    """Split ``row`` into the values for ``keys`` and the remaining data.

    ``keys`` is passed through ``ensure_tuple`` and every entry is resolved
    with ``_get_column_name`` (which already normalizes the name, so no
    separate normalization pass is needed).

    Returns a ``(args, row)`` tuple: ``args`` maps each resolved key to the
    value popped from a copy of ``row`` (``None`` when the key is absent)
    and ``row`` is that copy without the popped keys. The ``row`` passed in
    by the caller is not modified.
    """
    names = [self._get_column_name(k) for k in ensure_tuple(keys)]
    remainder = row.copy()
    args = {name: remainder.pop(name, None) for name in names}
    return args, remainder
@ -442,7 +466,7 @@ class Table(object):
table.create_column('key', unique=True, nullable=False)
table.create_column('food', default='banana')
"""
name = normalize_column_name(name)
name = self._get_column_name(name)
if self.has_column(name):
log.debug("Column exists: %s" % name)
return
@ -470,7 +494,7 @@ class Table(object):
"""
if self.db.engine.dialect.name == 'sqlite':
raise RuntimeError("SQLite does not support dropping columns.")
name = normalize_column_name(name)
name = self._get_column_name(name)
with self.db.lock:
if not self.exists or not self.has_column(name):
log.debug("Column does not exist: %s", name)
@ -494,12 +518,13 @@ class Table(object):
self._threading_warn()
self.table.drop(self.db.executable, checkfirst=True)
self._table = None
self._columns = None
def has_index(self, columns):
"""Check if an index exists to cover the given ``columns``."""
if not self.exists:
return False
columns = set([normalize_column_name(c) for c in columns])
columns = set([self._get_column_name(c) for c in columns])
if columns in self._indexes:
return True
for column in columns:
@ -520,7 +545,7 @@ class Table(object):
table.create_index(['name', 'country'])
"""
columns = [normalize_column_name(c) for c in ensure_tuple(columns)]
columns = [self._get_column_name(c) for c in ensure_tuple(columns)]
with self.db.lock:
if not self.exists:
raise DatasetException("Table has not been created yet.")

View File

@ -73,6 +73,13 @@ def normalize_column_name(name):
return name
def normalize_column_key(name):
    """Return a comparable column name.

    The key is ``name`` upper-cased, with surrounding whitespace stripped
    and every space character removed, e.g. ``' first Name '`` becomes
    ``'FIRSTNAME'``.

    Returns ``None`` when ``name`` is not a ``str`` (this includes
    ``None`` itself).
    """
    # ``isinstance`` is already False for ``None``, so no separate check.
    if not isinstance(name, str):
        return None
    return name.upper().strip().replace(' ', '')
def normalize_table_name(name):
"""Check if the table name is obviously invalid."""
if not isinstance(name, str):

View File

@ -253,6 +253,22 @@ class TableTestCase(unittest.TestCase):
'qux.bar': 'Huhu'
})
def test_cased_column_names(self):
    """Differently cased spellings of a column name hit the same column."""
    tbl = self.db['cased_column_names']
    # Insert one row per spelling; the last one also has trailing space.
    for spelling in ('place', 'Place', 'PLACE '):
        tbl.insert({spelling: 'Berlin'})
    # No extra column may have been created for the alternate spellings.
    assert len(tbl.columns) == 2, tbl.columns
    # Every spelling used as a filter must find all three rows.
    for spelling in ('Place', 'place', 'PLACE'):
        matches = list(tbl.find(**{spelling: 'Berlin'}))
        assert len(matches) == 3
def test_invalid_column_names(self):
tbl = self.db['weather']
with self.assertRaises(ValueError):