Migrate from SQLAlchemy 1.x legacy Query API to 2.x style select/delete statements

This commit is contained in:
2026-04-06 01:06:01 +08:00
parent 356950e2c7
commit 970c2e9946
39 changed files with 372 additions and 275 deletions
+18 -10
View File
@@ -21,6 +21,7 @@ import datetime as dt
import unittest
import httpx
import sqlalchemy as sa
from flask import Flask
from accounting.utils.next_uri import encode_next
@@ -275,8 +276,10 @@ class AccountTestCase(unittest.TestCase):
response: httpx.Response
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
{CASH.code, BANK.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code})
# Missing CSRF token
response = self.__client.post(store_uri,
@@ -367,10 +370,11 @@ class AccountTestCase(unittest.TestCase):
f"{PREFIX}/{STOCK.base_code}-003")
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
{CASH.code, BANK.code, STOCK.code,
f"{STOCK.base_code}-002",
f"{STOCK.base_code}-003"})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code, STOCK.code,
f"{STOCK.base_code}-002", f"{STOCK.base_code}-003"})
account: Account | None = Account.find_by_code(STOCK.code)
self.assertIsNotNone(account)
@@ -621,8 +625,10 @@ class AccountTestCase(unittest.TestCase):
"currency-1-credit-1-amount": "20"})
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
{CASH.code, PETTY.code, BANK.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, PETTY.code, BANK.code})
# Cannot delete the cash account
response = self.__client.post(f"{PREFIX}/{CASH.code}/delete",
@@ -645,8 +651,10 @@ class AccountTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], list_uri)
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
{CASH.code, BANK.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code})
response = self.__client.get(detail_uri)
self.assertEqual(response.status_code, 404)
+16 -10
View File
@@ -101,7 +101,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
for x in rows}
with self.__app.app_context():
accounts: list[BaseAccount] = BaseAccount.query.all()
accounts: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)).unique().all()
self.assertEqual(len(accounts), len(data))
for account in accounts:
@@ -141,10 +142,14 @@ class ConsoleCommandTestCase(unittest.TestCase):
from accounting.models import BaseAccount, Account, AccountL10n
with self.__app.app_context():
bases: list[BaseAccount] = BaseAccount.query\
.filter(sa.func.char_length(BaseAccount.code) == 4).all()
accounts: list[Account] = Account.query.all()
l10n: list[AccountL10n] = AccountL10n.query.all()
bases: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(sa.func.char_length(BaseAccount.code) == 4))\
.unique().all()
accounts: list[Account] = db.session.scalars(
sa.select(Account)).unique().all()
l10n: list[AccountL10n] = db.session.scalars(
sa.select(AccountL10n)).all()
self.assertEqual({x.code for x in bases},
{x.base_code for x in accounts})
@@ -175,7 +180,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
for x in csv.DictReader(fp)}
with self.__app.app_context():
currencies: list[Currency] = Currency.query.all()
currencies: list[Currency] = db.session.scalars(
sa.select(Currency)).unique().all()
self.assertEqual(len(currencies), len(data))
for currency in currencies:
@@ -216,9 +222,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
result.output + str(result.exception))
# Turns the titles into lowercase.
for base in BaseAccount.query:
for base in db.session.scalars(sa.select(BaseAccount)).unique():
base.title_l10n = base.title_l10n.lower()
for account in Account.query:
for account in db.session.scalars(sa.select(Account)).unique():
account.title_l10n = account.title_l10n.lower()
account.created_at \
= account.created_at - dt.timedelta(seconds=5)
@@ -242,9 +248,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
args=["accounting-titleize", "-u", "editor"])
self.assertEqual(result.exit_code, 0,
result.output + str(result.exception))
for base in BaseAccount.query:
for base in db.session.scalars(sa.select(BaseAccount)).unique():
self.__test_title_case(base.title_l10n)
for account in Account.query:
for account in db.session.scalars(sa.select(Account)).unique():
if account.id != new_account.id:
self.__test_title_case(account.title_l10n)
self.assertNotEqual(account.created_at, account.updated_at)
+17 -8
View File
@@ -21,6 +21,7 @@ import datetime as dt
import unittest
import httpx
import sqlalchemy as sa
from flask import Flask
from accounting.utils.next_uri import encode_next
@@ -221,8 +222,10 @@ class CurrencyTestCase(unittest.TestCase):
response: httpx.Response
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
{USD.code, EUR.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code})
# Missing CSRF token
response = self.__client.post(store_uri,
@@ -287,8 +290,10 @@ class CurrencyTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], create_uri)
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
{USD.code, EUR.code, TWD.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code, TWD.code})
currency: Currency = db.session.get(Currency, TWD.code)
self.assertEqual(currency.code, TWD.code)
@@ -554,8 +559,10 @@ class CurrencyTestCase(unittest.TestCase):
"currency-1-credit-1-amount": "20"})
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
{USD.code, EUR.code, JPY.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code, JPY.code})
# Cannot delete the default currency
response = self.__client.post(f"{PREFIX}/{USD.code}/delete",
@@ -578,8 +585,10 @@ class CurrencyTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], list_uri)
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
{USD.code, EUR.code})
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code})
response = self.__client.get(detail_uri)
self.assertEqual(response.status_code, 404)
+6 -3
View File
@@ -20,6 +20,7 @@
import os
from secrets import token_urlsafe
import sqlalchemy as sa
from click.testing import Result
from flask import Flask, Blueprint, render_template, redirect, Response, \
url_for
@@ -112,8 +113,8 @@ def create_app(is_testing: bool = False, is_skip_accounts: bool = False,
return auth.current_user()
def get_by_username(self, username: str) -> auth.User | None:
return auth.User.query\
.filter(auth.User.username == username).first()
return db.session.scalar(
sa.select(auth.User).where(auth.User.username == username))
def get_pk(self, user: auth.User) -> int:
return user.id
@@ -140,7 +141,9 @@ def init_db(app: Flask, is_skip_accounts: bool,
db.create_all()
from .auth import User
for username in ["viewer", "editor", "admin", "nobody"]:
if User.query.filter(User.username == username).first() is None:
user: User | None = db.session.scalar(
sa.select(User).where(User.username == username))
if user is None:
db.session.add(User(username=username))
db.session.commit()
runner: FlaskCliRunner = app.test_cli_runner()
+3 -2
View File
@@ -19,6 +19,7 @@
"""
from collections.abc import Callable
import sqlalchemy as sa
from flask import Blueprint, render_template, Flask, redirect, url_for, \
session, request, g, Response, abort
from sqlalchemy.orm import Mapped, mapped_column
@@ -91,8 +92,8 @@ def current_user() -> User | None:
if "user" not in session:
g.user = None
else:
g.user = User.query.filter(
User.username == session["user"]).first()
g.user = db.session.scalar(
sa.select(User).where(User.username == session["user"]))
return g.user
+2 -2
View File
@@ -218,8 +218,8 @@ class BaseTestData(ABC):
self._app: Flask = app
"""The Flask application."""
with self._app.app_context():
current_user: User | None = User.query\
.filter(User.username == username).first()
current_user: User | None = db.session.scalar(
sa.select(User).where(User.username == username))
assert current_user is not None
self.__current_user_id: int = current_user.id
"""The current user ID."""
+9 -8
View File
@@ -19,6 +19,7 @@
"""
import datetime as dt
import sqlalchemy as sa
from flask import Flask, Blueprint, url_for, flash, redirect, session, \
render_template, current_app, Response
from flask_babel import lazy_gettext
@@ -83,14 +84,14 @@ def __reset_database() -> None:
from accounting.account import init_accounts_command
from accounting.currency import init_currencies_command
JournalEntryLineItem.query.delete()
JournalEntry.query.delete()
CurrencyL10n.query.delete()
Currency.query.delete()
AccountL10n.query.delete()
Account.query.delete()
BaseAccountL10n.query.delete()
BaseAccount.query.delete()
db.session.execute(sa.delete(JournalEntryLineItem))
db.session.execute(sa.delete(JournalEntry))
db.session.execute(sa.delete(CurrencyL10n))
db.session.execute(sa.delete(Currency))
db.session.execute(sa.delete(AccountL10n))
db.session.execute(sa.delete(Account))
db.session.execute(sa.delete(BaseAccountL10n))
db.session.execute(sa.delete(BaseAccount))
init_base_accounts_command()
init_accounts_command(session["user"])
init_currencies_command(session["user"])