Begin implementing a types handler instead of using plain text types.
This potentially BREAKS scripts that use the string syntax (e.g. primary_type='String(50)').
This commit is contained in:
parent
2f0fa7cdd2
commit
a4c73a8fb8
@ -1,12 +1,10 @@
|
||||
import logging
|
||||
import threading
|
||||
import re
|
||||
|
||||
import six
|
||||
from six.moves.urllib.parse import parse_qs, urlparse
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import Integer, String
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.schema import MetaData, Column
|
||||
from sqlalchemy.schema import Table as SQLATable
|
||||
@ -18,7 +16,7 @@ from alembic.operations import Operations
|
||||
|
||||
from dataset.persistence.table import Table
|
||||
from dataset.persistence.util import ResultIter, row_type, safe_url
|
||||
from dataset.util import DatasetException
|
||||
from dataset.persistence.types import Types
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -49,6 +47,7 @@ class Database(object):
|
||||
if len(schema_qs):
|
||||
schema = schema_qs.pop()
|
||||
|
||||
self.types = Types()
|
||||
self.schema = schema
|
||||
self.engine = create_engine(url, **engine_kwargs)
|
||||
self.url = url
|
||||
@ -145,9 +144,8 @@ class Database(object):
|
||||
raise ValueError("Invalid table name: %r" % table_name)
|
||||
return table_name.strip()
|
||||
|
||||
def create_table(self, table_name, primary_id='id', primary_type='Integer'):
|
||||
"""
|
||||
Create a new table.
|
||||
def create_table(self, table_name, primary_id=None, primary_type=None):
|
||||
"""Create a new table.
|
||||
|
||||
The new table will automatically have an `id` column unless specified via
|
||||
optional parameter primary_id, which will be used as the primary key of the
|
||||
@ -166,34 +164,31 @@ class Database(object):
|
||||
|
||||
# custom id and type
|
||||
table2 = db.create_table('population2', 'age')
|
||||
table3 = db.create_table('population3', primary_id='race', primary_type='String')
|
||||
table3 = db.create_table('population3',
|
||||
primary_id='city',
|
||||
primary_type=db.types.string)
|
||||
# custom length of String
|
||||
table4 = db.create_table('population4', primary_id='race', primary_type='String(50)')
|
||||
table4 = db.create_table('population4',
|
||||
primary_id='city',
|
||||
primary_type=db.types.string(25))
|
||||
# no primary key
|
||||
table5 = db.create_table('population5',
|
||||
primary_id=False)
|
||||
"""
|
||||
table_name = self._valid_table_name(table_name)
|
||||
self._acquire()
|
||||
try:
|
||||
log.debug("Creating table: %s on %r" % (table_name, self.engine))
|
||||
match = re.match(r'^(Integer)$|^(String)(\(\d+\))?$', primary_type)
|
||||
if match:
|
||||
if match.group(1) == 'Integer':
|
||||
auto_flag = False
|
||||
if primary_id == 'id':
|
||||
auto_flag = True
|
||||
col = Column(primary_id, Integer, primary_key=True, autoincrement=auto_flag)
|
||||
elif not match.group(3):
|
||||
col = Column(primary_id, String(255), primary_key=True)
|
||||
else:
|
||||
len_string = int(match.group(3)[1:-1])
|
||||
len_string = min(len_string, 255)
|
||||
col = Column(primary_id, String(len_string), primary_key=True)
|
||||
else:
|
||||
raise DatasetException(
|
||||
"The primary_type has to be either 'Integer' or 'String'.")
|
||||
|
||||
log.debug("Creating table: %s" % (table_name))
|
||||
table = SQLATable(table_name, self.metadata, schema=self.schema)
|
||||
table.append_column(col)
|
||||
table.create(self.engine)
|
||||
if primary_id is not False:
|
||||
primary_id = primary_id or Table.PRIMARY_DEFAULT
|
||||
primary_type = primary_type or self.types.integer
|
||||
autoincrement = primary_id == Table.PRIMARY_DEFAULT
|
||||
col = Column(primary_id, primary_type,
|
||||
primary_key=True,
|
||||
autoincrement=autoincrement)
|
||||
table.append_column(col)
|
||||
table.create(self.executable)
|
||||
self._tables[table_name] = table
|
||||
return Table(self, table)
|
||||
finally:
|
||||
@ -227,13 +222,13 @@ class Database(object):
|
||||
"""Reload a table schema from the database."""
|
||||
table_name = self._valid_table_name(table_name)
|
||||
self.metadata = MetaData(schema=self.schema)
|
||||
self.metadata.bind = self.engine
|
||||
self.metadata.reflect(self.engine)
|
||||
self.metadata.bind = self.executable
|
||||
self.metadata.reflect(self.executable)
|
||||
self._tables[table_name] = SQLATable(table_name, self.metadata,
|
||||
schema=self.schema)
|
||||
return self._tables[table_name]
|
||||
|
||||
def get_table(self, table_name, primary_id='id', primary_type='Integer'):
|
||||
def get_table(self, table_name, primary_id=None, primary_type=None):
|
||||
"""
|
||||
Smart wrapper around *load_table* and *create_table*.
|
||||
|
||||
|
||||
@ -5,7 +5,8 @@ from sqlalchemy.sql import and_, expression
|
||||
from sqlalchemy.sql.expression import ClauseElement
|
||||
from sqlalchemy.schema import Column, Index
|
||||
from sqlalchemy import func, select, false
|
||||
from dataset.persistence.util import guess_type, normalize_column_name
|
||||
|
||||
from dataset.persistence.util import normalize_column_name
|
||||
from dataset.persistence.util import ResultIter
|
||||
from dataset.util import DatasetException
|
||||
|
||||
@ -15,6 +16,7 @@ log = logging.getLogger(__name__)
|
||||
|
||||
class Table(object):
|
||||
"""Represents a table in a database and exposes common operations."""
|
||||
PRIMARY_DEFAULT = 'id'
|
||||
|
||||
def __init__(self, database, table):
|
||||
"""Initialise the table from database schema."""
|
||||
@ -59,7 +61,7 @@ class Table(object):
|
||||
# filter out keys not in column set
|
||||
return {k: row[k] for k in row if k in self._normalized_columns}
|
||||
|
||||
def insert(self, row, ensure=None, types={}):
|
||||
def insert(self, row, ensure=None, types=None):
|
||||
"""
|
||||
Add a row (type: dict) by inserting it into the table.
|
||||
|
||||
@ -88,7 +90,7 @@ class Table(object):
|
||||
if len(res.inserted_primary_key) > 0:
|
||||
return res.inserted_primary_key[0]
|
||||
|
||||
def insert_ignore(self, row, keys, ensure=None, types={}):
|
||||
def insert_ignore(self, row, keys, ensure=None, types=None):
|
||||
"""
|
||||
Add a row (type: dict) into the table if the row does not exist.
|
||||
|
||||
@ -113,7 +115,7 @@ class Table(object):
|
||||
else:
|
||||
return False
|
||||
|
||||
def insert_many(self, rows, chunk_size=1000, ensure=None, types={}):
|
||||
def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
|
||||
"""
|
||||
Add many rows at a time.
|
||||
|
||||
@ -149,7 +151,7 @@ class Table(object):
|
||||
if chunk:
|
||||
_process_chunk(chunk)
|
||||
|
||||
def update(self, row, keys, ensure=None, types={}):
|
||||
def update(self, row, keys, ensure=None, types=None):
|
||||
"""
|
||||
Update a row in the table.
|
||||
|
||||
@ -211,7 +213,7 @@ class Table(object):
|
||||
filters[key] = row.get(key)
|
||||
return row, self.find_one(**filters)
|
||||
|
||||
def upsert(self, row, keys, ensure=None, types={}):
|
||||
def upsert(self, row, keys, ensure=None, types=None):
|
||||
"""
|
||||
An UPSERT is a smart combination of insert and update.
|
||||
|
||||
@ -260,17 +262,18 @@ class Table(object):
|
||||
def _has_column(self, column):
|
||||
return normalize_column_name(column) in self._normalized_columns
|
||||
|
||||
def _ensure_columns(self, row, types={}):
|
||||
def _ensure_columns(self, row, types=None):
|
||||
# Keep order of inserted columns
|
||||
for column in row.keys():
|
||||
if self._has_column(column):
|
||||
continue
|
||||
if column in types:
|
||||
if types is not None and column in types:
|
||||
_type = types[column]
|
||||
else:
|
||||
_type = guess_type(row[column])
|
||||
_type = self.database.types.guess(row[column])
|
||||
log.debug("Creating column: %s (%s) on %r" % (column,
|
||||
_type, self.table.name))
|
||||
_type,
|
||||
self.table.name))
|
||||
self.create_column(column, _type)
|
||||
|
||||
def _args_to_clause(self, args, ensure=None, clauses=()):
|
||||
@ -294,13 +297,13 @@ class Table(object):
|
||||
``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_.
|
||||
::
|
||||
|
||||
table.create_column('created_at', sqlalchemy.DateTime)
|
||||
table.create_column('created_at', db.types.datetime)
|
||||
"""
|
||||
self._check_dropped()
|
||||
self.database._acquire()
|
||||
try:
|
||||
if normalize_column_name(name) in self._normalized_columns:
|
||||
log.warn("Column with similar name exists: %s" % name)
|
||||
log.debug("Column exists: %s" % name)
|
||||
return
|
||||
|
||||
self.database.op.add_column(
|
||||
@ -321,7 +324,8 @@ class Table(object):
|
||||
|
||||
table.create_column_by_example('length', 4.2)
|
||||
"""
|
||||
self._ensure_columns({name: value}, {})
|
||||
type_ = self.database.types.guess(value)
|
||||
self.create_column(name, type_)
|
||||
|
||||
def drop_column(self, name):
|
||||
"""
|
||||
|
||||
37
dataset/persistence/types.py
Normal file
37
dataset/persistence/types.py
Normal file
@ -0,0 +1,37 @@
|
||||
from datetime import datetime, date
|
||||
|
||||
from sqlalchemy import Integer, UnicodeText, Float, BigInteger
|
||||
from sqlalchemy import Boolean, Date, DateTime, Unicode
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
|
||||
class Types(object):
    """A holder class for easy access to SQLAlchemy type names."""

    # Short, lower-case aliases for the SQLAlchemy column types imported at
    # the top of this module. Callers can pass these wherever a column type
    # is expected, e.g. ``db.types.string`` or ``db.types.string(25)``.
    integer = Integer
    string = Unicode
    text = UnicodeText
    float = Float
    bigint = BigInteger
    boolean = Boolean
    date = Date
    datetime = DateTime

    def guess(self, sample):
        """Given a single sample, guess the column type for the field.

        If the sample is an instance of an SQLAlchemy type, the type will be
        used instead.

        :param sample: an example value for the column, or an SQLAlchemy
            type instance.
        :returns: ``sample`` itself when it is an SQLAlchemy type instance,
            otherwise one of the type aliases defined on this class;
            ``text`` is the fallback when no other check matches.
        """
        if isinstance(sample, TypeEngine):
            return sample
        # bool has to be tested before int: bool is a subclass of int, so
        # True/False would otherwise be reported as integers.
        if isinstance(sample, bool):
            return self.boolean
        if isinstance(sample, int):
            return self.integer
        if isinstance(sample, float):
            return self.float
        # ``datetime`` and ``date`` below resolve to the classes imported
        # from the standard ``datetime`` module, not to the class attributes
        # of the same name (class scope is not visible inside methods).
        # datetime is tested first because it is a subclass of date.
        if isinstance(sample, datetime):
            return self.datetime
        if isinstance(sample, date):
            return self.date
        return self.text
|
||||
@ -1,4 +1,3 @@
|
||||
from datetime import datetime, date
|
||||
try:
|
||||
from urlparse import urlparse
|
||||
except ImportError:
|
||||
@ -9,26 +8,11 @@ try:
|
||||
except ImportError: # pragma: no cover
|
||||
from ordereddict import OrderedDict
|
||||
|
||||
from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean
|
||||
from six import string_types
|
||||
|
||||
row_type = OrderedDict
|
||||
|
||||
|
||||
def guess_type(sample):
|
||||
if isinstance(sample, bool):
|
||||
return Boolean
|
||||
elif isinstance(sample, int):
|
||||
return Integer
|
||||
elif isinstance(sample, float):
|
||||
return Float
|
||||
elif isinstance(sample, datetime):
|
||||
return DateTime
|
||||
elif isinstance(sample, date):
|
||||
return Date
|
||||
return UnicodeText
|
||||
|
||||
|
||||
def convert_row(row_type, row):
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
@ -50,25 +50,23 @@ class DatabaseTestCase(unittest.TestCase):
|
||||
|
||||
def test_create_table_custom_id1(self):
|
||||
pid = "string_id"
|
||||
table = self.db.create_table("foo2", pid, 'String')
|
||||
table = self.db.create_table("foo2", pid, self.db.types.string)
|
||||
assert table.table.exists()
|
||||
assert len(table.table.columns) == 1, table.table.columns
|
||||
assert pid in table.table.c, table.table.c
|
||||
|
||||
table.insert({
|
||||
'string_id': 'foobar'})
|
||||
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
|
||||
table.insert({pid: 'foobar'})
|
||||
assert table.find_one(string_id='foobar')[pid] == 'foobar'
|
||||
|
||||
def test_create_table_custom_id2(self):
|
||||
pid = "string_id"
|
||||
table = self.db.create_table("foo3", pid, 'String(50)')
|
||||
table = self.db.create_table("foo3", pid, self.db.types.string(50))
|
||||
assert table.table.exists()
|
||||
assert len(table.table.columns) == 1, table.table.columns
|
||||
assert pid in table.table.c, table.table.c
|
||||
|
||||
table.insert({
|
||||
'string_id': 'foobar'})
|
||||
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
|
||||
table.insert({pid: 'foobar'})
|
||||
assert table.find_one(string_id='foobar')[pid] == 'foobar'
|
||||
|
||||
def test_create_table_custom_id3(self):
|
||||
pid = "int_id"
|
||||
@ -77,11 +75,11 @@ class DatabaseTestCase(unittest.TestCase):
|
||||
assert len(table.table.columns) == 1, table.table.columns
|
||||
assert pid in table.table.c, table.table.c
|
||||
|
||||
table.insert({'int_id': 123})
|
||||
table.insert({'int_id': 124})
|
||||
assert table.find_one(int_id=123)['int_id'] == 123
|
||||
assert table.find_one(int_id=124)['int_id'] == 124
|
||||
self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123}))
|
||||
table.insert({pid: 123})
|
||||
table.insert({pid: 124})
|
||||
assert table.find_one(int_id=123)[pid] == 123
|
||||
assert table.find_one(int_id=124)[pid] == 124
|
||||
self.assertRaises(IntegrityError, lambda: table.insert({pid: 123}))
|
||||
|
||||
def test_create_table_shorthand1(self):
|
||||
pid = "int_id"
|
||||
@ -98,7 +96,8 @@ class DatabaseTestCase(unittest.TestCase):
|
||||
|
||||
def test_create_table_shorthand2(self):
|
||||
pid = "string_id"
|
||||
table = self.db.get_table('foo6', primary_id=pid, primary_type='String')
|
||||
table = self.db.get_table('foo6', primary_id=pid,
|
||||
primary_type=self.db.types.string)
|
||||
assert table.table.exists
|
||||
assert len(table.table.columns) == 1, table.table.columns
|
||||
assert pid in table.table.c, table.table.c
|
||||
@ -109,7 +108,8 @@ class DatabaseTestCase(unittest.TestCase):
|
||||
|
||||
def test_create_table_shorthand3(self):
|
||||
pid = "string_id"
|
||||
table = self.db.get_table('foo7', primary_id=pid, primary_type='String(20)')
|
||||
table = self.db.get_table('foo7', primary_id=pid,
|
||||
primary_type=self.db.types.string(20))
|
||||
assert table.table.exists
|
||||
assert len(table.table.columns) == 1, table.table.columns
|
||||
assert pid in table.table.c, table.table.c
|
||||
|
||||
Loading…
Reference in New Issue
Block a user