diff --git a/dataset/database.py b/dataset/database.py index ba38db5..87bc2d5 100644 --- a/dataset/database.py +++ b/dataset/database.py @@ -47,7 +47,8 @@ class Database(object): self.schema = schema self.engine = create_engine(url, **engine_kwargs) - self.types = Types(self.engine.dialect.name) + self.is_postgres = self.engine.dialect.name == 'postgresql' + self.types = Types(is_postgres=self.is_postgres) self.url = url self.row_type = row_type self.ensure_schema = ensure_schema diff --git a/dataset/table.py b/dataset/table.py index 5c8026a..fae124b 100644 --- a/dataset/table.py +++ b/dataset/table.py @@ -1,16 +1,17 @@ import logging import warnings import threading +from banal import ensure_list +from sqlalchemy import func, select, false from sqlalchemy.sql import and_, expression from sqlalchemy.sql.expression import bindparam, ClauseElement from sqlalchemy.schema import Column, Index -from sqlalchemy import func, select, false from sqlalchemy.schema import Table as SQLATable from sqlalchemy.exc import NoSuchTableError from dataset.types import Types -from dataset.util import index_name, ensure_tuple +from dataset.util import index_name from dataset.util import DatasetException, ResultIter, QUERY_STEP from dataset.util import normalize_table_name, pad_chunk_columns from dataset.util import normalize_column_name, normalize_column_key @@ -216,8 +217,7 @@ class Table(object): See :py:meth:`update() ` for details on the other parameters. """ - # Convert keys to a list if not a list or tuple. - keys = keys if type(keys) in (list, tuple) else [keys] + keys = ensure_list(keys) chunk = [] columns = [] @@ -270,8 +270,7 @@ class Table(object): See :py:meth:`upsert() ` and :py:meth:`insert_many() `. """ - # Convert keys to a list if not a list or tuple. - keys = keys if type(keys) in (list, tuple) else [keys] + keys = ensure_list(keys) to_insert = [] to_update = [] @@ -442,7 +441,7 @@ class Table(object): def _args_to_order_by(self, order_by): orderings = [] - for ordering in ensure_tuple(order_by): + for ordering in ensure_list(order_by): if ordering is None: continue column = ordering.lstrip('-') @@ -456,8 +455,7 @@ class Table(object): return orderings def _keys_to_args(self, row, keys): - keys = ensure_tuple(keys) - keys = [self._get_column_name(k) for k in keys] + keys = [self._get_column_name(k) for k in ensure_list(keys)] row = row.copy() args = {k: row.pop(k, None) for k in keys} return args, row @@ -557,7 +555,7 @@ class Table(object): table.create_index(['name', 'country']) """ - columns = [self._get_column_name(c) for c in ensure_tuple(columns)] + columns = [self._get_column_name(c) for c in ensure_list(columns)] with self.db.lock: if not self.exists: raise DatasetException("Table has not been created yet.") diff --git a/dataset/types.py b/dataset/types.py index d30d3c1..9c73c0b 100644 --- a/dataset/types.py +++ b/dataset/types.py @@ -17,14 +17,8 @@ class Types(object): date = Date datetime = DateTime - def __init__(self, dialect=None): - self._dialect = dialect - - @property - def json(self): - if self._dialect is not None and self._dialect == 'postgresql': - return JSONB - return JSON + def __init__(self, is_postgres=None): + self.json = JSONB if is_postgres else JSON def guess(self, sample): """Given a single sample, guess the column type for the field. diff --git a/setup.py b/setup.py index ac9d336..d5a2501 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,8 @@ setup( zip_safe=False, install_requires=[ 'sqlalchemy >= 1.3.2', - 'alembic >= 0.6.2' + 'alembic >= 0.6.2', + 'banal >= 1.0.1', ], extras_require={ 'dev': [