Make column lookups properly case-insensitive, refs #310.
This commit is contained in:
parent
5ddff753b7
commit
72ecf561fe
@ -10,9 +10,10 @@ from sqlalchemy.schema import Table as SQLATable
|
|||||||
from sqlalchemy.exc import NoSuchTableError
|
from sqlalchemy.exc import NoSuchTableError
|
||||||
|
|
||||||
from dataset.types import Types
|
from dataset.types import Types
|
||||||
from dataset.util import normalize_column_name, index_name, ensure_tuple
|
from dataset.util import index_name, ensure_tuple
|
||||||
from dataset.util import DatasetException, ResultIter, QUERY_STEP
|
from dataset.util import DatasetException, ResultIter, QUERY_STEP
|
||||||
from dataset.util import normalize_table_name, pad_chunk_columns
|
from dataset.util import normalize_table_name, pad_chunk_columns
|
||||||
|
from dataset.util import normalize_column_name, normalize_column_key
|
||||||
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
@ -28,6 +29,7 @@ class Table(object):
|
|||||||
self.db = database
|
self.db = database
|
||||||
self.name = normalize_table_name(table_name)
|
self.name = normalize_table_name(table_name)
|
||||||
self._table = None
|
self._table = None
|
||||||
|
self._columns = None
|
||||||
self._indexes = []
|
self._indexes = []
|
||||||
self._primary_id = primary_id if primary_id is not None \
|
self._primary_id = primary_id if primary_id is not None \
|
||||||
else self.PRIMARY_DEFAULT
|
else self.PRIMARY_DEFAULT
|
||||||
@ -49,17 +51,36 @@ class Table(object):
|
|||||||
self._sync_table(())
|
self._sync_table(())
|
||||||
return self._table
|
return self._table
|
||||||
|
|
||||||
|
@property
def _column_keys(self):
    """Return the mapping of comparable column keys to real column names.

    The mapping is keyed by ``normalize_column_key`` of each column name
    and holds the name as returned by ``normalize_column_name``. It is
    cached in ``self._columns`` and rebuilt whenever that attribute is
    ``None``. An empty dictionary is returned (and nothing is cached)
    while the table does not exist.
    """
    if not self.exists:
        return {}
    if self._columns is None:
        # Resolve the table before touching the cache: accessing
        # ``self.table`` can run ``_sync_table``, which resets
        # ``self._columns`` to None.
        table = self.table
        # Build into a local dict and publish it in one assignment so a
        # partially filled mapping is never visible via ``self._columns``.
        columns = {}
        for column in table.columns:
            name = normalize_column_name(column.name)
            key = normalize_column_key(name)
            if key in columns:
                # Two columns differ only by case/spacing; the last wins.
                log.warning("Duplicate column: %s", name)
            columns[key] = name
        self._columns = columns
    return self._columns
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def columns(self):
|
def columns(self):
|
||||||
"""Get a listing of all columns that exist in the table."""
|
"""Get a listing of all columns that exist in the table."""
|
||||||
if not self.exists:
|
return list(self._column_keys.values())
|
||||||
return []
|
|
||||||
return self.table.columns.keys()
|
|
||||||
|
|
||||||
def has_column(self, column):
|
def has_column(self, column):
|
||||||
"""Check if a column with the given name exists on this table."""
|
"""Check if a column with the given name exists on this table."""
|
||||||
columns = [c.lower() for c in self.columns]
|
key = normalize_column_key(normalize_column_name(column))
|
||||||
return normalize_column_name(column).lower() in columns
|
return key in self._column_keys
|
||||||
|
|
||||||
|
def _get_column_name(self, name):
    """Resolve ``name`` to an existing column name, ignoring case.

    The name is first passed through ``normalize_column_name``. If a
    column with the same comparable key is already known, that column's
    name is returned; otherwise the normalized input is returned.
    """
    canonical = normalize_column_name(name)
    lookup_key = normalize_column_key(canonical)
    known = self._column_keys
    if lookup_key in known:
        return known[lookup_key]
    return canonical
|
||||||
|
|
||||||
def insert(self, row, ensure=None, types=None):
|
def insert(self, row, ensure=None, types=None):
|
||||||
"""Add a ``row`` dict by inserting it into the table.
|
"""Add a ``row`` dict by inserting it into the table.
|
||||||
@ -282,6 +303,7 @@ class Table(object):
|
|||||||
def _reflect_table(self):
|
def _reflect_table(self):
|
||||||
"""Load the tables definition from the database."""
|
"""Load the tables definition from the database."""
|
||||||
with self.db.lock:
|
with self.db.lock:
|
||||||
|
self._columns = None
|
||||||
try:
|
try:
|
||||||
self._table = SQLATable(self.name,
|
self._table = SQLATable(self.name,
|
||||||
self.db.metadata,
|
self.db.metadata,
|
||||||
@ -299,6 +321,7 @@ class Table(object):
|
|||||||
|
|
||||||
def _sync_table(self, columns):
|
def _sync_table(self, columns):
|
||||||
"""Lazy load, create or adapt the table structure in the database."""
|
"""Lazy load, create or adapt the table structure in the database."""
|
||||||
|
self._columns = None
|
||||||
if self._table is None:
|
if self._table is None:
|
||||||
# Load an existing table from the database.
|
# Load an existing table from the database.
|
||||||
self._reflect_table()
|
self._reflect_table()
|
||||||
@ -344,23 +367,22 @@ class Table(object):
|
|||||||
this will remove any keys from the ``row`` for which there is no
|
this will remove any keys from the ``row`` for which there is no
|
||||||
matching column.
|
matching column.
|
||||||
"""
|
"""
|
||||||
columns = [c.lower() for c in self.columns]
|
|
||||||
ensure = self._check_ensure(ensure)
|
ensure = self._check_ensure(ensure)
|
||||||
types = types or {}
|
types = types or {}
|
||||||
types = {normalize_column_name(k): v for (k, v) in types.items()}
|
types = {self._get_column_name(k): v for (k, v) in types.items()}
|
||||||
out = {}
|
out = {}
|
||||||
sync_columns = []
|
sync_columns = {}
|
||||||
for name, value in row.items():
|
for name, value in row.items():
|
||||||
name = normalize_column_name(name)
|
name = self._get_column_name(name)
|
||||||
if ensure and name.lower() not in columns:
|
if self.has_column(name):
|
||||||
|
out[name] = value
|
||||||
|
elif ensure:
|
||||||
_type = types.get(name)
|
_type = types.get(name)
|
||||||
if _type is None:
|
if _type is None:
|
||||||
_type = self.db.types.guess(value)
|
_type = self.db.types.guess(value)
|
||||||
sync_columns.append(Column(name, _type))
|
sync_columns[name] = Column(name, _type)
|
||||||
columns.append(name)
|
|
||||||
if name in columns:
|
|
||||||
out[name] = value
|
out[name] = value
|
||||||
self._sync_table(sync_columns)
|
self._sync_table(sync_columns.values())
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def _check_ensure(self, ensure):
|
def _check_ensure(self, ensure):
|
||||||
@ -395,6 +417,7 @@ class Table(object):
|
|||||||
def _args_to_clause(self, args, clauses=()):
|
def _args_to_clause(self, args, clauses=()):
|
||||||
clauses = list(clauses)
|
clauses = list(clauses)
|
||||||
for column, value in args.items():
|
for column, value in args.items():
|
||||||
|
column = self._get_column_name(column)
|
||||||
if not self.has_column(column):
|
if not self.has_column(column):
|
||||||
clauses.append(false())
|
clauses.append(false())
|
||||||
elif isinstance(value, (list, tuple, set)):
|
elif isinstance(value, (list, tuple, set)):
|
||||||
@ -412,7 +435,8 @@ class Table(object):
|
|||||||
if ordering is None:
|
if ordering is None:
|
||||||
continue
|
continue
|
||||||
column = ordering.lstrip('-')
|
column = ordering.lstrip('-')
|
||||||
if column not in self.table.columns:
|
column = self._get_column_name(column)
|
||||||
|
if not self.has_column(column):
|
||||||
continue
|
continue
|
||||||
if ordering.startswith('-'):
|
if ordering.startswith('-'):
|
||||||
orderings.append(self.table.c[column].desc())
|
orderings.append(self.table.c[column].desc())
|
||||||
@ -422,7 +446,7 @@ class Table(object):
|
|||||||
|
|
||||||
def _keys_to_args(self, row, keys):
|
def _keys_to_args(self, row, keys):
|
||||||
keys = ensure_tuple(keys)
|
keys = ensure_tuple(keys)
|
||||||
keys = [normalize_column_name(k) for k in keys]
|
keys = [self._get_column_name(k) for k in keys]
|
||||||
row = row.copy()
|
row = row.copy()
|
||||||
args = {k: row.pop(k, None) for k in keys}
|
args = {k: row.pop(k, None) for k in keys}
|
||||||
return args, row
|
return args, row
|
||||||
@ -442,7 +466,7 @@ class Table(object):
|
|||||||
table.create_column('key', unique=True, nullable=False)
|
table.create_column('key', unique=True, nullable=False)
|
||||||
table.create_column('food', default='banana')
|
table.create_column('food', default='banana')
|
||||||
"""
|
"""
|
||||||
name = normalize_column_name(name)
|
name = self._get_column_name(name)
|
||||||
if self.has_column(name):
|
if self.has_column(name):
|
||||||
log.debug("Column exists: %s" % name)
|
log.debug("Column exists: %s" % name)
|
||||||
return
|
return
|
||||||
@ -470,7 +494,7 @@ class Table(object):
|
|||||||
"""
|
"""
|
||||||
if self.db.engine.dialect.name == 'sqlite':
|
if self.db.engine.dialect.name == 'sqlite':
|
||||||
raise RuntimeError("SQLite does not support dropping columns.")
|
raise RuntimeError("SQLite does not support dropping columns.")
|
||||||
name = normalize_column_name(name)
|
name = self._get_column_name(name)
|
||||||
with self.db.lock:
|
with self.db.lock:
|
||||||
if not self.exists or not self.has_column(name):
|
if not self.exists or not self.has_column(name):
|
||||||
log.debug("Column does not exist: %s", name)
|
log.debug("Column does not exist: %s", name)
|
||||||
@ -494,12 +518,13 @@ class Table(object):
|
|||||||
self._threading_warn()
|
self._threading_warn()
|
||||||
self.table.drop(self.db.executable, checkfirst=True)
|
self.table.drop(self.db.executable, checkfirst=True)
|
||||||
self._table = None
|
self._table = None
|
||||||
|
self._columns = None
|
||||||
|
|
||||||
def has_index(self, columns):
|
def has_index(self, columns):
|
||||||
"""Check if an index exists to cover the given ``columns``."""
|
"""Check if an index exists to cover the given ``columns``."""
|
||||||
if not self.exists:
|
if not self.exists:
|
||||||
return False
|
return False
|
||||||
columns = set([normalize_column_name(c) for c in columns])
|
columns = set([self._get_column_name(c) for c in columns])
|
||||||
if columns in self._indexes:
|
if columns in self._indexes:
|
||||||
return True
|
return True
|
||||||
for column in columns:
|
for column in columns:
|
||||||
@ -520,7 +545,7 @@ class Table(object):
|
|||||||
|
|
||||||
table.create_index(['name', 'country'])
|
table.create_index(['name', 'country'])
|
||||||
"""
|
"""
|
||||||
columns = [normalize_column_name(c) for c in ensure_tuple(columns)]
|
columns = [self._get_column_name(c) for c in ensure_tuple(columns)]
|
||||||
with self.db.lock:
|
with self.db.lock:
|
||||||
if not self.exists:
|
if not self.exists:
|
||||||
raise DatasetException("Table has not been created yet.")
|
raise DatasetException("Table has not been created yet.")
|
||||||
|
|||||||
@ -73,6 +73,13 @@ def normalize_column_name(name):
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_column_key(name):
    """Return a comparable column name.

    The key is the upper-cased name with surrounding whitespace stripped
    and all space characters removed, so names differing only in case or
    spacing compare equal. Returns ``None`` when ``name`` is not a string
    (which includes ``None`` itself).
    """
    # isinstance() already rejects None, so no separate None check is needed.
    if not isinstance(name, str):
        return None
    return name.upper().strip().replace(' ', '')
|
||||||
|
|
||||||
|
|
||||||
def normalize_table_name(name):
|
def normalize_table_name(name):
|
||||||
"""Check if the table name is obviously invalid."""
|
"""Check if the table name is obviously invalid."""
|
||||||
if not isinstance(name, str):
|
if not isinstance(name, str):
|
||||||
|
|||||||
@ -253,6 +253,22 @@ class TableTestCase(unittest.TestCase):
|
|||||||
'qux.bar': 'Huhu'
|
'qux.bar': 'Huhu'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def test_cased_column_names(self):
    """Differently cased/spaced keys must map onto one shared column."""
    tbl = self.db['cased_column_names']
    # Insert the same value under three spellings of the column name.
    for column in ('place', 'Place', 'PLACE '):
        tbl.insert({
            column: 'Berlin',
        })
    assert len(tbl.columns) == 2, tbl.columns
    # Every spelling used as a filter has to match all three rows.
    for column in ('Place', 'place', 'PLACE'):
        matches = list(tbl.find(**{column: 'Berlin'}))
        assert len(matches) == 3
|
||||||
|
|
||||||
def test_invalid_column_names(self):
|
def test_invalid_column_names(self):
|
||||||
tbl = self.db['weather']
|
tbl = self.db['weather']
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user