Migrate from SQLAlchemy 1.x legacy Query API to 2.x style select/delete statements
This commit is contained in:
+18
-10
@@ -21,6 +21,7 @@ import datetime as dt
|
||||
import unittest
|
||||
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from flask import Flask
|
||||
|
||||
from accounting.utils.next_uri import encode_next
|
||||
@@ -275,8 +276,10 @@ class AccountTestCase(unittest.TestCase):
|
||||
response: httpx.Response
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Account.query.all()},
|
||||
{CASH.code, BANK.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Account)).unique()},
|
||||
{CASH.code, BANK.code})
|
||||
|
||||
# Missing CSRF token
|
||||
response = self.__client.post(store_uri,
|
||||
@@ -367,10 +370,11 @@ class AccountTestCase(unittest.TestCase):
|
||||
f"{PREFIX}/{STOCK.base_code}-003")
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Account.query.all()},
|
||||
{CASH.code, BANK.code, STOCK.code,
|
||||
f"{STOCK.base_code}-002",
|
||||
f"{STOCK.base_code}-003"})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Account)).unique()},
|
||||
{CASH.code, BANK.code, STOCK.code,
|
||||
f"{STOCK.base_code}-002", f"{STOCK.base_code}-003"})
|
||||
|
||||
account: Account | None = Account.find_by_code(STOCK.code)
|
||||
self.assertIsNotNone(account)
|
||||
@@ -621,8 +625,10 @@ class AccountTestCase(unittest.TestCase):
|
||||
"currency-1-credit-1-amount": "20"})
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Account.query.all()},
|
||||
{CASH.code, PETTY.code, BANK.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Account)).unique()},
|
||||
{CASH.code, PETTY.code, BANK.code})
|
||||
|
||||
# Cannot delete the cash account
|
||||
response = self.__client.post(f"{PREFIX}/{CASH.code}/delete",
|
||||
@@ -645,8 +651,10 @@ class AccountTestCase(unittest.TestCase):
|
||||
self.assertEqual(response.headers["Location"], list_uri)
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Account.query.all()},
|
||||
{CASH.code, BANK.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Account)).unique()},
|
||||
{CASH.code, BANK.code})
|
||||
|
||||
response = self.__client.get(detail_uri)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
+16
-10
@@ -101,7 +101,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
|
||||
for x in rows}
|
||||
|
||||
with self.__app.app_context():
|
||||
accounts: list[BaseAccount] = BaseAccount.query.all()
|
||||
accounts: list[BaseAccount] = db.session.scalars(
|
||||
sa.select(BaseAccount)).unique().all()
|
||||
|
||||
self.assertEqual(len(accounts), len(data))
|
||||
for account in accounts:
|
||||
@@ -141,10 +142,14 @@ class ConsoleCommandTestCase(unittest.TestCase):
|
||||
from accounting.models import BaseAccount, Account, AccountL10n
|
||||
|
||||
with self.__app.app_context():
|
||||
bases: list[BaseAccount] = BaseAccount.query\
|
||||
.filter(sa.func.char_length(BaseAccount.code) == 4).all()
|
||||
accounts: list[Account] = Account.query.all()
|
||||
l10n: list[AccountL10n] = AccountL10n.query.all()
|
||||
bases: list[BaseAccount] = db.session.scalars(
|
||||
sa.select(BaseAccount)
|
||||
.where(sa.func.char_length(BaseAccount.code) == 4))\
|
||||
.unique().all()
|
||||
accounts: list[Account] = db.session.scalars(
|
||||
sa.select(Account)).unique().all()
|
||||
l10n: list[AccountL10n] = db.session.scalars(
|
||||
sa.select(AccountL10n)).all()
|
||||
|
||||
self.assertEqual({x.code for x in bases},
|
||||
{x.base_code for x in accounts})
|
||||
@@ -175,7 +180,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
|
||||
for x in csv.DictReader(fp)}
|
||||
|
||||
with self.__app.app_context():
|
||||
currencies: list[Currency] = Currency.query.all()
|
||||
currencies: list[Currency] = db.session.scalars(
|
||||
sa.select(Currency)).unique().all()
|
||||
|
||||
self.assertEqual(len(currencies), len(data))
|
||||
for currency in currencies:
|
||||
@@ -216,9 +222,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
|
||||
result.output + str(result.exception))
|
||||
|
||||
# Turns the titles into lowercase.
|
||||
for base in BaseAccount.query:
|
||||
for base in db.session.scalars(sa.select(BaseAccount)).unique():
|
||||
base.title_l10n = base.title_l10n.lower()
|
||||
for account in Account.query:
|
||||
for account in db.session.scalars(sa.select(Account)).unique():
|
||||
account.title_l10n = account.title_l10n.lower()
|
||||
account.created_at \
|
||||
= account.created_at - dt.timedelta(seconds=5)
|
||||
@@ -242,9 +248,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
|
||||
args=["accounting-titleize", "-u", "editor"])
|
||||
self.assertEqual(result.exit_code, 0,
|
||||
result.output + str(result.exception))
|
||||
for base in BaseAccount.query:
|
||||
for base in db.session.scalars(sa.select(BaseAccount)).unique():
|
||||
self.__test_title_case(base.title_l10n)
|
||||
for account in Account.query:
|
||||
for account in db.session.scalars(sa.select(Account)).unique():
|
||||
if account.id != new_account.id:
|
||||
self.__test_title_case(account.title_l10n)
|
||||
self.assertNotEqual(account.created_at, account.updated_at)
|
||||
|
||||
+17
-8
@@ -21,6 +21,7 @@ import datetime as dt
|
||||
import unittest
|
||||
|
||||
import httpx
|
||||
import sqlalchemy as sa
|
||||
from flask import Flask
|
||||
|
||||
from accounting.utils.next_uri import encode_next
|
||||
@@ -221,8 +222,10 @@ class CurrencyTestCase(unittest.TestCase):
|
||||
response: httpx.Response
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Currency.query.all()},
|
||||
{USD.code, EUR.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Currency)).unique()},
|
||||
{USD.code, EUR.code})
|
||||
|
||||
# Missing CSRF token
|
||||
response = self.__client.post(store_uri,
|
||||
@@ -287,8 +290,10 @@ class CurrencyTestCase(unittest.TestCase):
|
||||
self.assertEqual(response.headers["Location"], create_uri)
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Currency.query.all()},
|
||||
{USD.code, EUR.code, TWD.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Currency)).unique()},
|
||||
{USD.code, EUR.code, TWD.code})
|
||||
|
||||
currency: Currency = db.session.get(Currency, TWD.code)
|
||||
self.assertEqual(currency.code, TWD.code)
|
||||
@@ -554,8 +559,10 @@ class CurrencyTestCase(unittest.TestCase):
|
||||
"currency-1-credit-1-amount": "20"})
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Currency.query.all()},
|
||||
{USD.code, EUR.code, JPY.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Currency)).unique()},
|
||||
{USD.code, EUR.code, JPY.code})
|
||||
|
||||
# Cannot delete the default currency
|
||||
response = self.__client.post(f"{PREFIX}/{USD.code}/delete",
|
||||
@@ -578,8 +585,10 @@ class CurrencyTestCase(unittest.TestCase):
|
||||
self.assertEqual(response.headers["Location"], list_uri)
|
||||
|
||||
with self.__app.app_context():
|
||||
self.assertEqual({x.code for x in Currency.query.all()},
|
||||
{USD.code, EUR.code})
|
||||
self.assertEqual(
|
||||
{x.code
|
||||
for x in db.session.scalars(sa.select(Currency)).unique()},
|
||||
{USD.code, EUR.code})
|
||||
|
||||
response = self.__client.get(detail_uri)
|
||||
self.assertEqual(response.status_code, 404)
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
import os
|
||||
from secrets import token_urlsafe
|
||||
|
||||
import sqlalchemy as sa
|
||||
from click.testing import Result
|
||||
from flask import Flask, Blueprint, render_template, redirect, Response, \
|
||||
url_for
|
||||
@@ -112,8 +113,8 @@ def create_app(is_testing: bool = False, is_skip_accounts: bool = False,
|
||||
return auth.current_user()
|
||||
|
||||
def get_by_username(self, username: str) -> auth.User | None:
|
||||
return auth.User.query\
|
||||
.filter(auth.User.username == username).first()
|
||||
return db.session.scalar(
|
||||
sa.select(auth.User).where(auth.User.username == username))
|
||||
|
||||
def get_pk(self, user: auth.User) -> int:
|
||||
return user.id
|
||||
@@ -140,7 +141,9 @@ def init_db(app: Flask, is_skip_accounts: bool,
|
||||
db.create_all()
|
||||
from .auth import User
|
||||
for username in ["viewer", "editor", "admin", "nobody"]:
|
||||
if User.query.filter(User.username == username).first() is None:
|
||||
user: User | None = db.session.scalar(
|
||||
sa.select(User).where(User.username == username))
|
||||
if user is None:
|
||||
db.session.add(User(username=username))
|
||||
db.session.commit()
|
||||
runner: FlaskCliRunner = app.test_cli_runner()
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
"""
|
||||
from collections.abc import Callable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import Blueprint, render_template, Flask, redirect, url_for, \
|
||||
session, request, g, Response, abort
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
@@ -91,8 +92,8 @@ def current_user() -> User | None:
|
||||
if "user" not in session:
|
||||
g.user = None
|
||||
else:
|
||||
g.user = User.query.filter(
|
||||
User.username == session["user"]).first()
|
||||
g.user = db.session.scalar(
|
||||
sa.select(User).where(User.username == session["user"]))
|
||||
return g.user
|
||||
|
||||
|
||||
|
||||
@@ -218,8 +218,8 @@ class BaseTestData(ABC):
|
||||
self._app: Flask = app
|
||||
"""The Flask application."""
|
||||
with self._app.app_context():
|
||||
current_user: User | None = User.query\
|
||||
.filter(User.username == username).first()
|
||||
current_user: User | None = db.session.scalar(
|
||||
sa.select(User).where(User.username == username))
|
||||
assert current_user is not None
|
||||
self.__current_user_id: int = current_user.id
|
||||
"""The current user ID."""
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
"""
|
||||
import datetime as dt
|
||||
|
||||
import sqlalchemy as sa
|
||||
from flask import Flask, Blueprint, url_for, flash, redirect, session, \
|
||||
render_template, current_app, Response
|
||||
from flask_babel import lazy_gettext
|
||||
@@ -83,14 +84,14 @@ def __reset_database() -> None:
|
||||
from accounting.account import init_accounts_command
|
||||
from accounting.currency import init_currencies_command
|
||||
|
||||
JournalEntryLineItem.query.delete()
|
||||
JournalEntry.query.delete()
|
||||
CurrencyL10n.query.delete()
|
||||
Currency.query.delete()
|
||||
AccountL10n.query.delete()
|
||||
Account.query.delete()
|
||||
BaseAccountL10n.query.delete()
|
||||
BaseAccount.query.delete()
|
||||
db.session.execute(sa.delete(JournalEntryLineItem))
|
||||
db.session.execute(sa.delete(JournalEntry))
|
||||
db.session.execute(sa.delete(CurrencyL10n))
|
||||
db.session.execute(sa.delete(Currency))
|
||||
db.session.execute(sa.delete(AccountL10n))
|
||||
db.session.execute(sa.delete(Account))
|
||||
db.session.execute(sa.delete(BaseAccountL10n))
|
||||
db.session.execute(sa.delete(BaseAccount))
|
||||
init_base_accounts_command()
|
||||
init_accounts_command(session["user"])
|
||||
init_currencies_command(session["user"])
|
||||
|
||||
Reference in New Issue
Block a user