From 73c4c0f4bc7b4e60aa1591cb660902f502356200 Mon Sep 17 00:00:00 2001 From: retoor Date: Sun, 9 Nov 2025 17:02:19 +0100 Subject: [PATCH] refactor: simplify database and table interactions feat: update sqlalchemy version to 2.0.0 refactor: improve table insert and update methods refactor: streamline table select and distinct queries refactor: simplify table delete and schema creation feat: test: update tests to reflect api changes --- dataset/database.py | 11 ++++++----- dataset/table.py | 38 +++++++++++++++++++++----------------- dataset/util.py | 1 - setup.py | 2 +- test/test_dataset.py | 16 +++++++++++----- 5 files changed, 39 insertions(+), 29 deletions(-) diff --git a/dataset/database.py b/dataset/database.py index d8a07ad..16c37f5 100644 --- a/dataset/database.py +++ b/dataset/database.py @@ -106,14 +106,12 @@ class Database(object): @property def metadata(self): """Return a SQLAlchemy schema cache object.""" - return MetaData(schema=self.schema, bind=self.executable) + return MetaData(schema=self.schema) @property def in_transaction(self): """Check if this database is in a transactional context.""" - if not hasattr(self.local, "tx"): - return False - return len(self.local.tx) > 0 + return self.executable.in_transaction() def _flush_tables(self): """Clear the table metadata after transaction rollbacks.""" @@ -127,7 +125,10 @@ class Database(object): """ if not hasattr(self.local, "tx"): self.local.tx = [] - self.local.tx.append(self.executable.begin()) + if self.in_transaction: + self.local.tx.append(self.executable.begin_nested()) + else: + self.local.tx.append(self.executable.begin()) def commit(self): """Commit the current transaction. diff --git a/dataset/table.py b/dataset/table.py index 3fd0ff1..b53ac4e 100644 --- a/dataset/table.py +++ b/dataset/table.py @@ -116,7 +116,7 @@ class Table(object): Returns the inserted row's primary key. """ row = self._sync_columns(row, ensure, types=types) - res = self.db.executable.execute(self.table.insert(row)) + res = self.db.executable.execute(self.table.insert().values(row)) if len(res.inserted_primary_key) > 0: return res.inserted_primary_key[0] return True @@ -181,7 +181,7 @@ class Table(object): # Insert when chunk_size is fulfilled or this is the last row if len(chunk) == chunk_size or index == len(rows) - 1: chunk = pad_chunk_columns(chunk, columns) - self.table.insert().execute(chunk) + self.db.executable.execute(self.table.insert(), chunk) chunk = [] def update(self, row, keys, ensure=None, types=None, return_count=False): @@ -206,7 +206,7 @@ class Table(object): clause = self._args_to_clause(args) if not len(row): return self.count(clause) - stmt = self.table.update(whereclause=clause, values=row) + stmt = self.table.update().where(clause).values(row) rp = self.db.executable.execute(stmt) if rp.supports_sane_rowcount(): return rp.rowcount @@ -241,9 +241,10 @@ class Table(object): # Update when chunk_size is fulfilled or this is the last row if len(chunk) == chunk_size or index == len(rows) - 1: cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys] - stmt = self.table.update( - whereclause=and_(True, *cl), - values={col: bindparam(col, required=False) for col in columns}, + stmt = ( + self.table.update() + .where(and_(True, *cl)) + .values({col: bindparam(col, required=False) for col in columns}) ) self.db.executable.execute(stmt, chunk) chunk = [] @@ -293,7 +294,7 @@ class Table(object): if not self.exists: return False clause = self._args_to_clause(filters, clauses=clauses) - stmt = self.table.delete(whereclause=clause) + stmt = self.table.delete().where(clause) rp = self.db.executable.execute(stmt) return rp.rowcount > 0 @@ -303,7 +304,10 @@ class Table(object): self._columns = None try: self._table = SQLATable( - self.name, self.db.metadata, schema=self.db.schema, autoload=True + self.name, + self.db.metadata, + schema=self.db.schema, + autoload_with=self.db.executable, ) except NoSuchTableError: self._table = None @@ -625,7 +629,7 @@ class Table(object): order_by = self._args_to_order_by(order_by) args = self._args_to_clause(kwargs, clauses=_clauses) - query = self.table.select(whereclause=args, limit=_limit, offset=_offset) + query = self.table.select().where(args).limit(_limit).offset(_offset) if len(order_by): query = query.order_by(*order_by) @@ -666,7 +670,7 @@ class Table(object): return 0 args = self._args_to_clause(kwargs, clauses=_clauses) - query = select([func.count()], whereclause=args) + query = select(func.count()).where(args) query = query.select_from(self.table) rp = self.db.executable.execute(query) return rp.fetchone()[0] @@ -705,13 +709,13 @@ class Table(object): if not len(columns): return iter([]) - q = expression.select( - columns, - distinct=True, - whereclause=clause, - limit=_limit, - offset=_offset, - order_by=[c.asc() for c in columns], + q = ( + expression.select(*columns) + .distinct() + .where(clause) + .limit(_limit) + .offset(_offset) + .order_by(*[c.asc() for c in columns]) ) return self.db.query(q) diff --git a/dataset/util.py b/dataset/util.py index 4fa225d..8547da2 100644 --- a/dataset/util.py +++ b/dataset/util.py @@ -15,7 +15,6 @@ try: return None return row_type(row._mapping.items()) - except ImportError: # SQLAlchemy < 1.4.0, no _mapping. diff --git a/setup.py b/setup.py index 749af05..4097af5 100644 --- a/setup.py +++ b/setup.py @@ -30,7 +30,7 @@ setup( include_package_data=False, zip_safe=False, install_requires=[ - "sqlalchemy >= 1.3.2, < 2.0.0", + "sqlalchemy >= 2.0.0", "alembic >= 0.6.2", "banal >= 1.0.1", ], diff --git a/test/test_dataset.py b/test/test_dataset.py index 41058f9..601ac4c 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -45,7 +45,7 @@ class DatabaseTestCase(unittest.TestCase): if "sqlite" in self.db.engine.dialect.dbapi.__name__: return table = self.db.create_table("foo_no_id", primary_id=False) - assert table.table.exists() + assert self.db.has_table(table.table.name) assert len(table.table.columns) == 0, table.table.columns def test_create_table_custom_id1(self): @@ -83,7 +83,7 @@ class DatabaseTestCase(unittest.TestCase): def test_create_table_shorthand1(self): pid = "int_id" table = self.db.get_table("foo5", pid) - assert table.table.exists + assert self.db.has_table(table.table.name) assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c @@ -98,7 +98,7 @@ class DatabaseTestCase(unittest.TestCase): table = self.db.get_table( "foo6", primary_id=pid, primary_type=self.db.types.string(255) ) - assert table.table.exists + assert self.db.has_table(table.table.name) assert len(table.table.columns) == 1, table.table.columns assert pid in table.table.c, table.table.c @@ -384,9 +384,15 @@ class TableTestCase(unittest.TestCase): assert len(x) == 3, x x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"])) assert len(x) == 6, x - x = list(self.tbl.distinct("temperature", _limit=3, place=["B€rkeley", "G€lway"])) + x = list( + self.tbl.distinct("temperature", _limit=3, place=["B€rkeley", "G€lway"]) + ) assert len(x) == 3, x - x = list(self.tbl.distinct("temperature", _limit=6, _offset=1, place=["B€rkeley", "G€lway"])) + x = list( + self.tbl.distinct( + "temperature", _limit=6, _offset=1, place=["B€rkeley", "G€lway"] + ) + ) assert len(x) == 5, x def test_insert_many(self):