Apply black

This commit is contained in:
Friedrich Lindenberg 2020-08-02 12:52:11 +02:00
parent eb4abcbe42
commit a17f2c8d5c
10 changed files with 571 additions and 582 deletions

View File

@ -5,17 +5,24 @@ from dataset.table import Table
from dataset.util import row_type from dataset.util import row_type
# shut up useless SA warning: # shut up useless SA warning:
warnings.filterwarnings("ignore", "Unicode type received non-unicode bind param value.")
warnings.filterwarnings( warnings.filterwarnings(
'ignore', 'Unicode type received non-unicode bind param value.') "ignore", "Skipping unsupported ALTER for creation of implicit constraint"
warnings.filterwarnings( )
'ignore', 'Skipping unsupported ALTER for creation of implicit constraint')
__all__ = ['Database', 'Table', 'freeze', 'connect'] __all__ = ["Database", "Table", "freeze", "connect"]
__version__ = '1.3.2' __version__ = "1.3.2"
def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None, def connect(
reflect_views=True, ensure_schema=True, row_type=row_type): url=None,
schema=None,
reflect_metadata=True,
engine_kwargs=None,
reflect_views=True,
ensure_schema=True,
row_type=row_type,
):
""" Opens a new connection to a database. """ Opens a new connection to a database.
*url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined *url* can be any valid `SQLAlchemy engine URL`_. If *url* is not defined
@ -40,8 +47,14 @@ def connect(url=None, schema=None, reflect_metadata=True, engine_kwargs=None,
.. _DB connection timeout: http://docs.sqlalchemy.org/en/latest/core/pooling.html#setting-pool-recycle .. _DB connection timeout: http://docs.sqlalchemy.org/en/latest/core/pooling.html#setting-pool-recycle
""" """
if url is None: if url is None:
url = os.environ.get('DATABASE_URL', 'sqlite://') url = os.environ.get("DATABASE_URL", "sqlite://")
return Database(url, schema=schema, reflect_metadata=reflect_metadata, return Database(
engine_kwargs=engine_kwargs, reflect_views=reflect_views, url,
ensure_schema=ensure_schema, row_type=row_type) schema=schema,
reflect_metadata=reflect_metadata,
engine_kwargs=engine_kwargs,
reflect_views=reflect_views,
ensure_schema=ensure_schema,
row_type=row_type,
)

View File

@ -22,9 +22,16 @@ log = logging.getLogger(__name__)
class Database(object): class Database(object):
"""A database object represents a SQL database with multiple tables.""" """A database object represents a SQL database with multiple tables."""
def __init__(self, url, schema=None, reflect_metadata=True, def __init__(
engine_kwargs=None, reflect_views=True, self,
ensure_schema=True, row_type=row_type): url,
schema=None,
reflect_metadata=True,
engine_kwargs=None,
reflect_views=True,
ensure_schema=True,
row_type=row_type,
):
"""Configure and connect to the database.""" """Configure and connect to the database."""
if engine_kwargs is None: if engine_kwargs is None:
engine_kwargs = {} engine_kwargs = {}
@ -41,13 +48,13 @@ class Database(object):
if len(parsed_url.query): if len(parsed_url.query):
query = parse_qs(parsed_url.query) query = parse_qs(parsed_url.query)
if schema is None: if schema is None:
schema_qs = query.get('schema', query.get('searchpath', [])) schema_qs = query.get("schema", query.get("searchpath", []))
if len(schema_qs): if len(schema_qs):
schema = schema_qs.pop() schema = schema_qs.pop()
self.schema = schema self.schema = schema
self.engine = create_engine(url, **engine_kwargs) self.engine = create_engine(url, **engine_kwargs)
self.is_postgres = self.engine.dialect.name == 'postgresql' self.is_postgres = self.engine.dialect.name == "postgresql"
self.types = Types(is_postgres=self.is_postgres) self.types = Types(is_postgres=self.is_postgres)
self.url = url self.url = url
self.row_type = row_type self.row_type = row_type
@ -57,7 +64,7 @@ class Database(object):
@property @property
def executable(self): def executable(self):
"""Connection against which statements will be executed.""" """Connection against which statements will be executed."""
if not hasattr(self.local, 'conn'): if not hasattr(self.local, "conn"):
self.local.conn = self.engine.connect() self.local.conn = self.engine.connect()
return self.local.conn return self.local.conn
@ -80,7 +87,7 @@ class Database(object):
@property @property
def in_transaction(self): def in_transaction(self):
"""Check if this database is in a transactional context.""" """Check if this database is in a transactional context."""
if not hasattr(self.local, 'tx'): if not hasattr(self.local, "tx"):
return False return False
return len(self.local.tx) > 0 return len(self.local.tx) > 0
@ -94,7 +101,7 @@ class Database(object):
No data will be written until the transaction has been committed. No data will be written until the transaction has been committed.
""" """
if not hasattr(self.local, 'tx'): if not hasattr(self.local, "tx"):
self.local.tx = [] self.local.tx = []
self.local.tx.append(self.executable.begin()) self.local.tx.append(self.executable.begin())
@ -103,7 +110,7 @@ class Database(object):
Make all statements executed since the transaction was begun permanent. Make all statements executed since the transaction was begun permanent.
""" """
if hasattr(self.local, 'tx') and self.local.tx: if hasattr(self.local, "tx") and self.local.tx:
tx = self.local.tx.pop() tx = self.local.tx.pop()
tx.commit() tx.commit()
self._flush_tables() self._flush_tables()
@ -113,7 +120,7 @@ class Database(object):
Discard all statements executed since the transaction was begun. Discard all statements executed since the transaction was begun.
""" """
if hasattr(self.local, 'tx') and self.local.tx: if hasattr(self.local, "tx") and self.local.tx:
tx = self.local.tx.pop() tx = self.local.tx.pop()
tx.rollback() tx.rollback()
self._flush_tables() self._flush_tables()
@ -190,15 +197,19 @@ class Database(object):
table5 = db.create_table('population5', table5 = db.create_table('population5',
primary_id=False) primary_id=False)
""" """
assert not isinstance(primary_type, str), \ assert not isinstance(
'Text-based primary_type support is dropped, use db.types.' primary_type, str
), "Text-based primary_type support is dropped, use db.types."
table_name = normalize_table_name(table_name) table_name = normalize_table_name(table_name)
with self.lock: with self.lock:
if table_name not in self._tables: if table_name not in self._tables:
self._tables[table_name] = Table(self, table_name, self._tables[table_name] = Table(
self,
table_name,
primary_id=primary_id, primary_id=primary_id,
primary_type=primary_type, primary_type=primary_type,
auto_create=True) auto_create=True,
)
return self._tables.get(table_name) return self._tables.get(table_name)
def load_table(self, table_name): def load_table(self, table_name):
@ -265,7 +276,7 @@ class Database(object):
""" """
if isinstance(query, str): if isinstance(query, str):
query = text(query) query = text(query)
_step = kwargs.pop('_step', QUERY_STEP) _step = kwargs.pop("_step", QUERY_STEP)
if _step is False or _step == 0: if _step is False or _step == 0:
_step = None _step = None
rp = self.executable.execute(query, *args, **kwargs) rp = self.executable.execute(query, *args, **kwargs)
@ -273,4 +284,4 @@ class Database(object):
def __repr__(self): def __repr__(self):
"""Text representation contains the URL.""" """Text representation contains the URL."""
return '<Database(%s)>' % safe_url(self.url) return "<Database(%s)>" % safe_url(self.url)

View File

@ -22,20 +22,27 @@ log = logging.getLogger(__name__)
class Table(object): class Table(object):
"""Represents a table in a database and exposes common operations.""" """Represents a table in a database and exposes common operations."""
PRIMARY_DEFAULT = 'id'
def __init__(self, database, table_name, primary_id=None, PRIMARY_DEFAULT = "id"
primary_type=None, auto_create=False):
def __init__(
self,
database,
table_name,
primary_id=None,
primary_type=None,
auto_create=False,
):
"""Initialise the table from database schema.""" """Initialise the table from database schema."""
self.db = database self.db = database
self.name = normalize_table_name(table_name) self.name = normalize_table_name(table_name)
self._table = None self._table = None
self._columns = None self._columns = None
self._indexes = [] self._indexes = []
self._primary_id = primary_id if primary_id is not None \ self._primary_id = (
else self.PRIMARY_DEFAULT primary_id if primary_id is not None else self.PRIMARY_DEFAULT
self._primary_type = primary_type if primary_type is not None \ )
else Types.integer self._primary_type = primary_type if primary_type is not None else Types.integer
self._auto_create = auto_create self._auto_create = auto_create
@property @property
@ -206,8 +213,7 @@ class Table(object):
if return_count: if return_count:
return self.count(clause) return self.count(clause)
def update_many(self, rows, keys, chunk_size=1000, ensure=None, def update_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
types=None):
"""Update many rows in the table at a time. """Update many rows in the table at a time.
This is significantly faster than updating them one by one. Per default This is significantly faster than updating them one by one. Per default
@ -229,16 +235,14 @@ class Table(object):
# bindparam requires names to not conflict (cannot be "id" for id) # bindparam requires names to not conflict (cannot be "id" for id)
for key in keys: for key in keys:
row['_%s' % key] = row[key] row["_%s" % key] = row[key]
# Update when chunk_size is fulfilled or this is the last row # Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1: if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam('_%s' % k) for k in keys] cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
stmt = self.table.update( stmt = self.table.update(
whereclause=and_(*cl), whereclause=and_(*cl),
values={ values={col: bindparam(col, required=False) for col in columns},
col: bindparam(col, required=False) for col in columns
}
) )
self.db.executable.execute(stmt, chunk) self.db.executable.execute(stmt, chunk)
chunk = [] chunk = []
@ -261,8 +265,7 @@ class Table(object):
return self.insert(row, ensure=False) return self.insert(row, ensure=False)
return True return True
def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, def upsert_many(self, rows, keys, chunk_size=1000, ensure=None, types=None):
types=None):
""" """
Sorts multiple input rows into upserts and inserts. Inserts are passed Sorts multiple input rows into upserts and inserts. Inserts are passed
to insert_many and upserts are updated. to insert_many and upserts are updated.
@ -311,19 +314,20 @@ class Table(object):
with self.db.lock: with self.db.lock:
self._flush_metadata() self._flush_metadata()
try: try:
self._table = SQLATable(self.name, self._table = SQLATable(
self.db.metadata, self.name, self.db.metadata, schema=self.db.schema, autoload=True
schema=self.db.schema, )
autoload=True)
except NoSuchTableError: except NoSuchTableError:
self._table = None self._table = None
def _threading_warn(self): def _threading_warn(self):
if self.db.in_transaction and threading.active_count() > 1: if self.db.in_transaction and threading.active_count() > 1:
warnings.warn("Changing the database schema inside a transaction " warnings.warn(
"Changing the database schema inside a transaction "
"in a multi-threaded environment is likely to lead " "in a multi-threaded environment is likely to lead "
"to race conditions and synchronization issues.", "to race conditions and synchronization issues.",
RuntimeWarning) RuntimeWarning,
)
def _sync_table(self, columns): def _sync_table(self, columns):
"""Lazy load, create or adapt the table structure in the database.""" """Lazy load, create or adapt the table structure in the database."""
@ -338,18 +342,21 @@ class Table(object):
# Keep the lock scope small because this is run very often. # Keep the lock scope small because this is run very often.
with self.db.lock: with self.db.lock:
self._threading_warn() self._threading_warn()
self._table = SQLATable(self.name, self._table = SQLATable(
self.db.metadata, self.name, self.db.metadata, schema=self.db.schema
schema=self.db.schema) )
if self._primary_id is not False: if self._primary_id is not False:
# This can go wrong on DBMS like MySQL and SQLite where # This can go wrong on DBMS like MySQL and SQLite where
# tables cannot have no columns. # tables cannot have no columns.
primary_id = self._primary_id primary_id = self._primary_id
primary_type = self._primary_type primary_type = self._primary_type
increment = primary_type in [Types.integer, Types.bigint] increment = primary_type in [Types.integer, Types.bigint]
column = Column(primary_id, primary_type, column = Column(
primary_id,
primary_type,
primary_key=True, primary_key=True,
autoincrement=increment) autoincrement=increment,
)
self._table.append_column(column) self._table.append_column(column)
for column in columns: for column in columns:
if not column.name == self._primary_id: if not column.name == self._primary_id:
@ -361,9 +368,7 @@ class Table(object):
self._threading_warn() self._threading_warn()
for column in columns: for column in columns:
if not self.has_column(column.name): if not self.has_column(column.name):
self.db.op.add_column(self.name, self.db.op.add_column(self.name, column, self.db.schema)
column,
self.db.schema)
self._reflect_table() self._reflect_table()
def _sync_columns(self, row, ensure, types=None): def _sync_columns(self, row, ensure, types=None):
@ -397,31 +402,31 @@ class Table(object):
return ensure return ensure
def _generate_clause(self, column, op, value): def _generate_clause(self, column, op, value):
if op in ('like',): if op in ("like",):
return self.table.c[column].like(value) return self.table.c[column].like(value)
if op in ('ilike',): if op in ("ilike",):
return self.table.c[column].ilike(value) return self.table.c[column].ilike(value)
if op in ('>', 'gt'): if op in (">", "gt"):
return self.table.c[column] > value return self.table.c[column] > value
if op in ('<', 'lt'): if op in ("<", "lt"):
return self.table.c[column] < value return self.table.c[column] < value
if op in ('>=', 'gte'): if op in (">=", "gte"):
return self.table.c[column] >= value return self.table.c[column] >= value
if op in ('<=', 'lte'): if op in ("<=", "lte"):
return self.table.c[column] <= value return self.table.c[column] <= value
if op in ('=', '==', 'is'): if op in ("=", "==", "is"):
return self.table.c[column] == value return self.table.c[column] == value
if op in ('!=', '<>', 'not'): if op in ("!=", "<>", "not"):
return self.table.c[column] != value return self.table.c[column] != value
if op in ('in'): if op in ("in"):
return self.table.c[column].in_(value) return self.table.c[column].in_(value)
if op in ('between', '..'): if op in ("between", ".."):
start, end = value start, end = value
return self.table.c[column].between(start, end) return self.table.c[column].between(start, end)
if op in ('startswith',): if op in ("startswith",):
return self.table.c[column].like('%' + value) return self.table.c[column].like("%" + value)
if op in ('endswith',): if op in ("endswith",):
return self.table.c[column].like(value + '%') return self.table.c[column].like(value + "%")
return false() return false()
def _args_to_clause(self, args, clauses=()): def _args_to_clause(self, args, clauses=()):
@ -431,12 +436,12 @@ class Table(object):
if not self.has_column(column): if not self.has_column(column):
clauses.append(false()) clauses.append(false())
elif isinstance(value, (list, tuple, set)): elif isinstance(value, (list, tuple, set)):
clauses.append(self._generate_clause(column, 'in', value)) clauses.append(self._generate_clause(column, "in", value))
elif isinstance(value, dict): elif isinstance(value, dict):
for op, op_value in value.items(): for op, op_value in value.items():
clauses.append(self._generate_clause(column, op, op_value)) clauses.append(self._generate_clause(column, op, op_value))
else: else:
clauses.append(self._generate_clause(column, '=', value)) clauses.append(self._generate_clause(column, "=", value))
return and_(*clauses) return and_(*clauses)
def _args_to_order_by(self, order_by): def _args_to_order_by(self, order_by):
@ -444,11 +449,11 @@ class Table(object):
for ordering in ensure_list(order_by): for ordering in ensure_list(order_by):
if ordering is None: if ordering is None:
continue continue
column = ordering.lstrip('-') column = ordering.lstrip("-")
column = self._get_column_name(column) column = self._get_column_name(column)
if not self.has_column(column): if not self.has_column(column):
continue continue
if ordering.startswith('-'): if ordering.startswith("-"):
orderings.append(self.table.c[column].desc()) orderings.append(self.table.c[column].desc())
else: else:
orderings.append(self.table.c[column].asc()) orderings.append(self.table.c[column].asc())
@ -501,7 +506,7 @@ class Table(object):
:: ::
table.drop_column('created_at') table.drop_column('created_at')
""" """
if self.db.engine.dialect.name == 'sqlite': if self.db.engine.dialect.name == "sqlite":
raise RuntimeError("SQLite does not support dropping columns.") raise RuntimeError("SQLite does not support dropping columns.")
name = self._get_column_name(name) name = self._get_column_name(name)
with self.db.lock: with self.db.lock:
@ -510,11 +515,7 @@ class Table(object):
return return
self._threading_warn() self._threading_warn()
self.db.op.drop_column( self.db.op.drop_column(self.table.name, name, self.table.schema)
self.table.name,
name,
self.table.schema
)
self._reflect_table() self._reflect_table()
def drop(self): def drop(self):
@ -542,7 +543,7 @@ class Table(object):
return False return False
indexes = self.db.inspect.get_indexes(self.name, schema=self.db.schema) indexes = self.db.inspect.get_indexes(self.name, schema=self.db.schema)
for index in indexes: for index in indexes:
if columns == set(index.get('column_names', [])): if columns == set(index.get("column_names", [])):
self._indexes.append(columns) self._indexes.append(columns)
return True return True
return False return False
@ -600,19 +601,17 @@ class Table(object):
if not self.exists: if not self.exists:
return iter([]) return iter([])
_limit = kwargs.pop('_limit', None) _limit = kwargs.pop("_limit", None)
_offset = kwargs.pop('_offset', 0) _offset = kwargs.pop("_offset", 0)
order_by = kwargs.pop('order_by', None) order_by = kwargs.pop("order_by", None)
_streamed = kwargs.pop('_streamed', False) _streamed = kwargs.pop("_streamed", False)
_step = kwargs.pop('_step', QUERY_STEP) _step = kwargs.pop("_step", QUERY_STEP)
if _step is False or _step == 0: if _step is False or _step == 0:
_step = None _step = None
order_by = self._args_to_order_by(order_by) order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses) args = self._args_to_clause(kwargs, clauses=_clauses)
query = self.table.select(whereclause=args, query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
limit=_limit,
offset=_offset)
if len(order_by): if len(order_by):
query = query.order_by(*order_by) query = query.order_by(*order_by)
@ -621,9 +620,7 @@ class Table(object):
conn = self.db.engine.connect() conn = self.db.engine.connect()
conn = conn.execution_options(stream_results=True) conn = conn.execution_options(stream_results=True)
return ResultIter(conn.execute(query), return ResultIter(conn.execute(query), row_type=self.db.row_type, step=_step)
row_type=self.db.row_type,
step=_step)
def find_one(self, *args, **kwargs): def find_one(self, *args, **kwargs):
"""Get a single result from the table. """Get a single result from the table.
@ -637,8 +634,8 @@ class Table(object):
if not self.exists: if not self.exists:
return None return None
kwargs['_limit'] = 1 kwargs["_limit"] = 1
kwargs['_step'] = None kwargs["_step"] = None
resiter = self.find(*args, **kwargs) resiter = self.find(*args, **kwargs)
try: try:
for row in resiter: for row in resiter:
@ -692,10 +689,12 @@ class Table(object):
if not len(columns): if not len(columns):
return iter([]) return iter([])
q = expression.select(columns, q = expression.select(
columns,
distinct=True, distinct=True,
whereclause=clause, whereclause=clause,
order_by=[c.asc() for c in columns]) order_by=[c.asc() for c in columns],
)
return self.db.query(q) return self.db.query(q)
# Legacy methods for running find queries. # Legacy methods for running find queries.
@ -715,4 +714,4 @@ class Table(object):
def __repr__(self): def __repr__(self):
"""Get table representation.""" """Get table representation."""
return '<Table(%s)>' % self.table.name return "<Table(%s)>" % self.table.name

View File

@ -8,6 +8,7 @@ from sqlalchemy.types import TypeEngine
class Types(object): class Types(object):
"""A holder class for easy access to SQLAlchemy type names.""" """A holder class for easy access to SQLAlchemy type names."""
integer = Integer integer = Integer
string = Unicode string = Unicode
text = UnicodeText text = UnicodeText

View File

@ -62,17 +62,17 @@ class ResultIter(object):
def normalize_column_name(name): def normalize_column_name(name):
"""Check if a string is a reasonable thing to use as a column name.""" """Check if a string is a reasonable thing to use as a column name."""
if not isinstance(name, str): if not isinstance(name, str):
raise ValueError('%r is not a valid column name.' % name) raise ValueError("%r is not a valid column name." % name)
# limit to 63 characters # limit to 63 characters
name = name.strip()[:63] name = name.strip()[:63]
# column names can be 63 *bytes* max in postgresql # column names can be 63 *bytes* max in postgresql
if isinstance(name, str): if isinstance(name, str):
while len(name.encode('utf-8')) >= 64: while len(name.encode("utf-8")) >= 64:
name = name[: len(name) - 1] name = name[: len(name) - 1]
if not len(name) or '.' in name or '-' in name: if not len(name) or "." in name or "-" in name:
raise ValueError('%r is not a valid column name.' % name) raise ValueError("%r is not a valid column name." % name)
return name return name
@ -80,7 +80,7 @@ def normalize_column_key(name):
"""Return a comparable column name.""" """Return a comparable column name."""
if name is None or not isinstance(name, str): if name is None or not isinstance(name, str):
return None return None
return name.upper().strip().replace(' ', '') return name.upper().strip().replace(" ", "")
def normalize_table_name(name): def normalize_table_name(name):
@ -97,16 +97,16 @@ def safe_url(url):
"""Remove password from printed connection URLs.""" """Remove password from printed connection URLs."""
parsed = urlparse(url) parsed = urlparse(url)
if parsed.password is not None: if parsed.password is not None:
pwd = ':%s@' % parsed.password pwd = ":%s@" % parsed.password
url = url.replace(pwd, ':*****@') url = url.replace(pwd, ":*****@")
return url return url
def index_name(table, columns): def index_name(table, columns):
"""Generate an artificial index name.""" """Generate an artificial index name."""
sig = '||'.join(columns) sig = "||".join(columns)
key = sha1(sig.encode('utf-8')).hexdigest()[:16] key = sha1(sig.encode("utf-8")).hexdigest()[:16]
return 'ix_%s_%s' % (table, key) return "ix_%s_%s" % (table, key)
def pad_chunk_columns(chunk, columns): def pad_chunk_columns(chunk, columns):

View File

@ -1,7 +1,19 @@
# flasky extensions. flasky pygments style based on tango style # flasky extensions. flasky pygments style based on tango style
from pygments.style import Style from pygments.style import Style
from pygments.token import Keyword, Name, Comment, String, Error, \ from pygments.token import (
Number, Operator, Generic, Whitespace, Punctuation, Other, Literal Keyword,
Name,
Comment,
String,
Error,
Number,
Operator,
Generic,
Whitespace,
Punctuation,
Other,
Literal,
)
class FlaskyStyle(Style): class FlaskyStyle(Style):
@ -14,10 +26,8 @@ class FlaskyStyle(Style):
Whitespace: "underline #f8f8f8", # class: 'w' Whitespace: "underline #f8f8f8", # class: 'w'
Error: "#a40000 border:#ef2929", # class: 'err' Error: "#a40000 border:#ef2929", # class: 'err'
Other: "#000000", # class 'x' Other: "#000000", # class 'x'
Comment: "italic #8f5902", # class: 'c' Comment: "italic #8f5902", # class: 'c'
Comment.Preproc: "noitalic", # class: 'cp' Comment.Preproc: "noitalic", # class: 'cp'
Keyword: "bold #004461", # class: 'k' Keyword: "bold #004461", # class: 'k'
Keyword.Constant: "bold #004461", # class: 'kc' Keyword.Constant: "bold #004461", # class: 'kc'
Keyword.Declaration: "bold #004461", # class: 'kd' Keyword.Declaration: "bold #004461", # class: 'kd'
@ -25,12 +35,9 @@ class FlaskyStyle(Style):
Keyword.Pseudo: "bold #004461", # class: 'kp' Keyword.Pseudo: "bold #004461", # class: 'kp'
Keyword.Reserved: "bold #004461", # class: 'kr' Keyword.Reserved: "bold #004461", # class: 'kr'
Keyword.Type: "bold #004461", # class: 'kt' Keyword.Type: "bold #004461", # class: 'kt'
Operator: "#582800", # class: 'o' Operator: "#582800", # class: 'o'
Operator.Word: "bold #004461", # class: 'ow' - like keywords Operator.Word: "bold #004461", # class: 'ow' - like keywords
Punctuation: "bold #000000", # class: 'p' Punctuation: "bold #000000", # class: 'p'
# because special names such as Name.Class, Name.Function, etc. # because special names such as Name.Class, Name.Function, etc.
# are not recognized as such later in the parsing, we choose them # are not recognized as such later in the parsing, we choose them
# to look the same as ordinary variables. # to look the same as ordinary variables.
@ -53,12 +60,9 @@ class FlaskyStyle(Style):
Name.Variable.Class: "#000000", # class: 'vc' - to be revised Name.Variable.Class: "#000000", # class: 'vc' - to be revised
Name.Variable.Global: "#000000", # class: 'vg' - to be revised Name.Variable.Global: "#000000", # class: 'vg' - to be revised
Name.Variable.Instance: "#000000", # class: 'vi' - to be revised Name.Variable.Instance: "#000000", # class: 'vi' - to be revised
Number: "#990000", # class: 'm' Number: "#990000", # class: 'm'
Literal: "#000000", # class: 'l' Literal: "#000000", # class: 'l'
Literal.Date: "#000000", # class: 'ld' Literal.Date: "#000000", # class: 'ld'
String: "#4e9a06", # class: 's' String: "#4e9a06", # class: 's'
String.Backtick: "#4e9a06", # class: 'sb' String.Backtick: "#4e9a06", # class: 'sb'
String.Char: "#4e9a06", # class: 'sc' String.Char: "#4e9a06", # class: 'sc'
@ -71,7 +75,6 @@ class FlaskyStyle(Style):
String.Regex: "#4e9a06", # class: 'sr' String.Regex: "#4e9a06", # class: 'sr'
String.Single: "#4e9a06", # class: 's1' String.Single: "#4e9a06", # class: 's1'
String.Symbol: "#4e9a06", # class: 'ss' String.Symbol: "#4e9a06", # class: 'ss'
Generic: "#000000", # class: 'g' Generic: "#000000", # class: 'g'
Generic.Deleted: "#a40000", # class: 'gd' Generic.Deleted: "#a40000", # class: 'gd'
Generic.Emph: "italic #000000", # class: 'ge' Generic.Emph: "italic #000000", # class: 'ge'

View File

@ -16,7 +16,7 @@ import sys, os
# If extensions (or modules to document with autodoc) are in another directory, # If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the # add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here. # documentation root, use os.path.abspath to make it absolute, like shown here.
sys.path.insert(0, os.path.abspath('../')) sys.path.insert(0, os.path.abspath("../"))
# -- General configuration ----------------------------------------------------- # -- General configuration -----------------------------------------------------
@ -25,32 +25,32 @@ sys.path.insert(0, os.path.abspath('../'))
# Add any Sphinx extension module names here, as strings. They can be extensions # Add any Sphinx extension module names here, as strings. They can be extensions
# coming with Sphinx (named 'sphinx.ext.*') or your custom ones. # coming with Sphinx (named 'sphinx.ext.*') or your custom ones.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] extensions = ["sphinx.ext.autodoc", "sphinx.ext.viewcode"]
# Add any paths that contain templates here, relative to this directory. # Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates'] templates_path = ["_templates"]
# The suffix of source filenames. # The suffix of source filenames.
source_suffix = '.rst' source_suffix = ".rst"
# The encoding of source files. # The encoding of source files.
# source_encoding = 'utf-8-sig' # source_encoding = 'utf-8-sig'
# The master toctree document. # The master toctree document.
master_doc = 'index' master_doc = "index"
# General information about the project. # General information about the project.
project = u'dataset' project = u"dataset"
copyright = u'2013-2018, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer' copyright = u"2013-2018, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"
# The version info for the project you're documenting, acts as replacement for # The version info for the project you're documenting, acts as replacement for
# |version| and |release|, also used in various other places throughout the # |version| and |release|, also used in various other places throughout the
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = '1.0' version = "1.0"
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = '1.0.8' release = "1.0.8"
# The language for content autogenerated by Sphinx. Refer to documentation # The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages. # for a list of supported languages.
@ -64,7 +64,7 @@ release = '1.0.8'
# List of patterns, relative to source directory, that match files and # List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files. # directories to ignore when looking for source files.
exclude_patterns = ['_build'] exclude_patterns = ["_build"]
# The reST default role (used for this markup: `text`) to use for all documents. # The reST default role (used for this markup: `text`) to use for all documents.
# default_role = None # default_role = None
@ -81,7 +81,7 @@ exclude_patterns = ['_build']
# show_authors = False # show_authors = False
# The name of the Pygments (syntax highlighting) style to use. # The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx' pygments_style = "sphinx"
# A list of ignored prefixes for module index sorting. # A list of ignored prefixes for module index sorting.
# modindex_common_prefix = [] # modindex_common_prefix = []
@ -94,7 +94,7 @@ pygments_style = 'sphinx'
# The theme to use for HTML and HTML Help pages. See the documentation for # The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes. # a list of builtin themes.
html_theme = 'kr' html_theme = "kr"
# Theme options are theme-specific and customize the look and feel of a theme # Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the # further. For a list of options available for each theme, see the
@ -104,7 +104,7 @@ html_theme = 'kr'
# } # }
# Add any paths that contain custom themes here, relative to this directory. # Add any paths that contain custom themes here, relative to this directory.
html_theme_path = ['_themes'] html_theme_path = ["_themes"]
# The name for this set of Sphinx documents. If None, it defaults to # The name for this set of Sphinx documents. If None, it defaults to
# "<project> v<release> documentation". # "<project> v<release> documentation".
@ -125,7 +125,7 @@ html_theme_path = ['_themes']
# Add any paths that contain custom static files (such as style sheets) here, # Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files, # relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css". # so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static'] html_static_path = ["_static"]
# If not '', a 'Last updated on:' timestamp is inserted at every page bottom, # If not '', a 'Last updated on:' timestamp is inserted at every page bottom,
# using the given strftime format. # using the given strftime format.
@ -137,9 +137,9 @@ html_static_path = ['_static']
# Custom sidebar templates, maps document names to template names. # Custom sidebar templates, maps document names to template names.
html_sidebars = { html_sidebars = {
'index': ['sidebarlogo.html', 'sourcelink.html', 'searchbox.html'], "index": ["sidebarlogo.html", "sourcelink.html", "searchbox.html"],
'api': ['sidebarlogo.html', 'autotoc.html', 'sourcelink.html', 'searchbox.html'], "api": ["sidebarlogo.html", "autotoc.html", "sourcelink.html", "searchbox.html"],
'**': ['sidebarlogo.html', 'localtoc.html', 'sourcelink.html', 'searchbox.html'] "**": ["sidebarlogo.html", "localtoc.html", "sourcelink.html", "searchbox.html"],
} }
# Additional templates that should be rendered to pages, maps page names to # Additional templates that should be rendered to pages, maps page names to
@ -173,7 +173,7 @@ html_show_sourcelink = False
# html_file_suffix = None # html_file_suffix = None
# Output file base name for HTML help builder. # Output file base name for HTML help builder.
htmlhelp_basename = 'datasetdoc' htmlhelp_basename = "datasetdoc"
# -- Options for LaTeX output -------------------------------------------------- # -- Options for LaTeX output --------------------------------------------------
@ -181,10 +181,8 @@ htmlhelp_basename = 'datasetdoc'
latex_elements = { latex_elements = {
# The paper size ('letterpaper' or 'a4paper'). # The paper size ('letterpaper' or 'a4paper').
#'papersize': 'letterpaper', #'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt'). # The font size ('10pt', '11pt' or '12pt').
#'pointsize': '10pt', #'pointsize': '10pt',
# Additional stuff for the LaTeX preamble. # Additional stuff for the LaTeX preamble.
#'preamble': '', #'preamble': '',
} }
@ -192,8 +190,13 @@ latex_elements = {
# Grouping the document tree into LaTeX files. List of tuples # Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title, author, documentclass [howto/manual]). # (source start file, target name, title, author, documentclass [howto/manual]).
latex_documents = [ latex_documents = [
('index', 'dataset.tex', u'dataset Documentation', (
u'Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', 'manual'), "index",
"dataset.tex",
u"dataset Documentation",
u"Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer",
"manual",
),
] ]
# The name of an image file (relative to this directory) to place at the top of # The name of an image file (relative to this directory) to place at the top of
@ -222,8 +225,13 @@ latex_documents = [
# One entry per manual page. List of tuples # One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section). # (source start file, name, description, authors, manual section).
man_pages = [ man_pages = [
('index', 'dataset', u'dataset Documentation', (
[u'Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer'], 1) "index",
"dataset",
u"dataset Documentation",
[u"Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"],
1,
)
] ]
# If true, show URL addresses after external links. # If true, show URL addresses after external links.
@ -236,9 +244,15 @@ man_pages = [
# (source start file, target name, title, author, # (source start file, target name, title, author,
# dir menu entry, description, category) # dir menu entry, description, category)
texinfo_documents = [ texinfo_documents = [
('index', 'dataset', u'dataset Documentation', (
u'Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', 'dataset', 'Databases for lazy people.', "index",
'Miscellaneous'), "dataset",
u"dataset Documentation",
u"Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer",
"dataset",
"Databases for lazy people.",
"Miscellaneous",
),
] ]
# Documents to append as an appendix to all manuals. # Documents to append as an appendix to all manuals.

View File

@ -1,52 +1,46 @@
from setuptools import setup, find_packages from setuptools import setup, find_packages
with open('README.md') as f: with open("README.md") as f:
long_description = f.read() long_description = f.read()
setup( setup(
name='dataset', name="dataset",
version='1.3.2', version="1.3.2",
description="Toolkit for Python-based database access.", description="Toolkit for Python-based database access.",
long_description=long_description, long_description=long_description,
long_description_content_type='text/markdown', long_description_content_type="text/markdown",
classifiers=[ classifiers=[
"Development Status :: 3 - Alpha", "Development Status :: 3 - Alpha",
"Intended Audience :: Developers", "Intended Audience :: Developers",
"License :: OSI Approved :: MIT License", "License :: OSI Approved :: MIT License",
"Operating System :: OS Independent", "Operating System :: OS Independent",
'Programming Language :: Python :: 3.5', "Programming Language :: Python :: 3.5",
'Programming Language :: Python :: 3.6', "Programming Language :: Python :: 3.6",
'Programming Language :: Python :: 3.7', "Programming Language :: Python :: 3.7",
'Programming Language :: Python :: 3.8', "Programming Language :: Python :: 3.8",
], ],
keywords='sql sqlalchemy etl loading utility', keywords="sql sqlalchemy etl loading utility",
author='Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer', author="Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer",
author_email='friedrich.lindenberg@gmail.com', author_email="friedrich.lindenberg@gmail.com",
url='http://github.com/pudo/dataset', url="http://github.com/pudo/dataset",
license='MIT', license="MIT",
packages=find_packages(exclude=['ez_setup', 'examples', 'test']), packages=find_packages(exclude=["ez_setup", "examples", "test"]),
namespace_packages=[], namespace_packages=[],
include_package_data=False, include_package_data=False,
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=["sqlalchemy >= 1.3.2", "alembic >= 0.6.2", "banal >= 1.0.1",],
'sqlalchemy >= 1.3.2',
'alembic >= 0.6.2',
'banal >= 1.0.1',
],
extras_require={ extras_require={
'dev': [ "dev": [
'pip', "pip",
'nose', "nose",
'wheel', "wheel",
'flake8', "flake8",
'coverage', "coverage",
'psycopg2-binary', "psycopg2-binary",
'PyMySQL', "PyMySQL",
] ]
}, },
tests_require=[ tests_require=["nose"],
'nose' test_suite="test",
], entry_points={},
test_suite='test',
entry_points={}
) )

View File

@ -4,38 +4,14 @@ from __future__ import unicode_literals
from datetime import datetime from datetime import datetime
TEST_CITY_1 = 'B€rkeley' TEST_CITY_1 = "B€rkeley"
TEST_CITY_2 = 'G€lway' TEST_CITY_2 = "G€lway"
TEST_DATA = [ TEST_DATA = [
{ {"date": datetime(2011, 1, 1), "temperature": 1, "place": TEST_CITY_2},
'date': datetime(2011, 1, 1), {"date": datetime(2011, 1, 2), "temperature": -1, "place": TEST_CITY_2},
'temperature': 1, {"date": datetime(2011, 1, 3), "temperature": 0, "place": TEST_CITY_2},
'place': TEST_CITY_2 {"date": datetime(2011, 1, 1), "temperature": 6, "place": TEST_CITY_1},
}, {"date": datetime(2011, 1, 2), "temperature": 8, "place": TEST_CITY_1},
{ {"date": datetime(2011, 1, 3), "temperature": 5, "place": TEST_CITY_1},
'date': datetime(2011, 1, 2),
'temperature': -1,
'place': TEST_CITY_2
},
{
'date': datetime(2011, 1, 3),
'temperature': 0,
'place': TEST_CITY_2
},
{
'date': datetime(2011, 1, 1),
'temperature': 6,
'place': TEST_CITY_1
},
{
'date': datetime(2011, 1, 2),
'temperature': 8,
'place': TEST_CITY_1
},
{
'date': datetime(2011, 1, 3),
'temperature': 5,
'place': TEST_CITY_1
}
] ]

View File

@ -11,10 +11,9 @@ from .sample_data import TEST_DATA, TEST_CITY_1
class DatabaseTestCase(unittest.TestCase): class DatabaseTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.db = connect() self.db = connect()
self.tbl = self.db['weather'] self.tbl = self.db["weather"]
self.tbl.insert_many(TEST_DATA) self.tbl.insert_many(TEST_DATA)
def tearDown(self): def tearDown(self):
@ -22,28 +21,28 @@ class DatabaseTestCase(unittest.TestCase):
self.db[table].drop() self.db[table].drop()
def test_valid_database_url(self): def test_valid_database_url(self):
assert self.db.url, os.environ['DATABASE_URL'] assert self.db.url, os.environ["DATABASE_URL"]
def test_database_url_query_string(self): def test_database_url_query_string(self):
db = connect('sqlite:///:memory:/?cached_statements=1') db = connect("sqlite:///:memory:/?cached_statements=1")
assert 'cached_statements' in db.url, db.url assert "cached_statements" in db.url, db.url
def test_tables(self): def test_tables(self):
assert self.db.tables == ['weather'], self.db.tables assert self.db.tables == ["weather"], self.db.tables
def test_contains(self): def test_contains(self):
assert 'weather' in self.db, self.db.tables assert "weather" in self.db, self.db.tables
def test_create_table(self): def test_create_table(self):
table = self.db['foo'] table = self.db["foo"]
assert table.table.exists() assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert 'id' in table.table.c, table.table.c assert "id" in table.table.c, table.table.c
def test_create_table_no_ids(self): def test_create_table_no_ids(self):
if 'mysql' in self.db.engine.dialect.dbapi.__name__: if "mysql" in self.db.engine.dialect.dbapi.__name__:
return return
if 'sqlite' in self.db.engine.dialect.dbapi.__name__: if "sqlite" in self.db.engine.dialect.dbapi.__name__:
return return
table = self.db.create_table("foo_no_id", primary_id=False) table = self.db.create_table("foo_no_id", primary_id=False)
assert table.table.exists() assert table.table.exists()
@ -55,8 +54,8 @@ class DatabaseTestCase(unittest.TestCase):
assert table.table.exists() assert table.table.exists()
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({pid: 'foobar'}) table.insert({pid: "foobar"})
assert table.find_one(string_id='foobar')[pid] == 'foobar' assert table.find_one(string_id="foobar")[pid] == "foobar"
def test_create_table_custom_id2(self): def test_create_table_custom_id2(self):
pid = "string_id" pid = "string_id"
@ -65,8 +64,8 @@ class DatabaseTestCase(unittest.TestCase):
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({pid: 'foobar'}) table.insert({pid: "foobar"})
assert table.find_one(string_id='foobar')[pid] == 'foobar' assert table.find_one(string_id="foobar")[pid] == "foobar"
def test_create_table_custom_id3(self): def test_create_table_custom_id3(self):
pid = "int_id" pid = "int_id"
@ -83,253 +82,235 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand1(self): def test_create_table_shorthand1(self):
pid = "int_id" pid = "int_id"
table = self.db.get_table('foo5', pid) table = self.db.get_table("foo5", pid)
assert table.table.exists assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({'int_id': 123}) table.insert({"int_id": 123})
table.insert({'int_id': 124}) table.insert({"int_id": 124})
assert table.find_one(int_id=123)['int_id'] == 123 assert table.find_one(int_id=123)["int_id"] == 123
assert table.find_one(int_id=124)['int_id'] == 124 assert table.find_one(int_id=124)["int_id"] == 124
self.assertRaises(IntegrityError, self.assertRaises(IntegrityError, lambda: table.insert({"int_id": 123}))
lambda: table.insert({'int_id': 123}))
def test_create_table_shorthand2(self): def test_create_table_shorthand2(self):
pid = "string_id" pid = "string_id"
table = self.db.get_table('foo6', primary_id=pid, table = self.db.get_table(
primary_type=self.db.types.string(255)) "foo6", primary_id=pid, primary_type=self.db.types.string(255)
)
assert table.table.exists assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
table.insert({ table.insert({"string_id": "foobar"})
'string_id': 'foobar'}) assert table.find_one(string_id="foobar")["string_id"] == "foobar"
assert table.find_one(string_id='foobar')['string_id'] == 'foobar'
def test_with(self): def test_with(self):
init_length = len(self.db['weather']) init_length = len(self.db["weather"])
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
with self.db as tx: with self.db as tx:
tx['weather'].insert({'date': datetime(2011, 1, 1), tx["weather"].insert(
'temperature': 1, {
'place': 'tmp_place'}) "date": datetime(2011, 1, 1),
"temperature": 1,
"place": "tmp_place",
}
)
raise ValueError() raise ValueError()
assert len(self.db['weather']) == init_length assert len(self.db["weather"]) == init_length
def test_invalid_values(self): def test_invalid_values(self):
if 'mysql' in self.db.engine.dialect.dbapi.__name__: if "mysql" in self.db.engine.dialect.dbapi.__name__:
# WARNING: mysql seems to be doing some weird type casting # WARNING: mysql seems to be doing some weird type casting
# upon insert. The mysql-python driver is not affected but # upon insert. The mysql-python driver is not affected but
# it isn't compatible with Python 3 # it isn't compatible with Python 3
# Conclusion: use postgresql. # Conclusion: use postgresql.
return return
with self.assertRaises(SQLAlchemyError): with self.assertRaises(SQLAlchemyError):
tbl = self.db['weather'] tbl = self.db["weather"]
tbl.insert({ tbl.insert(
'date': True, {"date": True, "temperature": "wrong_value", "place": "tmp_place"}
'temperature': 'wrong_value', )
'place': 'tmp_place'
})
def test_load_table(self): def test_load_table(self):
tbl = self.db.load_table('weather') tbl = self.db.load_table("weather")
assert tbl.table.name == self.tbl.table.name assert tbl.table.name == self.tbl.table.name
def test_query(self): def test_query(self):
r = self.db.query('SELECT COUNT(*) AS num FROM weather').next() r = self.db.query("SELECT COUNT(*) AS num FROM weather").next()
assert r['num'] == len(TEST_DATA), r assert r["num"] == len(TEST_DATA), r
def test_table_cache_updates(self): def test_table_cache_updates(self):
tbl1 = self.db.get_table('people') tbl1 = self.db.get_table("people")
data = OrderedDict([('first_name', 'John'), ('last_name', 'Smith')]) data = OrderedDict([("first_name", "John"), ("last_name", "Smith")])
tbl1.insert(data) tbl1.insert(data)
data['id'] = 1 data["id"] = 1
tbl2 = self.db.get_table('people') tbl2 = self.db.get_table("people")
assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data) assert dict(tbl2.all().next()) == dict(data), (tbl2.all().next(), data)
class TableTestCase(unittest.TestCase): class TableTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
self.db = connect() self.db = connect()
self.tbl = self.db['weather'] self.tbl = self.db["weather"]
for row in TEST_DATA: for row in TEST_DATA:
self.tbl.insert(row) self.tbl.insert(row)
def test_insert(self): def test_insert(self):
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
last_id = self.tbl.insert({ last_id = self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10,
'place': 'Berlin'}
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
assert self.tbl.find_one(id=last_id)['place'] == 'Berlin' assert self.tbl.find_one(id=last_id)["place"] == "Berlin"
def test_insert_ignore(self): def test_insert_ignore(self):
self.tbl.insert_ignore({ self.tbl.insert_ignore(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["place"],
'place': 'Berlin'},
['place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
self.tbl.insert_ignore({ self.tbl.insert_ignore(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["place"],
'place': 'Berlin'},
['place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_insert_ignore_all_key(self): def test_insert_ignore_all_key(self):
for i in range(0, 4): for i in range(0, 4):
self.tbl.insert_ignore({ self.tbl.insert_ignore(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["date", "temperature", "place"],
'place': 'Berlin'},
['date', 'temperature', 'place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_insert_json(self): def test_insert_json(self):
last_id = self.tbl.insert({ last_id = self.tbl.insert(
'date': datetime(2011, 1, 2), {
'temperature': -10, "date": datetime(2011, 1, 2),
'place': 'Berlin', "temperature": -10,
'info': { "place": "Berlin",
'currency': 'EUR', "info": {
'language': 'German', "currency": "EUR",
'population': 3292365 "language": "German",
"population": 3292365,
},
} }
})
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
assert self.tbl.find_one(id=last_id)['place'] == 'Berlin'
def test_upsert(self):
self.tbl.upsert({
'date': datetime(2011, 1, 2),
'temperature': -10,
'place': 'Berlin'},
['place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
self.tbl.upsert({ assert self.tbl.find_one(id=last_id)["place"] == "Berlin"
'date': datetime(2011, 1, 2),
'temperature': -10, def test_upsert(self):
'place': 'Berlin'}, self.tbl.upsert(
['place'] {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
["place"],
)
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
self.tbl.upsert(
{"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
["place"],
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_upsert_single_column(self): def test_upsert_single_column(self):
table = self.db['banana_single_col'] table = self.db["banana_single_col"]
table.upsert({ table.upsert({"color": "Yellow"}, ["color"])
'color': 'Yellow'},
['color']
)
assert len(table) == 1, len(table) assert len(table) == 1, len(table)
table.upsert({ table.upsert({"color": "Yellow"}, ["color"])
'color': 'Yellow'},
['color']
)
assert len(table) == 1, len(table) assert len(table) == 1, len(table)
def test_upsert_all_key(self): def test_upsert_all_key(self):
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
for i in range(0, 2): for i in range(0, 2):
self.tbl.upsert({ self.tbl.upsert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"},
'temperature': -10, ["date", "temperature", "place"],
'place': 'Berlin'},
['date', 'temperature', 'place']
) )
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
def test_upsert_id(self): def test_upsert_id(self):
table = self.db['banana_with_id'] table = self.db["banana_with_id"]
data = dict(id=10, title='I am a banana!') data = dict(id=10, title="I am a banana!")
table.upsert(data, ['id']) table.upsert(data, ["id"])
assert len(table) == 1, len(table) assert len(table) == 1, len(table)
def test_update_while_iter(self): def test_update_while_iter(self):
for row in self.tbl: for row in self.tbl:
row['foo'] = 'bar' row["foo"] = "bar"
self.tbl.update(row, ['place', 'date']) self.tbl.update(row, ["place", "date"])
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
def test_weird_column_names(self): def test_weird_column_names(self):
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {
'temperature': -10, "date": datetime(2011, 1, 2),
'foo.bar': 'Berlin', "temperature": -10,
'qux.bar': 'Huhu' "foo.bar": "Berlin",
}) "qux.bar": "Huhu",
}
)
def test_cased_column_names(self): def test_cased_column_names(self):
tbl = self.db['cased_column_names'] tbl = self.db["cased_column_names"]
tbl.insert({ tbl.insert(
'place': 'Berlin', {"place": "Berlin",}
}) )
tbl.insert({ tbl.insert(
'Place': 'Berlin', {"Place": "Berlin",}
}) )
tbl.insert({ tbl.insert(
'PLACE ': 'Berlin', {"PLACE ": "Berlin",}
}) )
assert len(tbl.columns) == 2, tbl.columns assert len(tbl.columns) == 2, tbl.columns
assert len(list(tbl.find(Place='Berlin'))) == 3 assert len(list(tbl.find(Place="Berlin"))) == 3
assert len(list(tbl.find(place='Berlin'))) == 3 assert len(list(tbl.find(place="Berlin"))) == 3
assert len(list(tbl.find(PLACE='Berlin'))) == 3 assert len(list(tbl.find(PLACE="Berlin"))) == 3
def test_invalid_column_names(self): def test_invalid_column_names(self):
tbl = self.db['weather'] tbl = self.db["weather"]
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tbl.insert({None: 'banana'}) tbl.insert({None: "banana"})
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tbl.insert({'': 'banana'}) tbl.insert({"": "banana"})
with self.assertRaises(ValueError): with self.assertRaises(ValueError):
tbl.insert({'-': 'banana'}) tbl.insert({"-": "banana"})
def test_delete(self): def test_delete(self):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10,
'place': 'Berlin'}
) )
original_count = len(self.tbl) original_count = len(self.tbl)
assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl) assert len(self.tbl) == len(TEST_DATA) + 1, len(self.tbl)
# Test bad use of API # Test bad use of API
with self.assertRaises(ArgumentError): with self.assertRaises(ArgumentError):
self.tbl.delete({'place': 'Berlin'}) self.tbl.delete({"place": "Berlin"})
assert len(self.tbl) == original_count, len(self.tbl) assert len(self.tbl) == original_count, len(self.tbl)
assert self.tbl.delete(place='Berlin') is True, 'should return 1' assert self.tbl.delete(place="Berlin") is True, "should return 1"
assert len(self.tbl) == len(TEST_DATA), len(self.tbl) assert len(self.tbl) == len(TEST_DATA), len(self.tbl)
assert self.tbl.delete() is True, 'should return non zero' assert self.tbl.delete() is True, "should return non zero"
assert len(self.tbl) == 0, len(self.tbl) assert len(self.tbl) == 0, len(self.tbl)
def test_repr(self): def test_repr(self):
assert repr(self.tbl) == '<Table(weather)>', \ assert (
'the representation should be <Table(weather)>' repr(self.tbl) == "<Table(weather)>"
), "the representation should be <Table(weather)>"
def test_delete_nonexist_entry(self): def test_delete_nonexist_entry(self):
assert self.tbl.delete(place='Berlin') is False, \ assert (
'entry not exist, should fail to delete' self.tbl.delete(place="Berlin") is False
), "entry not exist, should fail to delete"
def test_find_one(self): def test_find_one(self):
self.tbl.insert({ self.tbl.insert(
'date': datetime(2011, 1, 2), {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
'temperature': -10, )
'place': 'Berlin' d = self.tbl.find_one(place="Berlin")
}) assert d["temperature"] == -10, d
d = self.tbl.find_one(place='Berlin') d = self.tbl.find_one(place="Atlantis")
assert d['temperature'] == -10, d
d = self.tbl.find_one(place='Atlantis')
assert d is None, d assert d is None, d
def test_count(self): def test_count(self):
@ -348,31 +329,31 @@ class TableTestCase(unittest.TestCase):
assert len(ds) == 1, ds assert len(ds) == 1, ds
ds = list(self.tbl.find(_step=2)) ds = list(self.tbl.find(_step=2))
assert len(ds) == len(TEST_DATA), ds assert len(ds) == len(TEST_DATA), ds
ds = list(self.tbl.find(order_by=['temperature'])) ds = list(self.tbl.find(order_by=["temperature"]))
assert ds[0]['temperature'] == -1, ds assert ds[0]["temperature"] == -1, ds
ds = list(self.tbl.find(order_by=['-temperature'])) ds = list(self.tbl.find(order_by=["-temperature"]))
assert ds[0]['temperature'] == 8, ds assert ds[0]["temperature"] == 8, ds
ds = list(self.tbl.find(self.tbl.table.columns.temperature > 4)) ds = list(self.tbl.find(self.tbl.table.columns.temperature > 4))
assert len(ds) == 3, ds assert len(ds) == 3, ds
def test_find_dsl(self): def test_find_dsl(self):
ds = list(self.tbl.find(place={'like': '%lw%'})) ds = list(self.tbl.find(place={"like": "%lw%"}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(temperature={'>': 5})) ds = list(self.tbl.find(temperature={">": 5}))
assert len(ds) == 2, ds assert len(ds) == 2, ds
ds = list(self.tbl.find(temperature={'>=': 5})) ds = list(self.tbl.find(temperature={">=": 5}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(temperature={'<': 0})) ds = list(self.tbl.find(temperature={"<": 0}))
assert len(ds) == 1, ds assert len(ds) == 1, ds
ds = list(self.tbl.find(temperature={'<=': 0})) ds = list(self.tbl.find(temperature={"<=": 0}))
assert len(ds) == 2, ds assert len(ds) == 2, ds
ds = list(self.tbl.find(temperature={'!=': -1})) ds = list(self.tbl.find(temperature={"!=": -1}))
assert len(ds) == 5, ds assert len(ds) == 5, ds
ds = list(self.tbl.find(temperature={'between': [5, 8]})) ds = list(self.tbl.find(temperature={"between": [5, 8]}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(place={'=': 'G€lway'})) ds = list(self.tbl.find(place={"=": "G€lway"}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
ds = list(self.tbl.find(place={'ilike': '%LwAy'})) ds = list(self.tbl.find(place={"ilike": "%LwAy"}))
assert len(ds) == 3, ds assert len(ds) == 3, ds
def test_offset(self): def test_offset(self):
@ -385,23 +366,26 @@ class TableTestCase(unittest.TestCase):
ds = list(self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1)) ds = list(self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1))
assert len(ds) == 3, len(ds) assert len(ds) == 3, len(ds)
for row in self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1): for row in self.tbl.find(place=TEST_CITY_1, _streamed=True, _step=1):
row['temperature'] = -1 row["temperature"] = -1
self.tbl.update(row, ['id']) self.tbl.update(row, ["id"])
def test_distinct(self): def test_distinct(self):
x = list(self.tbl.distinct('place')) x = list(self.tbl.distinct("place"))
assert len(x) == 2, x assert len(x) == 2, x
x = list(self.tbl.distinct('place', 'date')) x = list(self.tbl.distinct("place", "date"))
assert len(x) == 6, x assert len(x) == 6, x
x = list(self.tbl.distinct( x = list(
'place', 'date', self.tbl.distinct(
self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0))) "place",
"date",
self.tbl.table.columns.date >= datetime(2011, 1, 2, 0, 0),
)
)
assert len(x) == 4, x assert len(x) == 4, x
x = list(self.tbl.distinct('temperature', place='B€rkeley')) x = list(self.tbl.distinct("temperature", place="B€rkeley"))
assert len(x) == 3, x assert len(x) == 3, x
x = list(self.tbl.distinct('temperature', x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"]))
place=['B€rkeley', 'G€lway']))
assert len(x) == 6, x assert len(x) == 6, x
def test_insert_many(self): def test_insert_many(self):
@ -423,6 +407,7 @@ class TableTestCase(unittest.TestCase):
def callback(queue): def callback(queue):
nonlocal N nonlocal N
N += len(queue) N += len(queue)
with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl: with chunked.ChunkedInsert(self.tbl, callback=callback) as chunk_tbl:
for item in data: for item in data:
chunk_tbl.insert(item) chunk_tbl.insert(item)
@ -430,74 +415,73 @@ class TableTestCase(unittest.TestCase):
assert len(self.tbl) == len(data) + 6 assert len(self.tbl) == len(data) + 6
def test_update_many(self): def test_update_many(self):
tbl = self.db['update_many_test'] tbl = self.db["update_many_test"]
tbl.insert_many([ tbl.insert_many([dict(temp=10), dict(temp=20), dict(temp=30)])
dict(temp=10), dict(temp=20), dict(temp=30) tbl.update_many([dict(id=1, temp=50), dict(id=3, temp=50)], "id")
])
tbl.update_many(
[dict(id=1, temp=50), dict(id=3, temp=50)], 'id'
)
# Ensure data has been updated. # Ensure data has been updated.
assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"]
def test_chunked_update(self): def test_chunked_update(self):
tbl = self.db['update_many_test'] tbl = self.db["update_many_test"]
tbl.insert_many([ tbl.insert_many(
dict(temp=10, location='asdf'), dict(temp=20, location='qwer'), [
dict(temp=30, location='asdf') dict(temp=10, location="asdf"),
]) dict(temp=20, location="qwer"),
dict(temp=30, location="asdf"),
]
)
chunked_tbl = chunked.ChunkedUpdate(tbl, 'id') chunked_tbl = chunked.ChunkedUpdate(tbl, "id")
chunked_tbl.update(dict(id=1, temp=50)) chunked_tbl.update(dict(id=1, temp=50))
chunked_tbl.update(dict(id=2, location='asdf')) chunked_tbl.update(dict(id=2, location="asdf"))
chunked_tbl.update(dict(id=3, temp=50)) chunked_tbl.update(dict(id=3, temp=50))
chunked_tbl.flush() chunked_tbl.flush()
# Ensure data has been updated. # Ensure data has been updated.
assert tbl.find_one(id=1)['temp'] == tbl.find_one(id=3)['temp'] == 50 assert tbl.find_one(id=1)["temp"] == tbl.find_one(id=3)["temp"] == 50
assert tbl.find_one(id=2)['location'] == tbl.find_one(id=3)['location'] == 'asdf' # noqa assert (
tbl.find_one(id=2)["location"] == tbl.find_one(id=3)["location"] == "asdf"
) # noqa
def test_upsert_many(self): def test_upsert_many(self):
# Also tests updating on records with different attributes # Also tests updating on records with different attributes
tbl = self.db['upsert_many_test'] tbl = self.db["upsert_many_test"]
W = 100 W = 100
tbl.upsert_many([dict(age=10), dict(weight=W)], 'id') tbl.upsert_many([dict(age=10), dict(weight=W)], "id")
assert tbl.find_one(id=1)['age'] == 10 assert tbl.find_one(id=1)["age"] == 10
tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], 'id') tbl.upsert_many([dict(id=1, age=70), dict(id=2, weight=W / 2)], "id")
assert tbl.find_one(id=2)['weight'] == W / 2 assert tbl.find_one(id=2)["weight"] == W / 2
def test_drop_operations(self): def test_drop_operations(self):
assert self.tbl._table is not None, \ assert self.tbl._table is not None, "table shouldn't be dropped yet"
'table shouldn\'t be dropped yet'
self.tbl.drop() self.tbl.drop()
assert self.tbl._table is None, \ assert self.tbl._table is None, "table should be dropped now"
'table should be dropped now'
assert list(self.tbl.all()) == [], self.tbl.all() assert list(self.tbl.all()) == [], self.tbl.all()
assert self.tbl.count() == 0, self.tbl.count() assert self.tbl.count() == 0, self.tbl.count()
def test_table_drop(self): def test_table_drop(self):
assert 'weather' in self.db assert "weather" in self.db
self.db['weather'].drop() self.db["weather"].drop()
assert 'weather' not in self.db assert "weather" not in self.db
def test_table_drop_then_create(self):
    """Inserting after a drop transparently re-creates the table."""
    assert "weather" in self.db
    self.db["weather"].drop()
    assert "weather" not in self.db
    # Insert into the now-missing table; dataset should recreate it lazily.
    self.db["weather"].insert({"foo": "bar"})
def test_columns(self):
    """The weather fixture exposes its three data columns plus ``id``."""
    cols = self.tbl.columns
    assert len(list(cols)) == 4, "column count mismatch"
    assert all(name in cols for name in ("date", "temperature", "place"))
def test_drop_column(self):
    """Drop a column where the backend supports it."""
    try:
        self.tbl.drop_column("date")
        assert "date" not in self.tbl.columns
    except RuntimeError:
        # Backends without column-drop support raise RuntimeError; treat
        # that as an acceptable outcome rather than a test failure.
        pass
@ -509,71 +493,67 @@ class TableTestCase(unittest.TestCase):
def test_update(self):
    """Update a row matched on the (place, date) key columns."""
    date = datetime(2011, 1, 2)
    row = {"date": date, "temperature": -10, "place": TEST_CITY_1}
    res = self.tbl.update(row, ["place", "date"])
    assert res, "update should return True"
    # Re-read the row and confirm the new temperature was stored.
    m = self.tbl.find_one(place=TEST_CITY_1, date=date)
    assert m["temperature"] == -10, (
        "new temp. should be -10 but is %d" % m["temperature"]
    )
def test_create_column(self):
    """Explicitly adding a column makes it visible on the table schema."""
    table = self.tbl
    table.create_column("foo", FLOAT)
    assert "foo" in table.table.c, table.table.c
    assert isinstance(table.table.c["foo"].type, FLOAT), table.table.c["foo"].type
    assert "foo" in table.columns, table.columns
def test_ensure_column(self):
    """``create_column_by_example()`` infers column types from sample values.

    Floats map to FLOAT, integers (small or large, signed) to BIGINT, and
    strings to TEXT.
    """
    tbl = self.tbl
    tbl.create_column_by_example("foo", 0.1)
    assert "foo" in tbl.table.c, tbl.table.c
    # BUG FIX: the failure message previously evaluated
    # tbl.table.c["bar"].type — no "bar" column exists yet, so a failing
    # assert would raise KeyError instead of reporting the offending type.
    # Reference the column under test ("foo") instead.
    assert isinstance(tbl.table.c["foo"].type, FLOAT), tbl.table.c["foo"].type
    tbl.create_column_by_example("bar", 1)
    assert "bar" in tbl.table.c, tbl.table.c
    assert isinstance(tbl.table.c["bar"].type, BIGINT), tbl.table.c["bar"].type
    tbl.create_column_by_example("pippo", "test")
    assert "pippo" in tbl.table.c, tbl.table.c
    assert isinstance(tbl.table.c["pippo"].type, TEXT), tbl.table.c["pippo"].type
    tbl.create_column_by_example("bigbar", 11111111111)
    assert "bigbar" in tbl.table.c, tbl.table.c
    assert isinstance(tbl.table.c["bigbar"].type, BIGINT), tbl.table.c["bigbar"].type
    tbl.create_column_by_example("littlebar", -11111111111)
    assert "littlebar" in tbl.table.c, tbl.table.c
    assert isinstance(tbl.table.c["littlebar"].type, BIGINT), tbl.table.c[
        "littlebar"
    ].type
def test_key_order(self):
    """Result rows preserve the column order of the SELECT clause."""
    res = self.db.query("SELECT temperature, place FROM weather LIMIT 1")
    first_row = res.next()
    keys = list(first_row.keys())
    assert keys[0] == "temperature"
    assert keys[1] == "place"
def test_empty_query(self):
    """A find() with no matching rows yields an empty result set."""
    empty = list(self.tbl.find(place="not in data"))
    assert len(empty) == 0, empty
class Constructor(dict):
    """Minimal ``dict`` subclass whose keys are also readable as attributes."""

    def __getattr__(self, name):
        # Missing keys surface as KeyError, mirroring plain item access.
        return self[name]
class RowTypeTestCase(unittest.TestCase): class RowTypeTestCase(unittest.TestCase):
def setUp(self):
    """Open a connection using the custom row type and load the fixture."""
    self.db = connect(row_type=Constructor)
    self.tbl = self.db["weather"]
    for record in TEST_DATA:
        self.tbl.insert(record)
@ -582,15 +562,13 @@ class RowTypeTestCase(unittest.TestCase):
self.db[table].drop() self.db[table].drop()
def test_find_one(self):
    """find_one() returns a row_type instance, or None when unmatched."""
    self.tbl.insert(
        {"date": datetime(2011, 1, 2), "temperature": -10, "place": "Berlin"}
    )
    hit = self.tbl.find_one(place="Berlin")
    assert hit["temperature"] == -10, hit
    # Attribute access is provided by the Constructor row type.
    assert hit.temperature == -10, hit
    miss = self.tbl.find_one(place="Atlantis")
    assert miss is None, miss
def test_find(self): def test_find(self):
@ -604,11 +582,11 @@ class RowTypeTestCase(unittest.TestCase):
assert isinstance(item, Constructor), item assert isinstance(item, Constructor), item
def test_distinct(self):
    """distinct() rows also come back wrapped in the custom row type."""
    places = list(self.tbl.distinct("place"))
    assert len(places) == 2, places
    for row in places:
        assert isinstance(row, Constructor), row
    # Distinct over two columns yields the full set of combinations.
    pairs = list(self.tbl.distinct("place", "date"))
    assert len(pairs) == 6, pairs
    for row in pairs:
        assert isinstance(row, Constructor), row
@ -621,5 +599,5 @@ class RowTypeTestCase(unittest.TestCase):
assert c == len(self.tbl) assert c == len(self.tbl)
if __name__ == "__main__":
    # Allow running this test module directly with the unittest runner.
    unittest.main()