Begin implementing a types handler instead of using plain text types.

This potentially BREAKS scripts which use the string syntax for types.
This commit is contained in:
Friedrich Lindenberg 2017-09-02 16:47:04 +02:00
parent 2f0fa7cdd2
commit a4c73a8fb8
5 changed files with 95 additions and 75 deletions

View File

@ -1,12 +1,10 @@
import logging import logging
import threading import threading
import re
import six import six
from six.moves.urllib.parse import parse_qs, urlparse from six.moves.urllib.parse import parse_qs, urlparse
from sqlalchemy import create_engine from sqlalchemy import create_engine
from sqlalchemy import Integer, String
from sqlalchemy.sql import text from sqlalchemy.sql import text
from sqlalchemy.schema import MetaData, Column from sqlalchemy.schema import MetaData, Column
from sqlalchemy.schema import Table as SQLATable from sqlalchemy.schema import Table as SQLATable
@ -18,7 +16,7 @@ from alembic.operations import Operations
from dataset.persistence.table import Table from dataset.persistence.table import Table
from dataset.persistence.util import ResultIter, row_type, safe_url from dataset.persistence.util import ResultIter, row_type, safe_url
from dataset.util import DatasetException from dataset.persistence.types import Types
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -49,6 +47,7 @@ class Database(object):
if len(schema_qs): if len(schema_qs):
schema = schema_qs.pop() schema = schema_qs.pop()
self.types = Types()
self.schema = schema self.schema = schema
self.engine = create_engine(url, **engine_kwargs) self.engine = create_engine(url, **engine_kwargs)
self.url = url self.url = url
@ -145,9 +144,8 @@ class Database(object):
raise ValueError("Invalid table name: %r" % table_name) raise ValueError("Invalid table name: %r" % table_name)
return table_name.strip() return table_name.strip()
def create_table(self, table_name, primary_id='id', primary_type='Integer'): def create_table(self, table_name, primary_id=None, primary_type=None):
""" """Create a new table.
Create a new table.
The new table will automatically have an `id` column unless specified via The new table will automatically have an `id` column unless specified via
optional parameter primary_id, which will be used as the primary key of the optional parameter primary_id, which will be used as the primary key of the
@ -166,34 +164,31 @@ class Database(object):
# custom id and type # custom id and type
table2 = db.create_table('population2', 'age') table2 = db.create_table('population2', 'age')
table3 = db.create_table('population3', primary_id='race', primary_type='String') table3 = db.create_table('population3',
primary_id='city',
primary_type=db.types.string)
# custom length of String # custom length of String
table4 = db.create_table('population4', primary_id='race', primary_type='String(50)') table4 = db.create_table('population4',
primary_id='city',
primary_type=db.types.string(25))
# no primary key
table5 = db.create_table('population5',
primary_id=False)
""" """
table_name = self._valid_table_name(table_name) table_name = self._valid_table_name(table_name)
self._acquire() self._acquire()
try: try:
log.debug("Creating table: %s on %r" % (table_name, self.engine)) log.debug("Creating table: %s" % (table_name))
match = re.match(r'^(Integer)$|^(String)(\(\d+\))?$', primary_type)
if match:
if match.group(1) == 'Integer':
auto_flag = False
if primary_id == 'id':
auto_flag = True
col = Column(primary_id, Integer, primary_key=True, autoincrement=auto_flag)
elif not match.group(3):
col = Column(primary_id, String(255), primary_key=True)
else:
len_string = int(match.group(3)[1:-1])
len_string = min(len_string, 255)
col = Column(primary_id, String(len_string), primary_key=True)
else:
raise DatasetException(
"The primary_type has to be either 'Integer' or 'String'.")
table = SQLATable(table_name, self.metadata, schema=self.schema) table = SQLATable(table_name, self.metadata, schema=self.schema)
if primary_id is not False:
primary_id = primary_id or Table.PRIMARY_DEFAULT
primary_type = primary_type or self.types.integer
autoincrement = primary_id == Table.PRIMARY_DEFAULT
col = Column(primary_id, primary_type,
primary_key=True,
autoincrement=autoincrement)
table.append_column(col) table.append_column(col)
table.create(self.engine) table.create(self.executable)
self._tables[table_name] = table self._tables[table_name] = table
return Table(self, table) return Table(self, table)
finally: finally:
@ -227,13 +222,13 @@ class Database(object):
"""Reload a table schema from the database.""" """Reload a table schema from the database."""
table_name = self._valid_table_name(table_name) table_name = self._valid_table_name(table_name)
self.metadata = MetaData(schema=self.schema) self.metadata = MetaData(schema=self.schema)
self.metadata.bind = self.engine self.metadata.bind = self.executable
self.metadata.reflect(self.engine) self.metadata.reflect(self.executable)
self._tables[table_name] = SQLATable(table_name, self.metadata, self._tables[table_name] = SQLATable(table_name, self.metadata,
schema=self.schema) schema=self.schema)
return self._tables[table_name] return self._tables[table_name]
def get_table(self, table_name, primary_id='id', primary_type='Integer'): def get_table(self, table_name, primary_id=None, primary_type=None):
""" """
Smart wrapper around *load_table* and *create_table*. Smart wrapper around *load_table* and *create_table*.

View File

@ -5,7 +5,8 @@ from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false from sqlalchemy import func, select, false
from dataset.persistence.util import guess_type, normalize_column_name
from dataset.persistence.util import normalize_column_name
from dataset.persistence.util import ResultIter from dataset.persistence.util import ResultIter
from dataset.util import DatasetException from dataset.util import DatasetException
@ -15,6 +16,7 @@ log = logging.getLogger(__name__)
class Table(object): class Table(object):
"""Represents a table in a database and exposes common operations.""" """Represents a table in a database and exposes common operations."""
PRIMARY_DEFAULT = 'id'
def __init__(self, database, table): def __init__(self, database, table):
"""Initialise the table from database schema.""" """Initialise the table from database schema."""
@ -59,7 +61,7 @@ class Table(object):
# filter out keys not in column set # filter out keys not in column set
return {k: row[k] for k in row if k in self._normalized_columns} return {k: row[k] for k in row if k in self._normalized_columns}
def insert(self, row, ensure=None, types={}): def insert(self, row, ensure=None, types=None):
""" """
Add a row (type: dict) by inserting it into the table. Add a row (type: dict) by inserting it into the table.
@ -88,7 +90,7 @@ class Table(object):
if len(res.inserted_primary_key) > 0: if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0] return res.inserted_primary_key[0]
def insert_ignore(self, row, keys, ensure=None, types={}): def insert_ignore(self, row, keys, ensure=None, types=None):
""" """
Add a row (type: dict) into the table if the row does not exist. Add a row (type: dict) into the table if the row does not exist.
@ -113,7 +115,7 @@ class Table(object):
else: else:
return False return False
def insert_many(self, rows, chunk_size=1000, ensure=None, types={}): def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
""" """
Add many rows at a time. Add many rows at a time.
@ -149,7 +151,7 @@ class Table(object):
if chunk: if chunk:
_process_chunk(chunk) _process_chunk(chunk)
def update(self, row, keys, ensure=None, types={}): def update(self, row, keys, ensure=None, types=None):
""" """
Update a row in the table. Update a row in the table.
@ -211,7 +213,7 @@ class Table(object):
filters[key] = row.get(key) filters[key] = row.get(key)
return row, self.find_one(**filters) return row, self.find_one(**filters)
def upsert(self, row, keys, ensure=None, types={}): def upsert(self, row, keys, ensure=None, types=None):
""" """
An UPSERT is a smart combination of insert and update. An UPSERT is a smart combination of insert and update.
@ -260,17 +262,18 @@ class Table(object):
def _has_column(self, column): def _has_column(self, column):
return normalize_column_name(column) in self._normalized_columns return normalize_column_name(column) in self._normalized_columns
def _ensure_columns(self, row, types={}): def _ensure_columns(self, row, types=None):
# Keep order of inserted columns # Keep order of inserted columns
for column in row.keys(): for column in row.keys():
if self._has_column(column): if self._has_column(column):
continue continue
if column in types: if types is not None and column in types:
_type = types[column] _type = types[column]
else: else:
_type = guess_type(row[column]) _type = self.database.types.guess(row[column])
log.debug("Creating column: %s (%s) on %r" % (column, log.debug("Creating column: %s (%s) on %r" % (column,
_type, self.table.name)) _type,
self.table.name))
self.create_column(column, _type) self.create_column(column, _type)
def _args_to_clause(self, args, ensure=None, clauses=()): def _args_to_clause(self, args, ensure=None, clauses=()):
@ -294,13 +297,13 @@ class Table(object):
``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_. ``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_.
:: ::
table.create_column('created_at', sqlalchemy.DateTime) table.create_column('created_at', db.types.datetime)
""" """
self._check_dropped() self._check_dropped()
self.database._acquire() self.database._acquire()
try: try:
if normalize_column_name(name) in self._normalized_columns: if normalize_column_name(name) in self._normalized_columns:
log.warn("Column with similar name exists: %s" % name) log.debug("Column exists: %s" % name)
return return
self.database.op.add_column( self.database.op.add_column(
@ -321,7 +324,8 @@ class Table(object):
table.create_column_by_example('length', 4.2) table.create_column_by_example('length', 4.2)
""" """
self._ensure_columns({name: value}, {}) type_ = self.database.types.guess(value)
self.create_column(name, type_)
def drop_column(self, name): def drop_column(self, name):
""" """

View File

@ -0,0 +1,37 @@
from datetime import datetime, date
from sqlalchemy import Integer, UnicodeText, Float, BigInteger
from sqlalchemy import Boolean, Date, DateTime, Unicode
from sqlalchemy.types import TypeEngine
class Types(object):
    """A holder class for easy access to SQLAlchemy type names.

    Every attribute is a short alias for a SQLAlchemy column type, so a
    type can be referenced as e.g. ``types.string`` or, with a length,
    ``types.string(25)``.
    """

    # Short aliases for the SQLAlchemy column types.
    integer = Integer
    string = Unicode
    text = UnicodeText
    float = Float
    bigint = BigInteger
    boolean = Boolean
    date = Date
    datetime = DateTime

    def guess(self, sample):
        """Given a single sample, guess the column type for the field.

        If the sample is an instance of an SQLAlchemy type, the type will be
        used instead.

        :param sample: an example value for the column, or an SQLAlchemy
            type instance.
        :returns: ``sample`` itself when it is an SQLAlchemy type instance,
            otherwise the alias on this object that matches the Python type
            of ``sample``; ``text`` is the fallback for anything else.
        """
        # An explicit SQLAlchemy type instance is taken as-is.
        if isinstance(sample, TypeEngine):
            return sample
        # bool must be tested before int: bool is a subclass of int.
        if isinstance(sample, bool):
            return self.boolean
        elif isinstance(sample, int):
            return self.integer
        elif isinstance(sample, float):
            return self.float
        # datetime must be tested before date: datetime subclasses date.
        elif isinstance(sample, datetime):
            return self.datetime
        elif isinstance(sample, date):
            return self.date
        return self.text

View File

@ -1,4 +1,3 @@
from datetime import datetime, date
try: try:
from urlparse import urlparse from urlparse import urlparse
except ImportError: except ImportError:
@ -9,26 +8,11 @@ try:
except ImportError: # pragma: no cover except ImportError: # pragma: no cover
from ordereddict import OrderedDict from ordereddict import OrderedDict
from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean
from six import string_types from six import string_types
row_type = OrderedDict row_type = OrderedDict
def guess_type(sample):
if isinstance(sample, bool):
return Boolean
elif isinstance(sample, int):
return Integer
elif isinstance(sample, float):
return Float
elif isinstance(sample, datetime):
return DateTime
elif isinstance(sample, date):
return Date
return UnicodeText
def convert_row(row_type, row): def convert_row(row_type, row):
if row is None: if row is None:
return None return None

View File

@ -50,25 +50,23 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_custom_id1(self): def test_create_table_custom_id1(self):
pid = "string_id" pid = "string_id"
table = self.db.create_table("foo2", pid, 'String') table = self.db.create_table("foo2", pid, self.db.types.string)
assert table.table.exists() assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({ table.insert({pid: 'foobar'})
'string_id': 'foobar'}) assert table.find_one(string_id='foobar')[pid] == 'foobar'
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_create_table_custom_id2(self): def test_create_table_custom_id2(self):
pid = "string_id" pid = "string_id"
table = self.db.create_table("foo3", pid, 'String(50)') table = self.db.create_table("foo3", pid, self.db.types.string(50))
assert table.table.exists() assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({ table.insert({pid: 'foobar'})
'string_id': 'foobar'}) assert table.find_one(string_id='foobar')[pid] == 'foobar'
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_create_table_custom_id3(self): def test_create_table_custom_id3(self):
pid = "int_id" pid = "int_id"
@ -77,11 +75,11 @@ class DatabaseTestCase(unittest.TestCase):
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({'int_id': 123}) table.insert({pid: 123})
table.insert({'int_id': 124}) table.insert({pid: 124})
assert table.find_one(int_id=123)['int_id'] == 123 assert table.find_one(int_id=123)[pid] == 123
assert table.find_one(int_id=124)['int_id'] == 124 assert table.find_one(int_id=124)[pid] == 124
self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123})) self.assertRaises(IntegrityError, lambda: table.insert({pid: 123}))
def test_create_table_shorthand1(self): def test_create_table_shorthand1(self):
pid = "int_id" pid = "int_id"
@ -98,7 +96,8 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand2(self): def test_create_table_shorthand2(self):
pid = "string_id" pid = "string_id"
table = self.db.get_table('foo6', primary_id=pid, primary_type='String') table = self.db.get_table('foo6', primary_id=pid,
primary_type=self.db.types.string)
assert table.table.exists assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
@ -109,7 +108,8 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand3(self): def test_create_table_shorthand3(self):
pid = "string_id" pid = "string_id"
table = self.db.get_table('foo7', primary_id=pid, primary_type='String(20)') table = self.db.get_table('foo7', primary_id=pid,
primary_type=self.db.types.string(20))
assert table.table.exists assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c