Begin implementing a types handler instead of using plain text types.
this is potentially BREAKING scripts which use the string syntax.
This commit is contained in:
parent
2f0fa7cdd2
commit
a4c73a8fb8
@ -1,12 +1,10 @@
|
|||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
import re
|
|
||||||
|
|
||||||
import six
|
import six
|
||||||
from six.moves.urllib.parse import parse_qs, urlparse
|
from six.moves.urllib.parse import parse_qs, urlparse
|
||||||
|
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
from sqlalchemy import Integer, String
|
|
||||||
from sqlalchemy.sql import text
|
from sqlalchemy.sql import text
|
||||||
from sqlalchemy.schema import MetaData, Column
|
from sqlalchemy.schema import MetaData, Column
|
||||||
from sqlalchemy.schema import Table as SQLATable
|
from sqlalchemy.schema import Table as SQLATable
|
||||||
@ -18,7 +16,7 @@ from alembic.operations import Operations
|
|||||||
|
|
||||||
from dataset.persistence.table import Table
|
from dataset.persistence.table import Table
|
||||||
from dataset.persistence.util import ResultIter, row_type, safe_url
|
from dataset.persistence.util import ResultIter, row_type, safe_url
|
||||||
from dataset.util import DatasetException
|
from dataset.persistence.types import Types
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -49,6 +47,7 @@ class Database(object):
|
|||||||
if len(schema_qs):
|
if len(schema_qs):
|
||||||
schema = schema_qs.pop()
|
schema = schema_qs.pop()
|
||||||
|
|
||||||
|
self.types = Types()
|
||||||
self.schema = schema
|
self.schema = schema
|
||||||
self.engine = create_engine(url, **engine_kwargs)
|
self.engine = create_engine(url, **engine_kwargs)
|
||||||
self.url = url
|
self.url = url
|
||||||
@ -145,9 +144,8 @@ class Database(object):
|
|||||||
raise ValueError("Invalid table name: %r" % table_name)
|
raise ValueError("Invalid table name: %r" % table_name)
|
||||||
return table_name.strip()
|
return table_name.strip()
|
||||||
|
|
||||||
def create_table(self, table_name, primary_id='id', primary_type='Integer'):
|
def create_table(self, table_name, primary_id=None, primary_type=None):
|
||||||
"""
|
"""Create a new table.
|
||||||
Create a new table.
|
|
||||||
|
|
||||||
The new table will automatically have an `id` column unless specified via
|
The new table will automatically have an `id` column unless specified via
|
||||||
optional parameter primary_id, which will be used as the primary key of the
|
optional parameter primary_id, which will be used as the primary key of the
|
||||||
@ -166,34 +164,31 @@ class Database(object):
|
|||||||
|
|
||||||
# custom id and type
|
# custom id and type
|
||||||
table2 = db.create_table('population2', 'age')
|
table2 = db.create_table('population2', 'age')
|
||||||
table3 = db.create_table('population3', primary_id='race', primary_type='String')
|
table3 = db.create_table('population3',
|
||||||
|
primary_id='city',
|
||||||
|
primary_type=db.types.string)
|
||||||
# custom length of String
|
# custom length of String
|
||||||
table4 = db.create_table('population4', primary_id='race', primary_type='String(50)')
|
table4 = db.create_table('population4',
|
||||||
|
primary_id='city',
|
||||||
|
primary_type=db.types.string(25))
|
||||||
|
# no primary key
|
||||||
|
table5 = db.create_table('population5',
|
||||||
|
primary_id=False)
|
||||||
"""
|
"""
|
||||||
table_name = self._valid_table_name(table_name)
|
table_name = self._valid_table_name(table_name)
|
||||||
self._acquire()
|
self._acquire()
|
||||||
try:
|
try:
|
||||||
log.debug("Creating table: %s on %r" % (table_name, self.engine))
|
log.debug("Creating table: %s" % (table_name))
|
||||||
match = re.match(r'^(Integer)$|^(String)(\(\d+\))?$', primary_type)
|
|
||||||
if match:
|
|
||||||
if match.group(1) == 'Integer':
|
|
||||||
auto_flag = False
|
|
||||||
if primary_id == 'id':
|
|
||||||
auto_flag = True
|
|
||||||
col = Column(primary_id, Integer, primary_key=True, autoincrement=auto_flag)
|
|
||||||
elif not match.group(3):
|
|
||||||
col = Column(primary_id, String(255), primary_key=True)
|
|
||||||
else:
|
|
||||||
len_string = int(match.group(3)[1:-1])
|
|
||||||
len_string = min(len_string, 255)
|
|
||||||
col = Column(primary_id, String(len_string), primary_key=True)
|
|
||||||
else:
|
|
||||||
raise DatasetException(
|
|
||||||
"The primary_type has to be either 'Integer' or 'String'.")
|
|
||||||
|
|
||||||
table = SQLATable(table_name, self.metadata, schema=self.schema)
|
table = SQLATable(table_name, self.metadata, schema=self.schema)
|
||||||
|
if primary_id is not False:
|
||||||
|
primary_id = primary_id or Table.PRIMARY_DEFAULT
|
||||||
|
primary_type = primary_type or self.types.integer
|
||||||
|
autoincrement = primary_id == Table.PRIMARY_DEFAULT
|
||||||
|
col = Column(primary_id, primary_type,
|
||||||
|
primary_key=True,
|
||||||
|
autoincrement=autoincrement)
|
||||||
table.append_column(col)
|
table.append_column(col)
|
||||||
table.create(self.engine)
|
table.create(self.executable)
|
||||||
self._tables[table_name] = table
|
self._tables[table_name] = table
|
||||||
return Table(self, table)
|
return Table(self, table)
|
||||||
finally:
|
finally:
|
||||||
@ -227,13 +222,13 @@ class Database(object):
|
|||||||
"""Reload a table schema from the database."""
|
"""Reload a table schema from the database."""
|
||||||
table_name = self._valid_table_name(table_name)
|
table_name = self._valid_table_name(table_name)
|
||||||
self.metadata = MetaData(schema=self.schema)
|
self.metadata = MetaData(schema=self.schema)
|
||||||
self.metadata.bind = self.engine
|
self.metadata.bind = self.executable
|
||||||
self.metadata.reflect(self.engine)
|
self.metadata.reflect(self.executable)
|
||||||
self._tables[table_name] = SQLATable(table_name, self.metadata,
|
self._tables[table_name] = SQLATable(table_name, self.metadata,
|
||||||
schema=self.schema)
|
schema=self.schema)
|
||||||
return self._tables[table_name]
|
return self._tables[table_name]
|
||||||
|
|
||||||
def get_table(self, table_name, primary_id='id', primary_type='Integer'):
|
def get_table(self, table_name, primary_id=None, primary_type=None):
|
||||||
"""
|
"""
|
||||||
Smart wrapper around *load_table* and *create_table*.
|
Smart wrapper around *load_table* and *create_table*.
|
||||||
|
|
||||||
|
|||||||
@ -5,7 +5,8 @@ from sqlalchemy.sql import and_, expression
|
|||||||
from sqlalchemy.sql.expression import ClauseElement
|
from sqlalchemy.sql.expression import ClauseElement
|
||||||
from sqlalchemy.schema import Column, Index
|
from sqlalchemy.schema import Column, Index
|
||||||
from sqlalchemy import func, select, false
|
from sqlalchemy import func, select, false
|
||||||
from dataset.persistence.util import guess_type, normalize_column_name
|
|
||||||
|
from dataset.persistence.util import normalize_column_name
|
||||||
from dataset.persistence.util import ResultIter
|
from dataset.persistence.util import ResultIter
|
||||||
from dataset.util import DatasetException
|
from dataset.util import DatasetException
|
||||||
|
|
||||||
@ -15,6 +16,7 @@ log = logging.getLogger(__name__)
|
|||||||
|
|
||||||
class Table(object):
|
class Table(object):
|
||||||
"""Represents a table in a database and exposes common operations."""
|
"""Represents a table in a database and exposes common operations."""
|
||||||
|
PRIMARY_DEFAULT = 'id'
|
||||||
|
|
||||||
def __init__(self, database, table):
|
def __init__(self, database, table):
|
||||||
"""Initialise the table from database schema."""
|
"""Initialise the table from database schema."""
|
||||||
@ -59,7 +61,7 @@ class Table(object):
|
|||||||
# filter out keys not in column set
|
# filter out keys not in column set
|
||||||
return {k: row[k] for k in row if k in self._normalized_columns}
|
return {k: row[k] for k in row if k in self._normalized_columns}
|
||||||
|
|
||||||
def insert(self, row, ensure=None, types={}):
|
def insert(self, row, ensure=None, types=None):
|
||||||
"""
|
"""
|
||||||
Add a row (type: dict) by inserting it into the table.
|
Add a row (type: dict) by inserting it into the table.
|
||||||
|
|
||||||
@ -88,7 +90,7 @@ class Table(object):
|
|||||||
if len(res.inserted_primary_key) > 0:
|
if len(res.inserted_primary_key) > 0:
|
||||||
return res.inserted_primary_key[0]
|
return res.inserted_primary_key[0]
|
||||||
|
|
||||||
def insert_ignore(self, row, keys, ensure=None, types={}):
|
def insert_ignore(self, row, keys, ensure=None, types=None):
|
||||||
"""
|
"""
|
||||||
Add a row (type: dict) into the table if the row does not exist.
|
Add a row (type: dict) into the table if the row does not exist.
|
||||||
|
|
||||||
@ -113,7 +115,7 @@ class Table(object):
|
|||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def insert_many(self, rows, chunk_size=1000, ensure=None, types={}):
|
def insert_many(self, rows, chunk_size=1000, ensure=None, types=None):
|
||||||
"""
|
"""
|
||||||
Add many rows at a time.
|
Add many rows at a time.
|
||||||
|
|
||||||
@ -149,7 +151,7 @@ class Table(object):
|
|||||||
if chunk:
|
if chunk:
|
||||||
_process_chunk(chunk)
|
_process_chunk(chunk)
|
||||||
|
|
||||||
def update(self, row, keys, ensure=None, types={}):
|
def update(self, row, keys, ensure=None, types=None):
|
||||||
"""
|
"""
|
||||||
Update a row in the table.
|
Update a row in the table.
|
||||||
|
|
||||||
@ -211,7 +213,7 @@ class Table(object):
|
|||||||
filters[key] = row.get(key)
|
filters[key] = row.get(key)
|
||||||
return row, self.find_one(**filters)
|
return row, self.find_one(**filters)
|
||||||
|
|
||||||
def upsert(self, row, keys, ensure=None, types={}):
|
def upsert(self, row, keys, ensure=None, types=None):
|
||||||
"""
|
"""
|
||||||
An UPSERT is a smart combination of insert and update.
|
An UPSERT is a smart combination of insert and update.
|
||||||
|
|
||||||
@ -260,17 +262,18 @@ class Table(object):
|
|||||||
def _has_column(self, column):
|
def _has_column(self, column):
|
||||||
return normalize_column_name(column) in self._normalized_columns
|
return normalize_column_name(column) in self._normalized_columns
|
||||||
|
|
||||||
def _ensure_columns(self, row, types={}):
|
def _ensure_columns(self, row, types=None):
|
||||||
# Keep order of inserted columns
|
# Keep order of inserted columns
|
||||||
for column in row.keys():
|
for column in row.keys():
|
||||||
if self._has_column(column):
|
if self._has_column(column):
|
||||||
continue
|
continue
|
||||||
if column in types:
|
if types is not None and column in types:
|
||||||
_type = types[column]
|
_type = types[column]
|
||||||
else:
|
else:
|
||||||
_type = guess_type(row[column])
|
_type = self.database.types.guess(row[column])
|
||||||
log.debug("Creating column: %s (%s) on %r" % (column,
|
log.debug("Creating column: %s (%s) on %r" % (column,
|
||||||
_type, self.table.name))
|
_type,
|
||||||
|
self.table.name))
|
||||||
self.create_column(column, _type)
|
self.create_column(column, _type)
|
||||||
|
|
||||||
def _args_to_clause(self, args, ensure=None, clauses=()):
|
def _args_to_clause(self, args, ensure=None, clauses=()):
|
||||||
@ -294,13 +297,13 @@ class Table(object):
|
|||||||
``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_.
|
``type`` must be a `SQLAlchemy column type <http://docs.sqlalchemy.org/en/rel_0_8/core/types.html>`_.
|
||||||
::
|
::
|
||||||
|
|
||||||
table.create_column('created_at', sqlalchemy.DateTime)
|
table.create_column('created_at', db.types.datetime)
|
||||||
"""
|
"""
|
||||||
self._check_dropped()
|
self._check_dropped()
|
||||||
self.database._acquire()
|
self.database._acquire()
|
||||||
try:
|
try:
|
||||||
if normalize_column_name(name) in self._normalized_columns:
|
if normalize_column_name(name) in self._normalized_columns:
|
||||||
log.warn("Column with similar name exists: %s" % name)
|
log.debug("Column exists: %s" % name)
|
||||||
return
|
return
|
||||||
|
|
||||||
self.database.op.add_column(
|
self.database.op.add_column(
|
||||||
@ -321,7 +324,8 @@ class Table(object):
|
|||||||
|
|
||||||
table.create_column_by_example('length', 4.2)
|
table.create_column_by_example('length', 4.2)
|
||||||
"""
|
"""
|
||||||
self._ensure_columns({name: value}, {})
|
type_ = self.database.types.guess(value)
|
||||||
|
self.create_column(name, type_)
|
||||||
|
|
||||||
def drop_column(self, name):
|
def drop_column(self, name):
|
||||||
"""
|
"""
|
||||||
|
|||||||
37
dataset/persistence/types.py
Normal file
37
dataset/persistence/types.py
Normal file
@ -0,0 +1,37 @@
|
|||||||
|
from datetime import datetime, date
|
||||||
|
|
||||||
|
from sqlalchemy import Integer, UnicodeText, Float, BigInteger
|
||||||
|
from sqlalchemy import Boolean, Date, DateTime, Unicode
|
||||||
|
from sqlalchemy.types import TypeEngine
|
||||||
|
|
||||||
|
|
||||||
|
class Types(object):
|
||||||
|
"""A holder class for easy access to SQLAlchemy type names."""
|
||||||
|
integer = Integer
|
||||||
|
string = Unicode
|
||||||
|
text = UnicodeText
|
||||||
|
float = Float
|
||||||
|
bigint = BigInteger
|
||||||
|
boolean = Boolean
|
||||||
|
date = Date
|
||||||
|
datetime = DateTime
|
||||||
|
|
||||||
|
def guess(cls, sample):
|
||||||
|
"""Given a single sample, guess the column type for the field.
|
||||||
|
|
||||||
|
If the sample is an instance of an SQLAlchemy type, the type will be
|
||||||
|
used instead.
|
||||||
|
"""
|
||||||
|
if isinstance(sample, TypeEngine):
|
||||||
|
return sample
|
||||||
|
if isinstance(sample, bool):
|
||||||
|
return cls.boolean
|
||||||
|
elif isinstance(sample, int):
|
||||||
|
return cls.integer
|
||||||
|
elif isinstance(sample, float):
|
||||||
|
return cls.float
|
||||||
|
elif isinstance(sample, datetime):
|
||||||
|
return cls.datetime
|
||||||
|
elif isinstance(sample, date):
|
||||||
|
return cls.date
|
||||||
|
return cls.text
|
||||||
@ -1,4 +1,3 @@
|
|||||||
from datetime import datetime, date
|
|
||||||
try:
|
try:
|
||||||
from urlparse import urlparse
|
from urlparse import urlparse
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -9,26 +8,11 @@ try:
|
|||||||
except ImportError: # pragma: no cover
|
except ImportError: # pragma: no cover
|
||||||
from ordereddict import OrderedDict
|
from ordereddict import OrderedDict
|
||||||
|
|
||||||
from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean
|
|
||||||
from six import string_types
|
from six import string_types
|
||||||
|
|
||||||
row_type = OrderedDict
|
row_type = OrderedDict
|
||||||
|
|
||||||
|
|
||||||
def guess_type(sample):
|
|
||||||
if isinstance(sample, bool):
|
|
||||||
return Boolean
|
|
||||||
elif isinstance(sample, int):
|
|
||||||
return Integer
|
|
||||||
elif isinstance(sample, float):
|
|
||||||
return Float
|
|
||||||
elif isinstance(sample, datetime):
|
|
||||||
return DateTime
|
|
||||||
elif isinstance(sample, date):
|
|
||||||
return Date
|
|
||||||
return UnicodeText
|
|
||||||
|
|
||||||
|
|
||||||
def convert_row(row_type, row):
|
def convert_row(row_type, row):
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
|
|||||||
@ -50,25 +50,23 @@ class DatabaseTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_create_table_custom_id1(self):
|
def test_create_table_custom_id1(self):
|
||||||
pid = "string_id"
|
pid = "string_id"
|
||||||
table = self.db.create_table("foo2", pid, 'String')
|
table = self.db.create_table("foo2", pid, self.db.types.string)
|
||||||
assert table.table.exists()
|
assert table.table.exists()
|
||||||
assert len(table.table.columns) == 1, table.table.columns
|
assert len(table.table.columns) == 1, table.table.columns
|
||||||
assert pid in table.table.c, table.table.c
|
assert pid in table.table.c, table.table.c
|
||||||
|
|
||||||
table.insert({
|
table.insert({pid: 'foobar'})
|
||||||
'string_id': 'foobar'})
|
assert table.find_one(string_id='foobar')[pid] == 'foobar'
|
||||||
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
|
|
||||||
|
|
||||||
def test_create_table_custom_id2(self):
|
def test_create_table_custom_id2(self):
|
||||||
pid = "string_id"
|
pid = "string_id"
|
||||||
table = self.db.create_table("foo3", pid, 'String(50)')
|
table = self.db.create_table("foo3", pid, self.db.types.string(50))
|
||||||
assert table.table.exists()
|
assert table.table.exists()
|
||||||
assert len(table.table.columns) == 1, table.table.columns
|
assert len(table.table.columns) == 1, table.table.columns
|
||||||
assert pid in table.table.c, table.table.c
|
assert pid in table.table.c, table.table.c
|
||||||
|
|
||||||
table.insert({
|
table.insert({pid: 'foobar'})
|
||||||
'string_id': 'foobar'})
|
assert table.find_one(string_id='foobar')[pid] == 'foobar'
|
||||||
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
|
|
||||||
|
|
||||||
def test_create_table_custom_id3(self):
|
def test_create_table_custom_id3(self):
|
||||||
pid = "int_id"
|
pid = "int_id"
|
||||||
@ -77,11 +75,11 @@ class DatabaseTestCase(unittest.TestCase):
|
|||||||
assert len(table.table.columns) == 1, table.table.columns
|
assert len(table.table.columns) == 1, table.table.columns
|
||||||
assert pid in table.table.c, table.table.c
|
assert pid in table.table.c, table.table.c
|
||||||
|
|
||||||
table.insert({'int_id': 123})
|
table.insert({pid: 123})
|
||||||
table.insert({'int_id': 124})
|
table.insert({pid: 124})
|
||||||
assert table.find_one(int_id=123)['int_id'] == 123
|
assert table.find_one(int_id=123)[pid] == 123
|
||||||
assert table.find_one(int_id=124)['int_id'] == 124
|
assert table.find_one(int_id=124)[pid] == 124
|
||||||
self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123}))
|
self.assertRaises(IntegrityError, lambda: table.insert({pid: 123}))
|
||||||
|
|
||||||
def test_create_table_shorthand1(self):
|
def test_create_table_shorthand1(self):
|
||||||
pid = "int_id"
|
pid = "int_id"
|
||||||
@ -98,7 +96,8 @@ class DatabaseTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_create_table_shorthand2(self):
|
def test_create_table_shorthand2(self):
|
||||||
pid = "string_id"
|
pid = "string_id"
|
||||||
table = self.db.get_table('foo6', primary_id=pid, primary_type='String')
|
table = self.db.get_table('foo6', primary_id=pid,
|
||||||
|
primary_type=self.db.types.string)
|
||||||
assert table.table.exists
|
assert table.table.exists
|
||||||
assert len(table.table.columns) == 1, table.table.columns
|
assert len(table.table.columns) == 1, table.table.columns
|
||||||
assert pid in table.table.c, table.table.c
|
assert pid in table.table.c, table.table.c
|
||||||
@ -109,7 +108,8 @@ class DatabaseTestCase(unittest.TestCase):
|
|||||||
|
|
||||||
def test_create_table_shorthand3(self):
|
def test_create_table_shorthand3(self):
|
||||||
pid = "string_id"
|
pid = "string_id"
|
||||||
table = self.db.get_table('foo7', primary_id=pid, primary_type='String(20)')
|
table = self.db.get_table('foo7', primary_id=pid,
|
||||||
|
primary_type=self.db.types.string(20))
|
||||||
assert table.table.exists
|
assert table.table.exists
|
||||||
assert len(table.table.columns) == 1, table.table.columns
|
assert len(table.table.columns) == 1, table.table.columns
|
||||||
assert pid in table.table.c, table.table.c
|
assert pid in table.table.c, table.table.c
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user