Handle column name validation more coherently, fixes #45.
This commit is contained in:
parent
fd3d9f326c
commit
7f7cb41858
@ -8,7 +8,6 @@ from six.moves.urllib.parse import urlencode, parse_qs
|
|||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy import Integer, String
|
from sqlalchemy import Integer, String
|
||||||
from sqlalchemy.sql import text
|
from sqlalchemy.sql import text
|
||||||
from sqlalchemy.pool import NullPool
|
|
||||||
from sqlalchemy.schema import MetaData, Column
|
from sqlalchemy.schema import MetaData, Column
|
||||||
from sqlalchemy.schema import Table as SQLATable
|
from sqlalchemy.schema import Table as SQLATable
|
||||||
from sqlalchemy.util import safe_reraise
|
from sqlalchemy.util import safe_reraise
|
||||||
|
|||||||
@ -4,7 +4,7 @@ from itertools import count
|
|||||||
from sqlalchemy.sql import and_, expression
|
from sqlalchemy.sql import and_, expression
|
||||||
from sqlalchemy.schema import Column, Index
|
from sqlalchemy.schema import Column, Index
|
||||||
from sqlalchemy import alias
|
from sqlalchemy import alias
|
||||||
from dataset.persistence.util import guess_type
|
from dataset.persistence.util import guess_type, normalize_column_name
|
||||||
from dataset.persistence.util import ResultIter
|
from dataset.persistence.util import ResultIter
|
||||||
from dataset.util import DatasetException
|
from dataset.util import DatasetException
|
||||||
|
|
||||||
@ -27,6 +27,10 @@ class Table(object):
|
|||||||
"""
|
"""
|
||||||
return list(self.table.columns.keys())
|
return list(self.table.columns.keys())
|
||||||
|
|
||||||
|
@property
|
||||||
|
def _normalized_columns(self):
|
||||||
|
return map(normalize_column_name, self.columns)
|
||||||
|
|
||||||
def drop(self):
|
def drop(self):
|
||||||
"""
|
"""
|
||||||
Drop the table from the database, deleting both the schema
|
Drop the table from the database, deleting both the schema
|
||||||
@ -199,7 +203,7 @@ class Table(object):
|
|||||||
def _ensure_columns(self, row, types={}):
|
def _ensure_columns(self, row, types={}):
|
||||||
# Keep order of inserted columns
|
# Keep order of inserted columns
|
||||||
for column in row.keys():
|
for column in row.keys():
|
||||||
if column in self.table.columns.keys():
|
if normalize_column_name(column) in self._normalized_columns:
|
||||||
continue
|
continue
|
||||||
if column in types:
|
if column in types:
|
||||||
_type = types[column]
|
_type = types[column]
|
||||||
@ -230,12 +234,8 @@ class Table(object):
|
|||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
self.database._acquire()
|
self.database._acquire()
|
||||||
|
|
||||||
# check that column name is OK:
|
|
||||||
if '.' in name:
|
|
||||||
raise ValueError("Invalid column name: %r" % name)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if name not in self.table.columns.keys():
|
if normalize_column_name(name) not in self._normalized_columns:
|
||||||
self.database.op.add_column(
|
self.database.op.add_column(
|
||||||
self.table.name,
|
self.table.name,
|
||||||
Column(name, type)
|
Column(name, type)
|
||||||
|
|||||||
@ -7,6 +7,7 @@ except ImportError: # pragma: no cover
|
|||||||
from ordereddict import OrderedDict
|
from ordereddict import OrderedDict
|
||||||
|
|
||||||
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
|
from sqlalchemy import Integer, UnicodeText, Float, DateTime, Boolean
|
||||||
|
from six import string_types
|
||||||
|
|
||||||
|
|
||||||
def guess_type(sample):
|
def guess_type(sample):
|
||||||
@ -27,6 +28,15 @@ def convert_row(row):
|
|||||||
return OrderedDict(row.items())
|
return OrderedDict(row.items())
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_column_name(name):
|
||||||
|
if not isinstance(name, string_types):
|
||||||
|
raise ValueError('%r is not a valid column name.' % name)
|
||||||
|
name = name.lower().strip()
|
||||||
|
if not len(name) or '.' in name or '-' in name:
|
||||||
|
raise ValueError('%r is not a valid column name.' % name)
|
||||||
|
return name
|
||||||
|
|
||||||
|
|
||||||
class ResultIter(object):
|
class ResultIter(object):
|
||||||
""" SQLAlchemy ResultProxies are not iterable to get a
|
""" SQLAlchemy ResultProxies are not iterable to get a
|
||||||
list of dictionaries. This is to wrap them. """
|
list of dictionaries. This is to wrap them. """
|
||||||
|
|||||||
@ -209,6 +209,17 @@ class TableTestCase(unittest.TestCase):
|
|||||||
'qux.bar': 'Huhu'
|
'qux.bar': 'Huhu'
|
||||||
})
|
})
|
||||||
|
|
||||||
|
def test_invalid_column_names(self):
|
||||||
|
tbl = self.db['weather']
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tbl.insert({None: 'banana'})
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tbl.insert({'': 'banana'})
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
tbl.insert({'-': 'banana'})
|
||||||
|
|
||||||
def test_delete(self):
|
def test_delete(self):
|
||||||
self.tbl.insert({
|
self.tbl.insert({
|
||||||
'date': datetime(2011, 1, 2),
|
'date': datetime(2011, 1, 2),
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user