From 7f7cb41858e05ad43b2a577083f92e5dcfb5d8f0 Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Sat, 23 May 2015 15:30:19 +0200 Subject: [PATCH] Handle column name validation more coherently, fixes #45. --- dataset/persistence/database.py | 1 - dataset/persistence/table.py | 14 +++++++------- dataset/persistence/util.py | 10 ++++++++++ test/test_persistence.py | 11 +++++++++++ 4 files changed, 28 insertions(+), 8 deletions(-) diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index 2caddc8..f8b38de 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -8,7 +8,6 @@ from six.moves.urllib.parse import urlencode, parse_qs from sqlalchemy import create_engine from sqlalchemy import Integer, String from sqlalchemy.sql import text -from sqlalchemy.pool import NullPool from sqlalchemy.schema import MetaData, Column from sqlalchemy.schema import Table as SQLATable from sqlalchemy.util import safe_reraise diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 160ebec..a00c887 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -4,7 +4,7 @@ from itertools import count from sqlalchemy.sql import and_, expression from sqlalchemy.schema import Column, Index from sqlalchemy import alias -from dataset.persistence.util import guess_type +from dataset.persistence.util import guess_type, normalize_column_name from dataset.persistence.util import ResultIter from dataset.util import DatasetException @@ -27,6 +27,10 @@ class Table(object): """ return list(self.table.columns.keys()) + @property + def _normalized_columns(self): + return map(normalize_column_name, self.columns) + def drop(self): """ Drop the table from the database, deleting both the schema @@ -199,7 +203,7 @@ class Table(object): def _ensure_columns(self, row, types={}): # Keep order of inserted columns for column in row.keys(): - if column in self.table.columns.keys(): + if normalize_column_name(column) in self._normalized_columns: continue if column in types: _type = types[column] @@ -230,12 +234,8 @@ class Table(object): self._check_dropped() self.database._acquire() - # check that column name is OK: - if '.' in name: - raise ValueError("Invalid column name: %r" % name) - try: - if name not in self.table.columns.keys(): + if normalize_column_name(name) not in self._normalized_columns: self.database.op.add_column( self.table.name, Column(name, type) diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index da126c7..4da08a9 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -7,6 +7,7 @@ except ImportError: # pragma: no cover from ordereddict import OrderedDict from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean +from six import string_types def guess_type(sample): @@ -27,6 +28,15 @@ def convert_row(row): return OrderedDict(row.items()) +def normalize_column_name(name): + if not isinstance(name, string_types): + raise ValueError('%r is not a valid column name.' % name) + name = name.lower().strip() + if not len(name) or '.' in name or '-' in name: + raise ValueError('%r is not a valid column name.' % name) + return name + + class ResultIter(object): """ SQLAlchemy ResultProxies are not iterable to get a list of dictionaries. This is to wrap them. """ diff --git a/test/test_persistence.py b/test/test_persistence.py index ce56bc2..8ed09d0 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -209,6 +209,17 @@ class TableTestCase(unittest.TestCase): 'qux.bar': 'Huhu' }) + def test_invalid_column_names(self): + tbl = self.db['weather'] + with self.assertRaises(ValueError): + tbl.insert({None: 'banana'}) + + with self.assertRaises(ValueError): + tbl.insert({'': 'banana'}) + + with self.assertRaises(ValueError): + tbl.insert({'-': 'banana'}) + def test_delete(self): self.tbl.insert({ 'date': datetime(2011, 1, 2),