Migrate from SQLAlchemy 1.x legacy Query API to 2.x style select/delete statements

This commit is contained in:
2026-04-06 01:06:01 +08:00
parent 356950e2c7
commit 970c2e9946
39 changed files with 372 additions and 275 deletions
+2 -1
View File
@@ -49,7 +49,8 @@ The following is an example configuration for *Mia! Accounting*.
return current_user() return current_user()
def get_by_username(self, username: str) -> User | None: def get_by_username(self, username: str) -> User | None:
return User.query.filter(User.username == username).first() return db.session.scalar(
sa.select(User).where(User.username == username))
def get_pk(self, user: User) -> int: def get_pk(self, user: User) -> int:
return user.id return user.id
+5 -4
View File
@@ -36,13 +36,14 @@ def init_accounts_command(username: str) -> None:
"""Initializes the accounts.""" """Initializes the accounts."""
creator_pk: int = get_user_pk(username) creator_pk: int = get_user_pk(username)
bases: list[BaseAccount] = BaseAccount.query\ bases: list[BaseAccount] = db.session.scalars(
.filter(db.func.length(BaseAccount.code) == 4)\ sa.select(BaseAccount).where(db.func.length(BaseAccount.code) == 4)
.order_by(BaseAccount.code).all() .order_by(BaseAccount.code)).unique().all()
if len(bases) == 0: if len(bases) == 0:
raise click.Abort raise click.Abort
existing: list[Account] = Account.query.all() existing: list[Account] = \
db.session.scalars(sa.select(Account)).unique().all()
existing_base_code: set[str] = {x.base_code for x in existing} existing_base_code: set[str] = {x.base_code for x in existing}
bases_to_add: list[BaseAccount] = [x for x in bases bases_to_add: list[BaseAccount] = [x for x in bases
+11 -9
View File
@@ -97,8 +97,9 @@ class AccountForm(FlaskForm):
if obj.base_code is not None: if obj.base_code is not None:
sort_accounts_in(obj.base_code, obj.id) sort_accounts_in(obj.base_code, obj.id)
sort_accounts_in(self.base_code.data, obj.id) sort_accounts_in(self.base_code.data, obj.id)
count: int = Account.query\ count: int = db.session.scalar(
.filter(Account.base_code == self.base_code.data).count() sa.select(sa.func.count(Account.id))
.where(Account.base_code == self.base_code.data))
obj.base_code = self.base_code.data obj.base_code = self.base_code.data
obj.no = count + 1 obj.no = count + 1
obj.title = self.title.data obj.title = self.title.data
@@ -137,9 +138,10 @@ class AccountForm(FlaskForm):
:return: The selectable base accounts. :return: The selectable base accounts.
""" """
return BaseAccount.query\ return db.session.scalars(
.filter(sa.func.char_length(BaseAccount.code) == 4)\ sa.select(BaseAccount)
.order_by(BaseAccount.code).all() .where(sa.func.char_length(BaseAccount.code) == 4)
.order_by(BaseAccount.code)).unique()
def sort_accounts_in(base_code: str, exclude: int) -> None: def sort_accounts_in(base_code: str, exclude: int) -> None:
@@ -150,10 +152,10 @@ def sort_accounts_in(base_code: str, exclude: int) -> None:
:param exclude: The account ID to exclude. :param exclude: The account ID to exclude.
:return: None. :return: None.
""" """
accounts: list[Account] = Account.query\ accounts: list[Account] = db.session.scalars(
.filter(Account.base_code == base_code, sa.select(Account)
Account.id != exclude)\ .where(Account.base_code == base_code, Account.id != exclude)
.order_by(Account.no).all() .order_by(Account.no)).unique().all()
for i in range(len(accounts)): for i in range(len(accounts)):
if accounts[i].no != i + 1: if accounts[i].no != i + 1:
accounts[i].no = i + 1 accounts[i].no = i + 1
+10 -5
View File
@@ -20,6 +20,7 @@
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from .. import db
from ..locale import gettext from ..locale import gettext
from ..models import Account, AccountL10n from ..models import Account, AccountL10n
from ..utils.query import parse_query_keywords from ..utils.query import parse_query_keywords
@@ -32,15 +33,18 @@ def get_account_query() -> list[Account]:
""" """
keywords: list[str] = parse_query_keywords(request.args.get("q")) keywords: list[str] = parse_query_keywords(request.args.get("q"))
if len(keywords) == 0: if len(keywords) == 0:
return Account.query.order_by(Account.base_code, Account.no).all() return db.session.scalars(
sa.select(Account)
.order_by(Account.base_code, Account.no)).unique().all()
code: sa.ColumnElement[str] = Account.base_code + "-" \ code: sa.ColumnElement[str] = Account.base_code + "-" \
+ sa.func.substr("000" + sa.cast(Account.no, sa.String), + sa.func.substr("000" + sa.cast(Account.no, sa.String),
sa.func.char_length(sa.cast(Account.no, sa.func.char_length(sa.cast(Account.no,
sa.String)) + 1) sa.String)) + 1)
conditions: list[sa.ColumnElement[bool]] = [] conditions: list[sa.ColumnElement[bool]] = []
for k in keywords: for k in keywords:
l10n: list[AccountL10n] = AccountL10n.query\ l10n: list[AccountL10n] = db.session.scalars(
.filter(AccountL10n.title.icontains(k)).all() sa.select(AccountL10n)
.where(AccountL10n.title.icontains(k))).all()
l10n_matches: set[int] = {x.account_id for x in l10n} l10n_matches: set[int] = {x.account_id for x in l10n}
sub_conditions: list[sa.ColumnElement[bool]] \ sub_conditions: list[sa.ColumnElement[bool]] \
= [Account.base_code.contains(k), = [Account.base_code.contains(k),
@@ -51,5 +55,6 @@ def get_account_query() -> list[Account]:
sub_conditions.append(Account.is_need_offset) sub_conditions.append(Account.is_need_offset)
conditions.append(sa.or_(*sub_conditions)) conditions.append(sa.or_(*sub_conditions))
return Account.query.filter(*conditions)\ return db.session.scalars(
.order_by(Account.base_code, Account.no).all() sa.select(Account).where(*conditions)
.order_by(Account.base_code, Account.no)).unique().all()
+1 -1
View File
@@ -28,7 +28,7 @@ from ..utils.title_case import title_case
def init_base_accounts_command() -> None: def init_base_accounts_command() -> None:
"""Initializes the base accounts.""" """Initializes the base accounts."""
if BaseAccount.query.first() is not None: if db.session.scalar(sa.select(BaseAccount)) is not None:
return return
with open(data_dir / "base_accounts.csv") as fp: with open(data_dir / "base_accounts.csv") as fp:
+9 -5
View File
@@ -20,6 +20,7 @@
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from .. import db
from ..models import BaseAccount, BaseAccountL10n from ..models import BaseAccount, BaseAccountL10n
from ..utils.query import parse_query_keywords from ..utils.query import parse_query_keywords
@@ -31,14 +32,17 @@ def get_base_account_query() -> list[BaseAccount]:
""" """
keywords: list[str] = parse_query_keywords(request.args.get("q")) keywords: list[str] = parse_query_keywords(request.args.get("q"))
if len(keywords) == 0: if len(keywords) == 0:
return BaseAccount.query.order_by(BaseAccount.code).all() return db.session.scalars(
sa.select(BaseAccount).order_by(BaseAccount.code)).unique().all()
conditions: list[sa.ColumnElement[bool]] = [] conditions: list[sa.ColumnElement[bool]] = []
for k in keywords: for k in keywords:
l10n: list[BaseAccountL10n] = BaseAccountL10n.query\ l10n: list[BaseAccountL10n] = db.session.scalars(
.filter(BaseAccountL10n.title.icontains(k)).all() sa.select(BaseAccountL10n)
.where((BaseAccountL10n.title.icontains(k)))).all()
l10n_matches: set[str] = {x.account_code for x in l10n} l10n_matches: set[str] = {x.account_code for x in l10n}
conditions.append(sa.or_(BaseAccount.code.contains(k), conditions.append(sa.or_(BaseAccount.code.contains(k),
BaseAccount.title_l10n.icontains(k), BaseAccount.title_l10n.icontains(k),
BaseAccount.code.in_(l10n_matches))) BaseAccount.code.in_(l10n_matches)))
return BaseAccount.query.filter(*conditions)\ return db.session.scalars(
.order_by(BaseAccount.code).all() sa.select(BaseAccount).where(*conditions)
.order_by(BaseAccount.code)).unique().all()
+4 -2
View File
@@ -66,8 +66,10 @@ def init_db_command(username: str, skip_accounts: bool,
init_base_accounts_command() init_base_accounts_command()
if not skip_accounts: if not skip_accounts:
init_accounts_command(username) init_accounts_command(username)
print("OK 1")
if not skip_currencies: if not skip_currencies:
init_currencies_command(username) init_currencies_command(username)
print("OK 2")
db.session.commit() db.session.commit()
click.echo("Accounting database initialized.") click.echo("Accounting database initialized.")
@@ -81,12 +83,12 @@ def titleize_command(username: str) -> None:
"""Capitalize the account titles.""" """Capitalize the account titles."""
updater_pk: int = get_user_pk(username) updater_pk: int = get_user_pk(username)
updated: int = 0 updated: int = 0
for base in BaseAccount.query: for base in db.session.scalars(sa.select(BaseAccount)).unique():
new_title: str = title_case(base.title_l10n) new_title: str = title_case(base.title_l10n)
if base.title_l10n != new_title: if base.title_l10n != new_title:
base.title_l10n = new_title base.title_l10n = new_title
updated = updated + 1 updated = updated + 1
for account in Account.query: for account in db.session.scalars(sa.select(Account)).unique():
if account.title_l10n.lower() == account.base.title_l10n.lower(): if account.title_l10n.lower() == account.base.title_l10n.lower():
new_title: str = title_case(account.title_l10n) new_title: str = title_case(account.title_l10n)
if account.title_l10n != new_title: if account.title_l10n != new_title:
+2 -1
View File
@@ -29,7 +29,8 @@ from ..utils.user import get_user_pk
def init_currencies_command(username: str) -> None: def init_currencies_command(username: str) -> None:
"""Initializes the currencies.""" """Initializes the currencies."""
existing_codes: set[str] = {x.code for x in Currency.query.all()} existing_codes: set[str] = \
{x.code for x in db.session.scalars(sa.select(Currency)).unique()}
with open(data_dir / "currencies.csv") as fp: with open(data_dir / "currencies.csv") as fp:
data: list[dict[str, str]] = [x for x in csv.DictReader(fp)] data: list[dict[str, str]] = [x for x in csv.DictReader(fp)]
+9 -5
View File
@@ -20,6 +20,7 @@
import sqlalchemy as sa import sqlalchemy as sa
from flask import request from flask import request
from .. import db
from ..models import Currency, CurrencyL10n from ..models import Currency, CurrencyL10n
from ..utils.query import parse_query_keywords from ..utils.query import parse_query_keywords
@@ -31,14 +32,17 @@ def get_currency_query() -> list[Currency]:
""" """
keywords: list[str] = parse_query_keywords(request.args.get("q")) keywords: list[str] = parse_query_keywords(request.args.get("q"))
if len(keywords) == 0: if len(keywords) == 0:
return Currency.query.order_by(Currency.code).all() return db.session.scalars(
sa.select(Currency).order_by(Currency.code)).unique().all()
conditions: list[sa.ColumnElement[bool]] = [] conditions: list[sa.ColumnElement[bool]] = []
for k in keywords: for k in keywords:
l10n: list[CurrencyL10n] = CurrencyL10n.query\ l10n: list[CurrencyL10n] = db.session.scalars(
.filter(CurrencyL10n.name.icontains(k)).all() sa.select(CurrencyL10n)
.where(CurrencyL10n.name.icontains(k))).all()
l10n_matches: set[str] = {x.account_code for x in l10n} l10n_matches: set[str] = {x.account_code for x in l10n}
conditions.append(sa.or_(Currency.code.icontains(k), conditions.append(sa.or_(Currency.code.icontains(k),
Currency.name_l10n.icontains(k), Currency.name_l10n.icontains(k),
Currency.code.in_(l10n_matches))) Currency.code.in_(l10n_matches)))
return Currency.query.filter(*conditions)\ return db.session.scalars(
.order_by(Currency.code).all() sa.select(Currency).where(*conditions)
.order_by(Currency.code)).unique().all()
+11 -11
View File
@@ -55,7 +55,7 @@ class SameCurrencyAsOriginalLineItems:
return return
original_line_item_currency_codes: set[str] = set(db.session.scalars( original_line_item_currency_codes: set[str] = set(db.session.scalars(
sa.select(JournalEntryLineItem.currency_code) sa.select(JournalEntryLineItem.currency_code)
.filter(JournalEntryLineItem.id.in_(original_line_item_id))).all()) .where(JournalEntryLineItem.id.in_(original_line_item_id))).all())
for currency_code in original_line_item_currency_codes: for currency_code in original_line_item_currency_codes:
if field.data != currency_code: if field.data != currency_code:
raise ValidationError(lazy_gettext( raise ValidationError(lazy_gettext(
@@ -72,17 +72,17 @@ class KeepCurrencyWhenHavingOffset:
if field.data is None: if field.data is None:
return return
offset: sa.Alias = offset_alias() offset: sa.Alias = offset_alias()
original_line_items: list[JournalEntryLineItem]\ original_line_items: list[JournalEntryLineItem] = db.session.scalars(
= JournalEntryLineItem.query\ sa.select(JournalEntryLineItem)
.join(offset, .join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id, JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True)\ isouter=True)
.filter(JournalEntryLineItem.id .where(JournalEntryLineItem.id
.in_({x.id.data for x in form.line_items .in_({x.id.data for x in form.line_items
if x.id.data is not None}))\ if x.id.data is not None}))
.group_by(JournalEntryLineItem.id, .group_by(JournalEntryLineItem.id,
JournalEntryLineItem.currency_code)\ JournalEntryLineItem.currency_code)
.having(sa.func.count(offset.c.id) > 0).all() .having(sa.func.count(offset.c.id) > 0)).unique().all()
for original_line_item in original_line_items: for original_line_item in original_line_items:
if original_line_item.currency_code != field.data: if original_line_item.currency_code != field.data:
raise ValidationError(lazy_gettext( raise ValidationError(lazy_gettext(
@@ -152,8 +152,8 @@ class CurrencyForm(FlaskForm):
line_item_id: set[int] = {x.id.data for x in line_item_forms line_item_id: set[int] = {x.id.data for x in line_item_forms
if x.id.data is not None} if x.id.data is not None}
select: sa.Select = sa.select(sa.func.count(JournalEntryLineItem.id))\ select: sa.Select = sa.select(sa.func.count(JournalEntryLineItem.id))\
.filter(JournalEntryLineItem.original_line_item_id .where(JournalEntryLineItem.original_line_item_id
.in_(line_item_id)) .in_(line_item_id))
return db.session.scalar(select) > 0 return db.session.scalar(select) > 0
@@ -159,8 +159,9 @@ class JournalEntryForm(FlaskForm):
to_delete: set[int] = {x.id for x in obj.line_items to_delete: set[int] = {x.id for x in obj.line_items
if x.id not in collector.to_keep} if x.id not in collector.to_keep}
if len(to_delete) > 0: if len(to_delete) > 0:
JournalEntryLineItem.query\ db.session.execute(
.filter(JournalEntryLineItem.id.in_(to_delete)).delete() sa.delete(JournalEntryLineItem)
.where(JournalEntryLineItem.id.in_(to_delete)))
self.is_modified = True self.is_modified = True
if is_new or db.session.is_modified(obj): if is_new or db.session.is_modified(obj):
@@ -195,7 +196,7 @@ class JournalEntryForm(FlaskForm):
if self.max_date is not None and new_date == self.max_date: if self.max_date is not None and new_date == self.max_date:
db_min_no: int | None = db.session.scalar( db_min_no: int | None = db.session.scalar(
sa.select(sa.func.min(JournalEntry.no)) sa.select(sa.func.min(JournalEntry.no))
.filter(JournalEntry.date == new_date)) .where(JournalEntry.date == new_date))
if db_min_no is None: if db_min_no is None:
obj.date = new_date obj.date = new_date
obj.no = 1 obj.no = 1
@@ -205,8 +206,9 @@ class JournalEntryForm(FlaskForm):
sort_journal_entries_in(new_date) sort_journal_entries_in(new_date)
else: else:
sort_journal_entries_in(new_date, obj.id) sort_journal_entries_in(new_date, obj.id)
count: int = JournalEntry.query\ count: int = db.session.scalar(
.filter(JournalEntry.date == new_date).count() sa.select(sa.func.count(JournalEntry.id))
.where(JournalEntry.date == new_date))
obj.date = new_date obj.date = new_date
obj.no = count + 1 obj.no = count + 1
@@ -221,7 +223,7 @@ class JournalEntryForm(FlaskForm):
if not (x.code[0] == "2" and x.is_need_offset)] if not (x.code[0] == "2" and x.is_need_offset)]
in_use: set[int] = set(db.session.scalars( in_use: set[int] = set(db.session.scalars(
sa.select(JournalEntryLineItem.account_id) sa.select(JournalEntryLineItem.account_id)
.filter(JournalEntryLineItem.is_debit) .where(JournalEntryLineItem.is_debit)
.group_by(JournalEntryLineItem.account_id)).all()) .group_by(JournalEntryLineItem.account_id)).all())
for account in accounts: for account in accounts:
account.is_in_use = account.id in in_use account.is_in_use = account.id in in_use
@@ -238,7 +240,7 @@ class JournalEntryForm(FlaskForm):
if not (x.code[0] == "1" and x.is_need_offset)] if not (x.code[0] == "1" and x.is_need_offset)]
in_use: set[int] = set(db.session.scalars( in_use: set[int] = set(db.session.scalars(
sa.select(JournalEntryLineItem.account_id) sa.select(JournalEntryLineItem.account_id)
.filter(sa.not_(JournalEntryLineItem.is_debit)) .where(sa.not_(JournalEntryLineItem.is_debit))
.group_by(JournalEntryLineItem.account_id)).all()) .group_by(JournalEntryLineItem.account_id)).all())
for account in accounts: for account in accounts:
account.is_in_use = account.id in in_use account.is_in_use = account.id in in_use
@@ -288,7 +290,7 @@ class JournalEntryForm(FlaskForm):
return None return None
select: sa.Select = sa.select(sa.func.max(JournalEntry.date))\ select: sa.Select = sa.select(sa.func.max(JournalEntry.date))\
.join(JournalEntryLineItem)\ .join(JournalEntryLineItem)\
.filter(JournalEntryLineItem.id.in_(original_line_item_id)) .where(JournalEntryLineItem.id.in_(original_line_item_id))
return db.session.scalar(select) return db.session.scalar(select)
@property @property
@@ -301,8 +303,8 @@ class JournalEntryForm(FlaskForm):
if x.id.data is not None} if x.id.data is not None}
select: sa.Select = sa.select(sa.func.min(JournalEntry.date))\ select: sa.Select = sa.select(sa.func.min(JournalEntry.date))\
.join(JournalEntryLineItem)\ .join(JournalEntryLineItem)\
.filter(JournalEntryLineItem.original_line_item_id .where(JournalEntryLineItem.original_line_item_id
.in_(line_item_id)) .in_(line_item_id))
return db.session.scalar(select) return db.session.scalar(select)
@@ -202,9 +202,9 @@ class NotExceedingOriginalLineItemNetBalance:
else_=-JournalEntryLineItem.amount)) else_=-JournalEntryLineItem.amount))
offset_total_but_form: Decimal | None = db.session.scalar( offset_total_but_form: Decimal | None = db.session.scalar(
sa.select(offset_total_func) sa.select(offset_total_func)
.filter(JournalEntryLineItem.original_line_item_id .where(JournalEntryLineItem.original_line_item_id
== original_line_item.id, == original_line_item.id,
JournalEntryLineItem.id.not_in(existing_line_item_id))) JournalEntryLineItem.id.not_in(existing_line_item_id)))
if offset_total_but_form is None: if offset_total_but_form is None:
offset_total_but_form = Decimal("0") offset_total_but_form = Decimal("0")
offset_total_on_form: Decimal = sum( offset_total_on_form: Decimal = sum(
@@ -231,7 +231,7 @@ class NotLessThanOffsetTotal:
(JournalEntryLineItem.is_debit != is_debit, (JournalEntryLineItem.is_debit != is_debit,
JournalEntryLineItem.amount), JournalEntryLineItem.amount),
else_=-JournalEntryLineItem.amount)))\ else_=-JournalEntryLineItem.amount)))\
.filter(JournalEntryLineItem.original_line_item_id == form.id.data) .where(JournalEntryLineItem.original_line_item_id == form.id.data)
offset_total: Decimal | None = db.session.scalar(select_offset_total) offset_total: Decimal | None = db.session.scalar(select_offset_total)
if offset_total is not None and field.data < offset_total: if offset_total is not None and field.data < offset_total:
raise ValidationError(lazy_gettext( raise ValidationError(lazy_gettext(
@@ -353,13 +353,14 @@ class LineItemForm(FlaskForm):
def get_offsets() -> list[JournalEntryLineItem]: def get_offsets() -> list[JournalEntryLineItem]:
if not self.is_need_offset or self.id.data is None: if not self.is_need_offset or self.id.data is None:
return [] return []
return JournalEntryLineItem.query.join(JournalEntry)\ return db.session.scalars(
.filter(JournalEntryLineItem.original_line_item_id sa.select(JournalEntryLineItem).join(JournalEntry)
== self.id.data)\ .where(JournalEntryLineItem.original_line_item_id
== self.id.data)
.order_by(JournalEntry.date, JournalEntry.no, .order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.no)\ JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.journal_entry), .options(selectinload(JournalEntryLineItem.journal_entry),
selectinload(JournalEntryLineItem.account)).all() selectinload(JournalEntryLineItem.account))).all()
setattr(self, "__offsets", get_offsets()) setattr(self, "__offsets", get_offsets())
return getattr(self, "__offsets") return getattr(self, "__offsets")
@@ -37,9 +37,9 @@ def sort_journal_entries_in(date: dt.date, exclude: int | None = None) -> None:
conditions: list[sa.ColumnElement[bool]] = [JournalEntry.date == date] conditions: list[sa.ColumnElement[bool]] = [JournalEntry.date == date]
if exclude is not None: if exclude is not None:
conditions.append(JournalEntry.id != exclude) conditions.append(JournalEntry.id != exclude)
journal_entries: list[JournalEntry] = JournalEntry.query\ journal_entries: list[JournalEntry] = db.session.scalars(
.filter(*conditions)\ sa.select(JournalEntry).where(*conditions)
.order_by(JournalEntry.no).all() .order_by(JournalEntry.no)).all()
for i in range(len(journal_entries)): for i in range(len(journal_entries)):
if journal_entries[i].no != i + 1: if journal_entries[i].no != i + 1:
journal_entries[i].no = i + 1 journal_entries[i].no = i + 1
@@ -63,8 +63,9 @@ class JournalEntryReorderForm:
:return: :return:
""" """
journal_entries: list[JournalEntry] = JournalEntry.query\ journal_entries: list[JournalEntry] = db.session.scalars(
.filter(JournalEntry.date == self.date).all() sa.select(JournalEntry)
.where(JournalEntry.date == self.date)).all()
# Collects the specified order. # Collects the specified order.
orders: dict[JournalEntry, int] = {} orders: dict[JournalEntry, int] = {}
@@ -272,15 +272,17 @@ class DescriptionEditor:
select: sa.Select = sa.Select(debit_credit, tag_type, tag, select: sa.Select = sa.Select(debit_credit, tag_type, tag,
JournalEntryLineItem.account_id, JournalEntryLineItem.account_id,
sa.func.count().label("freq"))\ sa.func.count().label("freq"))\
.filter(JournalEntryLineItem.description.is_not(None), .where(JournalEntryLineItem.description.is_not(None),
JournalEntryLineItem.description.like("_%—_%"), JournalEntryLineItem.description.like("_%—_%"),
JournalEntryLineItem.original_line_item_id.is_(None))\ JournalEntryLineItem.original_line_item_id.is_(None))\
.group_by(debit_credit, tag_type, tag, .group_by(debit_credit, tag_type, tag,
JournalEntryLineItem.account_id) JournalEntryLineItem.account_id)
result: list[sa.Row] = db.session.execute(select).all() result: list[sa.Row] = db.session.execute(select).all()
accounts: dict[int, Account] \ accounts: dict[int, Account] \
= {x.id: x for x in Account.query = {x.id: x for x in db.session.scalars(
.filter(Account.id.in_({x.account_id for x in result})).all()} sa.select(Account)
.where(Account.id.in_({x.account_id for x in result})))
.unique()}
debit_credit_dict: dict[Literal["debit", "credit"], debit_credit_dict: dict[Literal["debit", "credit"],
DescriptionDebitCredit] \ DescriptionDebitCredit] \
= {x.debit_credit: x for x in {self.debit, self.credit}} = {x.debit_credit: x for x in {self.debit, self.credit}}
@@ -326,7 +328,8 @@ class DescriptionEditor:
= [get_condition(x) for x in codes] = [get_condition(x) for x in codes]
accounts: dict[str, Account] \ accounts: dict[str, Account] \
= {x.code: x for x in = {x.code: x for x in
Account.query.filter(sa.or_(*conditions)).all()} db.session.scalars(
sa.select(Account).where(sa.or_(*conditions))).unique()}
for code in codes: for code in codes:
assert code in accounts, \ assert code in accounts, \
f"Unknown account \"{code}\" for regular transactions." f"Unknown account \"{code}\" for regular transactions."
@@ -61,20 +61,21 @@ def get_selectable_original_line_items(
.join(offset, .join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id, JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True)\ isouter=True)\
.filter(*conditions)\ .where(*conditions)\
.group_by(JournalEntryLineItem.id)\ .group_by(JournalEntryLineItem.id)\
.having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0)) .having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0))
net_balances: dict[int, Decimal] \ net_balances: dict[int, Decimal] \
= {x.id: x.net_balance = {x.id: x.net_balance
for x in db.session.execute(select_net_balances).all()} for x in db.session.execute(select_net_balances)}
line_items: list[JournalEntryLineItem] = JournalEntryLineItem.query\ line_items: list[JournalEntryLineItem] = db.session.scalars(
.filter(JournalEntryLineItem.id.in_({x for x in net_balances}))\ sa.select(JournalEntryLineItem)
.join(JournalEntry)\ .where(JournalEntryLineItem.id.in_({x for x in net_balances}))
.join(JournalEntry)
.order_by(JournalEntry.date, JournalEntry.no, .order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.no)\ JournalEntryLineItem.is_debit, JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.currency), .options(selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.account), selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.journal_entry)).all() selectinload(JournalEntryLineItem.journal_entry))).all()
line_items.reverse() line_items.reverse()
for line_item in line_items: for line_item in line_items:
line_item.net_balance = line_item.amount \ line_item.net_balance = line_item.amount \
+3 -3
View File
@@ -195,9 +195,9 @@ def show_journal_entry_order(date: dt.date) -> str:
:param date: The date. :param date: The date.
:return: The order of the journal entries in the date. :return: The order of the journal entries in the date.
""" """
journal_entries: list[JournalEntry] = JournalEntry.query \ journal_entries: list[JournalEntry] = db.session.scalars(
.filter(JournalEntry.date == date) \ sa.select(JournalEntry).where(JournalEntry.date == date)
.order_by(JournalEntry.no).all() .order_by(JournalEntry.no)).all()
return render_template("accounting/journal-entry/order.html", return render_template("accounting/journal-entry/order.html",
date=date, list=journal_entries) date=date, list=journal_entries)
+48 -39
View File
@@ -268,9 +268,10 @@ class Account(db.Model):
:return: None. :return: None.
""" """
AccountL10n.query.filter(AccountL10n.account == self).delete() db.session.execute(sa.delete(AccountL10n)
.where(AccountL10n.account == self))
cls: type[Self] = self.__class__ cls: type[Self] = self.__class__
cls.query.filter(cls.id == self.id).delete() db.session.execute(sa.delete(cls).where(cls.id == self.id))
@classmethod @classmethod
def find_by_code(cls, code: str) -> Self | None: def find_by_code(cls, code: str) -> Self | None:
@@ -282,8 +283,9 @@ class Account(db.Model):
m: re.Match[str] | None = re.match(r"^([1-9]{4})-(\d{3})$", code) m: re.Match[str] | None = re.match(r"^([1-9]{4})-(\d{3})$", code)
if m is None: if m is None:
return None return None
return cls.query.filter(cls.base_code == m.group(1), return db.session.scalar(
cls.no == int(m.group(2))).first() sa.select(cls).where(cls.base_code == m.group(1),
cls.no == int(m.group(2))))
@classmethod @classmethod
def selectable_debit(cls) -> list[Self]: def selectable_debit(cls) -> list[Self]:
@@ -292,20 +294,22 @@ class Account(db.Model):
:return: The selectable debit accounts. :return: The selectable debit accounts.
""" """
return cls.query.filter(sa.or_(cls.base_code.startswith("1"), return db.session.scalars(
sa.and_(cls.base_code.startswith("2"), sa.select(cls)
sa.not_(cls.is_need_offset)), .where(sa.or_(cls.base_code.startswith("1"),
cls.base_code.startswith("3"), sa.and_(cls.base_code.startswith("2"),
cls.base_code.startswith("5"), sa.not_(cls.is_need_offset)),
cls.base_code.startswith("6"), cls.base_code.startswith("3"),
cls.base_code.startswith("75"), cls.base_code.startswith("5"),
cls.base_code.startswith("76"), cls.base_code.startswith("6"),
cls.base_code.startswith("77"), cls.base_code.startswith("75"),
cls.base_code.startswith("78"), cls.base_code.startswith("76"),
cls.base_code.startswith("8"), cls.base_code.startswith("77"),
cls.base_code.startswith("9")), cls.base_code.startswith("78"),
cls.base_code != "3353")\ cls.base_code.startswith("8"),
.order_by(cls.base_code, cls.no).all() cls.base_code.startswith("9")),
cls.base_code != "3353")
.order_by(cls.base_code, cls.no)).unique().all()
@classmethod @classmethod
def selectable_credit(cls) -> list[Self]: def selectable_credit(cls) -> list[Self]:
@@ -314,19 +318,21 @@ class Account(db.Model):
:return: The selectable debit accounts. :return: The selectable debit accounts.
""" """
return cls.query.filter(sa.or_(sa.and_(cls.base_code.startswith("1"), return db.session.scalars(
sa.not_(cls.is_need_offset)), sa.select(cls)
cls.base_code.startswith("2"), .where(sa.or_(sa.and_(cls.base_code.startswith("1"),
cls.base_code.startswith("3"), sa.not_(cls.is_need_offset)),
cls.base_code.startswith("4"), cls.base_code.startswith("2"),
cls.base_code.startswith("71"), cls.base_code.startswith("3"),
cls.base_code.startswith("72"), cls.base_code.startswith("4"),
cls.base_code.startswith("73"), cls.base_code.startswith("71"),
cls.base_code.startswith("74"), cls.base_code.startswith("72"),
cls.base_code.startswith("8"), cls.base_code.startswith("73"),
cls.base_code.startswith("9")), cls.base_code.startswith("74"),
cls.base_code != "3353")\ cls.base_code.startswith("8"),
.order_by(cls.base_code, cls.no).all() cls.base_code.startswith("9")),
cls.base_code != "3353")
.order_by(cls.base_code, cls.no)).unique().all()
@classmethod @classmethod
def cash(cls) -> Self: def cash(cls) -> Self:
@@ -472,9 +478,10 @@ class Currency(db.Model):
:return: None. :return: None.
""" """
CurrencyL10n.query.filter(CurrencyL10n.currency == self).delete() db.session.execute(
cls: type[Self] = self.__class__ sa.delete(CurrencyL10n)
cls.query.filter(cls.code == self.code).delete() .where(CurrencyL10n.currency_code == self.code))
db.session.delete(self)
class CurrencyL10n(db.Model): class CurrencyL10n(db.Model):
@@ -649,8 +656,9 @@ class JournalEntry(db.Model):
:return: None. :return: None.
""" """
JournalEntryLineItem.query\ db.session.execute(
.filter(JournalEntryLineItem.journal_entry_id == self.id).delete() sa.delete(JournalEntryLineItem)
.where(JournalEntryLineItem.journal_entry_id == self.id))
db.session.delete(self) db.session.delete(self)
@@ -816,10 +824,11 @@ class JournalEntryLineItem(db.Model):
""" """
if not hasattr(self, "__offsets"): if not hasattr(self, "__offsets"):
cls: type[Self] = self.__class__ cls: type[Self] = self.__class__
offsets: list[Self] = cls.query.join(JournalEntry)\ offsets: list[Self] = db.session.scalars(
.filter(JournalEntryLineItem.original_line_item_id == self.id)\ sa.select(cls).join(JournalEntry)
.where(cls.original_line_item_id == self.id)
.order_by(JournalEntry.date, JournalEntry.no, .order_by(JournalEntry.date, JournalEntry.no,
cls.is_debit, cls.no).all() cls.is_debit, cls.no)).unique().all()
setattr(self, "__offsets", offsets) setattr(self, "__offsets", offsets)
return getattr(self, "__offsets") return getattr(self, "__offsets")
+5 -2
View File
@@ -23,9 +23,12 @@ This file is largely taken from the NanoParma ERP project, first written in
import datetime as dt import datetime as dt
from collections.abc import Callable from collections.abc import Callable
import sqlalchemy as sa
from .period import Period from .period import Period
from .shortcuts import ThisMonth, LastMonth, SinceLastMonth, ThisYear, \ from .shortcuts import ThisMonth, LastMonth, SinceLastMonth, ThisYear, \
LastYear, Today, Yesterday, AllTime, TemplatePeriod, YearPeriod LastYear, Today, Yesterday, AllTime, TemplatePeriod, YearPeriod
from ... import db
from ...models import JournalEntry from ...models import JournalEntry
from ...utils.timezone import get_tz_today from ...utils.timezone import get_tz_today
@@ -62,8 +65,8 @@ class PeriodChooser:
self.url_template: str = get_url(TemplatePeriod()) self.url_template: str = get_url(TemplatePeriod())
"""The URL template.""" """The URL template."""
first: JournalEntry | None \ first: JournalEntry | None = db.session.scalar(
= JournalEntry.query.order_by(JournalEntry.date).first() sa.select(JournalEntry).order_by(JournalEntry.date))
start: dt.date | None = None if first is None else first.date start: dt.date | None = None if first is None else first.date
# Attributes # Attributes
+14 -11
View File
@@ -133,16 +133,17 @@ class AccountCollector:
= sa.select(Account.id, Account.base_code, Account.no, = sa.select(Account.id, Account.base_code, Account.no,
balance_func)\ balance_func)\
.join(JournalEntry).join(Account)\ .join(JournalEntry).join(Account)\
.filter(*conditions)\ .where(*conditions)\
.group_by(Account.id, Account.base_code, Account.no)\ .group_by(Account.id, Account.base_code, Account.no)\
.having(balance_func != 0)\ .having(balance_func != 0)\
.order_by(Account.base_code, Account.no) .order_by(Account.base_code, Account.no)
account_balances: list[sa.Row] \ account_balances: list[sa.Row] \
= db.session.execute(select_balance).all() = db.session.execute(select_balance).all()
self.__all_accounts: list[Account] = Account.query\ self.__all_accounts: list[Account] = db.session.scalars(
.filter(sa.or_(Account.id.in_({x.id for x in account_balances}), sa.select(Account)
Account.base_code == "3351", .where(sa.or_(Account.id.in_({x.id for x in account_balances}),
Account.base_code == "3353")).all() Account.base_code == "3351",
Account.base_code == "3353"))).unique().all()
"""The accounts.""" """The accounts."""
account_by_id: dict[int, Account] \ account_by_id: dict[int, Account] \
= {x.id: x for x in self.__all_accounts} = {x.id: x for x in self.__all_accounts}
@@ -219,7 +220,7 @@ class AccountCollector:
(JournalEntryLineItem.is_debit, JournalEntryLineItem.amount), (JournalEntryLineItem.is_debit, JournalEntryLineItem.amount),
else_=-JournalEntryLineItem.amount)) else_=-JournalEntryLineItem.amount))
select_balance: sa.Select = sa.select(balance_func)\ select_balance: sa.Select = sa.select(balance_func)\
.join(JournalEntry).join(Account).filter(*conditions) .join(JournalEntry).join(Account).where(*conditions)
return db.session.scalar(select_balance) return db.session.scalar(select_balance)
def __add_owner_s_equity(self, code: str, amount: Decimal | None, def __add_owner_s_equity(self, code: str, amount: Decimal | None,
@@ -383,11 +384,13 @@ class BalanceSheet(BaseReport):
balances: list[ReportAccount] = AccountCollector( balances: list[ReportAccount] = AccountCollector(
self.__currency, self.__period).accounts self.__currency, self.__period).accounts
titles: list[BaseAccount] = BaseAccount.query\ titles: list[BaseAccount] = db.session.scalars(
.filter(BaseAccount.code.in_({"1", "2", "3"})).all() sa.select(BaseAccount)
subtitles: list[BaseAccount] = BaseAccount.query\ .where(BaseAccount.code.in_({"1", "2", "3"}))).unique().all()
.filter(BaseAccount.code.in_({x.account.base_code[:2] subtitle_codes: set[str] = {x.account.base_code[:2] for x in balances}
for x in balances})).all() subtitles: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(BaseAccount.code.in_(subtitle_codes))).unique().all()
sections: dict[str, Section] = {x.code: Section(x) for x in titles} sections: dict[str, Section] = {x.code: Section(x) for x in titles}
subsections: dict[str, Subsection] = {x.code: Subsection(x) subsections: dict[str, Subsection] = {x.code: Subsection(x)
@@ -119,9 +119,9 @@ class LineItemCollector:
else_=-JournalEntryLineItem.amount)) else_=-JournalEntryLineItem.amount))
select: sa.Select[tuple[Decimal]] = sa.Select(balance_func)\ select: sa.Select[tuple[Decimal]] = sa.Select(balance_func)\
.join(JournalEntry).join(Account)\ .join(JournalEntry).join(Account)\
.filter(JournalEntryLineItem.currency_code == self.__currency.code, .where(JournalEntryLineItem.currency_code == self.__currency.code,
self.__account_condition, self.__account_condition,
JournalEntry.date < self.__period.start) JournalEntry.date < self.__period.start)
balance: Decimal | None = db.session.scalar(select) balance: Decimal | None = db.session.scalar(select)
if balance is None: if balance is None:
return None return None
@@ -150,22 +150,22 @@ class LineItemCollector:
if self.__period.end is not None: if self.__period.end is not None:
conditions.append(JournalEntry.date <= self.__period.end) conditions.append(JournalEntry.date <= self.__period.end)
journal_entry_with_account: sa.Select = sa.Select(JournalEntry.id).\ journal_entry_with_account: sa.Select = sa.Select(JournalEntry.id).\
join(JournalEntryLineItem).join(Account).filter(*conditions) join(JournalEntryLineItem).join(Account).where(*conditions)
return [ReportLineItem(x) return [ReportLineItem(x) for x in db.session.scalars(
for x in JournalEntryLineItem.query sa.select(JournalEntryLineItem)
.join(JournalEntry).join(Account) .join(JournalEntry).join(Account)
.filter(JournalEntryLineItem.journal_entry_id .where(JournalEntryLineItem.journal_entry_id
.in_(journal_entry_with_account), .in_(journal_entry_with_account),
JournalEntryLineItem.currency_code JournalEntryLineItem.currency_code
== self.__currency.code, == self.__currency.code,
sa.not_(self.__account_condition)) sa.not_(self.__account_condition))
.order_by(JournalEntry.date, .order_by(JournalEntry.date,
JournalEntry.no, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.is_debit,
JournalEntryLineItem.no) JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.account), .options(selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.journal_entry))] selectinload(JournalEntryLineItem.journal_entry)))]
@property @property
def __account_condition(self) -> sa.ColumnElement[bool]: def __account_condition(self) -> sa.ColumnElement[bool]:
@@ -343,8 +343,8 @@ class PageParams(BasePageParams):
self.account.id == 0)] self.account.id == 0)]
in_use: sa.Select = sa.Select(JournalEntryLineItem.account_id)\ in_use: sa.Select = sa.Select(JournalEntryLineItem.account_id)\
.join(Account)\ .join(Account)\
.filter(JournalEntryLineItem.currency_code == self.currency.code, .where(JournalEntryLineItem.currency_code == self.currency.code,
CurrentAccount.sql_condition())\ CurrentAccount.sql_condition())\
.group_by(JournalEntryLineItem.account_id) .group_by(JournalEntryLineItem.account_id)
options.extend([OptionLink(str(x), options.extend([OptionLink(str(x),
income_expenses_url( income_expenses_url(
@@ -352,8 +352,10 @@ class PageParams(BasePageParams):
CurrentAccount(x), CurrentAccount(x),
self.period), self.period),
x.id == self.account.id) x.id == self.account.id)
for x in Account.query.filter(Account.id.in_(in_use)) for x in db.session.scalars(
.order_by(Account.base_code, Account.no).all()]) sa.select(Account).where(Account.id.in_(in_use))
.order_by(Account.base_code, Account.no))
.unique()])
return options return options
@@ -218,11 +218,14 @@ class IncomeStatement(BaseReport):
""" """
balances: list[ReportAccount] = self.__query_balances() balances: list[ReportAccount] = self.__query_balances()
titles: list[BaseAccount] = BaseAccount.query\ title_codes: set[str] = {"4", "5", "6", "7", "8", "9"}
.filter(BaseAccount.code.in_({"4", "5", "6", "7", "8", "9"})).all() titles: list[BaseAccount] = db.session.scalars(
subtitles: list[BaseAccount] = BaseAccount.query\ sa.select(BaseAccount)
.filter(BaseAccount.code.in_({x.account.base_code[:2] .where(BaseAccount.code.in_(title_codes))).unique().all()
for x in balances})).all() subtitle_codes: set[str] = {x.account.base_code[:2] for x in balances}
subtitles: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(BaseAccount.code.in_(subtitle_codes))).unique().all()
total_titles: dict[str, str] \ total_titles: dict[str, str] \
= {"4": gettext("Total Operating Revenue"), = {"4": gettext("Total Operating Revenue"),
@@ -267,14 +270,15 @@ class IncomeStatement(BaseReport):
else_=JournalEntryLineItem.amount)).label("balance") else_=JournalEntryLineItem.amount)).label("balance")
select_balances: sa.Select = sa.select(Account.id, balance_func)\ select_balances: sa.Select = sa.select(Account.id, balance_func)\
.join(JournalEntry).join(Account)\ .join(JournalEntry).join(Account)\
.filter(*conditions)\ .where(*conditions)\
.group_by(Account.id)\ .group_by(Account.id)\
.having(balance_func != 0)\ .having(balance_func != 0)\
.order_by(Account.base_code, Account.no) .order_by(Account.base_code, Account.no)
balances: list[sa.Row] = db.session.execute(select_balances).all() balances: list[sa.Row] = db.session.execute(select_balances).all()
accounts: dict[int, Account] \ accounts: dict[int, Account] \
= {x.id: x for x in Account.query = {x.id: x for x in db.session.scalars(
.filter(Account.id.in_([x.id for x in balances])).all()} sa.select(Account)
.where(Account.id.in_([x.id for x in balances]))).unique()}
return [ReportAccount(account=accounts[x.id], return [ReportAccount(account=accounts[x.id],
amount=x.balance, amount=x.balance,
url=ledger_url(self.__currency, url=ledger_url(self.__currency,
+6 -4
View File
@@ -31,6 +31,7 @@ from ..utils.csv_export import BaseCSVRow, csv_download, period_spec
from ..utils.report_chooser import ReportChooser from ..utils.report_chooser import ReportChooser
from ..utils.report_type import ReportType from ..utils.report_type import ReportType
from ..utils.urls import journal_url from ..utils.urls import journal_url
from ... import db
from ...locale import gettext from ...locale import gettext
from ...models import Currency, Account, JournalEntry, JournalEntryLineItem from ...models import Currency, Account, JournalEntry, JournalEntryLineItem
from ...utils.pagination import Pagination from ...utils.pagination import Pagination
@@ -188,15 +189,16 @@ class Journal(BaseReport):
conditions.append(JournalEntry.date >= self.__period.start) conditions.append(JournalEntry.date >= self.__period.start)
if self.__period.end is not None: if self.__period.end is not None:
conditions.append(JournalEntry.date <= self.__period.end) conditions.append(JournalEntry.date <= self.__period.end)
return JournalEntryLineItem.query.join(JournalEntry)\ return db.session.scalars(
.filter(*conditions)\ sa.select(JournalEntryLineItem).join(JournalEntry)
.where(*conditions)
.order_by(JournalEntry.date, .order_by(JournalEntry.date,
JournalEntry.no, JournalEntry.no,
JournalEntryLineItem.is_debit.desc(), JournalEntryLineItem.is_debit.desc(),
JournalEntryLineItem.no)\ JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.account), .options(selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.currency), selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all() selectinload(JournalEntryLineItem.journal_entry))).all()
def csv(self) -> Response: def csv(self) -> Response:
"""Returns the report as CSV for download. """Returns the report as CSV for download.
+12 -11
View File
@@ -115,9 +115,9 @@ class LineItemCollector:
(JournalEntryLineItem.is_debit, JournalEntryLineItem.amount), (JournalEntryLineItem.is_debit, JournalEntryLineItem.amount),
else_=-JournalEntryLineItem.amount)) else_=-JournalEntryLineItem.amount))
select: sa.Select = sa.Select(balance_func).join(JournalEntry)\ select: sa.Select = sa.Select(balance_func).join(JournalEntry)\
.filter(JournalEntryLineItem.currency_code == self.__currency.code, .where(JournalEntryLineItem.currency_code == self.__currency.code,
JournalEntryLineItem.account_id == self.__account.id, JournalEntryLineItem.account_id == self.__account.id,
JournalEntry.date < self.__period.start) JournalEntry.date < self.__period.start)
balance: int | None = db.session.scalar(select) balance: int | None = db.session.scalar(select)
if balance is None: if balance is None:
return None return None
@@ -144,15 +144,15 @@ class LineItemCollector:
conditions.append(JournalEntry.date >= self.__period.start) conditions.append(JournalEntry.date >= self.__period.start)
if self.__period.end is not None: if self.__period.end is not None:
conditions.append(JournalEntry.date <= self.__period.end) conditions.append(JournalEntry.date <= self.__period.end)
return [ReportLineItem(x) for x in JournalEntryLineItem.query return [ReportLineItem(x) for x in db.session.scalars(
.join(JournalEntry) sa.select(JournalEntryLineItem).join(JournalEntry)
.filter(*conditions) .where(*conditions)
.order_by(JournalEntry.date, .order_by(JournalEntry.date,
JournalEntry.no, JournalEntry.no,
JournalEntryLineItem.is_debit.desc(), JournalEntryLineItem.is_debit.desc(),
JournalEntryLineItem.no) JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.journal_entry)) .options(selectinload(JournalEntryLineItem.journal_entry)))
.all()] .unique()]
def __get_total(self) -> ReportLineItem | None: def __get_total(self) -> ReportLineItem | None:
"""Composes the total line item. """Composes the total line item.
@@ -308,12 +308,13 @@ class PageParams(BasePageParams):
:return: The account options. :return: The account options.
""" """
in_use: sa.Select = sa.Select(JournalEntryLineItem.account_id)\ in_use: sa.Select = sa.Select(JournalEntryLineItem.account_id)\
.filter(JournalEntryLineItem.currency_code == self.currency.code)\ .where(JournalEntryLineItem.currency_code == self.currency.code)\
.group_by(JournalEntryLineItem.account_id) .group_by(JournalEntryLineItem.account_id)
return [OptionLink(str(x), ledger_url(self.currency, x, self.period), return [OptionLink(str(x), ledger_url(self.currency, x, self.period),
x.id == self.account.id) x.id == self.account.id)
for x in Account.query.filter(Account.id.in_(in_use)) for x in db.session.scalars(
.order_by(Account.base_code, Account.no).all()] sa.select(Account).where(Account.id.in_(in_use))
.order_by(Account.base_code, Account.no)).unique()]
class Ledger(BaseReport): class Ledger(BaseReport):
+14 -12
View File
@@ -30,6 +30,7 @@ from ..utils.base_report import BaseReport
from ..utils.csv_export import csv_download from ..utils.csv_export import csv_download
from ..utils.report_chooser import ReportChooser from ..utils.report_chooser import ReportChooser
from ..utils.report_type import ReportType from ..utils.report_type import ReportType
from ... import db
from ...locale import gettext from ...locale import gettext
from ...models import Currency, CurrencyL10n, Account, AccountL10n, \ from ...models import Currency, CurrencyL10n, Account, AccountL10n, \
JournalEntry, JournalEntryLineItem JournalEntry, JournalEntryLineItem
@@ -69,15 +70,16 @@ class LineItemCollector:
except ArithmeticError: except ArithmeticError:
pass pass
conditions.append(sa.or_(*sub_conditions)) conditions.append(sa.or_(*sub_conditions))
return JournalEntryLineItem.query.join(JournalEntry)\ return db.session.scalars(
.filter(*conditions)\ sa.select(JournalEntryLineItem).join(JournalEntry)
.where(*conditions)
.order_by(JournalEntry.date, .order_by(JournalEntry.date,
JournalEntry.no, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.is_debit,
JournalEntryLineItem.no)\ JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.account), .options(selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.currency), selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all() selectinload(JournalEntryLineItem.journal_entry))).all()
@staticmethod @staticmethod
def __get_account_condition(k: str) -> sa.Select: def __get_account_condition(k: str) -> sa.Select:
@@ -91,7 +93,7 @@ class LineItemCollector:
sa.func.char_length(sa.cast(Account.no, sa.func.char_length(sa.cast(Account.no,
sa.String)) + 1) sa.String)) + 1)
select_l10n: sa.Select = sa.select(AccountL10n.account_id)\ select_l10n: sa.Select = sa.select(AccountL10n.account_id)\
.filter(AccountL10n.title.icontains(k)) .where(AccountL10n.title.icontains(k))
conditions: list[sa.ColumnElement[bool]] \ conditions: list[sa.ColumnElement[bool]] \
= [Account.base_code.contains(k), = [Account.base_code.contains(k),
Account.title_l10n.icontains(k), Account.title_l10n.icontains(k),
@@ -99,7 +101,7 @@ class LineItemCollector:
Account.id.in_(select_l10n)] Account.id.in_(select_l10n)]
if k in gettext("Needs Offset"): if k in gettext("Needs Offset"):
conditions.append(Account.is_need_offset) conditions.append(Account.is_need_offset)
return sa.select(Account.id).filter(sa.or_(*conditions)) return sa.select(Account.id).where(sa.or_(*conditions))
@staticmethod @staticmethod
def __get_currency_condition(k: str) -> sa.Select: def __get_currency_condition(k: str) -> sa.Select:
@@ -109,11 +111,11 @@ class LineItemCollector:
:return: The condition to filter the currency. :return: The condition to filter the currency.
""" """
select_l10n: sa.Select = sa.select(CurrencyL10n.currency_code)\ select_l10n: sa.Select = sa.select(CurrencyL10n.currency_code)\
.filter(CurrencyL10n.name.icontains(k)) .where(CurrencyL10n.name.icontains(k))
return sa.select(Currency.code).filter( return sa.select(Currency.code)\
sa.or_(Currency.code.icontains(k), .where(sa.or_(Currency.code.icontains(k),
Currency.name_l10n.icontains(k), Currency.name_l10n.icontains(k),
Currency.code.in_(select_l10n))) Currency.code.in_(select_l10n)))
@staticmethod @staticmethod
def __get_journal_entry_condition(k: str) -> sa.Select: def __get_journal_entry_condition(k: str) -> sa.Select:
@@ -153,7 +155,7 @@ class LineItemCollector:
sa.extract("day", JournalEntry.date) == date.day)) sa.extract("day", JournalEntry.date) == date.day))
except ValueError: except ValueError:
pass pass
return sa.select(JournalEntry.id).filter(sa.or_(*conditions)) return sa.select(JournalEntry.id).where(sa.or_(*conditions))
class PageParams(BasePageParams): class PageParams(BasePageParams):
@@ -187,14 +187,15 @@ class TrialBalance(BaseReport):
else_=-JournalEntryLineItem.amount)).label("balance") else_=-JournalEntryLineItem.amount)).label("balance")
select_balances: sa.Select = sa.select(Account.id, balance_func)\ select_balances: sa.Select = sa.select(Account.id, balance_func)\
.join(JournalEntry).join(Account)\ .join(JournalEntry).join(Account)\
.filter(*conditions)\ .where(*conditions)\
.group_by(Account.id)\ .group_by(Account.id)\
.having(balance_func != 0)\ .having(balance_func != 0)\
.order_by(Account.base_code, Account.no) .order_by(Account.base_code, Account.no)
balances: list[sa.Row] = db.session.execute(select_balances).all() balances: list[sa.Row] = db.session.execute(select_balances).all()
accounts: dict[int, Account] \ accounts: dict[int, Account] \
= {x.id: x for x in Account.query = {x.id: x for x in db.session.scalars(
.filter(Account.id.in_([x.id for x in balances])).all()} sa.select(Account)
.where(Account.id.in_([x.id for x in balances]))).unique()}
self.__accounts = [ReportAccount(account=accounts[x.id], self.__accounts = [ReportAccount(account=accounts[x.id],
amount=x.balance, amount=x.balance,
url=ledger_url(self.__currency, url=ledger_url(self.__currency,
+8 -5
View File
@@ -20,6 +20,7 @@
import datetime as dt import datetime as dt
from decimal import Decimal from decimal import Decimal
import sqlalchemy as sa
from flask import render_template, Response from flask import render_template, Response
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@@ -31,6 +32,7 @@ from ..utils.report_chooser import ReportChooser
from ..utils.report_type import ReportType from ..utils.report_type import ReportType
from ..utils.unapplied import get_accounts_with_unapplied, get_net_balances from ..utils.unapplied import get_accounts_with_unapplied, get_net_balances
from ..utils.urls import unapplied_url from ..utils.urls import unapplied_url
from ... import db
from ...locale import gettext from ...locale import gettext
from ...models import Currency, Account, JournalEntry, JournalEntryLineItem from ...models import Currency, Account, JournalEntry, JournalEntryLineItem
from ...utils.pagination import Pagination from ...utils.pagination import Pagination
@@ -176,13 +178,14 @@ class UnappliedOriginalLineItems(BaseReport):
""" """
net_balances: dict[int, Decimal | None] \ net_balances: dict[int, Decimal | None] \
= get_net_balances(self.__currency, self.__account) = get_net_balances(self.__currency, self.__account)
line_items: list[JournalEntryLineItem] = JournalEntryLineItem.query \ line_items: list[JournalEntryLineItem] = db.session.scalars(
.join(Account).join(JournalEntry) \ sa.select(JournalEntryLineItem).join(Account).join(JournalEntry)
.filter(JournalEntryLineItem.id.in_(net_balances)) \ .where(JournalEntryLineItem.id.in_(net_balances))
.order_by(JournalEntry.date, JournalEntry.no, .order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.no) \ JournalEntryLineItem.is_debit, JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.currency), .options(selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all() selectinload(JournalEntryLineItem.journal_entry)))\
.unique().all()
for line_item in line_items: for line_item in line_items:
line_item.net_balance = line_item.amount \ line_item.net_balance = line_item.amount \
if net_balances[line_item.id] is None \ if net_balances[line_item.id] is None \
@@ -84,5 +84,6 @@ class BasePageParams(ABC):
sa.select(JournalEntryLineItem.currency_code) sa.select(JournalEntryLineItem.currency_code)
.group_by(JournalEntryLineItem.currency_code)).all()) .group_by(JournalEntryLineItem.currency_code)).all())
return [OptionLink(str(x), get_url(x), x.code == active_currency.code) return [OptionLink(str(x), get_url(x), x.code == active_currency.code)
for x in Currency.query.filter(Currency.code.in_(in_use)) for x in db.session.scalars(
.order_by(Currency.code).all()] sa.select(Currency).where(Currency.code.in_(in_use))
.order_by(Currency.code)).unique()]
@@ -24,6 +24,7 @@ from flask_babel import LazyString
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from ..utils.unapplied import get_net_balances from ..utils.unapplied import get_net_balances
from ... import db
from ...locale import lazy_gettext from ...locale import lazy_gettext
from ...models import Currency, Account, JournalEntry, JournalEntryLineItem from ...models import Currency, Account, JournalEntry, JournalEntryLineItem
@@ -113,14 +114,15 @@ class OffsetMatcher:
JournalEntryLineItem.is_debit), JournalEntryLineItem.is_debit),
sa.and_(Account.base_code.startswith("1"), sa.and_(Account.base_code.startswith("1"),
sa.not_(JournalEntryLineItem.is_debit)))) sa.not_(JournalEntryLineItem.is_debit))))
self.line_items = JournalEntryLineItem.query \ self.line_items = db.session.scalars(
.join(Account).join(JournalEntry) \ sa.select(JournalEntryLineItem).join(Account).join(JournalEntry)
.filter(sa.or_(JournalEntryLineItem.id.in_(net_balances), .where(sa.or_(JournalEntryLineItem.id.in_(net_balances),
unmatched_offset_condition)) \ unmatched_offset_condition))
.order_by(JournalEntry.date, JournalEntry.no, .order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.no) \ JournalEntryLineItem.is_debit, JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.currency), .options(selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all() selectinload(JournalEntryLineItem.journal_entry)))\
.unique().all()
for line_item in self.line_items: for line_item in self.line_items:
line_item.is_offset = line_item.id not in net_balances line_item.is_offset = line_item.id not in net_balances
self.unapplied = [x for x in self.line_items if not x.is_offset] self.unapplied = [x for x in self.line_items if not x.is_offset]
+17 -16
View File
@@ -45,12 +45,12 @@ def get_accounts_with_unapplied(currency: Currency) -> list[Account]:
.join(offset, .join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id, JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True)\ isouter=True)\
.filter(Account.is_need_offset, .where(Account.is_need_offset,
JournalEntryLineItem.currency_code == currency.code, JournalEntryLineItem.currency_code == currency.code,
sa.or_(sa.and_(Account.base_code.startswith("2"), sa.or_(sa.and_(Account.base_code.startswith("2"),
sa.not_(JournalEntryLineItem.is_debit)), sa.not_(JournalEntryLineItem.is_debit)),
sa.and_(Account.base_code.startswith("1"), sa.and_(Account.base_code.startswith("1"),
JournalEntryLineItem.is_debit)))\ JournalEntryLineItem.is_debit)))\
.group_by(JournalEntryLineItem.id)\ .group_by(JournalEntryLineItem.id)\
.having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0)) .having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0))
@@ -58,13 +58,14 @@ def get_accounts_with_unapplied(currency: Currency) -> list[Account]:
= sa.func.count(JournalEntryLineItem.id).label("count") = sa.func.count(JournalEntryLineItem.id).label("count")
select: sa.Select = sa.select(Account.id, count_func)\ select: sa.Select = sa.select(Account.id, count_func)\
.join(JournalEntryLineItem, isouter=True)\ .join(JournalEntryLineItem, isouter=True)\
.filter(JournalEntryLineItem.id.in_(select_unapplied))\ .where(JournalEntryLineItem.id.in_(select_unapplied))\
.group_by(Account.id)\ .group_by(Account.id)\
.having(count_func > 0) .having(count_func > 0)
counts: dict[int, int] \ counts: dict[int, int] \
= {x.id: x.count for x in db.session.execute(select)} = {x.id: x.count for x in db.session.execute(select)}
accounts: list[Account] = Account.query.filter(Account.id.in_(counts))\ accounts: list[Account] = db.session.scalars(
.order_by(Account.base_code, Account.no).all() sa.select(Account).where(Account.id.in_(counts))
.order_by(Account.base_code, Account.no)).unique().all()
for account in accounts: for account in accounts:
account.count = counts[account.id] account.count = counts[account.id]
return accounts return accounts
@@ -91,13 +92,13 @@ def get_net_balances(currency: Currency, account: Account) \
.join(offset, .join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id, JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True) \ isouter=True) \
.filter(Account.id == account.id, .where(Account.id == account.id,
JournalEntryLineItem.currency_code == currency.code, JournalEntryLineItem.currency_code == currency.code,
sa.or_(sa.and_(Account.base_code.startswith("2"), sa.or_(sa.and_(Account.base_code.startswith("2"),
sa.not_(JournalEntryLineItem.is_debit)), sa.not_(JournalEntryLineItem.is_debit)),
sa.and_(Account.base_code.startswith("1"), sa.and_(Account.base_code.startswith("1"),
JournalEntryLineItem.is_debit))) \ JournalEntryLineItem.is_debit))) \
.group_by(JournalEntryLineItem.id) \ .group_by(JournalEntryLineItem.id) \
.having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0)) .having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0))
return {x.id: x.net_balance return {x.id: x.net_balance
for x in db.session.execute(select_net_balances).all()} for x in db.session.execute(select_net_balances)}
+10 -9
View File
@@ -35,19 +35,20 @@ def get_accounts_with_unmatched(currency: Currency) -> list[Account]:
select: sa.Select = sa.select(Account.id, count_func)\ select: sa.Select = sa.select(Account.id, count_func)\
.select_from(Account)\ .select_from(Account)\
.join(JournalEntryLineItem, isouter=True).join(JournalEntry)\ .join(JournalEntryLineItem, isouter=True).join(JournalEntry)\
.filter(Account.is_need_offset, .where(Account.is_need_offset,
JournalEntryLineItem.currency_code == currency.code, JournalEntryLineItem.currency_code == currency.code,
JournalEntryLineItem.original_line_item_id.is_(None), JournalEntryLineItem.original_line_item_id.is_(None),
sa.or_(sa.and_(Account.base_code.startswith("2"), sa.or_(sa.and_(Account.base_code.startswith("2"),
JournalEntryLineItem.is_debit), JournalEntryLineItem.is_debit),
sa.and_(Account.base_code.startswith("1"), sa.and_(Account.base_code.startswith("1"),
sa.not_(JournalEntryLineItem.is_debit))))\ sa.not_(JournalEntryLineItem.is_debit))))\
.group_by(Account.id)\ .group_by(Account.id)\
.having(count_func > 0) .having(count_func > 0)
counts: dict[int, int] \ counts: dict[int, int] \
= {x.id: x.count for x in db.session.execute(select)} = {x.id: x.count for x in db.session.execute(select)}
accounts: list[Account] = Account.query.filter(Account.id.in_(counts))\ accounts: list[Account] = db.session.scalars(
.order_by(Account.base_code, Account.no).all() sa.select(Account).where(Account.id.in_(counts))
.order_by(Account.base_code, Account.no)).unique().all()
for account in accounts: for account in accounts:
account.count = counts[account.id] account.count = counts[account.id]
return accounts return accounts
+5 -1
View File
@@ -17,6 +17,9 @@
"""The template globals. """The template globals.
""" """
import sqlalchemy as sa
from . import db
from .models import Currency from .models import Currency
from .utils.options import options from .utils.options import options
@@ -26,7 +29,8 @@ def currency_options() -> list[Currency]:
:return: The currency options. :return: The currency options.
""" """
return Currency.query.order_by(Currency.code).all() return db.session.scalars(
sa.select(Currency).order_by(Currency.code)).unique().all()
def default_currency_code() -> str: def default_currency_code() -> str:
+5 -3
View File
@@ -21,6 +21,7 @@ from typing import Self
import sqlalchemy as sa import sqlalchemy as sa
from .. import db
from ..locale import gettext from ..locale import gettext
from ..models import Account from ..models import Account
@@ -74,9 +75,10 @@ class CurrentAccount:
""" """
accounts: list[cls] = [cls.current_assets_and_liabilities()] accounts: list[cls] = [cls.current_assets_and_liabilities()]
accounts.extend([cls(x) accounts.extend([cls(x)
for x in Account.query for x in db.session.scalars(
.filter(cls.sql_condition()) sa.select(Account).where(cls.sql_condition())
.order_by(Account.base_code, Account.no)]) .order_by(Account.base_code, Account.no))
.unique()])
return accounts return accounts
@classmethod @classmethod
+18 -10
View File
@@ -21,6 +21,7 @@ import datetime as dt
import unittest import unittest
import httpx import httpx
import sqlalchemy as sa
from flask import Flask from flask import Flask
from accounting.utils.next_uri import encode_next from accounting.utils.next_uri import encode_next
@@ -275,8 +276,10 @@ class AccountTestCase(unittest.TestCase):
response: httpx.Response response: httpx.Response
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()}, self.assertEqual(
{CASH.code, BANK.code}) {x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code})
# Missing CSRF token # Missing CSRF token
response = self.__client.post(store_uri, response = self.__client.post(store_uri,
@@ -367,10 +370,11 @@ class AccountTestCase(unittest.TestCase):
f"{PREFIX}/{STOCK.base_code}-003") f"{PREFIX}/{STOCK.base_code}-003")
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()}, self.assertEqual(
{CASH.code, BANK.code, STOCK.code, {x.code
f"{STOCK.base_code}-002", for x in db.session.scalars(sa.select(Account)).unique()},
f"{STOCK.base_code}-003"}) {CASH.code, BANK.code, STOCK.code,
f"{STOCK.base_code}-002", f"{STOCK.base_code}-003"})
account: Account | None = Account.find_by_code(STOCK.code) account: Account | None = Account.find_by_code(STOCK.code)
self.assertIsNotNone(account) self.assertIsNotNone(account)
@@ -621,8 +625,10 @@ class AccountTestCase(unittest.TestCase):
"currency-1-credit-1-amount": "20"}) "currency-1-credit-1-amount": "20"})
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()}, self.assertEqual(
{CASH.code, PETTY.code, BANK.code}) {x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, PETTY.code, BANK.code})
# Cannot delete the cash account # Cannot delete the cash account
response = self.__client.post(f"{PREFIX}/{CASH.code}/delete", response = self.__client.post(f"{PREFIX}/{CASH.code}/delete",
@@ -645,8 +651,10 @@ class AccountTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], list_uri) self.assertEqual(response.headers["Location"], list_uri)
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()}, self.assertEqual(
{CASH.code, BANK.code}) {x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code})
response = self.__client.get(detail_uri) response = self.__client.get(detail_uri)
self.assertEqual(response.status_code, 404) self.assertEqual(response.status_code, 404)
+16 -10
View File
@@ -101,7 +101,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
for x in rows} for x in rows}
with self.__app.app_context(): with self.__app.app_context():
accounts: list[BaseAccount] = BaseAccount.query.all() accounts: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)).unique().all()
self.assertEqual(len(accounts), len(data)) self.assertEqual(len(accounts), len(data))
for account in accounts: for account in accounts:
@@ -141,10 +142,14 @@ class ConsoleCommandTestCase(unittest.TestCase):
from accounting.models import BaseAccount, Account, AccountL10n from accounting.models import BaseAccount, Account, AccountL10n
with self.__app.app_context(): with self.__app.app_context():
bases: list[BaseAccount] = BaseAccount.query\ bases: list[BaseAccount] = db.session.scalars(
.filter(sa.func.char_length(BaseAccount.code) == 4).all() sa.select(BaseAccount)
accounts: list[Account] = Account.query.all() .where(sa.func.char_length(BaseAccount.code) == 4))\
l10n: list[AccountL10n] = AccountL10n.query.all() .unique().all()
accounts: list[Account] = db.session.scalars(
sa.select(Account)).unique().all()
l10n: list[AccountL10n] = db.session.scalars(
sa.select(AccountL10n)).all()
self.assertEqual({x.code for x in bases}, self.assertEqual({x.code for x in bases},
{x.base_code for x in accounts}) {x.base_code for x in accounts})
@@ -175,7 +180,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
for x in csv.DictReader(fp)} for x in csv.DictReader(fp)}
with self.__app.app_context(): with self.__app.app_context():
currencies: list[Currency] = Currency.query.all() currencies: list[Currency] = db.session.scalars(
sa.select(Currency)).unique().all()
self.assertEqual(len(currencies), len(data)) self.assertEqual(len(currencies), len(data))
for currency in currencies: for currency in currencies:
@@ -216,9 +222,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
result.output + str(result.exception)) result.output + str(result.exception))
# Turns the titles into lowercase. # Turns the titles into lowercase.
for base in BaseAccount.query: for base in db.session.scalars(sa.select(BaseAccount)).unique():
base.title_l10n = base.title_l10n.lower() base.title_l10n = base.title_l10n.lower()
for account in Account.query: for account in db.session.scalars(sa.select(Account)).unique():
account.title_l10n = account.title_l10n.lower() account.title_l10n = account.title_l10n.lower()
account.created_at \ account.created_at \
= account.created_at - dt.timedelta(seconds=5) = account.created_at - dt.timedelta(seconds=5)
@@ -242,9 +248,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
args=["accounting-titleize", "-u", "editor"]) args=["accounting-titleize", "-u", "editor"])
self.assertEqual(result.exit_code, 0, self.assertEqual(result.exit_code, 0,
result.output + str(result.exception)) result.output + str(result.exception))
for base in BaseAccount.query: for base in db.session.scalars(sa.select(BaseAccount)).unique():
self.__test_title_case(base.title_l10n) self.__test_title_case(base.title_l10n)
for account in Account.query: for account in db.session.scalars(sa.select(Account)).unique():
if account.id != new_account.id: if account.id != new_account.id:
self.__test_title_case(account.title_l10n) self.__test_title_case(account.title_l10n)
self.assertNotEqual(account.created_at, account.updated_at) self.assertNotEqual(account.created_at, account.updated_at)
+17 -8
View File
@@ -21,6 +21,7 @@ import datetime as dt
import unittest import unittest
import httpx import httpx
import sqlalchemy as sa
from flask import Flask from flask import Flask
from accounting.utils.next_uri import encode_next from accounting.utils.next_uri import encode_next
@@ -221,8 +222,10 @@ class CurrencyTestCase(unittest.TestCase):
response: httpx.Response response: httpx.Response
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()}, self.assertEqual(
{USD.code, EUR.code}) {x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code})
# Missing CSRF token # Missing CSRF token
response = self.__client.post(store_uri, response = self.__client.post(store_uri,
@@ -287,8 +290,10 @@ class CurrencyTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], create_uri) self.assertEqual(response.headers["Location"], create_uri)
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()}, self.assertEqual(
{USD.code, EUR.code, TWD.code}) {x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code, TWD.code})
currency: Currency = db.session.get(Currency, TWD.code) currency: Currency = db.session.get(Currency, TWD.code)
self.assertEqual(currency.code, TWD.code) self.assertEqual(currency.code, TWD.code)
@@ -554,8 +559,10 @@ class CurrencyTestCase(unittest.TestCase):
"currency-1-credit-1-amount": "20"}) "currency-1-credit-1-amount": "20"})
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()}, self.assertEqual(
{USD.code, EUR.code, JPY.code}) {x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code, JPY.code})
# Cannot delete the default currency # Cannot delete the default currency
response = self.__client.post(f"{PREFIX}/{USD.code}/delete", response = self.__client.post(f"{PREFIX}/{USD.code}/delete",
@@ -578,8 +585,10 @@ class CurrencyTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], list_uri) self.assertEqual(response.headers["Location"], list_uri)
with self.__app.app_context(): with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()}, self.assertEqual(
{USD.code, EUR.code}) {x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code})
response = self.__client.get(detail_uri) response = self.__client.get(detail_uri)
self.assertEqual(response.status_code, 404) self.assertEqual(response.status_code, 404)
+6 -3
View File
@@ -20,6 +20,7 @@
import os import os
from secrets import token_urlsafe from secrets import token_urlsafe
import sqlalchemy as sa
from click.testing import Result from click.testing import Result
from flask import Flask, Blueprint, render_template, redirect, Response, \ from flask import Flask, Blueprint, render_template, redirect, Response, \
url_for url_for
@@ -112,8 +113,8 @@ def create_app(is_testing: bool = False, is_skip_accounts: bool = False,
return auth.current_user() return auth.current_user()
def get_by_username(self, username: str) -> auth.User | None: def get_by_username(self, username: str) -> auth.User | None:
return auth.User.query\ return db.session.scalar(
.filter(auth.User.username == username).first() sa.select(auth.User).where(auth.User.username == username))
def get_pk(self, user: auth.User) -> int: def get_pk(self, user: auth.User) -> int:
return user.id return user.id
@@ -140,7 +141,9 @@ def init_db(app: Flask, is_skip_accounts: bool,
db.create_all() db.create_all()
from .auth import User from .auth import User
for username in ["viewer", "editor", "admin", "nobody"]: for username in ["viewer", "editor", "admin", "nobody"]:
if User.query.filter(User.username == username).first() is None: user: User | None = db.session.scalar(
sa.select(User).where(User.username == username))
if user is None:
db.session.add(User(username=username)) db.session.add(User(username=username))
db.session.commit() db.session.commit()
runner: FlaskCliRunner = app.test_cli_runner() runner: FlaskCliRunner = app.test_cli_runner()
+3 -2
View File
@@ -19,6 +19,7 @@
""" """
from collections.abc import Callable from collections.abc import Callable
import sqlalchemy as sa
from flask import Blueprint, render_template, Flask, redirect, url_for, \ from flask import Blueprint, render_template, Flask, redirect, url_for, \
session, request, g, Response, abort session, request, g, Response, abort
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
@@ -91,8 +92,8 @@ def current_user() -> User | None:
if "user" not in session: if "user" not in session:
g.user = None g.user = None
else: else:
g.user = User.query.filter( g.user = db.session.scalar(
User.username == session["user"]).first() sa.select(User).where(User.username == session["user"]))
return g.user return g.user
+2 -2
View File
@@ -218,8 +218,8 @@ class BaseTestData(ABC):
self._app: Flask = app self._app: Flask = app
"""The Flask application.""" """The Flask application."""
with self._app.app_context(): with self._app.app_context():
current_user: User | None = User.query\ current_user: User | None = db.session.scalar(
.filter(User.username == username).first() sa.select(User).where(User.username == username))
assert current_user is not None assert current_user is not None
self.__current_user_id: int = current_user.id self.__current_user_id: int = current_user.id
"""The current user ID.""" """The current user ID."""
+9 -8
View File
@@ -19,6 +19,7 @@
""" """
import datetime as dt import datetime as dt
import sqlalchemy as sa
from flask import Flask, Blueprint, url_for, flash, redirect, session, \ from flask import Flask, Blueprint, url_for, flash, redirect, session, \
render_template, current_app, Response render_template, current_app, Response
from flask_babel import lazy_gettext from flask_babel import lazy_gettext
@@ -83,14 +84,14 @@ def __reset_database() -> None:
from accounting.account import init_accounts_command from accounting.account import init_accounts_command
from accounting.currency import init_currencies_command from accounting.currency import init_currencies_command
JournalEntryLineItem.query.delete() db.session.execute(sa.delete(JournalEntryLineItem))
JournalEntry.query.delete() db.session.execute(sa.delete(JournalEntry))
CurrencyL10n.query.delete() db.session.execute(sa.delete(CurrencyL10n))
Currency.query.delete() db.session.execute(sa.delete(Currency))
AccountL10n.query.delete() db.session.execute(sa.delete(AccountL10n))
Account.query.delete() db.session.execute(sa.delete(Account))
BaseAccountL10n.query.delete() db.session.execute(sa.delete(BaseAccountL10n))
BaseAccount.query.delete() db.session.execute(sa.delete(BaseAccount))
init_base_accounts_command() init_base_accounts_command()
init_accounts_command(session["user"]) init_accounts_command(session["user"])
init_currencies_command(session["user"]) init_currencies_command(session["user"])