Adopt banal for utility functions

This commit is contained in:
Friedrich Lindenberg 2020-06-28 18:58:00 +02:00
parent 8dede2b757
commit 37c0f87d10
4 changed files with 14 additions and 20 deletions

View File

@ -47,7 +47,8 @@ class Database(object):
self.schema = schema
self.engine = create_engine(url, **engine_kwargs)
self.types = Types(self.engine.dialect.name)
self.is_postgres = self.engine.dialect.name == 'postgresql'
self.types = Types(is_postgres=self.is_postgres)
self.url = url
self.row_type = row_type
self.ensure_schema = ensure_schema

View File

@ -1,16 +1,17 @@
import logging
import warnings
import threading
from banal import ensure_list
from sqlalchemy import func, select, false
from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import bindparam, ClauseElement
from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError
from dataset.types import Types
from dataset.util import index_name, ensure_tuple
from dataset.util import index_name
from dataset.util import DatasetException, ResultIter, QUERY_STEP
from dataset.util import normalize_table_name, pad_chunk_columns
from dataset.util import normalize_column_name, normalize_column_key
@ -216,8 +217,7 @@ class Table(object):
See :py:meth:`update() <dataset.Table.update>` for details on
the other parameters.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]
keys = ensure_list(keys)
chunk = []
columns = []
@ -270,8 +270,7 @@ class Table(object):
See :py:meth:`upsert() <dataset.Table.upsert>` and
:py:meth:`insert_many() <dataset.Table.insert_many>`.
"""
# Convert keys to a list if not a list or tuple.
keys = keys if type(keys) in (list, tuple) else [keys]
keys = ensure_list(keys)
to_insert = []
to_update = []
@ -442,7 +441,7 @@ class Table(object):
def _args_to_order_by(self, order_by):
orderings = []
for ordering in ensure_tuple(order_by):
for ordering in ensure_list(order_by):
if ordering is None:
continue
column = ordering.lstrip('-')
@ -456,8 +455,7 @@ class Table(object):
return orderings
def _keys_to_args(self, row, keys):
keys = ensure_tuple(keys)
keys = [self._get_column_name(k) for k in keys]
keys = [self._get_column_name(k) for k in ensure_list(keys)]
row = row.copy()
args = {k: row.pop(k, None) for k in keys}
return args, row
@ -557,7 +555,7 @@ class Table(object):
table.create_index(['name', 'country'])
"""
columns = [self._get_column_name(c) for c in ensure_tuple(columns)]
columns = [self._get_column_name(c) for c in ensure_list(columns)]
with self.db.lock:
if not self.exists:
raise DatasetException("Table has not been created yet.")

View File

@ -17,14 +17,8 @@ class Types(object):
date = Date
datetime = DateTime
def __init__(self, dialect=None):
self._dialect = dialect
@property
def json(self):
if self._dialect is not None and self._dialect == 'postgresql':
return JSONB
return JSON
def __init__(self, is_postgres=None):
self.json = JSONB if is_postgres else JSON
def guess(self, sample):
"""Given a single sample, guess the column type for the field.

View File

@ -30,7 +30,8 @@ setup(
zip_safe=False,
install_requires=[
'sqlalchemy >= 1.3.2',
'alembic >= 0.6.2'
'alembic >= 0.6.2',
'banal >= 1.0.1',
],
extras_require={
'dev': [