Handle column name validation more coherently, fixes #45.

This commit is contained in:
Friedrich Lindenberg 2015-05-23 15:30:19 +02:00
parent fd3d9f326c
commit 7f7cb41858
4 changed files with 28 additions and 8 deletions

View File

@ -8,7 +8,6 @@ from six.moves.urllib.parse import urlencode, parse_qs
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import Integer, String from sqlalchemy import Integer, String
from sqlalchemy.sql import text from sqlalchemy.sql import text
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import MetaData, Column from sqlalchemy.schema import MetaData, Column
from sqlalchemy.schema import Table as SQLATable from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.util import safe_reraise from sqlalchemy.util import safe_reraise

View File

@ -4,7 +4,7 @@ from itertools import count
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import alias from sqlalchemy import alias
from dataset.persistence.util import guess_type from dataset.persistence.util import guess_type, normalize_column_name
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -27,6 +27,10 @@ class Table(object):
""" """
return list(self.table.columns.keys()) return list(self.table.columns.keys())
@property
def _normalized_columns(self):
return map(normalize_column_name, self.columns)
def drop(self): def drop(self):
""" """
Drop the table from the database, deleting both the schema Drop the table from the database, deleting both the schema
@ -199,7 +203,7 @@ class Table(object):
def _ensure_columns(self, row, types={}): def _ensure_columns(self, row, types={}):
# Keep order of inserted columns # Keep order of inserted columns
for column in row.keys(): for column in row.keys():
if column in self.table.columns.keys(): if normalize_column_name(column) in self._normalized_columns:
continue continue
if column in types: if column in types:
_type = types[column] _type = types[column]
@ -230,12 +234,8 @@ class Table(object):
self._check_dropped() self._check_dropped()
self.database._acquire() self.database._acquire()
# check that column name is OK:
if '.' in name:
raise ValueError("Invalid column name: %r" % name)
try: try:
if name not in self.table.columns.keys(): if normalize_column_name(name) not in self._normalized_columns:
self.database.op.add_column( self.database.op.add_column(
self.table.name, self.table.name,
Column(name, type) Column(name, type)

View File

@ -7,6 +7,7 @@ except ImportError: # pragma: no cover
from ordereddict import OrderedDict from ordereddict import OrderedDict
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
from six import string_types
def guess_type(sample): def guess_type(sample):
@ -27,6 +28,15 @@ def convert_row(row):
return OrderedDict(row.items()) return OrderedDict(row.items())
def normalize_column_name(name):
if not isinstance(name, string_types):
raise ValueError('%r is not a valid column name.' % name)
name = name.lower().strip()
if not len(name) or '.' in name or '-' in name:
raise ValueError('%r is not a valid column name.' % name)
return name
class ResultIter(object): class ResultIter(object):
""" SQLAlchemy ResultProxies are not iterable to get a """ SQLAlchemy ResultProxies are not iterable to get a
list of dictionaries. This is to wrap them. """ list of dictionaries. This is to wrap them. """

View File

@ -209,6 +209,17 @@ class TableTestCase(unittest.TestCase):
'qux.bar': 'Huhu' 'qux.bar': 'Huhu'
}) })
def test_invalid_column_names(self):
tbl = self.db['weather']
with self.assertRaises(ValueError):
tbl.insert({None: 'banana'})
with self.assertRaises(ValueError):
tbl.insert({'': 'banana'})
with self.assertRaises(ValueError):
tbl.insert({'-': 'banana'})
def test_delete(self): def test_delete(self):
self.tbl.insert({ self.tbl.insert({
'date': datetime(2011, 1, 2), 'date': datetime(2011, 1, 2),