refactor: simplify database and table interactions

feat: update sqlalchemy version to 2.0.0
refactor: improve table insert and update methods
refactor: streamline table select and distinct queries
refactor: simplify table delete and schema creation
feat: test: update tests to reflect api changes
This commit is contained in:
retoor 2025-11-09 17:02:19 +01:00
parent b2ab09e58c
commit 73c4c0f4bc
5 changed files with 39 additions and 29 deletions

View File

@ -106,14 +106,12 @@ class Database(object):
@property
def metadata(self):
"""Return a SQLAlchemy schema cache object."""
return MetaData(schema=self.schema, bind=self.executable)
return MetaData(schema=self.schema)
@property
def in_transaction(self):
"""Check if this database is in a transactional context."""
if not hasattr(self.local, "tx"):
return False
return len(self.local.tx) > 0
return self.executable.in_transaction()
def _flush_tables(self):
"""Clear the table metadata after transaction rollbacks."""
@ -127,7 +125,10 @@ class Database(object):
"""
if not hasattr(self.local, "tx"):
self.local.tx = []
self.local.tx.append(self.executable.begin())
if self.in_transaction:
self.local.tx.append(self.executable.begin_nested())
else:
self.local.tx.append(self.executable.begin())
def commit(self):
"""Commit the current transaction.

View File

@ -116,7 +116,7 @@ class Table(object):
Returns the inserted row's primary key.
"""
row = self._sync_columns(row, ensure, types=types)
res = self.db.executable.execute(self.table.insert(row))
res = self.db.executable.execute(self.table.insert().values(row))
if len(res.inserted_primary_key) > 0:
return res.inserted_primary_key[0]
return True
@ -181,7 +181,7 @@ class Table(object):
# Insert when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
chunk = pad_chunk_columns(chunk, columns)
self.table.insert().execute(chunk)
self.db.executable.execute(self.table.insert(), chunk)
chunk = []
def update(self, row, keys, ensure=None, types=None, return_count=False):
@ -206,7 +206,7 @@ class Table(object):
clause = self._args_to_clause(args)
if not len(row):
return self.count(clause)
stmt = self.table.update(whereclause=clause, values=row)
stmt = self.table.update().where(clause).values(row)
rp = self.db.executable.execute(stmt)
if rp.supports_sane_rowcount():
return rp.rowcount
@ -241,9 +241,10 @@ class Table(object):
# Update when chunk_size is fulfilled or this is the last row
if len(chunk) == chunk_size or index == len(rows) - 1:
cl = [self.table.c[k] == bindparam("_%s" % k) for k in keys]
stmt = self.table.update(
whereclause=and_(True, *cl),
values={col: bindparam(col, required=False) for col in columns},
stmt = (
self.table.update()
.where(and_(True, *cl))
.values({col: bindparam(col, required=False) for col in columns})
)
self.db.executable.execute(stmt, chunk)
chunk = []
@ -293,7 +294,7 @@ class Table(object):
if not self.exists:
return False
clause = self._args_to_clause(filters, clauses=clauses)
stmt = self.table.delete(whereclause=clause)
stmt = self.table.delete().where(clause)
rp = self.db.executable.execute(stmt)
return rp.rowcount > 0
@ -303,7 +304,10 @@ class Table(object):
self._columns = None
try:
self._table = SQLATable(
self.name, self.db.metadata, schema=self.db.schema, autoload=True
self.name,
self.db.metadata,
schema=self.db.schema,
autoload_with=self.db.executable,
)
except NoSuchTableError:
self._table = None
@ -625,7 +629,7 @@ class Table(object):
order_by = self._args_to_order_by(order_by)
args = self._args_to_clause(kwargs, clauses=_clauses)
query = self.table.select(whereclause=args, limit=_limit, offset=_offset)
query = self.table.select().where(args).limit(_limit).offset(_offset)
if len(order_by):
query = query.order_by(*order_by)
@ -666,7 +670,7 @@ class Table(object):
return 0
args = self._args_to_clause(kwargs, clauses=_clauses)
query = select([func.count()], whereclause=args)
query = select(func.count()).where(args)
query = query.select_from(self.table)
rp = self.db.executable.execute(query)
return rp.fetchone()[0]
@ -705,13 +709,13 @@ class Table(object):
if not len(columns):
return iter([])
q = expression.select(
columns,
distinct=True,
whereclause=clause,
limit=_limit,
offset=_offset,
order_by=[c.asc() for c in columns],
q = (
expression.select(*columns)
.distinct()
.where(clause)
.limit(_limit)
.offset(_offset)
.order_by(*[c.asc() for c in columns])
)
return self.db.query(q)

View File

@ -15,7 +15,6 @@ try:
return None
return row_type(row._mapping.items())
except ImportError:
# SQLAlchemy < 1.4.0, no _mapping.

View File

@ -30,7 +30,7 @@ setup(
include_package_data=False,
zip_safe=False,
install_requires=[
"sqlalchemy >= 1.3.2, < 2.0.0",
"sqlalchemy >= 2.0.0",
"alembic >= 0.6.2",
"banal >= 1.0.1",
],

View File

@ -45,7 +45,7 @@ class DatabaseTestCase(unittest.TestCase):
if "sqlite" in self.db.engine.dialect.dbapi.__name__:
return
table = self.db.create_table("foo_no_id", primary_id=False)
assert table.table.exists()
assert self.db.has_table(table.table.name)
assert len(table.table.columns) == 0, table.table.columns
def test_create_table_custom_id1(self):
@ -83,7 +83,7 @@ class DatabaseTestCase(unittest.TestCase):
def test_create_table_shorthand1(self):
pid = "int_id"
table = self.db.get_table("foo5", pid)
assert table.table.exists
assert self.db.has_table(table.table.name)
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
@ -98,7 +98,7 @@ class DatabaseTestCase(unittest.TestCase):
table = self.db.get_table(
"foo6", primary_id=pid, primary_type=self.db.types.string(255)
)
assert table.table.exists
assert self.db.has_table(table.table.name)
assert len(table.table.columns) == 1, table.table.columns
assert pid in table.table.c, table.table.c
@ -384,9 +384,15 @@ class TableTestCase(unittest.TestCase):
assert len(x) == 3, x
x = list(self.tbl.distinct("temperature", place=["B€rkeley", "G€lway"]))
assert len(x) == 6, x
x = list(self.tbl.distinct("temperature", _limit=3, place=["B€rkeley", "G€lway"]))
x = list(
self.tbl.distinct("temperature", _limit=3, place=["B€rkeley", "G€lway"])
)
assert len(x) == 3, x
x = list(self.tbl.distinct("temperature", _limit=6, _offset=1, place=["B€rkeley", "G€lway"]))
x = list(
self.tbl.distinct(
"temperature", _limit=6, _offset=1, place=["B€rkeley", "G€lway"]
)
)
assert len(x) == 5, x
def test_insert_many(self):