From a4c73a8fb8b8d1b78528bdd786d7944cf95220e3 Mon Sep 17 00:00:00 2001 From: Friedrich Lindenberg Date: Sat, 2 Sep 2017 16:47:04 +0200 Subject: [PATCH] Begin implementing a types handler instead of using plain text types. this is potentially BREAKING scripts which use the string syntax. --- dataset/persistence/database.py | 57 +++++++++++++++------------------ dataset/persistence/table.py | 30 +++++++++-------- dataset/persistence/types.py | 37 +++++++++++++++++++++ dataset/persistence/util.py | 16 --------- test/test_persistence.py | 30 ++++++++--------- 5 files changed, 95 insertions(+), 75 deletions(-) create mode 100644 dataset/persistence/types.py diff --git a/dataset/persistence/database.py b/dataset/persistence/database.py index d48af22..67aa1c1 100644 --- a/dataset/persistence/database.py +++ b/dataset/persistence/database.py @@ -1,12 +1,10 @@ import logging import threading -import re import six from six.moves.urllib.parse import parse_qs, urlparse from sqlalchemy import create_engine -from sqlalchemy import Integer, String from sqlalchemy.sql import text from sqlalchemy.schema import MetaData, Column from sqlalchemy.schema import Table as SQLATable @@ -18,7 +16,7 @@ from alembic.operations import Operations from dataset.persistence.table import Table from dataset.persistence.util import ResultIter, row_type, safe_url -from dataset.util import DatasetException +from dataset.persistence.types import Types log = logging.getLogger(__name__) @@ -49,6 +47,7 @@ class Database(object): if len(schema_qs): schema = schema_qs.pop() + self.types = Types() self.schema = schema self.engine = create_engine(url, **engine_kwargs) self.url = url @@ -145,9 +144,8 @@ class Database(object): raise ValueError("Invalid table name: %r" % table_name) return table_name.strip() - def create_table(self, table_name, primary_id='id', primary_type='Integer'): - """ - Create a new table. + def create_table(self, table_name, primary_id=None, primary_type=None): + """Create a new table. The new table will automatically have an `id` column unless specified via optional parameter primary_id, which will be used as the primary key of the @@ -166,34 +164,31 @@ class Database(object): # custom id and type table2 = db.create_table('population2', 'age') - table3 = db.create_table('population3', primary_id='race', primary_type='String') + table3 = db.create_table('population3', + primary_id='city', + primary_type=db.types.string) # custom length of String - table4 = db.create_table('population4', primary_id='race', primary_type='String(50)') + table4 = db.create_table('population4', + primary_id='city', + primary_type=db.types.string(25)) + # no primary key + table5 = db.create_table('population5', + primary_id=False) """ table_name = self._valid_table_name(table_name) self._acquire() try: - log.debug("Creating table: %s on %r" % (table_name, self.engine)) - match = re.match(r'^(Integer)$|^(String)(\(\d+\))?$', primary_type) - if match: - if match.group(1) == 'Integer': - auto_flag = False - if primary_id == 'id': - auto_flag = True - col = Column(primary_id, Integer, primary_key=True, autoincrement=auto_flag) - elif not match.group(3): - col = Column(primary_id, String(255), primary_key=True) - else: - len_string = int(match.group(3)[1:-1]) - len_string = min(len_string, 255) - col = Column(primary_id, String(len_string), primary_key=True) - else: - raise DatasetException( - "The primary_type has to be either 'Integer' or 'String'.") - + log.debug("Creating table: %s" % (table_name)) table = SQLATable(table_name, self.metadata, schema=self.schema) - table.append_column(col) - table.create(self.engine) + if primary_id is not False: + primary_id = primary_id or Table.PRIMARY_DEFAULT + primary_type = primary_type or self.types.integer + autoincrement = primary_id == Table.PRIMARY_DEFAULT + col = Column(primary_id, primary_type, + primary_key=True, + autoincrement=autoincrement) + table.append_column(col) + table.create(self.executable) self._tables[table_name] = table return Table(self, table) finally: @@ -227,13 +222,13 @@ class Database(object): """Reload a table schema from the database.""" table_name = self._valid_table_name(table_name) self.metadata = MetaData(schema=self.schema) - self.metadata.bind = self.engine - self.metadata.reflect(self.engine) + self.metadata.bind = self.executable + self.metadata.reflect(self.executable) self._tables[table_name] = SQLATable(table_name, self.metadata, schema=self.schema) return self._tables[table_name] - def get_table(self, table_name, primary_id='id', primary_type='Integer'): + def get_table(self, table_name, primary_id=None, primary_type=None): """ Smart wrapper around *load_table* and *create_table*. diff --git a/dataset/persistence/table.py b/dataset/persistence/table.py index 2153b28..a4187e9 100644 --- a/dataset/persistence/table.py +++ b/dataset/persistence/table.py @@ -5,7 +5,8 @@ from sqlalchemy.sql import and_, expression from sqlalchemy.sql.expression import ClauseElement from sqlalchemy.schema import Column, Index from sqlalchemy import func, select, false -from dataset.persistence.util import guess_type, normalize_column_name + +from dataset.persistence.util import normalize_column_name from dataset.persistence.util import ResultIter from dataset.util import DatasetException @@ -15,6 +16,7 @@ log = logging.getLogger(__name__) class Table(object): """Represents a table in a database and exposes common operations.""" + PRIMARY_DEFAULT = 'id' def __init__(self, database, table): """Initialise the table from database schema.""" @@ -59,7 +61,7 @@ class Table(object): # filter out keys not in column set return {k: row[k] for k in row if k in self._normalized_columns} - def insert(self, row, ensure=None, types={}): + def insert(self, row, ensure=None, types=None): """ Add a row (type: dict) by inserting it into the table. @@ -88,7 +90,7 @@ class Table(object): if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] - def insert_ignore(self, row, keys, ensure=None, types={}): + def insert_ignore(self, row, keys, ensure=None, types=None): """ Add a row (type: dict) into the table if the row does not exist. @@ -113,7 +115,7 @@ class Table(object): else: return False - def insert_many(self, rows, chunk_size=1000, ensure=None, types={}): + def insert_many(self, rows, chunk_size=1000, ensure=None, types=None): """ Add many rows at a time. @@ -149,7 +151,7 @@ class Table(object): if chunk: _process_chunk(chunk) - def update(self, row, keys, ensure=None, types={}): + def update(self, row, keys, ensure=None, types=None): """ Update a row in the table. @@ -211,7 +213,7 @@ class Table(object): filters[key] = row.get(key) return row, self.find_one(**filters) - def upsert(self, row, keys, ensure=None, types={}): + def upsert(self, row, keys, ensure=None, types=None): """ An UPSERT is a smart combination of insert and update. @@ -260,17 +262,18 @@ class Table(object): def _has_column(self, column): return normalize_column_name(column) in self._normalized_columns - def _ensure_columns(self, row, types={}): + def _ensure_columns(self, row, types=None): # Keep order of inserted columns for column in row.keys(): if self._has_column(column): continue - if column in types: + if types is not None and column in types: _type = types[column] else: - _type = guess_type(row[column]) + _type = self.database.types.guess(row[column]) log.debug("Creating column: %s (%s) on %r" % (column, - _type, self.table.name)) + _type, + self.table.name)) self.create_column(column, _type) def _args_to_clause(self, args, ensure=None, clauses=()): @@ -294,13 +297,13 @@ class Table(object): ``type`` must be a `SQLAlchemy column type `_. :: - table.create_column('created_at', sqlalchemy.DateTime) + table.create_column('created_at', db.types.datetime) """ self._check_dropped() self.database._acquire() try: if normalize_column_name(name) in self._normalized_columns: - log.warn("Column with similar name exists: %s" % name) + log.debug("Column exists: %s" % name) return self.database.op.add_column( @@ -321,7 +324,8 @@ class Table(object): table.create_column_by_example('length', 4.2) """ - self._ensure_columns({name: value}, {}) + type_ = self.database.types.guess(value) + self.create_column(name, type_) def drop_column(self, name): """ diff --git a/dataset/persistence/types.py b/dataset/persistence/types.py new file mode 100644 index 0000000..20606ce --- /dev/null +++ b/dataset/persistence/types.py @@ -0,0 +1,37 @@ +from datetime import datetime, date + +from sqlalchemy import Integer, UnicodeText, Float, BigInteger +from sqlalchemy import Boolean, Date, DateTime, Unicode +from sqlalchemy.types import TypeEngine + + +class Types(object): + """A holder class for easy access to SQLAlchemy type names.""" + integer = Integer + string = Unicode + text = UnicodeText + float = Float + bigint = BigInteger + boolean = Boolean + date = Date + datetime = DateTime + + def guess(cls, sample): + """Given a single sample, guess the column type for the field. + + If the sample is an instance of an SQLAlchemy type, the type will be + used instead. + """ + if isinstance(sample, TypeEngine): + return sample + if isinstance(sample, bool): + return cls.boolean + elif isinstance(sample, int): + return cls.integer + elif isinstance(sample, float): + return cls.float + elif isinstance(sample, datetime): + return cls.datetime + elif isinstance(sample, date): + return cls.date + return cls.text diff --git a/dataset/persistence/util.py b/dataset/persistence/util.py index 41fbce2..c354339 100644 --- a/dataset/persistence/util.py +++ b/dataset/persistence/util.py @@ -1,4 +1,3 @@ -from datetime import datetime, date try: from urlparse import urlparse except ImportError: @@ -9,26 +8,11 @@ try: except ImportError: # pragma: no cover from ordereddict import OrderedDict -from sqlalchemy import Integer, UnicodeText, Float, Date, DateTime, Boolean from six import string_types row_type = OrderedDict -def guess_type(sample): - if isinstance(sample, bool): - return Boolean - elif isinstance(sample, int): - return Integer - elif isinstance(sample, float): - return Float - elif isinstance(sample, datetime): - return DateTime - elif isinstance(sample, date): - return Date - return UnicodeText - - def convert_row(row_type, row): if row is None: return None diff --git a/test/test_persistence.py b/test/test_persistence.py index 1e6ab92..d868e33 100644 --- a/test/test_persistence.py +++ b/test/test_persistence.py @@ -50,25 +50,23 @@ class DatabaseTestCase(unittest.TestCase): def test_create_table_custom_id1(self): pid = "string_id" - table = self.db.create_table("foo2", pid, 'String') + table = self.db.create_table("foo2", pid, self.db.types.string) assert table.table.exists() assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c - table.insert({ - 'string_id': 'foobar'}) - assert table.find_one(string_id='foobar')['string_id'] == 'foobar' + table.insert({pid: 'foobar'}) + assert table.find_one(string_id='foobar')[pid] == 'foobar' def test_create_table_custom_id2(self): pid = "string_id" - table = self.db.create_table("foo3", pid, 'String(50)') + table = self.db.create_table("foo3", pid, self.db.types.string(50)) assert table.table.exists() assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c - table.insert({ - 'string_id': 'foobar'}) - assert table.find_one(string_id='foobar')['string_id'] == 'foobar' + table.insert({pid: 'foobar'}) + assert table.find_one(string_id='foobar')[pid] == 'foobar' def test_create_table_custom_id3(self): pid = "int_id" @@ -77,11 +75,11 @@ class DatabaseTestCase(unittest.TestCase): assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c - table.insert({'int_id': 123}) - table.insert({'int_id': 124}) - assert table.find_one(int_id=123)['int_id'] == 123 - assert table.find_one(int_id=124)['int_id'] == 124 - self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123})) + table.insert({pid: 123}) + table.insert({pid: 124}) + assert table.find_one(int_id=123)[pid] == 123 + assert table.find_one(int_id=124)[pid] == 124 + self.assertRaises(IntegrityError, lambda: table.insert({pid: 123})) def test_create_table_shorthand1(self): pid = "int_id" @@ -98,7 +96,8 @@ class DatabaseTestCase(unittest.TestCase): def test_create_table_shorthand2(self): pid = "string_id" - table = self.db.get_table('foo6', primary_id=pid, primary_type='String') + table = self.db.get_table('foo6', primary_id=pid, + primary_type=self.db.types.string) assert table.table.exists assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c @@ -109,7 +108,8 @@ class DatabaseTestCase(unittest.TestCase): def test_create_table_shorthand3(self): pid = "string_id" - table = self.db.get_table('foo7', primary_id=pid, primary_type='String(20)') + table = self.db.get_table('foo7', primary_id=pid, + primary_type=self.db.types.string(20)) assert table.table.exists assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c