Begin implementing a types handler instead of using plain text types.

this is potentially BREAKING scripts which use the string syntax.
This commit is contained in:
Friedrich Lindenberg 2017-09-02 16:47:04 +02:00
parent 2f0fa7cdd2
commit a4c73a8fb8
5 changed files with 95 additions and 75 deletions

View File

@ -1,12 +1,10 @@
import logging
import threading
import re
import six
from six.moves.urllib.parse import parse_qs, urlparse
from sqlalchemy import create_engine
from sqlalchemy import Integer, String
from sqlalchemy.sql import text
from sqlalchemy.schema import MetaData, Column
from sqlalchemy.schema import Table as SQLATable
@ -18,7 +16,7 @@ from alembic.operations import Operations
from dataset.persistence.table import Table
from dataset.persistence.util import ResultIter, row_type, safe_url
from dataset.util import DatasetException
from dataset.persistence.types import Types
log = logging.getLogger(__name__)
@ -49,6 +47,7 @@ class Database(object):
if len(schema_qs):
schema = schema_qs.pop()
self.types = Types()
self.schema = schema
self.engine = create_engine(url, **engine_kwargs)
self.url = url
@ -145,9 +144,8 @@ class Database(object):
raise ValueError("Invalid table name: %r" % table_name)
return table_name.strip()
def create_table(self, table_name, primary_id='id', primary_type='Integer'):
"""
Create a new table.
def create_table(self, table_name, primary_id=None, primary_type=None):
"""Create a new table.
The new table will automatically have an `id` column unless specified via
optional parameter primary_id, which will be used as the primary key of the
@ -166,34 +164,31 @@ class Database(object):
# custom id and type
table2 = db.create_table('population2', 'age')
table3 = db.create_table('population3', primary_id='race', primary_type='String')
table3 = db.create_table('population3',
primary_id='city',
primary_type=db.types.string)
# custom length of String
table4 = db.create_table('population4', primary_id='race', primary_type='String(50)')
table4 = db.create_table('population4',
primary_id='city',
primary_type=db.types.string(25))
# no primary key
table5 = db.create_table('population5',
primary_id=False)
"""
table_name = self._valid_table_name(table_name)
self._acquire()
try:
log.debug("Creating table: %s on %r" % (table_name, self.engine))
match = re.match(r'^(Integer)$|^(String)(\(\d+\))?$', primary_type)
if match:
if match.group(1) == 'Integer':
auto_flag = False
if primary_id == 'id':
auto_flag = True
col = Column(primary_id, Integer, primary_key=True, autoincrement=auto_flag)
elif not match.group(3):
col = Column(primary_id, String(255), primary_key=True)
else:
len_string = int(match.group(3)[1:-1])
len_string = min(len_string, 255)
col = Column(primary_id, String(len_string), primary_key=True)
else:
raise DatasetException(
"The primary_type has to be either 'Integer' or 'String'.")
log.debug("Creating table: %s" % (table_name))
table = SQLATable(table_name, self.metadata, schema=self.schema)
if primary_id is not False:
primary_id = primary_id or Table.PRIMARY_DEFAULT
primary_type = primary_type or self.types.integer
autoincrement = primary_id == Table.PRIMARY_DEFAULT
col = Column(primary_id, primary_type,
primary_key=True,
autoincrement=autoincrement)
table.append_column(col)
table.create(self.engine)
table.create(self.executable)
self._tables[table_name] = table
return Table(self, table)
finally:
@ -227,13 +222,13 @@ class Database(object):
"""Reload a table schema from the database."""
table_name = self._valid_table_name(table_name)
self.metadata = MetaData(schema=self.schema)
self.metadata.bind = self.engine
self.metadata.reflect(self.engine)
self.metadata.bind = self.executable
self.metadata.reflect(self.executable)
self._tables[table_name] = SQLATable(table_name, self.metadata,
schema=self.schema)
return self._tables[table_name]
def get_table(self, table_name, primary_id='id', primary_type='Integer'):
def get_table(self, table_name, primary_id=None, primary_type=None):
"""
Smart wrapper around *load_table* and *create_table*.

View File

@ -5,7 +5,8 @@ from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from dataset.persistence.util import guess_type, normalize_column_name
from dataset.persistence.util import normalize_column_name
from dataset.persistence.util import ResultIter
from dataset.util import DatasetException
@ -15,6 +16,7 @@ log = logging.getLogger(__name__)
class Table(object):
"""Represents a table in a database and exposes common operations."""
PRIMARY_DEFAULT = 'id'
def __init__(self, database, table):
"""Initialise the table from database schema."""
@ -59,7 +61,7 @@ class Table(object):
# filter out keys not in column set
return {k: row[k] for k in row if k in self._normalized_columns}
def insert(self, row, ensure=None, types={}):
def insert(self, row, ensure=None, types=None):
"""
Add a row (type: dict) by inserting it into the table.
@ -88,7 +90,7 @@ class Table(object):
if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
def insert_ignore(self, row, keys, ensure=None, types={}):
def insert_ignore(self, row, keys, ensure=None, types=None):
"""
Add a row (type: dict) into the table if the row does not exist.
@ -113,7 +115,7 @@ class Table(object):
else:
return False
def insert_many(self, rows, chunk_size=1000, ensure=None, types={}):
def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
"""
Add many rows at a time.
@ -149,7 +151,7 @@ class Table(object):
if chunk:
_process_chunk(chunk)
def update(self, row, keys, ensure=None, types={}):
def update(self, row, keys, ensure=None, types=None):
"""
Update a row in the table.
@ -211,7 +213,7 @@ class Table(object):
filters[key] = row.get(key)
return row, self.find_one(**filters)
def upsert(self, row, keys, ensure=None, types={}):
def upsert(self, row, keys, ensure=None, types=None):
"""
An UPSERT is a smart combination of insert and update.
@ -260,17 +262,18 @@ class Table(object):
def _has_column(self, column):
return normalize_column_name(column) in self._normalized_columns
def _ensure_columns(self, row, types={}):
def _ensure_columns(self, row, types=None):
# Keep order of inserted columns
for column in row.keys():
if self._has_column(column):
continue
if column in types:
if types is not None and column in types:
_type = types[column]
else:
_type = guess_type(row[column])
_type = self.database.types.guess(row[column])
log.debug("Creating column: %s (%s) on %r" % (column,
_type, self.table.name))
_type,
self.table.name))
self.create_column(column, _type)
def _args_to_clause(self, args, ensure=None, clauses=()):
@ -294,13 +297,13 @@ class Table(object):
``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_.
::
table.create_column('created_at', sqlalchemy.DateTime)
table.create_column('created_at', db.types.datetime)
"""
self._check_dropped()
self.database._acquire()
try:
if normalize_column_name(name) in self._normalized_columns:
log.warn("Column with similar name exists: %s" % name)
log.debug("Column exists: %s" % name)
return
self.database.op.add_column(
@ -321,7 +324,8 @@ class Table(object):
table.create_column_by_example('length', 4.2)
"""
self._ensure_columns({name: value}, {})
type_ = self.database.types.guess(value)
self.create_column(name, type_)
def drop_column(self, name):
"""

View File

@ -0,0 +1,37 @@
from datetime import datetime, date
from sqlalchemy import Integer, UnicodeText, Float, BigInteger
from sqlalchemy import Boolean, Date, DateTime, Unicode
from sqlalchemy.types import TypeEngine
class Types(object):
"""A holder class for easy access to SQLAlchemy type names."""
integer = Integer
string = Unicode
text = UnicodeText
float = Float
bigint = BigInteger
boolean = Boolean
date = Date
datetime = DateTime
def guess(cls, sample):
"""Given a single sample, guess the column type for the field.
If the sample is an instance of an SQLAlchemy type, the type will be
used instead.
"""
if isinstance(sample, TypeEngine):
return sample
if isinstance(sample, bool):
return cls.boolean
elif isinstance(sample, int):
return cls.integer
elif isinstance(sample, float):
return cls.float
elif isinstance(sample, datetime):
return cls.datetime
elif isinstance(sample, date):
return cls.date
return cls.text

View File

@ -1,4 +1,3 @@
from datetime import datetime, date
try:
from urlparse import urlparse
except ImportError:
@ -9,26 +8,11 @@ try:
except ImportError: # pragma: no cover
from ordereddict import OrderedDict
from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean
from six import string_types
row_type = OrderedDict
def guess_type(sample):
if isinstance(sample, bool):
return Boolean
elif isinstance(sample, int):
return Integer
elif isinstance(sample, float):
return Float
elif isinstance(sample, datetime):
return DateTime
elif isinstance(sample, date):
return Date
return UnicodeText
def convert_row(row_type, row):
if row is None:
return None

View File

@ -50,25 +50,23 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_custom_id1(self):
pid = "string_id"
table = self.db.create_table("foo2", pid, 'String')
table = self.db.create_table("foo2", pid, self.db.types.string)
assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
table.insert({
'string_id': 'foobar'})
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
table.insert({pid: 'foobar'})
assert table.find_one(string_id='foobar')[pid] == 'foobar'
def test_create_table_custom_id2(self):
pid = "string_id"
table = self.db.create_table("foo3", pid, 'String(50)')
table = self.db.create_table("foo3", pid, self.db.types.string(50))
assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
table.insert({
'string_id': 'foobar'})
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
table.insert({pid: 'foobar'})
assert table.find_one(string_id='foobar')[pid] == 'foobar'
def test_create_table_custom_id3(self):
pid = "int_id"
@ -77,11 +75,11 @@ class DatabaseTestCase(unittest.TestCase):
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
table.insert({'int_id': 123})
table.insert({'int_id': 124})
assert table.find_one(int_id=123)['int_id'] == 123
assert table.find_one(int_id=124)['int_id'] == 124
self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123}))
table.insert({pid: 123})
table.insert({pid: 124})
assert table.find_one(int_id=123)[pid] == 123
assert table.find_one(int_id=124)[pid] == 124
self.assertRaises(IntegrityError, lambda: table.insert({pid: 123}))
def test_create_table_shorthand1(self):
pid = "int_id"
@ -98,7 +96,8 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand2(self):
pid = "string_id"
table = self.db.get_table('foo6', primary_id=pid, primary_type='String')
table = self.db.get_table('foo6', primary_id=pid,
primary_type=self.db.types.string)
assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
@ -109,7 +108,8 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand3(self):
pid = "string_id"
table = self.db.get_table('foo7', primary_id=pid, primary_type='String(20)')
table = self.db.get_table('foo7', primary_id=pid,
primary_type=self.db.types.string(20))
assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c