Merge pull request #336 from pudo/fix-ci

Fix CI use of Postgres and MySQL
This commit is contained in:
Friedrich Lindenberg 2020-08-23 13:54:17 +02:00 committed by GitHub
commit 14db5e00e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 659 additions and 675 deletions

View File

@ -16,13 +16,13 @@ jobs:
ports: ports:
- 5432/tcp - 5432/tcp
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
mariadb: mysql:
image: mariadb image: mysql
env: env:
MYSQL_USER: mariadb MYSQL_USER: mysql
MYSQL_PASSWORD: mariadb MYSQL_PASSWORD: mysql
MYSQL_DATABASE: dataset MYSQL_DATABASE: dataset
MYSQL_ROOT_PASSWORD: mariadb MYSQL_ROOT_PASSWORD: mysql
ports: ports:
- 3306/tcp - 3306/tcp
options: --health-cmd="mysqladmin ping" --health-interval=5s --health-timeout=2s --health-retries=3 options: --health-cmd="mysqladmin ping" --health-interval=5s --health-timeout=2s --health-retries=3
@ -35,7 +35,7 @@ jobs:
- name: Set up Python - name: Set up Python
uses: actions/setup-python@v1 uses: actions/setup-python@v1
with: with:
python-version: '3.x' python-version: "3.x"
- name: Install dependencies - name: Install dependencies
env: env:
DEBIAN_FRONTEND: noninteractive DEBIAN_FRONTEND: noninteractive
@ -44,17 +44,17 @@ jobs:
pip install -e ".[dev]" pip install -e ".[dev]"
- name: Run SQLite tests - name: Run SQLite tests
env: env:
DATABASE_URI: 'sqlite:///:memory:' DATABASE_URL: "sqlite:///:memory:"
run: | run: |
make test make test
- name: Run PostgreSQL tests - name: Run PostgreSQL tests
env: env:
DATABASE_URI: 'postgresql+psycopg2://postgres:postgres@postgres:${{ job.services.postgres.ports[5432] }}/dataset' DATABASE_URL: "postgresql://postgres:postgres@127.0.0.1:${{ job.services.postgres.ports[5432] }}/dataset"
run: | run: |
make test make test
- name: Run MariaDB tests - name: Run mysql tests
env: env:
DATABASE_URI: 'mysql+pymysql://mariadb:mariadb@mariadb:${{ job.services.mariadb.ports[3306] }}/dataset?charset=utf8' DATABASE_URL: "mysql+pymysql://mysql:mysql@127.0.0.1:${{ job.services.mysql.ports[3306] }}/dataset?charset=utf8"
run: | run: |
make test make test
- name: Run flake8 to lint - name: Run flake8 to lint

View File

@ -5,17 +5,24 @@ from dataset.table import Table
from dataset.util import row_type from dataset.util import row_type
# shut up useless SA warning: # shut up useless SA warning:
warnings.filterwarnings("ignore", "Unicode type received non-unicode bind param value.")
warnings.filterwarnings( warnings.filterwarnings(
'ignore', 'Unicode type received non-unicode bind param value.') "ignore", "Skipping unsupported ALTER for creation of implicit constraint"
warnings.filterwarnings( )
'ignore', 'Skipping unsupported ALTER for creation of implicit constraint')
__all__ = ['Database', 'Table', 'freeze', 'connect'] __all__ = ["Database", "Table", "freeze", "connect"]
__version__ = '1.3.2' __version__ = "1.3.2"
def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None, def connect(
reflect_views=True, ensure_schema=True, row_type=row_type): url=None,
schema=None,
reflect_metadata=True,
engine_kwargs=None,
reflect_views=True,
ensure_schema=True,
row_type=row_type,
):
""" Opens a new connection to a database. """ Opens a new connection to a database.
*url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined *url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined
@ -40,8 +47,14 @@ def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
.. _DB connection timeout: http://docs.sqlalchemy.org/en/latest/core/pooling.html#setting-pool-recycle .. _DB connection timeout: http://docs.sqlalchemy.org/en/latest/core/pooling.html#setting-pool-recycle
""" """
if url is None: if url is None:
url = os.environ.get('DATABASE_URL', 'sqlite://') url = os.environ.get("DATABASE_URL", "sqlite://")
return Database(url, schema=schema, reflect_metadata=reflect_metadata, return Database(
engine_kwargs=engine_kwargs, reflect_views=reflect_views, url,
ensure_schema=ensure_schema, row_type=row_type) schema=schema,
reflect_metadata=reflect_metadata,
engine_kwargs=engine_kwargs,
reflect_views=reflect_views,
ensure_schema=ensure_schema,
row_type=row_type,
)

View File

@ -22,9 +22,16 @@ log = logging.getLogger(__name__)
class Database(object): class Database(object):
"""A database object represents a SQL database with multiple tables.""" """A database object represents a SQL database with multiple tables."""
def __init__(self, url, schema=None, reflect_metadata=True, def __init__(
engine_kwargs=None, reflect_views=True, self,
ensure_schema=True, row_type=row_type): url,
schema=None,
reflect_metadata=True,
engine_kwargs=None,
reflect_views=True,
ensure_schema=True,
row_type=row_type,
):
"""Configure and connect to the database.""" """Configure and connect to the database."""
if engine_kwargs is None: if engine_kwargs is None:
engine_kwargs = {} engine_kwargs = {}
@ -41,13 +48,14 @@ class Database(object):
if len(parsed_url.query): if len(parsed_url.query):
query = parse_qs(parsed_url.query) query = parse_qs(parsed_url.query)
if schema is None: if schema is None:
schema_qs = query.get('schema', query.get('searchpath', [])) schema_qs = query.get("schema", query.get("searchpath", []))
if len(schema_qs): if len(schema_qs):
schema = schema_qs.pop() schema = schema_qs.pop()
self.schema = schema self.schema = schema
self.engine = create_engine(url, **engine_kwargs) self.engine = create_engine(url, **engine_kwargs)
self.types = Types(self.engine.dialect.name) self.is_postgres = self.engine.dialect.name == "postgresql"
self.types = Types(is_postgres=self.is_postgres)
self.url = url self.url = url
self.row_type = row_type self.row_type = row_type
self.ensure_schema = ensure_schema self.ensure_schema = ensure_schema
@ -56,7 +64,7 @@ class Database(object):
@property @property
def executable(self): def executable(self):
"""Connection against which statements will be executed.""" """Connection against which statements will be executed."""
if not hasattr(self.local, 'conn'): if not hasattr(self.local, "conn"):
self.local.conn = self.engine.connect() self.local.conn = self.engine.connect()
return self.local.conn return self.local.conn
@ -79,7 +87,7 @@ class Database(object):
@property @property
def in_transaction(self): def in_transaction(self):
"""Check if this database is in a transactional context.""" """Check if this database is in a transactional context."""
if not hasattr(self.local, 'tx'): if not hasattr(self.local, "tx"):
return False return False
return len(self.local.tx) > 0 return len(self.local.tx) > 0
@ -93,7 +101,7 @@ class Database(object):
No data will be written until the transaction has been committed. No data will be written until the transaction has been committed.
""" """
if not hasattr(self.local, 'tx'): if not hasattr(self.local, "tx"):
self.local.tx = [] self.local.tx = []
self.local.tx.append(self.executable.begin()) self.local.tx.append(self.executable.begin())
@ -102,7 +110,7 @@ class Database(object):
Make all statements executed since the transaction was begun permanent. Make all statements executed since the transaction was begun permanent.
""" """
if hasattr(self.local, 'tx') and self.local.tx: if hasattr(self.local, "tx") and self.local.tx:
tx = self.local.tx.pop() tx = self.local.tx.pop()
tx.commit() tx.commit()
self._flush_tables() self._flush_tables()
@ -112,7 +120,7 @@ class Database(object):
Discard all statements executed since the transaction was begun. Discard all statements executed since the transaction was begun.
""" """
if hasattr(self.local, 'tx') and self.local.tx: if hasattr(self.local, "tx") and self.local.tx:
tx = self.local.tx.pop() tx = self.local.tx.pop()
tx.rollback() tx.rollback()
self._flush_tables() self._flush_tables()
@ -189,15 +197,19 @@ class Database(object):
table5 = db.create_table('population5', table5 = db.create_table('population5',
primary_id=False) primary_id=False)
""" """
assert not isinstance(primary_type, str), \ assert not isinstance(
'Text-based primary_type support is dropped, use db.types.' primary_type, str
), "Text-based primary_type support is dropped, use db.types."
table_name = normalize_table_name(table_name) table_name = normalize_table_name(table_name)
with self.lock: with self.lock:
if table_name not in self._tables: if table_name not in self._tables:
self._tables[table_name] = Table(self, table_name, self._tables[table_name] = Table(
self,
table_name,
primary_id=primary_id, primary_id=primary_id,
primary_type=primary_type, primary_type=primary_type,
auto_create=True) auto_create=True,
)
return self._tables.get(table_name) return self._tables.get(table_name)
def load_table(self, table_name): def load_table(self, table_name):
@ -264,7 +276,7 @@ class Database(object):
""" """
if isinstance(query, str): if isinstance(query, str):
query = text(query) query = text(query)
_step = kwargs.pop('_step', QUERY_STEP) _step = kwargs.pop("_step", QUERY_STEP)
if _step is False or _step == 0: if _step is False or _step == 0:
_step = None _step = None
rp = self.executable.execute(query, *args, **kwargs) rp = self.executable.execute(query, *args, **kwargs)
@ -272,4 +284,4 @@ class Database(object):
def __repr__(self): def __repr__(self):
"""Text representation contains the URL.""" """Text representation contains the URL."""
return '<Database(%s)>' % safe_url(self.url) return "<Database(%s)>" % safe_url(self.url)

View File

@ -1,16 +1,17 @@
import logging import logging
import warnings import warnings
import threading import threading
from banal import ensure_list
from sqlalchemy import func, select, false
from sqlalchemy.sql import and_, expression from sqlalchemy.sql import and_, expression
from sqlalchemy.sql.expression import bindparam, ClauseElement from sqlalchemy.sql.expression import bindparam, ClauseElement
from sqlalchemy.schema import Column, Index from sqlalchemy.schema import Column, Index
from sqlalchemy import func, select, false
from sqlalchemy.schema import Table as SQLATable from sqlalchemy.schema import Table as SQLATable
from sqlalchemy.exc import NoSuchTableError from sqlalchemy.exc import NoSuchTableError
from dataset.types import Types from dataset.types import Types, MYSQL_LENGTH_TYPES
from dataset.util import index_name, ensure_tuple from dataset.util import index_name
from dataset.util import DatasetException, ResultIter, QUERY_STEP from dataset.util import DatasetException, ResultIter, QUERY_STEP
from dataset.util import normalize_table_name, pad_chunk_columns from dataset.util import normalize_table_name, pad_chunk_columns
from dataset.util import normalize_column_name, normalize_column_key from dataset.util import normalize_column_name, normalize_column_key
@ -21,20 +22,27 @@ log = logging.getLogger(__name__)
class Table(object): class Table(object):
"""Represents a table in a database and exposes common operations.""" """Represents a table in a database and exposes common operations."""
PRIMARY_DEFAULT = 'id'
def __init__(self, database, table_name, primary_id=None, PRIMARY_DEFAULT = "id"
primary_type=None, auto_create=False):
def __init__(
self,
database,
table_name,
primary_id=None,
primary_type=None,
auto_create=False,
):
"""Initialise the table from database schema.""" """Initialise the table from database schema."""
self.db = database self.db = database
self.name = normalize_table_name(table_name) self.name = normalize_table_name(table_name)
self._table = None self._table = None
self._columns = None self._columns = None
self._indexes = [] self._indexes = []
self._primary_id = primary_id if primary_id is not None \ self._primary_id = (
else self.PRIMARY_DEFAULT primary_id if primary_id is not None else self.PRIMARY_DEFAULT
self._primary_type = primary_type if primary_type is not None \ )
else Types.integer self._primary_type = primary_type if primary_type is not None else Types.integer
self._auto_create = auto_create self._auto_create = auto_create
@property @property
@ -69,10 +77,6 @@ class Table(object):
self._columns[key] = name self._columns[key] = name
return self._columns return self._columns
def _flush_metadata(self):
with self.db.lock:
self._columns = None
@property @property
def columns(self): def columns(self):
"""Get a listing of all columns that exist in the table.""" """Get a listing of all columns that exist in the table."""
@ -205,8 +209,7 @@ class Table(object):
if return_count: if return_count:
return self.count(clause) return self.count(clause)
def update_many(self, rows, keys, chunk_size=1000, ensure=None, def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
types=None):
"""Update many rows in the table at a time. """Update many rows in the table at a time.
This is significantly faster than updating them one by one. Per default This is significantly faster than updating them one by one. Per default
@ -216,8 +219,7 @@ class Table(object):
See :py:meth:`update() <dataset.Table.update>` for details on See :py:meth:`update() <dataset.Table.update>` for details on
the other parameters. the other parameters.
""" """
# Convert keys to a list if not a list or tuple. keys = ensure_list(keys)
keys = keys if type(keys) in (list, tuple) else [keys]
chunk = [] chunk = []
columns = [] columns = []
@ -229,16 +231,14 @@ class Table(object):
# bindparam requires names to not conflict (cannot be "id" for id) # bindparam requires names to not conflict (cannot be "id" for id)
for key in keys: for key in keys:
row['_%s' % key] = row[key] row["_%s" % key] = row[key]
# Update when chunk_size is fulfilled or this is the last row # Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1: if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam('_%s' % k) for k in keys] cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
stmt = self.table.update( stmt = self.table.update(
whereclause=and_(*cl), whereclause=and_(*cl),
values={ values={col: bindparam(col, required=False) for col in columns},
col: bindparam(col, required=False) for col in columns
}
) )
self.db.executable.execute(stmt, chunk) self.db.executable.execute(stmt, chunk)
chunk = [] chunk = []
@ -261,8 +261,7 @@ class Table(object):
return self.insert(row, ensure=False) return self.insert(row, ensure=False)
return True return True
def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
types=None):
""" """
Sorts multiple input rows into upserts and inserts. Inserts are passed Sorts multiple input rows into upserts and inserts. Inserts are passed
to insert_many and upserts are updated. to insert_many and upserts are updated.
@ -270,8 +269,7 @@ class Table(object):
See :py:meth:`upsert() <dataset.Table.upsert>` and See :py:meth:`upsert() <dataset.Table.upsert>` and
:py:meth:`insert_many() <dataset.Table.insert_many>`. :py:meth:`insert_many() <dataset.Table.insert_many>`.
""" """
# Convert keys to a list if not a list or tuple. keys = ensure_list(keys)
keys = keys if type(keys) in (list, tuple) else [keys]
to_insert = [] to_insert = []
to_update = [] to_update = []
@ -310,25 +308,25 @@ class Table(object):
def _reflect_table(self): def _reflect_table(self):
"""Load the tables definition from the database.""" """Load the tables definition from the database."""
with self.db.lock: with self.db.lock:
self._flush_metadata() self._columns = None
try: try:
self._table = SQLATable(self.name, self._table = SQLATable(
self.db.metadata, self.name, self.db.metadata, schema=self.db.schema, autoload=True
schema=self.db.schema, )
autoload=True)
except NoSuchTableError: except NoSuchTableError:
self._table = None self._table = None
def _threading_warn(self): def _threading_warn(self):
if self.db.in_transaction and threading.active_count() > 1: if self.db.in_transaction and threading.active_count() > 1:
warnings.warn("Changing the database schema inside a transaction " warnings.warn(
"Changing the database schema inside a transaction "
"in a multi-threaded environment is likely to lead " "in a multi-threaded environment is likely to lead "
"to race conditions and synchronization issues.", "to race conditions and synchronization issues.",
RuntimeWarning) RuntimeWarning,
)
def _sync_table(self, columns): def _sync_table(self, columns):
"""Lazy load, create or adapt the table structure in the database.""" """Lazy load, create or adapt the table structure in the database."""
self._flush_metadata()
if self._table is None: if self._table is None:
# Load an existing table from the database. # Load an existing table from the database.
self._reflect_table() self._reflect_table()
@ -339,32 +337,34 @@ class Table(object):
# Keep the lock scope small because this is run very often. # Keep the lock scope small because this is run very often.
with self.db.lock: with self.db.lock:
self._threading_warn() self._threading_warn()
self._table = SQLATable(self.name, self._table = SQLATable(
self.db.metadata, self.name, self.db.metadata, schema=self.db.schema
schema=self.db.schema) )
if self._primary_id is not False: if self._primary_id is not False:
# This can go wrong on DBMS like MySQL and SQLite where # This can go wrong on DBMS like MySQL and SQLite where
# tables cannot have no columns. # tables cannot have no columns.
primary_id = self._primary_id primary_id = self._primary_id
primary_type = self._primary_type primary_type = self._primary_type
increment = primary_type in [Types.integer, Types.bigint] increment = primary_type in [Types.integer, Types.bigint]
column = Column(primary_id, primary_type, column = Column(
primary_id,
primary_type,
primary_key=True, primary_key=True,
autoincrement=increment) autoincrement=increment,
)
self._table.append_column(column) self._table.append_column(column)
for column in columns: for column in columns:
if not column.name == self._primary_id: if not column.name == self._primary_id:
self._table.append_column(column) self._table.append_column(column)
self._table.create(self.db.executable, checkfirst=True) self._table.create(self.db.executable, checkfirst=True)
self._columns = None
elif len(columns): elif len(columns):
with self.db.lock: with self.db.lock:
self._reflect_table() self._reflect_table()
self._threading_warn() self._threading_warn()
for column in columns: for column in columns:
if not self.has_column(column.name): if not self.has_column(column.name):
self.db.op.add_column(self.name, self.db.op.add_column(self.name, column, self.db.schema)
column,
self.db.schema)
self._reflect_table() self._reflect_table()
def _sync_columns(self, row, ensure, types=None): def _sync_columns(self, row, ensure, types=None):
@ -398,31 +398,31 @@ class Table(object):
return ensure return ensure
def _generate_clause(self, column, op, value): def _generate_clause(self, column, op, value):
if op in ('like',): if op in ("like",):
return self.table.c[column].like(value) return self.table.c[column].like(value)
if op in ('ilike',): if op in ("ilike",):
return self.table.c[column].ilike(value) return self.table.c[column].ilike(value)
if op in ('>', 'gt'): if op in (">", "gt"):
return self.table.c[column] > value return self.table.c[column] > value
if op in ('<', 'lt'): if op in ("<", "lt"):
return self.table.c[column] < value return self.table.c[column] < value
if op in ('>=', 'gte'): if op in (">=", "gte"):
return self.table.c[column] >= value return self.table.c[column] >= value
if op in ('<=', 'lte'): if op in ("<=", "lte"):
return self.table.c[column] <= value return self.table.c[column] <= value
if op in ('=', '==', 'is'): if op in ("=", "==", "is"):
return self.table.c[column] == value return self.table.c[column] == value
if op in ('!=', '<>', 'not'): if op in ("!=", "<>", "not"):
return self.table.c[column] != value return self.table.c[column] != value
if op in ('in'): if op in ("in"):
return self.table.c[column].in_(value) return self.table.c[column].in_(value)
if op in ('between', '..'): if op in ("between", ".."):
start, end = value start, end = value
return self.table.c[column].between(start, end) return self.table.c[column].between(start, end)
if op in ('startswith',): if op in ("startswith",):
return self.table.c[column].like('%' + value) return self.table.c[column].like("%" + value)
if op in ('endswith',): if op in ("endswith",):
return self.table.c[column].like(value + '%') return self.table.c[column].like(value + "%")
return false() return false()
def _args_to_clause(self, args, clauses=()): def _args_to_clause(self, args, clauses=()):
@ -432,32 +432,31 @@ class Table(object):
if not self.has_column(column): if not self.has_column(column):
clauses.append(false()) clauses.append(false())
elif isinstance(value, (list, tuple, set)): elif isinstance(value, (list, tuple, set)):
clauses.append(self._generate_clause(column, 'in', value)) clauses.append(self._generate_clause(column, "in", value))
elif isinstance(value, dict): elif isinstance(value, dict):
for op, op_value in value.items(): for op, op_value in value.items():
clauses.append(self._generate_clause(column, op, op_value)) clauses.append(self._generate_clause(column, op, op_value))
else: else:
clauses.append(self._generate_clause(column, '=', value)) clauses.append(self._generate_clause(column, "=", value))
return and_(*clauses) return and_(*clauses)
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
orderings = [] orderings = []
for ordering in ensure_tuple(order_by): for ordering in ensure_list(order_by):
if ordering is None: if ordering is None:
continue continue
column = ordering.lstrip('-') column = ordering.lstrip("-")
column = self._get_column_name(column) column = self._get_column_name(column)
if not self.has_column(column): if not self.has_column(column):
continue continue
if ordering.startswith('-'): if ordering.startswith("-"):
orderings.append(self.table.c[column].desc()) orderings.append(self.table.c[column].desc())
else: else:
orderings.append(self.table.c[column].asc()) orderings.append(self.table.c[column].asc())
return orderings return orderings
def _keys_to_args(self, row, keys): def _keys_to_args(self, row, keys):
keys = ensure_tuple(keys) keys = [self._get_column_name(k) for k in ensure_list(keys)]
keys = [self._get_column_name(k) for k in keys]
row = row.copy() row = row.copy()
args = {k: row.pop(k, None) for k in keys} args = {k: row.pop(k, None) for k in keys}
return args, row return args, row
@ -503,7 +502,7 @@ class Table(object):
:: ::
table.drop_column('created_at') table.drop_column('created_at')
""" """
if self.db.engine.dialect.name == 'sqlite': if self.db.engine.dialect.name == "sqlite":
raise RuntimeError("SQLite does not support dropping columns.") raise RuntimeError("SQLite does not support dropping columns.")
name = self._get_column_name(name) name = self._get_column_name(name)
with self.db.lock: with self.db.lock:
@ -512,11 +511,7 @@ class Table(object):
return return
self._threading_warn() self._threading_warn()
self.db.op.drop_column( self.db.op.drop_column(self.table.name, name, self.table.schema)
self.table.name,
name,
self.table.schema
)
self._reflect_table() self._reflect_table()
def drop(self): def drop(self):
@ -529,8 +524,8 @@ class Table(object):
self._threading_warn() self._threading_warn()
self.table.drop(self.db.executable, checkfirst=True) self.table.drop(self.db.executable, checkfirst=True)
self._table = None self._table = None
self._columns = None
self.db._tables.pop(self.name, None) self.db._tables.pop(self.name, None)
self._flush_metadata()
def has_index(self, columns): def has_index(self, columns):
"""Check if an index exists to cover the given ``columns``.""" """Check if an index exists to cover the given ``columns``."""
@ -544,7 +539,7 @@ class Table(object):
return False return False
indexes = self.db.inspect.get_indexes(self.name, schema=self.db.schema) indexes = self.db.inspect.get_indexes(self.name, schema=self.db.schema)
for index in indexes: for index in indexes:
if columns == set(index.get('column_names', [])): if columns == set(index.get("column_names", [])):
self._indexes.append(columns) self._indexes.append(columns)
return True return True
return False return False
@ -557,7 +552,7 @@ class Table(object):
table.create_index(['name', 'country']) table.create_index(['name', 'country'])
""" """
columns = [self._get_column_name(c) for c in ensure_tuple(columns)] columns = [self._get_column_name(c) for c in ensure_list(columns)]
with self.db.lock: with self.db.lock:
if not self.exists: if not self.exists:
raise DatasetException("Table has not been created yet.") raise DatasetException("Table has not been created yet.")
@ -570,6 +565,17 @@ class Table(object):
self._threading_warn() self._threading_warn()
name = name or index_name(self.name, columns) name = name or index_name(self.name, columns)
columns = [self.table.c[c] for c in columns] columns = [self.table.c[c] for c in columns]
# MySQL crashes out if you try to index very long text fields,
# apparently. This defines (a somewhat random) prefix that
# will be captured by the index, after which I assume the engine
# conducts a more linear scan:
mysql_length = {}
for col in columns:
if isinstance(col.type, MYSQL_LENGTH_TYPES):
mysql_length[col.name] = 10
kw["mysql_length"] = mysql_length
idx = Index(name, *columns, **kw) idx = Index(name, *columns, **kw)
idx.create(self.db.executable) idx.create(self.db.executable)
@ -602,19 +608,17 @@ class Table(object):
if not self.exists: if not self.exists:
return iter([]) return iter([])
_limit = kwargs.pop('_limit', None) _limit = kwargs.pop("_limit", None)
_offset = kwargs.pop('_offset', 0) _offset = kwargs.pop("_offset", 0)
order_by = kwargs.pop('order_by', None) order_by = kwargs.pop("order_by", None)
_streamed = kwargs.pop('_streamed', False) _streamed = kwargs.pop("_streamed", False)
_step = kwargs.pop('_step', QUERY_STEP) _step = kwargs.pop("_step", QUERY_STEP)
if _step is False or _step == 0: if _step is False or _step == 0:
_step = None _step = None
order_by = self._args_to_order_by(order_by) order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses) args = self._args_to_clause(kwargs, clauses=_clauses)
query = self.table.select(whereclause=args, query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
limit=_limit,
offset=_offset)
if len(order_by): if len(order_by):
query = query.order_by(*order_by) query = query.order_by(*order_by)
@ -623,9 +627,7 @@ class Table(object):
conn = self.db.engine.connect() conn = self.db.engine.connect()
conn = conn.execution_options(stream_results=True) conn = conn.execution_options(stream_results=True)
return ResultIter(conn.execute(query), return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step)
row_type=self.db.row_type,
step=_step)
def find_one(self, *args, **kwargs): def find_one(self, *args, **kwargs):
"""Get a single result from the table. """Get a single result from the table.
@ -639,8 +641,8 @@ class Table(object):
if not self.exists: if not self.exists:
return None return None
kwargs['_limit'] = 1 kwargs["_limit"] = 1
kwargs['_step'] = None kwargs["_step"] = None
resiter = self.find(*args, **kwargs) resiter = self.find(*args, **kwargs)
try: try:
for row in resiter: for row in resiter:
@ -694,10 +696,12 @@ class Table(object):
if not len(columns): if not len(columns):
return iter([]) return iter([])
q = expression.select(columns, q = expression.select(
columns,
distinct=True, distinct=True,
whereclause=clause, whereclause=clause,
order_by=[c.asc() for c in columns]) order_by=[c.asc() for c in columns],
)
return self.db.query(q) return self.db.query(q)
# Legacy methods for running find queries. # Legacy methods for running find queries.
@ -717,4 +721,4 @@ class Table(object):
def __repr__(self): def __repr__(self):
"""Get table representation.""" """Get table representation."""
return '<Table(%s)>' % self.table.name return "<Table(%s)>" % self.table.name

View File

@ -1,13 +1,16 @@
from datetime import datetime, date from datetime import datetime, date
from sqlalchemy import Integer, UnicodeText, Float, BigInteger from sqlalchemy import Integer, UnicodeText, Float, BigInteger
from sqlalchemy import Boolean, Date, DateTime, Unicode, JSON from sqlalchemy import String, Boolean, Date, DateTime, Unicode, JSON
from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.types import TypeEngine from sqlalchemy.types import TypeEngine, _Binary
MYSQL_LENGTH_TYPES = (String, _Binary)
class Types(object): class Types(object):
"""A holder class for easy access to SQLAlchemy type names.""" """A holder class for easy access to SQLAlchemy type names."""
integer = Integer integer = Integer
string = Unicode string = Unicode
text = UnicodeText text = UnicodeText
@ -17,14 +20,8 @@ class Types(object):
date = Date date = Date
datetime = DateTime datetime = DateTime
def __init__(self, dialect=None): def __init__(self, is_postgres=None):
self._dialect = dialect self.json = JSONB if is_postgres else JSON
@property
def json(self):
if self._dialect is not None and self._dialect == 'postgresql':
return JSONB
return JSON
def guess(self, sample): def guess(self, sample):
"""Given a single sample, guess the column type for the field. """Given a single sample, guess the column type for the field.

View File

@ -1,7 +1,6 @@
from hashlib import sha1 from hashlib import sha1
from urllib.parse import urlparse from urllib.parse import urlparse
from collections import OrderedDict from collections import OrderedDict
from collections.abc import Iterable
from sqlalchemy.exc import ResourceClosedError from sqlalchemy.exc import ResourceClosedError
QUERY_STEP = 1000 QUERY_STEP = 1000
@ -63,17 +62,17 @@ class ResultIter(object):
def normalize_column_name(name): def normalize_column_name(name):
"""Check if a string is a reasonable thing to use as a column name.""" """Check if a string is a reasonable thing to use as a column name."""
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError('%r is not a valid column name.' % name) raise ValueError("%r is not a valid column name." % name)
# limit to 63 characters # limit to 63 characters
name = name.strip()[:63] name = name.strip()[:63]
# column names can be 63 *bytes* max in postgresql # column names can be 63 *bytes* max in postgresql
if isinstance(name, str): if isinstance(name, str):
while len(name.encode('utf-8')) >= 64: while len(name.encode("utf-8")) >= 64:
name = name[:len(name) - 1] name = name[: len(name) - 1]
if not len(name) or '.' in name or '-' in name: if not len(name) or "." in name or "-" in name:
raise ValueError('%r is not a valid column name.' % name) raise ValueError("%r is not a valid column name." % name)
return name return name
@ -81,7 +80,7 @@ def normalize_column_key(name):
"""Return a comparable column name.""" """Return a comparable column name."""
if name is None or not isinstance(name, str): if name is None or not isinstance(name, str):
return None return None
return name.upper().strip().replace(' ', '') return name.upper().strip().replace(" ", "")
def normalize_table_name(name): def normalize_table_name(name):
@ -98,25 +97,16 @@ def safe_url(url):
"""Remove password from printed connection URLs.""" """Remove password from printed connection URLs."""
parsed = urlparse(url) parsed = urlparse(url)
if parsed.password is not None: if parsed.password is not None:
pwd = ':%s@' % parsed.password pwd = ":%s@" % parsed.password
url = url.replace(pwd, ':*****@') url = url.replace(pwd, ":*****@")
return url return url
def index_name(table, columns): def index_name(table, columns):
"""Generate an artificial index name.""" """Generate an artificial index name."""
sig = '||'.join(columns) sig = "||".join(columns)
key = sha1(sig.encode('utf-8')).hexdigest()[:16] key = sha1(sig.encode("utf-8")).hexdigest()[:16]
return 'ix_%s_%s' % (table, key) return "ix_%s_%s" % (table, key)
def ensure_tuple(obj):
"""Try and make the given argument into a tuple."""
if obj is None:
return tuple()
if isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)):
return tuple(obj)
return obj,
def pad_chunk_columns(chunk, columns): def pad_chunk_columns(chunk, columns):

View File

@ -1,7 +1,19 @@
# flasky extensions. flasky pygments style based on tango style # flasky extensions. flasky pygments style based on tango style
from pygments.style import Style from pygments.style import Style
from pygments.token import Keyword, Name, Comment, String, Error, \ from pygments.token import (
Number, Operator, Generic, Whitespace, Punctuation, Other, Literal Keyword,
Name,
Comment,
String,
Error,
Number,
Operator,
Generic,
Whitespace,
Punctuation,
Other,
Literal,
)
class FlaskyStyle(Style): class FlaskyStyle(Style):
@ -10,14 +22,12 @@ class FlaskyStyle(Style):
styles = { styles = {
# No corresponding class for the following: # No corresponding class for the following:
#Text: "", # class: '' # Text: "", # class: ''
Whitespace: "underline #f8f8f8", # class: 'w' Whitespace: "underline #f8f8f8", # class: 'w'
Error: "#a40000 border:#ef2929", # class: 'err' Error: "#a40000 border:#ef2929", # class: 'err'
Other: "#000000", # class 'x' Other: "#000000", # class 'x'
Comment: "italic #8f5902", # class: 'c' Comment: "italic #8f5902", # class: 'c'
Comment.Preproc: "noitalic", # class: 'cp' Comment.Preproc: "noitalic", # class: 'cp'
Keyword: "bold #004461", # class: 'k' Keyword: "bold #004461", # class: 'k'
Keyword.Constant: "bold #004461", # class: 'kc' Keyword.Constant: "bold #004461", # class: 'kc'
Keyword.Declaration: "bold #004461", # class: 'kd' Keyword.Declaration: "bold #004461", # class: 'kd'
@ -25,12 +35,9 @@ class FlaskyStyle(Style):
Keyword.Pseudo: "bold #004461", # class: 'kp' Keyword.Pseudo: "bold #004461", # class: 'kp'
Keyword.Reserved: "bold #004461", # class: 'kr' Keyword.Reserved: "bold #004461", # class: 'kr'
Keyword.Type: "bold #004461", # class: 'kt' Keyword.Type: "bold #004461", # class: 'kt'
Operator: "#582800", # class: 'o' Operator: "#582800", # class: 'o'
Operator.Word: "bold #004461", # class: 'ow' - like keywords Operator.Word: "bold #004461", # class: 'ow' - like keywords
Punctuation: "bold #000000", # class: 'p' Punctuation: "bold #000000", # class: 'p'
# because special names such as Name.Class, Name.Function, etc. # because special names such as Name.Class, Name.Function, etc.
# are not recognized as such later in the parsing, we choose them # are not recognized as such later in the parsing, we choose them
# to look the same as ordinary variables. # to look the same as ordinary variables.
@ -53,12 +60,9 @@ class FlaskyStyle(Style):
Name.Variable.Class: "#000000", # class: 'vc' - to be revised Name.Variable.Class: "#000000", # class: 'vc' - to be revised
Name.Variable.Global: "#000000", # class: 'vg' - to be revised Name.Variable.Global: "#000000", # class: 'vg' - to be revised
Name.Variable.Instance: "#000000", # class: 'vi' - to be revised Name.Variable.Instance: "#000000", # class: 'vi' - to be revised
Number: "#990000", # class: 'm' Number: "#990000", # class: 'm'
Literal: "#000000", # class: 'l' Literal: "#000000", # class: 'l'
Literal.Date: "#000000", # class: 'ld' Literal.Date: "#000000", # class: 'ld'
String: "#4e9a06", # class: 's' String: "#4e9a06", # class: 's'
String.Backtick: "#4e9a06", # class: 'sb' String.Backtick: "#4e9a06", # class: 'sb'
String.Char: "#4e9a06", # class: 'sc' String.Char: "#4e9a06", # class: 'sc'
@ -71,7 +75,6 @@ class FlaskyStyle(Style):
String.Regex: "#4e9a06", # class: 'sr' String.Regex: "#4e9a06", # class: 'sr'
String.Single: "#4e9a06", # class: 's1' String.Single: "#4e9a06", # class: 's1'
String.Symbol: "#4e9a06", # class: 'ss' String.Symbol: "#4e9a06", # class: 'ss'
Generic: "#000000", # class: 'g' Generic: "#000000", # class: 'g'
Generic.Deleted: "#a40000", # class: 'gd' Generic.Deleted: "#a40000", # class: 'gd'
Generic.Emph: "italic #000000", # class: 'ge' Generic.Emph: "italic #000000", # class: 'ge'

View File

@ -16,85 +16,85 @@ import sys, os
# If extensions (or modules to document with autodoc) are in another directory, # If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the # add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here. # documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('../')) sys.path.insert(0, os.path.abspath("../"))
# -- General configuration ----------------------------------------------------- # -- General configuration -----------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here. # If your documentation needs a minimal Sphinx version, state it here.
#needs_sphinx = '1.0' # needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be extensions # Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. # coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode"]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# The suffix of source filenames. # The suffix of source filenames.
source_suffix = '.rst' source_suffix = ".rst"
# The encoding of source files. # The encoding of source files.
#source_encoding = 'utf-8-sig' # source_encoding = 'utf-8-sig'
# The master toctree document. # The master toctree document.
master_doc = 'index' master_doc = "index"
# General information about the project. # General information about the project.
project = u'dataset' project = u"dataset"
copyright = u'2013-2018, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer' copyright = u"2013-2018, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = '1.0' version = "1.0"
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = '1.0.8' release = "1.0.8"
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
#language = None # language = None
# There are two options for replacing |today|: either, you set today to some # There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used: # non-false value, then it is used:
#today = '' # today = ''
# Else, today_fmt is used as the format for a strftime call. # Else, today_fmt is used as the format for a strftime call.
#today_fmt = '%B %d, %Y' # today_fmt = '%B %d, %Y'
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
exclude_patterns = ['_build'] exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all documents. # The reST default role (used for this markup: `text`) to use for all documents.
#default_role = None # default_role = None
# If true, '()' will be appended to :func: etc. cross-reference text. # If true, '()' will be appended to :func: etc. cross-reference text.
#add_function_parentheses = True # add_function_parentheses = True
# If true, the current module name will be prepended to all description # If true, the current module name will be prepended to all description
# unit titles (such as .. function::). # unit titles (such as .. function::).
#add_module_names = True # add_module_names = True
# If true, sectionauthor and moduleauthor directives will be shown in the # If true, sectionauthor and moduleauthor directives will be shown in the
# output. They are ignored by default. # output. They are ignored by default.
#show_authors = False # show_authors = False
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting. # A list of ignored prefixes for module index sorting.
#modindex_common_prefix = [] # modindex_common_prefix = []
# If true, keep warnings as "system message" paragraphs in the built documents. # If true, keep warnings as "system message" paragraphs in the built documents.
#keep_warnings = False # keep_warnings = False
# -- Options for HTML output --------------------------------------------------- # -- Options for HTML output ---------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
html_theme = 'kr' html_theme = "kr"
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
@ -104,117 +104,120 @@ html_theme = 'kr'
# } # }
# Add any paths that contain custom themes here, relative to this directory. # Add any paths that contain custom themes here, relative to this directory.
html_theme_path = ['_themes'] html_theme_path = ["_themes"]
# The name for this set of Sphinx documents. If None, it defaults to # The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation". # "<project> v<release> documentation".
#html_title = None # html_title = None
# A shorter title for the navigation bar. Default is the same as html_title. # A shorter title for the navigation bar. Default is the same as html_title.
#html_short_title = None # html_short_title = None
# The name of an image file (relative to this directory) to place at the top # The name of an image file (relative to this directory) to place at the top
# of the sidebar. # of the sidebar.
#html_logo = None # html_logo = None
# The name of an image file (within the static path) to use as favicon of the # The name of an image file (within the static path) to use as favicon of the
# docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32 # docs. This file should be a Windows icon file (.ico) being 16x16 or 32x32
# pixels large. # pixels large.
#html_favicon = None # html_favicon = None
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ["_static"]
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format. # using the given strftime format.
#html_last_updated_fmt = '%b %d, %Y' # html_last_updated_fmt = '%b %d, %Y'
# If true, SmartyPants will be used to convert quotes and dashes to # If true, SmartyPants will be used to convert quotes and dashes to
# typographically correct entities. # typographically correct entities.
#html_use_smartypants = True # html_use_smartypants = True
# Custom sidebar templates, maps document names to template names. # Custom sidebar templates, maps document names to template names.
html_sidebars = { html_sidebars = {
'index': ['sidebarlogo.html', 'sourcelink.html', 'searchbox.html'], "index": ["sidebarlogo.html", "sourcelink.html", "searchbox.html"],
'api': ['sidebarlogo.html', 'autotoc.html', 'sourcelink.html', 'searchbox.html'], "api": ["sidebarlogo.html", "autotoc.html", "sourcelink.html", "searchbox.html"],
'**': ['sidebarlogo.html', 'localtoc.html', 'sourcelink.html', 'searchbox.html'] "**": ["sidebarlogo.html", "localtoc.html", "sourcelink.html", "searchbox.html"],
} }
# Additional templates that should be rendered to pages, maps page names to # Additional templates that should be rendered to pages, maps page names to
# template names. # template names.
#html_additional_pages = {} # html_additional_pages = {}
# If false, no module index is generated. # If false, no module index is generated.
#html_domain_indices = True # html_domain_indices = True
# If false, no index is generated. # If false, no index is generated.
#html_use_index = True # html_use_index = True
# If true, the index is split into individual pages for each letter. # If true, the index is split into individual pages for each letter.
#html_split_index = False # html_split_index = False
# If true, links to the reST sources are added to the pages. # If true, links to the reST sources are added to the pages.
html_show_sourcelink = False html_show_sourcelink = False
# If true, "Created using Sphinx" is shown in the HTML footer. Default is True. # If true, "Created using Sphinx" is shown in the HTML footer. Default is True.
#html_show_sphinx = True # html_show_sphinx = True
# If true, "(C) Copyright ..." is shown in the HTML footer. Default is True. # If true, "(C) Copyright ..." is shown in the HTML footer. Default is True.
#html_show_copyright = True # html_show_copyright = True
# If true, an OpenSearch description file will be output, and all pages will # If true, an OpenSearch description file will be output, and all pages will
# contain a <link> tag referring to it. The value of this option must be the # contain a <link> tag referring to it. The value of this option must be the
# base URL from which the finished HTML is served. # base URL from which the finished HTML is served.
#html_use_opensearch = '' # html_use_opensearch = ''
# This is the file name suffix for HTML files (e.g. ".xhtml"). # This is the file name suffix for HTML files (e.g. ".xhtml").
#html_file_suffix = None # html_file_suffix = None
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'datasetdoc' htmlhelp_basename = "datasetdoc"
# -- Options for LaTeX output -------------------------------------------------- # -- Options for LaTeX output --------------------------------------------------
latex_elements = { latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper', #'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
# The font size ('10pt', '11pt' or '12pt'). #'pointsize': '10pt',
#'pointsize': '10pt', # Additional stuff for the LaTeX preamble.
#'preamble': '',
# Additional stuff for the LaTeX preamble.
#'preamble': '',
} }
# Grouping the document tree into LaTeX files. List of tuples # Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]). # (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [ latex_documents = [
('index', 'dataset.tex', u'dataset Documentation', (
u'Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', 'manual'), "index",
"dataset.tex",
u"dataset Documentation",
u"Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer",
"manual",
),
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
# the title page. # the title page.
#latex_logo = None # latex_logo = None
# For "manual" documents, if this is true, then toplevel headings are parts, # For "manual" documents, if this is true, then toplevel headings are parts,
# not chapters. # not chapters.
#latex_use_parts = False # latex_use_parts = False
# If true, show page references after internal links. # If true, show page references after internal links.
#latex_show_pagerefs = False # latex_show_pagerefs = False
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
#latex_show_urls = False # latex_show_urls = False
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.
#latex_appendices = [] # latex_appendices = []
# If false, no module index is generated. # If false, no module index is generated.
#latex_domain_indices = True # latex_domain_indices = True
# -- Options for manual page output -------------------------------------------- # -- Options for manual page output --------------------------------------------
@ -222,12 +225,17 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [
('index', 'dataset', u'dataset Documentation', (
[u'Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer'], 1) "index",
"dataset",
u"dataset Documentation",
[u"Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"],
1,
)
] ]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
#man_show_urls = False # man_show_urls = False
# -- Options for Texinfo output ------------------------------------------------ # -- Options for Texinfo output ------------------------------------------------
@ -236,19 +244,25 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
('index', 'dataset', u'dataset Documentation', (
u'Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', 'dataset', 'Databases for lazy people.', "index",
'Miscellaneous'), "dataset",
u"dataset Documentation",
u"Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer",
"dataset",
"Databases for lazy people.",
"Miscellaneous",
),
] ]
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.
#texinfo_appendices = [] # texinfo_appendices = []
# If false, no module index is generated. # If false, no module index is generated.
#texinfo_domain_indices = True # texinfo_domain_indices = True
# How to display URL addresses: 'footnote', 'no', or 'inline'. # How to display URL addresses: 'footnote', 'no', or 'inline'.
#texinfo_show_urls = 'footnote' # texinfo_show_urls = 'footnote'
# If true, do not generate a @detailmenu in the "Top" node's menu. # If true, do not generate a @detailmenu in the "Top" node's menu.
#texinfo_no_detailmenu = False # texinfo_no_detailmenu = False

View File

@ -1,51 +1,47 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
with open('README.md') as f: with open("README.md") as f:
long_description = f.read() long_description = f.read()
setup( setup(
name='dataset', name="dataset",
version='1.3.2', version="1.3.2",
description="Toolkit for Python-based database access.", description="Toolkit for Python-based database access.",
long_description=long_description, long_description=long_description,
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.5",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Programming Language :: Python :: 3.8', "Programming Language :: Python :: 3.8",
], ],
keywords='sql sqlalchemy etl loading utility', keywords="sql sqlalchemy etl loading utility",
author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', author="Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer",
author_email='friedrich.lindenberg@gmail.com', author_email="friedrich.lindenberg@gmail.com",
url='http://github.com/pudo/dataset', url="http://github.com/pudo/dataset",
license='MIT', license="MIT",
packages=find_packages(exclude=['ez_setup', 'examples', 'test']), packages=find_packages(exclude=["ez_setup", "examples", "test"]),
namespace_packages=[], namespace_packages=[],
include_package_data=False, include_package_data=False,
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=["sqlalchemy >= 1.3.2", "alembic >= 0.6.2", "banal >= 1.0.1"],
'sqlalchemy >= 1.3.2',
'alembic >= 0.6.2'
],
extras_require={ extras_require={
'dev': [ "dev": [
'pip', "pip",
'nose', "nose",
'wheel', "wheel",
'flake8', "flake8",
'coverage', "coverage",
'psycopg2-binary', "psycopg2-binary",
'PyMySQL', "PyMySQL",
"cryptography",
] ]
}, },
tests_require=[ tests_require=["nose"],
'nose' test_suite="test",
], entry_points={},
test_suite='test',
entry_points={}
) )

View File

@ -4,38 +4,14 @@ from __future__ import unicode_literals
from datetime import datetime from datetime import datetime
TEST_CITY_1 = 'B€rkeley' TEST_CITY_1 = "B€rkeley"
TEST_CITY_2 = 'G€lway' TEST_CITY_2 = "G€lway"
TEST_DATA = [ TEST_DATA = [
{ {"date": datetime(2011, 1, 1), "temperature": 1, "place": TEST_CITY_2},
'date': datetime(2011, 1, 1), {"date": datetime(2011, 1, 2), "temperature": -1, "place": TEST_CITY_2},
'temperature': 1, {"date": datetime(2011, 1, 3), "temperature": 0, "place": TEST_CITY_2},
'place': TEST_CITY_2 {"date": datetime(2011, 1, 1), "temperature": 6, "place": TEST_CITY_1},
}, {"date": datetime(2011, 1, 2), "temperature": 8, "place": TEST_CITY_1},
{ {"date": datetime(2011, 1, 3), "temperature": 5, "place": TEST_CITY_1},
'date': datetime(2011, 1, 2),
'temperature': -1,
'place': TEST_CITY_2
},
{
'date': datetime(2011, 1, 3),
'temperature': 0,
'place': TEST_CITY_2
},
{
'date': datetime(2011, 1, 1),
'temperature': 6,
'place': TEST_CITY_1
},
{
'date': datetime(2011, 1, 2),
'temperature': 8,
'place': TEST_CITY_1
},
{
'date': datetime(2011, 1, 3),
'temperature': 5,
'place': TEST_CITY_1
}
] ]

View File

@ -1,12 +1,8 @@
# coding: utf-8
from __future__ import unicode_literals
from collections import OrderedDict
import os import os
import unittest import unittest
from datetime import datetime from datetime import datetime
from collections import OrderedDict
from sqlalchemy import FLOAT, TEXT, BIGINT from sqlalchemy import TEXT, BIGINT
from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError from sqlalchemy.exc import IntegrityError, SQLAlchemyError, ArgumentError
from dataset import connect, chunked from dataset import connect, chunked
@ -15,11 +11,9 @@ from .sample_data import TEST_DATA, TEST_CITY_1
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
os.environ.setdefault('DATABASE_URL', 'sqlite:///:memory:') self.db = connect()
self.db = connect(os.environ['DATABASE_URL']) self.tbl = self.db["weather"]
self.tbl = self.db['weather']
self.tbl.insert_many(TEST_DATA) self.tbl.insert_many(TEST_DATA)
def tearDown(self): def tearDown(self):
@ -27,28 +21,28 @@ class DatabaseTestCase(unittest.TestCase):
self.db[table].drop() self.db[table].drop()
def test_valid_database_url(self): def test_valid_database_url(self):
assert self.db.url, os.environ['DATABASE_URL'] assert self.db.url, os.environ["DATABASE_URL"]
def test_database_url_query_string(self): def test_database_url_query_string(self):
db = connect('sqlite:///:memory:/?cached_statements=1') db = connect("sqlite:///:memory:/?cached_statements=1")
assert 'cached_statements' in db.url, db.url assert "cached_statements" in db.url, db.url
def test_tables(self): def test_tables(self):
assert self.db.tables == ['weather'], self.db.tables assert self.db.tables == ["weather"], self.db.tables
def test_contains(self): def test_contains(self):
assert 'weather' in self.db, self.db.tables assert "weather" in self.db, self.db.tables
def test_create_table(self): def test_create_table(self):
table = self.db['foo'] table = self.db["foo"]
assert table.table.exists() assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert 'id' in table.table.c, table.table.c assert "id" in table.table.c, table.table.c
def test_create_table_no_ids(self): def test_create_table_no_ids(self):
if 'mysql' in self.db.engine.dialect.dbapi.__name__: if "mysql" in self.db.engine.dialect.dbapi.__name__:
return return
if 'sqlite' in self.db.engine.dialect.dbapi.__name__: if "sqlite" in self.db.engine.dialect.dbapi.__name__:
return return
table = self.db.create_table("foo_no_id", primary_id=False) table = self.db.create_table("foo_no_id", primary_id=False)
assert table.table.exists() assert table.table.exists()
@ -60,8 +54,8 @@ class DatabaseTestCase(unittest.TestCase):
assert table.table.exists() assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({pid: 'foobar'}) table.insert({pid: "foobar"})
assert table.find_one(string_id='foobar')[pid] == 'foobar' assert table.find_one(string_id="foobar")[pid] == "foobar"
def test_create_table_custom_id2(self): def test_create_table_custom_id2(self):
pid = "string_id" pid = "string_id"
@ -70,8 +64,8 @@ class DatabaseTestCase(unittest.TestCase):
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({pid: 'foobar'}) table.insert({pid: "foobar"})
assert table.find_one(string_id='foobar')[pid] == 'foobar' assert table.find_one(string_id="foobar")[pid] == "foobar"
def test_create_table_custom_id3(self): def test_create_table_custom_id3(self):
pid = "int_id" pid = "int_id"
@ -88,247 +82,232 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand1(self): def test_create_table_shorthand1(self):
pid = "int_id" pid = "int_id"
table = self.db.get_table('foo5', pid) table = self.db.get_table("foo5", pid)
assert table.table.exists assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({'int_id': 123}) table.insert({"int_id": 123})
table.insert({'int_id': 124}) table.insert({"int_id": 124})
assert table.find_one(int_id=123)['int_id'] == 123 assert table.find_one(int_id=123)["int_id"] == 123
assert table.find_one(int_id=124)['int_id'] == 124 assert table.find_one(int_id=124)["int_id"] == 124
self.assertRaises(IntegrityError, lambda: table.insert({'int_id': 123})) self.assertRaises(IntegrityError, lambda: table.insert({"int_id": 123}))
def test_create_table_shorthand2(self): def test_create_table_shorthand2(self):
pid = "string_id" pid = "string_id"
table = self.db.get_table('foo6', primary_id=pid, table = self.db.get_table(
primary_type=self.db.types.string(255)) "foo6", primary_id=pid, primary_type=self.db.types.string(255)
)
assert table.table.exists assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({ table.insert({"string_id": "foobar"})
'string_id': 'foobar'}) assert table.find_one(string_id="foobar")["string_id"] == "foobar"
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_with(self): def test_with(self):
init_length = len(self.db['weather']) init_length = len(self.db["weather"])
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
with self.db as tx: with self.db as tx:
tx['weather'].insert({'date': datetime(2011, 1, 1), tx["weather"].insert(
'temperature': 1, {
'place': 'tmp_place'}) "date": datetime(2011, 1, 1),
"temperature": 1,
"place": "tmp_place",
}
)
raise ValueError() raise ValueError()
assert len(self.db['weather']) == init_length assert len(self.db["weather"]) == init_length
def test_invalid_values(self): def test_invalid_values(self):
if 'mysql' in self.db.engine.dialect.dbapi.__name__: if "mysql" in self.db.engine.dialect.dbapi.__name__:
# WARNING: mysql seems to be doing some weird type casting upon insert. # WARNING: mysql seems to be doing some weird type casting
# The mysql-python driver is not affected but it isn't compatible with Python 3 # upon insert. The mysql-python driver is not affected but
# it isn't compatible with Python 3
# Conclusion: use postgresql. # Conclusion: use postgresql.
return return
with self.assertRaises(SQLAlchemyError): with self.assertRaises(SQLAlchemyError):
tbl = self.db['weather'] tbl = self.db["weather"]
tbl.insert({'date': True, 'temperature': 'wrong_value', 'place': 'tmp_place'}) tbl.insert(
{"date": True, "temperature": "wrong_value", "place": "tmp_place"}
)
def test_load_table(self): def test_load_table(self):
tbl = self.db.load_table('weather') tbl = self.db.load_table("weather")
assert tbl.table.name == self.tbl.table.name assert tbl.table.name == self.tbl.table.name
def test_query(self): def test_query(self):
r = self.db.query('SELECT COUNT(*) AS num FROM weather').next() r = self.db.query("SELECT COUNT(*) AS num FROM weather").next()
assert r['num'] == len(TEST_DATA), r assert r["num"] == len(TEST_DATA), r
def test_table_cache_updates(self): def test_table_cache_updates(self):
tbl1 = self.db.get_table('people') tbl1 = self.db.get_table("people")
data = OrderedDict([('first_name', 'John'), ('last_name', 'Smith')]) data = OrderedDict([("first_name", "John"), ("last_name", "Smith")])
tbl1.insert(data) tbl1.insert(data)
data['id'] = 1 data["id"] = 1
tbl2 = self.db.get_table('people') tbl2 = self.db.get_table("people")
assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data) assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data)
class TableTestCase(unittest.TestCase): class TableTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.db = connect('sqlite:///:memory:') self.db = connect()
self.tbl = self.db['weather'] self.tbl = self.db["weather"]
for row in TEST_DATA: for row in TEST_DATA:
self.tbl.insert(row) self.tbl.insert(row)
def tearDown(self):
self.tbl.drop()
def test_insert(self): def test_insert(self):
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
last_id = self.tbl.insert({ last_id = self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10,
'place': 'Berlin'}
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
assert self.tbl.find_one(id=last_id)['place'] == 'Berlin' assert self.tbl.find_one(id=last_id)["place"] == "Berlin"
def test_insert_ignore(self): def test_insert_ignore(self):
self.tbl.insert_ignore({ self.tbl.insert_ignore(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["place"],
'place': 'Berlin'},
['place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
self.tbl.insert_ignore({ self.tbl.insert_ignore(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["place"],
'place': 'Berlin'},
['place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_insert_ignore_all_key(self): def test_insert_ignore_all_key(self):
for i in range(0, 4): for i in range(0, 4):
self.tbl.insert_ignore({ self.tbl.insert_ignore(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["date", "temperature", "place"],
'place': 'Berlin'},
['date', 'temperature', 'place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_insert_json(self): def test_insert_json(self):
last_id = self.tbl.insert({ last_id = self.tbl.insert(
'date': datetime(2011, 1, 2), {
'temperature': -10, "date": datetime(2011, 1, 2),
'place': 'Berlin', "temperature": -10,
'info': { "place": "Berlin",
'currency': 'EUR', "info": {
'language': 'German', "currency": "EUR",
'population': 3292365 "language": "German",
"population": 3292365,
},
} }
})
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
assert self.tbl.find_one(id=last_id)['place'] == 'Berlin'
def test_upsert(self):
self.tbl.upsert({
'date': datetime(2011, 1, 2),
'temperature': -10,
'place': 'Berlin'},
['place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
self.tbl.upsert({ assert self.tbl.find_one(id=last_id)["place"] == "Berlin"
'date': datetime(2011, 1, 2),
'temperature': -10, def test_upsert(self):
'place': 'Berlin'}, self.tbl.upsert(
['place'] {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
["place"],
)
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
self.tbl.upsert(
{"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
["place"],
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_upsert_single_column(self): def test_upsert_single_column(self):
table = self.db['banana_single_col'] table = self.db["banana_single_col"]
table.upsert({ table.upsert({"color": "Yellow"}, ["color"])
'color': 'Yellow'},
['color']
)
assert len(table) == 1, len(table) assert len(table) == 1, len(table)
table.upsert({ table.upsert({"color": "Yellow"}, ["color"])
'color': 'Yellow'},
['color']
)
assert len(table) == 1, len(table) assert len(table) == 1, len(table)
def test_upsert_all_key(self): def test_upsert_all_key(self):
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
for i in range(0, 2): for i in range(0, 2):
self.tbl.upsert({ self.tbl.upsert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["date", "temperature", "place"],
'place': 'Berlin'},
['date', 'temperature', 'place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_upsert_id(self): def test_upsert_id(self):
table = self.db['banana_with_id'] table = self.db["banana_with_id"]
data = dict(id=10, title='I am a banana!') data = dict(id=10, title="I am a banana!")
table.upsert(data, ['id']) table.upsert(data, ["id"])
assert len(table) == 1, len(table) assert len(table) == 1, len(table)
def test_update_while_iter(self): def test_update_while_iter(self):
for row in self.tbl: for row in self.tbl:
row['foo'] = 'bar' row["foo"] = "bar"
self.tbl.update(row, ['place', 'date']) self.tbl.update(row, ["place", "date"])
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
def test_weird_column_names(self): def test_weird_column_names(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {
'temperature': -10, "date": datetime(2011, 1, 2),
'foo.bar': 'Berlin', "temperature": -10,
'qux.bar': 'Huhu' "foo.bar": "Berlin",
}) "qux.bar": "Huhu",
}
)
def test_cased_column_names(self): def test_cased_column_names(self):
tbl = self.db['cased_column_names'] tbl = self.db["cased_column_names"]
tbl.insert({ tbl.insert({"place": "Berlin"})
'place': 'Berlin', tbl.insert({"Place": "Berlin"})
}) tbl.insert({"PLACE ": "Berlin"})
tbl.insert({
'Place': 'Berlin',
})
tbl.insert({
'PLACE ': 'Berlin',
})
assert len(tbl.columns) == 2, tbl.columns assert len(tbl.columns) == 2, tbl.columns
assert len(list(tbl.find(Place='Berlin'))) == 3 assert len(list(tbl.find(Place="Berlin"))) == 3
assert len(list(tbl.find(place='Berlin'))) == 3 assert len(list(tbl.find(place="Berlin"))) == 3
assert len(list(tbl.find(PLACE='Berlin'))) == 3 assert len(list(tbl.find(PLACE="Berlin"))) == 3
def test_invalid_column_names(self): def test_invalid_column_names(self):
tbl = self.db['weather'] tbl = self.db["weather"]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tbl.insert({None: 'banana'}) tbl.insert({None: "banana"})
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tbl.insert({'': 'banana'}) tbl.insert({"": "banana"})
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tbl.insert({'-': 'banana'}) tbl.insert({"-": "banana"})
def test_delete(self): def test_delete(self):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10,
'place': 'Berlin'}
) )
original_count = len(self.tbl) original_count = len(self.tbl)
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
# Test bad use of API # Test bad use of API
with self.assertRaises(ArgumentError): with self.assertRaises(ArgumentError):
self.tbl.delete({'place': 'Berlin'}) self.tbl.delete({"place": "Berlin"})
assert len(self.tbl) == original_count, len(self.tbl) assert len(self.tbl) == original_count, len(self.tbl)
assert self.tbl.delete(place='Berlin') is True, 'should return 1' assert self.tbl.delete(place="Berlin") is True, "should return 1"
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
assert self.tbl.delete() is True, 'should return non zero' assert self.tbl.delete() is True, "should return non zero"
assert len(self.tbl) == 0, len(self.tbl) assert len(self.tbl) == 0, len(self.tbl)
def test_repr(self): def test_repr(self):
assert repr(self.tbl) == '<Table(weather)>', \ assert (
'the representation should be <Table(weather)>' repr(self.tbl) == "<Table(weather)>"
), "the representation should be <Table(weather)>"
def test_delete_nonexist_entry(self): def test_delete_nonexist_entry(self):
assert self.tbl.delete(place='Berlin') is False, \ assert (
'entry not exist, should fail to delete' self.tbl.delete(place="Berlin") is False
), "entry not exist, should fail to delete"
def test_find_one(self): def test_find_one(self):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10, )
'place': 'Berlin' d = self.tbl.find_one(place="Berlin")
}) assert d["temperature"] == -10, d
d = self.tbl.find_one(place='Berlin') d = self.tbl.find_one(place="Atlantis")
assert d['temperature'] == -10, d
d = self.tbl.find_one(place='Atlantis')
assert d is None, d assert d is None, d
def test_count(self): def test_count(self):
@ -347,31 +326,31 @@ class TableTestCase(unittest.TestCase):
assert len(ds) == 1, ds assert len(ds) == 1, ds
ds = list(self.tbl.find(_step=2)) ds = list(self.tbl.find(_step=2))
assert len(ds) == len(TEST_DATA), ds assert len(ds) == len(TEST_DATA), ds
ds = list(self.tbl.find(order_by=['temperature'])) ds = list(self.tbl.find(order_by=["temperature"]))
assert ds[0]['temperature'] == -1, ds assert ds[0]["temperature"] == -1, ds
ds = list(self.tbl.find(order_by=['-temperature'])) ds = list(self.tbl.find(order_by=["-temperature"]))
assert ds[0]['temperature'] == 8, ds assert ds[0]["temperature"] == 8, ds
ds = list(self.tbl.find(self.tbl.table.columns.temperature > 4)) ds = list(self.tbl.find(self.tbl.table.columns.temperature > 4))
assert len(ds) == 3, ds assert len(ds) == 3, ds
def test_find_dsl(self): def test_find_dsl(self):
ds = list(self.tbl.find(place={'like': '%lw%'})) ds = list(self.tbl.find(place={"like": "%lw%"}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(temperature={'>': 5})) ds = list(self.tbl.find(temperature={">": 5}))
assert len(ds) == 2, ds assert len(ds) == 2, ds
ds = list(self.tbl.find(temperature={'>=': 5})) ds = list(self.tbl.find(temperature={">=": 5}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(temperature={'<': 0})) ds = list(self.tbl.find(temperature={"<": 0}))
assert len(ds) == 1, ds assert len(ds) == 1, ds
ds = list(self.tbl.find(temperature={'<=': 0})) ds = list(self.tbl.find(temperature={"<=": 0}))
assert len(ds) == 2, ds assert len(ds) == 2, ds
ds = list(self.tbl.find(temperature={'!=': -1})) ds = list(self.tbl.find(temperature={"!=": -1}))
assert len(ds) == 5, ds assert len(ds) == 5, ds
ds = list(self.tbl.find(temperature={'between': [5, 8]})) ds = list(self.tbl.find(temperature={"between": [5, 8]}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(place={'=': 'G€lway'})) ds = list(self.tbl.find(place={"=": "G€lway"}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(place={'ilike': '%LwAy'})) ds = list(self.tbl.find(place={"ilike": "%LwAy"}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
def test_offset(self): def test_offset(self):
@ -384,36 +363,39 @@ class TableTestCase(unittest.TestCase):
ds = list(self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1)) ds = list(self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1))
assert len(ds) == 3, len(ds) assert len(ds) == 3, len(ds)
for row in self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1): for row in self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1):
row['temperature'] = -1 row["temperature"] = -1
self.tbl.update(row, ['id']) self.tbl.update(row, ["id"])
def test_distinct(self): def test_distinct(self):
x = list(self.tbl.distinct('place')) x = list(self.tbl.distinct("place"))
assert len(x) == 2, x assert len(x) == 2, x
x = list(self.tbl.distinct('place', 'date')) x = list(self.tbl.distinct("place", "date"))
assert len(x) == 6, x assert len(x) == 6, x
x = list(self.tbl.distinct( x = list(
'place', 'date', self.tbl.distinct(
self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0))) "place",
"date",
self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0),
)
)
assert len(x) == 4, x assert len(x) == 4, x
x = list(self.tbl.distinct('temperature', place='B€rkeley')) x = list(self.tbl.distinct("temperature", place="B€rkeley"))
assert len(x) == 3, x assert len(x) == 3, x
x = list(self.tbl.distinct('temperature', x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"]))
place=['B€rkeley', 'G€lway']))
assert len(x) == 6, x assert len(x) == 6, x
def test_insert_many(self): def test_insert_many(self):
data = TEST_DATA * 100 data = TEST_DATA * 100
self.tbl.insert_many(data, chunk_size=13) self.tbl.insert_many(data, chunk_size=13)
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6, (len(self.tbl), len(data))
def test_chunked_insert(self): def test_chunked_insert(self):
data = TEST_DATA * 100 data = TEST_DATA * 100
with chunked.ChunkedInsert(self.tbl) as chunk_tbl: with chunked.ChunkedInsert(self.tbl) as chunk_tbl:
for item in data: for item in data:
chunk_tbl.insert(item) chunk_tbl.insert(item)
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6, (len(self.tbl), len(data))
def test_chunked_insert_callback(self): def test_chunked_insert_callback(self):
data = TEST_DATA * 100 data = TEST_DATA * 100
@ -422,6 +404,7 @@ class TableTestCase(unittest.TestCase):
def callback(queue): def callback(queue):
nonlocal N nonlocal N
N += len(queue) N += len(queue)
with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl: with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl:
for item in data: for item in data:
chunk_tbl.insert(item) chunk_tbl.insert(item)
@ -429,73 +412,73 @@ class TableTestCase(unittest.TestCase):
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6
def test_update_many(self): def test_update_many(self):
tbl = self.db['update_many_test'] tbl = self.db["update_many_test"]
tbl.insert_many([ tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])
dict(temp=10), dict(temp=20), dict(temp=30) tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id")
])
tbl.update_many(
[dict(id=1, temp=50), dict(id=3, temp=50)], 'id'
)
# Ensure data has been updated. # Ensure data has been updated.
assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"]
def test_chunked_update(self): def test_chunked_update(self):
tbl = self.db['update_many_test'] tbl = self.db["update_many_test"]
tbl.insert_many([ tbl.insert_many(
dict(temp=10, location='asdf'), dict(temp=20, location='qwer'), dict(temp=30, location='asdf') [
]) dict(temp=10, location="asdf"),
dict(temp=20, location="qwer"),
dict(temp=30, location="asdf"),
]
)
chunked_tbl = chunked.ChunkedUpdate(tbl, 'id') chunked_tbl = chunked.ChunkedUpdate(tbl, "id")
chunked_tbl.update(dict(id=1, temp=50)) chunked_tbl.update(dict(id=1, temp=50))
chunked_tbl.update(dict(id=2, location='asdf')) chunked_tbl.update(dict(id=2, location="asdf"))
chunked_tbl.update(dict(id=3, temp=50)) chunked_tbl.update(dict(id=3, temp=50))
chunked_tbl.flush() chunked_tbl.flush()
# Ensure data has been updated. # Ensure data has been updated.
assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] == 50 assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50
assert tbl.find_one(id=2)['location'] == tbl.find_one(id=3)['location'] == 'asdf' assert (
tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf"
) # noqa
def test_upsert_many(self): def test_upsert_many(self):
# Also tests updating on records with different attributes # Also tests updating on records with different attributes
tbl = self.db['upsert_many_test'] tbl = self.db["upsert_many_test"]
W = 100 W = 100
tbl.upsert_many([dict(age=10), dict(weight=W)], 'id') tbl.upsert_many([dict(age=10), dict(weight=W)], "id")
assert tbl.find_one(id=1)['age'] == 10 assert tbl.find_one(id=1)["age"] == 10
tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], 'id') tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id")
assert tbl.find_one(id=2)['weight'] == W / 2 assert tbl.find_one(id=2)["weight"] == W / 2
def test_drop_operations(self): def test_drop_operations(self):
assert self.tbl._table is not None, \ assert self.tbl._table is not None, "table shouldn't be dropped yet"
'table shouldn\'t be dropped yet'
self.tbl.drop() self.tbl.drop()
assert self.tbl._table is None, \ assert self.tbl._table is None, "table should be dropped now"
'table should be dropped now'
assert list(self.tbl.all()) == [], self.tbl.all() assert list(self.tbl.all()) == [], self.tbl.all()
assert self.tbl.count() == 0, self.tbl.count() assert self.tbl.count() == 0, self.tbl.count()
def test_table_drop(self): def test_table_drop(self):
assert 'weather' in self.db assert "weather" in self.db
self.db['weather'].drop() self.db["weather"].drop()
assert 'weather' not in self.db assert "weather" not in self.db
def test_table_drop_then_create(self): def test_table_drop_then_create(self):
assert 'weather' in self.db assert "weather" in self.db
self.db['weather'].drop() self.db["weather"].drop()
assert 'weather' not in self.db assert "weather" not in self.db
self.db['weather'].insert({'foo': 'bar'}) self.db["weather"].insert({"foo": "bar"})
def test_columns(self): def test_columns(self):
cols = self.tbl.columns cols = self.tbl.columns
assert len(list(cols)) == 4, 'column count mismatch' assert len(list(cols)) == 4, "column count mismatch"
assert 'date' in cols and 'temperature' in cols and 'place' in cols assert "date" in cols and "temperature" in cols and "place" in cols
def test_drop_column(self): def test_drop_column(self):
try: try:
self.tbl.drop_column('date') self.tbl.drop_column("date")
assert 'date' not in self.tbl.columns assert "date" not in self.tbl.columns
except RuntimeError: except RuntimeError:
pass pass
@ -507,71 +490,69 @@ class TableTestCase(unittest.TestCase):
def test_update(self): def test_update(self):
date = datetime(2011, 1, 2) date = datetime(2011, 1, 2)
res = self.tbl.update({ res = self.tbl.update(
'date': date, {"date": date, "temperature": -10, "place": TEST_CITY_1}, ["place", "date"]
'temperature': -10,
'place': TEST_CITY_1},
['place', 'date']
) )
assert res, 'update should return True' assert res, "update should return True"
m = self.tbl.find_one(place=TEST_CITY_1, date=date) m = self.tbl.find_one(place=TEST_CITY_1, date=date)
assert m['temperature'] == -10, \ assert m["temperature"] == -10, (
'new temp. should be -10 but is %d' % m['temperature'] "new temp. should be -10 but is %d" % m["temperature"]
)
def test_create_column(self): def test_create_column(self):
tbl = self.tbl tbl = self.tbl
tbl.create_column('foo', FLOAT) flt = self.db.types.float
assert 'foo' in tbl.table.c, tbl.table.c tbl.create_column("foo", flt)
assert isinstance(tbl.table.c['foo'].type, FLOAT), \ assert "foo" in tbl.table.c, tbl.table.c
tbl.table.c['foo'].type assert isinstance(tbl.table.c["foo"].type, flt), tbl.table.c["foo"].type
assert 'foo' in tbl.columns, tbl.columns assert "foo" in tbl.columns, tbl.columns
def test_ensure_column(self): def test_ensure_column(self):
tbl = self.tbl tbl = self.tbl
tbl.create_column_by_example('foo', 0.1) flt = self.db.types.float
assert 'foo' in tbl.table.c, tbl.table.c tbl.create_column_by_example("foo", 0.1)
assert isinstance(tbl.table.c['foo'].type, FLOAT), \ assert "foo" in tbl.table.c, tbl.table.c
tbl.table.c['bar'].type assert isinstance(tbl.table.c["foo"].type, flt), tbl.table.c["bar"].type
tbl.create_column_by_example('bar', 1) tbl.create_column_by_example("bar", 1)
assert 'bar' in tbl.table.c, tbl.table.c assert "bar" in tbl.table.c, tbl.table.c
assert isinstance(tbl.table.c['bar'].type, BIGINT), \ assert isinstance(tbl.table.c["bar"].type, BIGINT), tbl.table.c["bar"].type
tbl.table.c['bar'].type tbl.create_column_by_example("pippo", "test")
tbl.create_column_by_example('pippo', 'test') assert "pippo" in tbl.table.c, tbl.table.c
assert 'pippo' in tbl.table.c, tbl.table.c assert isinstance(tbl.table.c["pippo"].type, TEXT), tbl.table.c["pippo"].type
assert isinstance(tbl.table.c['pippo'].type, TEXT), \ tbl.create_column_by_example("bigbar", 11111111111)
tbl.table.c['pippo'].type assert "bigbar" in tbl.table.c, tbl.table.c
tbl.create_column_by_example('bigbar', 11111111111) assert isinstance(tbl.table.c["bigbar"].type, BIGINT), tbl.table.c[
assert 'bigbar' in tbl.table.c, tbl.table.c "bigbar"
assert isinstance(tbl.table.c['bigbar'].type, BIGINT), \ ].type
tbl.table.c['bigbar'].type tbl.create_column_by_example("littlebar", -11111111111)
tbl.create_column_by_example('littlebar', -11111111111) assert "littlebar" in tbl.table.c, tbl.table.c
assert 'littlebar' in tbl.table.c, tbl.table.c assert isinstance(tbl.table.c["littlebar"].type, BIGINT), tbl.table.c[
assert isinstance(tbl.table.c['littlebar'].type, BIGINT), \ "littlebar"
tbl.table.c['littlebar'].type ].type
def test_key_order(self): def test_key_order(self):
res = self.db.query('SELECT temperature, place FROM weather LIMIT 1') res = self.db.query("SELECT temperature, place FROM weather LIMIT 1")
keys = list(res.next().keys()) keys = list(res.next().keys())
assert keys[0] == 'temperature' assert keys[0] == "temperature"
assert keys[1] == 'place' assert keys[1] == "place"
def test_empty_query(self): def test_empty_query(self):
empty = list(self.tbl.find(place='not in data')) empty = list(self.tbl.find(place="not in data"))
assert len(empty) == 0, empty assert len(empty) == 0, empty
class Constructor(dict):
    """Minimal ``dict`` subclass exposing keys as attributes.

    A missing attribute raises ``KeyError`` (not ``AttributeError``),
    mirroring a plain dictionary lookup.
    """

    def __getattr__(self, name):
        return self[name]
class RowTypeTestCase(unittest.TestCase): class RowTypeTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.db = connect('sqlite:///:memory:', row_type=Constructor) self.db = connect(row_type=Constructor)
self.tbl = self.db['weather'] self.tbl = self.db["weather"]
for row in TEST_DATA: for row in TEST_DATA:
self.tbl.insert(row) self.tbl.insert(row)
@ -580,15 +561,13 @@ class RowTypeTestCase(unittest.TestCase):
self.db[table].drop() self.db[table].drop()
def test_find_one(self): def test_find_one(self):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10,
'place': 'Berlin'}
) )
d = self.tbl.find_one(place='Berlin') d = self.tbl.find_one(place="Berlin")
assert d['temperature'] == -10, d assert d["temperature"] == -10, d
assert d.temperature == -10, d assert d.temperature == -10, d
d = self.tbl.find_one(place='Atlantis') d = self.tbl.find_one(place="Atlantis")
assert d is None, d assert d is None, d
def test_find(self): def test_find(self):
@ -602,11 +581,11 @@ class RowTypeTestCase(unittest.TestCase):
assert isinstance(item, Constructor), item assert isinstance(item, Constructor), item
def test_distinct(self): def test_distinct(self):
x = list(self.tbl.distinct('place')) x = list(self.tbl.distinct("place"))
assert len(x) == 2, x assert len(x) == 2, x
for item in x: for item in x:
assert isinstance(item, Constructor), item assert isinstance(item, Constructor), item
x = list(self.tbl.distinct('place', 'date')) x = list(self.tbl.distinct("place", "date"))
assert len(x) == 6, x assert len(x) == 6, x
for item in x: for item in x:
assert isinstance(item, Constructor), item assert isinstance(item, Constructor), item
@ -619,5 +598,5 @@ class RowTypeTestCase(unittest.TestCase):
assert c == len(self.tbl) assert c == len(self.tbl)
# Allow running this test module directly: ``python test_dataset.py``.
if __name__ == "__main__":
    unittest.main()