Compare commits

..

No commits in common. "781b8436a39e6de91f9528d78c9c3d66bd8db0e3" and "cea87af0b2c65b78d2973429b3a8a2ca08f587b5" have entirely different histories.

9 changed files with 34 additions and 60 deletions

View File

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 1.6.2 current_version = 1.5.2
tag_name = {new_version} tag_name = {new_version}
commit = True commit = True
tag = True tag = True

View File

@ -1,13 +1,5 @@
# dataset ChangeLog # dataset ChangeLog
## Version 1.7.0 - 2025-11-09
Improved database interactions for faster and more reliable data operations. Updated dependencies and tests to support these changes.
**Changes:** 5 files, 68 lines
**Languages:** Python (68 lines)
*The changelog has only been started with version 0.3.12, previous *The changelog has only been started with version 0.3.12, previous
changes must be reconstructed from revision history.* changes must be reconstructed from revision history.*

View File

@ -11,7 +11,7 @@ warnings.filterwarnings(
) )
__all__ = ["Database", "Table", "connect"] __all__ = ["Database", "Table", "connect"]
__version__ = "1.7.0" __version__ = "1.5.2"
def connect( def connect(

View File

@ -106,12 +106,14 @@ class Database(object):
@property @property
def metadata(self): def metadata(self):
"""Return a SQLAlchemy schema cache object.""" """Return a SQLAlchemy schema cache object."""
return MetaData(schema=self.schema) return MetaData(schema=self.schema, bind=self.executable)
@property @property
def in_transaction(self): def in_transaction(self):
"""Check if this database is in a transactional context.""" """Check if this database is in a transactional context."""
return self.executable.in_transaction() if not hasattr(self.local, "tx"):
return False
return len(self.local.tx) > 0
def _flush_tables(self): def _flush_tables(self):
"""Clear the table metadata after transaction rollbacks.""" """Clear the table metadata after transaction rollbacks."""
@ -125,9 +127,6 @@ class Database(object):
""" """
if not hasattr(self.local, "tx"): if not hasattr(self.local, "tx"):
self.local.tx = [] self.local.tx = []
if self.in_transaction:
self.local.tx.append(self.executable.begin_nested())
else:
self.local.tx.append(self.executable.begin()) self.local.tx.append(self.executable.begin())
def commit(self): def commit(self):

View File

@ -116,7 +116,7 @@ class Table(object):
Returns the inserted row's primary key. Returns the inserted row's primary key.
""" """
row = self._sync_columns(row, ensure, types=types) row = self._sync_columns(row, ensure, types=types)
res = self.db.executable.execute(self.table.insert().values(row)) res = self.db.executable.execute(self.table.insert(row))
if len(res.inserted_primary_key) > 0: if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0] return res.inserted_primary_key[0]
return True return True
@ -181,7 +181,7 @@ class Table(object):
# Insert when chunk_size is fulfilled or this is the last row # Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1: if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns) chunk = pad_chunk_columns(chunk, columns)
self.db.executable.execute(self.table.insert(), chunk) self.table.insert().execute(chunk)
chunk = [] chunk = []
def update(self, row, keys, ensure=None, types=None, return_count=False): def update(self, row, keys, ensure=None, types=None, return_count=False):
@ -206,7 +206,7 @@ class Table(object):
clause = self._args_to_clause(args) clause = self._args_to_clause(args)
if not len(row): if not len(row):
return self.count(clause) return self.count(clause)
stmt = self.table.update().where(clause).values(row) stmt = self.table.update(whereclause=clause, values=row)
rp = self.db.executable.execute(stmt) rp = self.db.executable.execute(stmt)
if rp.supports_sane_rowcount(): if rp.supports_sane_rowcount():
return rp.rowcount return rp.rowcount
@ -241,10 +241,9 @@ class Table(object):
# Update when chunk_size is fulfilled or this is the last row # Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1: if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys] cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
stmt = ( stmt = self.table.update(
self.table.update() whereclause=and_(True, *cl),
.where(and_(True, *cl)) values={col: bindparam(col, required=False) for col in columns},
.values({col: bindparam(col, required=False) for col in columns})
) )
self.db.executable.execute(stmt, chunk) self.db.executable.execute(stmt, chunk)
chunk = [] chunk = []
@ -294,7 +293,7 @@ class Table(object):
if not self.exists: if not self.exists:
return False return False
clause = self._args_to_clause(filters, clauses=clauses) clause = self._args_to_clause(filters, clauses=clauses)
stmt = self.table.delete().where(clause) stmt = self.table.delete(whereclause=clause)
rp = self.db.executable.execute(stmt) rp = self.db.executable.execute(stmt)
return rp.rowcount > 0 return rp.rowcount > 0
@ -304,10 +303,7 @@ class Table(object):
self._columns = None self._columns = None
try: try:
self._table = SQLATable( self._table = SQLATable(
self.name, self.name, self.db.metadata, schema=self.db.schema, autoload=True
self.db.metadata,
schema=self.db.schema,
autoload_with=self.db.executable,
) )
except NoSuchTableError: except NoSuchTableError:
self._table = None self._table = None
@ -357,7 +353,7 @@ class Table(object):
self._threading_warn() self._threading_warn()
for column in columns: for column in columns:
if not self.has_column(column.name): if not self.has_column(column.name):
self.db.op.add_column(self.name, column, schema=self.db.schema) self.db.op.add_column(self.name, column, self.db.schema)
self._reflect_table() self._reflect_table()
def _sync_columns(self, row, ensure, types=None): def _sync_columns(self, row, ensure, types=None):
@ -513,7 +509,7 @@ class Table(object):
return return
self._threading_warn() self._threading_warn()
self.db.op.drop_column(self.table.name, name, schema=self.table.schema) self.db.op.drop_column(self.table.name, name, self.table.schema)
self._reflect_table() self._reflect_table()
def drop(self): def drop(self):
@ -629,7 +625,7 @@ class Table(object):
order_by = self._args_to_order_by(order_by) order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses) args = self._args_to_clause(kwargs, clauses=_clauses)
query = self.table.select().where(args).limit(_limit).offset(_offset) query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
if len(order_by): if len(order_by):
query = query.order_by(*order_by) query = query.order_by(*order_by)
@ -670,7 +666,7 @@ class Table(object):
return 0 return 0
args = self._args_to_clause(kwargs, clauses=_clauses) args = self._args_to_clause(kwargs, clauses=_clauses)
query = select(func.count()).where(args) query = select([func.count()], whereclause=args)
query = query.select_from(self.table) query = query.select_from(self.table)
rp = self.db.executable.execute(query) rp = self.db.executable.execute(query)
return rp.fetchone()[0] return rp.fetchone()[0]
@ -679,7 +675,7 @@ class Table(object):
"""Return the number of rows in the table.""" """Return the number of rows in the table."""
return self.count() return self.count()
def distinct(self, *args, **kwargs): def distinct(self, *args, **_filter):
"""Return all the unique (distinct) values for the given ``columns``. """Return all the unique (distinct) values for the given ``columns``.
:: ::
@ -703,19 +699,15 @@ class Table(object):
raise DatasetException("No such column: %s" % column) raise DatasetException("No such column: %s" % column)
columns.append(self.table.c[column]) columns.append(self.table.c[column])
_limit = kwargs.pop("_limit", None) clause = self._args_to_clause(_filter, clauses=clauses)
_offset = kwargs.pop("_offset", 0)
clause = self._args_to_clause(kwargs, clauses=clauses)
if not len(columns): if not len(columns):
return iter([]) return iter([])
q = ( q = expression.select(
expression.select(*columns) columns,
.distinct() distinct=True,
.where(clause) whereclause=clause,
.limit(_limit) order_by=[c.asc() for c in columns],
.offset(_offset)
.order_by(*[c.asc() for c in columns])
) )
return self.db.query(q) return self.db.query(q)

View File

@ -15,6 +15,7 @@ try:
return None return None
return row_type(row._mapping.items()) return row_type(row._mapping.items())
except ImportError: except ImportError:
# SQLAlchemy < 1.4.0, no _mapping. # SQLAlchemy < 1.4.0, no _mapping.

View File

@ -45,9 +45,9 @@ copyright = u"2013-2021, Friedrich Lindenberg, Gregor Aisch, Stefan Wehrmeyer"
# built documents. # built documents.
# #
# The short X.Y version. # The short X.Y version.
version = "1.6.2" version = "1.5.2"
# The full version, including alpha/beta/rc tags. # The full version, including alpha/beta/rc tags.
release = "1.6.2" release = "1.5.2"
# There are two options for replacing |today|: either, you set today to some # There are two options for replacing |today|: either, you set today to some
# non-false value, then it is used: # non-false value, then it is used:

View File

@ -5,7 +5,7 @@ with open("README.md") as f:
setup( setup(
name="dataset", name="dataset",
version="1.7.0", version="1.5.2",
description="Toolkit for Python-based database access.", description="Toolkit for Python-based database access.",
long_description=long_description, long_description=long_description,
long_description_content_type="text/markdown", long_description_content_type="text/markdown",
@ -30,7 +30,7 @@ setup(
include_package_data=False, include_package_data=False,
zip_safe=False, zip_safe=False,
install_requires=[ install_requires=[
"sqlalchemy >= 2.0.0", "sqlalchemy >= 1.3.2, < 2.0.0",
"alembic >= 0.6.2", "alembic >= 0.6.2",
"banal >= 1.0.1", "banal >= 1.0.1",
], ],

View File

@ -45,7 +45,7 @@ class DatabaseTestCase(unittest.TestCase):
if "sqlite" in self.db.engine.dialect.dbapi.__name__: if "sqlite" in self.db.engine.dialect.dbapi.__name__:
return return
table = self.db.create_table("foo_no_id", primary_id=False) table = self.db.create_table("foo_no_id", primary_id=False)
assert self.db.has_table(table.table.name) assert table.table.exists()
assert len(table.table.columns) == 0, table.table.columns assert len(table.table.columns) == 0, table.table.columns
def test_create_table_custom_id1(self): def test_create_table_custom_id1(self):
@ -83,7 +83,7 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand1(self): def test_create_table_shorthand1(self):
pid = "int_id" pid = "int_id"
table = self.db.get_table("foo5", pid) table = self.db.get_table("foo5", pid)
assert self.db.has_table(table.table.name) assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
@ -98,7 +98,7 @@ class DatabaseTestCase(unittest.TestCase):
table = self.db.get_table( table = self.db.get_table(
"foo6", primary_id=pid, primary_type=self.db.types.string(255) "foo6", primary_id=pid, primary_type=self.db.types.string(255)
) )
assert self.db.has_table(table.table.name) assert table.table.exists
assert len(table.table.columns) == 1, table.table.columns assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c assert pid in table.table.c, table.table.c
@ -384,16 +384,6 @@ class TableTestCase(unittest.TestCase):
assert len(x) == 3, x assert len(x) == 3, x
x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"])) x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"]))
assert len(x) == 6, x assert len(x) == 6, x
x = list(
self.tbl.distinct("temperature", _limit=3, place=["B€rkeley", "G€lway"])
)
assert len(x) == 3, x
x = list(
self.tbl.distinct(
"temperature", _limit=6, _offset=1, place=["B€rkeley", "G€lway"]
)
)
assert len(x) == 5, x
def test_insert_many(self): def test_insert_many(self):
data = TEST_DATA * 100 data = TEST_DATA * 100