Handle column name validation more coherently, fixes #45.

This commit is contained in:
Friedrich Lindenberg 2015-05-23 15:30:19 +02:00
parent fd3d9f326c
commit 7f7cb41858
4 changed files with 28 additions and 8 deletions

View File

@ -8,7 +8,6 @@ from six.moves.urllib.parse import urlencode, parse_qs
from sqlalchemy import create_engine
from sqlalchemy import Integer, String
from sqlalchemy.sql import text
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import MetaData, Column
from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.util import safe_reraise

View File

@ -4,7 +4,7 @@ from itertools import count
from sqlalchemy.sql import and_, expression
from sqlalchemy.schema import Column, Index
from sqlalchemy import alias
from dataset.persistence.util import guess_type
from dataset.persistence.util import guess_type, normalize_column_name
from dataset.persistence.util import ResultIter
from dataset.util import DatasetException
@ -27,6 +27,10 @@ class Table(object):
"""
return list(self.table.columns.keys())
@property
def _normalized_columns(self):
return map(normalize_column_name, self.columns)
def drop(self):
"""
Drop the table from the database, deleting both the schema
@ -199,7 +203,7 @@ class Table(object):
def _ensure_columns(self, row, types={}):
# Keep order of inserted columns
for column in row.keys():
if column in self.table.columns.keys():
if normalize_column_name(column) in self._normalized_columns:
continue
if column in types:
_type = types[column]
@ -230,12 +234,8 @@ class Table(object):
self._check_dropped()
self.database._acquire()
# check that column name is OK:
if '.' in name:
raise ValueError("Invalid column name: %r" % name)
try:
if name not in self.table.columns.keys():
if normalize_column_name(name) not in self._normalized_columns:
self.database.op.add_column(
self.table.name,
Column(name, type)

View File

@ -7,6 +7,7 @@ except ImportError: # pragma: no cover
from ordereddict import OrderedDict
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
from six import string_types
def guess_type(sample):
@ -27,6 +28,15 @@ def convert_row(row):
return OrderedDict(row.items())
def normalize_column_name(name):
if not isinstance(name, string_types):
raise ValueError('%r is not a valid column name.' % name)
name = name.lower().strip()
if not len(name) or '.' in name or '-' in name:
raise ValueError('%r is not a valid column name.' % name)
return name
class ResultIter(object):
""" SQLAlchemy ResultProxies are not iterable to get a
list of dictionaries. This is to wrap them. """

View File

@ -209,6 +209,17 @@ class TableTestCase(unittest.TestCase):
'qux.bar': 'Huhu'
})
def test_invalid_column_names(self):
tbl = self.db['weather']
with self.assertRaises(ValueError):
tbl.insert({None: 'banana'})
with self.assertRaises(ValueError):
tbl.insert({'': 'banana'})
with self.assertRaises(ValueError):
tbl.insert({'-': 'banana'})
def test_delete(self):
self.tbl.insert({
'date': datetime(2011, 1, 2),