Handle column name validation more coherently, fixes #45.
This commit is contained in:
parent
fd3d9f326c
commit
7f7cb41858
@ -8,7 +8,6 @@ from six.moves.urllib.parse import urlencode, parse_qs
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import Integer, String
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.schema import MetaData, Column
|
||||
from sqlalchemy.schema import Table as SQLATable
|
||||
from sqlalchemy.util import safe_reraise
|
||||
|
||||
@ -4,7 +4,7 @@ from itertools import count
|
||||
from sqlalchemy.sql import and_, expression
|
||||
from sqlalchemy.schema import Column, Index
|
||||
from sqlalchemy import alias
|
||||
from dataset.persistence.util import guess_type
|
||||
from dataset.persistence.util import guess_type, normalize_column_name
|
||||
from dataset.persistence.util import ResultIter
|
||||
from dataset.util import DatasetException
|
||||
|
||||
@ -27,6 +27,10 @@ class Table(object):
|
||||
"""
|
||||
return list(self.table.columns.keys())
|
||||
|
||||
@property
|
||||
def _normalized_columns(self):
|
||||
return map(normalize_column_name, self.columns)
|
||||
|
||||
def drop(self):
|
||||
"""
|
||||
Drop the table from the database, deleting both the schema
|
||||
@ -199,7 +203,7 @@ class Table(object):
|
||||
def _ensure_columns(self, row, types={}):
|
||||
# Keep order of inserted columns
|
||||
for column in row.keys():
|
||||
if column in self.table.columns.keys():
|
||||
if normalize_column_name(column) in self._normalized_columns:
|
||||
continue
|
||||
if column in types:
|
||||
_type = types[column]
|
||||
@ -230,12 +234,8 @@ class Table(object):
|
||||
self._check_dropped()
|
||||
self.database._acquire()
|
||||
|
||||
# check that column name is OK:
|
||||
if '.' in name:
|
||||
raise ValueError("Invalid column name: %r" % name)
|
||||
|
||||
try:
|
||||
if name not in self.table.columns.keys():
|
||||
if normalize_column_name(name) not in self._normalized_columns:
|
||||
self.database.op.add_column(
|
||||
self.table.name,
|
||||
Column(name, type)
|
||||
|
||||
@ -7,6 +7,7 @@ except ImportError: # pragma: no cover
|
||||
from ordereddict import OrderedDict
|
||||
|
||||
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
|
||||
from six import string_types
|
||||
|
||||
|
||||
def guess_type(sample):
|
||||
@ -27,6 +28,15 @@ def convert_row(row):
|
||||
return OrderedDict(row.items())
|
||||
|
||||
|
||||
def normalize_column_name(name):
|
||||
if not isinstance(name, string_types):
|
||||
raise ValueError('%r is not a valid column name.' % name)
|
||||
name = name.lower().strip()
|
||||
if not len(name) or '.' in name or '-' in name:
|
||||
raise ValueError('%r is not a valid column name.' % name)
|
||||
return name
|
||||
|
||||
|
||||
class ResultIter(object):
|
||||
""" SQLAlchemy ResultProxies are not iterable to get a
|
||||
list of dictionaries. This is to wrap them. """
|
||||
|
||||
@ -209,6 +209,17 @@ class TableTestCase(unittest.TestCase):
|
||||
'qux.bar': 'Huhu'
|
||||
})
|
||||
|
||||
def test_invalid_column_names(self):
|
||||
tbl = self.db['weather']
|
||||
with self.assertRaises(ValueError):
|
||||
tbl.insert({None: 'banana'})
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tbl.insert({'': 'banana'})
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
tbl.insert({'-': 'banana'})
|
||||
|
||||
def test_delete(self):
|
||||
self.tbl.insert({
|
||||
'date': datetime(2011, 1, 2),
|
||||
|
||||
Loading…
Reference in New Issue
Block a user