Migrate from SQLAlchemy 1.x legacy Query API to 2.x style select/delete statements

This commit is contained in:
2026-04-06 01:06:01 +08:00
parent 356950e2c7
commit 970c2e9946
39 changed files with 372 additions and 275 deletions
+2 -1
View File
@@ -49,7 +49,8 @@ The following is an example configuration for *Mia! Accounting*.
return current_user()
def get_by_username(self, username: str) -> User | None:
return User.query.filter(User.username == username).first()
return db.session.scalar(
sa.select(User).where(User.username == username))
def get_pk(self, user: User) -> int:
return user.id
+5 -4
View File
@@ -36,13 +36,14 @@ def init_accounts_command(username: str) -> None:
"""Initializes the accounts."""
creator_pk: int = get_user_pk(username)
bases: list[BaseAccount] = BaseAccount.query\
.filter(db.func.length(BaseAccount.code) == 4)\
.order_by(BaseAccount.code).all()
bases: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount).where(db.func.length(BaseAccount.code) == 4)
.order_by(BaseAccount.code)).unique().all()
if len(bases) == 0:
raise click.Abort
existing: list[Account] = Account.query.all()
existing: list[Account] = \
db.session.scalars(sa.select(Account)).unique().all()
existing_base_code: set[str] = {x.base_code for x in existing}
bases_to_add: list[BaseAccount] = [x for x in bases
+11 -9
View File
@@ -97,8 +97,9 @@ class AccountForm(FlaskForm):
if obj.base_code is not None:
sort_accounts_in(obj.base_code, obj.id)
sort_accounts_in(self.base_code.data, obj.id)
count: int = Account.query\
.filter(Account.base_code == self.base_code.data).count()
count: int = db.session.scalar(
sa.select(sa.func.count(Account.id))
.where(Account.base_code == self.base_code.data))
obj.base_code = self.base_code.data
obj.no = count + 1
obj.title = self.title.data
@@ -137,9 +138,10 @@ class AccountForm(FlaskForm):
:return: The selectable base accounts.
"""
return BaseAccount.query\
.filter(sa.func.char_length(BaseAccount.code) == 4)\
.order_by(BaseAccount.code).all()
return db.session.scalars(
sa.select(BaseAccount)
.where(sa.func.char_length(BaseAccount.code) == 4)
.order_by(BaseAccount.code)).unique()
def sort_accounts_in(base_code: str, exclude: int) -> None:
@@ -150,10 +152,10 @@ def sort_accounts_in(base_code: str, exclude: int) -> None:
:param exclude: The account ID to exclude.
:return: None.
"""
accounts: list[Account] = Account.query\
.filter(Account.base_code == base_code,
Account.id != exclude)\
.order_by(Account.no).all()
accounts: list[Account] = db.session.scalars(
sa.select(Account)
.where(Account.base_code == base_code, Account.id != exclude)
.order_by(Account.no)).unique().all()
for i in range(len(accounts)):
if accounts[i].no != i + 1:
accounts[i].no = i + 1
+10 -5
View File
@@ -20,6 +20,7 @@
import sqlalchemy as sa
from flask import request
from .. import db
from ..locale import gettext
from ..models import Account, AccountL10n
from ..utils.query import parse_query_keywords
@@ -32,15 +33,18 @@ def get_account_query() -> list[Account]:
"""
keywords: list[str] = parse_query_keywords(request.args.get("q"))
if len(keywords) == 0:
return Account.query.order_by(Account.base_code, Account.no).all()
return db.session.scalars(
sa.select(Account)
.order_by(Account.base_code, Account.no)).unique().all()
code: sa.ColumnElement[str] = Account.base_code + "-" \
+ sa.func.substr("000" + sa.cast(Account.no, sa.String),
sa.func.char_length(sa.cast(Account.no,
sa.String)) + 1)
conditions: list[sa.ColumnElement[bool]] = []
for k in keywords:
l10n: list[AccountL10n] = AccountL10n.query\
.filter(AccountL10n.title.icontains(k)).all()
l10n: list[AccountL10n] = db.session.scalars(
sa.select(AccountL10n)
.where(AccountL10n.title.icontains(k))).all()
l10n_matches: set[int] = {x.account_id for x in l10n}
sub_conditions: list[sa.ColumnElement[bool]] \
= [Account.base_code.contains(k),
@@ -51,5 +55,6 @@ def get_account_query() -> list[Account]:
sub_conditions.append(Account.is_need_offset)
conditions.append(sa.or_(*sub_conditions))
return Account.query.filter(*conditions)\
.order_by(Account.base_code, Account.no).all()
return db.session.scalars(
sa.select(Account).where(*conditions)
.order_by(Account.base_code, Account.no)).unique().all()
+1 -1
View File
@@ -28,7 +28,7 @@ from ..utils.title_case import title_case
def init_base_accounts_command() -> None:
"""Initializes the base accounts."""
if BaseAccount.query.first() is not None:
if db.session.scalar(sa.select(BaseAccount)) is not None:
return
with open(data_dir / "base_accounts.csv") as fp:
+9 -5
View File
@@ -20,6 +20,7 @@
import sqlalchemy as sa
from flask import request
from .. import db
from ..models import BaseAccount, BaseAccountL10n
from ..utils.query import parse_query_keywords
@@ -31,14 +32,17 @@ def get_base_account_query() -> list[BaseAccount]:
"""
keywords: list[str] = parse_query_keywords(request.args.get("q"))
if len(keywords) == 0:
return BaseAccount.query.order_by(BaseAccount.code).all()
return db.session.scalars(
sa.select(BaseAccount).order_by(BaseAccount.code)).unique().all()
conditions: list[sa.ColumnElement[bool]] = []
for k in keywords:
l10n: list[BaseAccountL10n] = BaseAccountL10n.query\
.filter(BaseAccountL10n.title.icontains(k)).all()
l10n: list[BaseAccountL10n] = db.session.scalars(
sa.select(BaseAccountL10n)
.where((BaseAccountL10n.title.icontains(k)))).all()
l10n_matches: set[str] = {x.account_code for x in l10n}
conditions.append(sa.or_(BaseAccount.code.contains(k),
BaseAccount.title_l10n.icontains(k),
BaseAccount.code.in_(l10n_matches)))
return BaseAccount.query.filter(*conditions)\
.order_by(BaseAccount.code).all()
return db.session.scalars(
sa.select(BaseAccount).where(*conditions)
.order_by(BaseAccount.code)).unique().all()
+4 -2
View File
@@ -66,8 +66,10 @@ def init_db_command(username: str, skip_accounts: bool,
init_base_accounts_command()
if not skip_accounts:
init_accounts_command(username)
print("OK 1")
if not skip_currencies:
init_currencies_command(username)
print("OK 2")
db.session.commit()
click.echo("Accounting database initialized.")
@@ -81,12 +83,12 @@ def titleize_command(username: str) -> None:
"""Capitalize the account titles."""
updater_pk: int = get_user_pk(username)
updated: int = 0
for base in BaseAccount.query:
for base in db.session.scalars(sa.select(BaseAccount)).unique():
new_title: str = title_case(base.title_l10n)
if base.title_l10n != new_title:
base.title_l10n = new_title
updated = updated + 1
for account in Account.query:
for account in db.session.scalars(sa.select(Account)).unique():
if account.title_l10n.lower() == account.base.title_l10n.lower():
new_title: str = title_case(account.title_l10n)
if account.title_l10n != new_title:
+2 -1
View File
@@ -29,7 +29,8 @@ from ..utils.user import get_user_pk
def init_currencies_command(username: str) -> None:
"""Initializes the currencies."""
existing_codes: set[str] = {x.code for x in Currency.query.all()}
existing_codes: set[str] = \
{x.code for x in db.session.scalars(sa.select(Currency)).unique()}
with open(data_dir / "currencies.csv") as fp:
data: list[dict[str, str]] = [x for x in csv.DictReader(fp)]
+9 -5
View File
@@ -20,6 +20,7 @@
import sqlalchemy as sa
from flask import request
from .. import db
from ..models import Currency, CurrencyL10n
from ..utils.query import parse_query_keywords
@@ -31,14 +32,17 @@ def get_currency_query() -> list[Currency]:
"""
keywords: list[str] = parse_query_keywords(request.args.get("q"))
if len(keywords) == 0:
return Currency.query.order_by(Currency.code).all()
return db.session.scalars(
sa.select(Currency).order_by(Currency.code)).unique().all()
conditions: list[sa.ColumnElement[bool]] = []
for k in keywords:
l10n: list[CurrencyL10n] = CurrencyL10n.query\
.filter(CurrencyL10n.name.icontains(k)).all()
l10n: list[CurrencyL10n] = db.session.scalars(
sa.select(CurrencyL10n)
.where(CurrencyL10n.name.icontains(k))).all()
l10n_matches: set[str] = {x.account_code for x in l10n}
conditions.append(sa.or_(Currency.code.icontains(k),
Currency.name_l10n.icontains(k),
Currency.code.in_(l10n_matches)))
return Currency.query.filter(*conditions)\
.order_by(Currency.code).all()
return db.session.scalars(
sa.select(Currency).where(*conditions)
.order_by(Currency.code)).unique().all()
@@ -55,7 +55,7 @@ class SameCurrencyAsOriginalLineItems:
return
original_line_item_currency_codes: set[str] = set(db.session.scalars(
sa.select(JournalEntryLineItem.currency_code)
.filter(JournalEntryLineItem.id.in_(original_line_item_id))).all())
.where(JournalEntryLineItem.id.in_(original_line_item_id))).all())
for currency_code in original_line_item_currency_codes:
if field.data != currency_code:
raise ValidationError(lazy_gettext(
@@ -72,17 +72,17 @@ class KeepCurrencyWhenHavingOffset:
if field.data is None:
return
offset: sa.Alias = offset_alias()
original_line_items: list[JournalEntryLineItem]\
= JournalEntryLineItem.query\
original_line_items: list[JournalEntryLineItem] = db.session.scalars(
sa.select(JournalEntryLineItem)
.join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True)\
.filter(JournalEntryLineItem.id
isouter=True)
.where(JournalEntryLineItem.id
.in_({x.id.data for x in form.line_items
if x.id.data is not None}))\
if x.id.data is not None}))
.group_by(JournalEntryLineItem.id,
JournalEntryLineItem.currency_code)\
.having(sa.func.count(offset.c.id) > 0).all()
JournalEntryLineItem.currency_code)
.having(sa.func.count(offset.c.id) > 0)).unique().all()
for original_line_item in original_line_items:
if original_line_item.currency_code != field.data:
raise ValidationError(lazy_gettext(
@@ -152,7 +152,7 @@ class CurrencyForm(FlaskForm):
line_item_id: set[int] = {x.id.data for x in line_item_forms
if x.id.data is not None}
select: sa.Select = sa.select(sa.func.count(JournalEntryLineItem.id))\
.filter(JournalEntryLineItem.original_line_item_id
.where(JournalEntryLineItem.original_line_item_id
.in_(line_item_id))
return db.session.scalar(select) > 0
@@ -159,8 +159,9 @@ class JournalEntryForm(FlaskForm):
to_delete: set[int] = {x.id for x in obj.line_items
if x.id not in collector.to_keep}
if len(to_delete) > 0:
JournalEntryLineItem.query\
.filter(JournalEntryLineItem.id.in_(to_delete)).delete()
db.session.execute(
sa.delete(JournalEntryLineItem)
.where(JournalEntryLineItem.id.in_(to_delete)))
self.is_modified = True
if is_new or db.session.is_modified(obj):
@@ -195,7 +196,7 @@ class JournalEntryForm(FlaskForm):
if self.max_date is not None and new_date == self.max_date:
db_min_no: int | None = db.session.scalar(
sa.select(sa.func.min(JournalEntry.no))
.filter(JournalEntry.date == new_date))
.where(JournalEntry.date == new_date))
if db_min_no is None:
obj.date = new_date
obj.no = 1
@@ -205,8 +206,9 @@ class JournalEntryForm(FlaskForm):
sort_journal_entries_in(new_date)
else:
sort_journal_entries_in(new_date, obj.id)
count: int = JournalEntry.query\
.filter(JournalEntry.date == new_date).count()
count: int = db.session.scalar(
sa.select(sa.func.count(JournalEntry.id))
.where(JournalEntry.date == new_date))
obj.date = new_date
obj.no = count + 1
@@ -221,7 +223,7 @@ class JournalEntryForm(FlaskForm):
if not (x.code[0] == "2" and x.is_need_offset)]
in_use: set[int] = set(db.session.scalars(
sa.select(JournalEntryLineItem.account_id)
.filter(JournalEntryLineItem.is_debit)
.where(JournalEntryLineItem.is_debit)
.group_by(JournalEntryLineItem.account_id)).all())
for account in accounts:
account.is_in_use = account.id in in_use
@@ -238,7 +240,7 @@ class JournalEntryForm(FlaskForm):
if not (x.code[0] == "1" and x.is_need_offset)]
in_use: set[int] = set(db.session.scalars(
sa.select(JournalEntryLineItem.account_id)
.filter(sa.not_(JournalEntryLineItem.is_debit))
.where(sa.not_(JournalEntryLineItem.is_debit))
.group_by(JournalEntryLineItem.account_id)).all())
for account in accounts:
account.is_in_use = account.id in in_use
@@ -288,7 +290,7 @@ class JournalEntryForm(FlaskForm):
return None
select: sa.Select = sa.select(sa.func.max(JournalEntry.date))\
.join(JournalEntryLineItem)\
.filter(JournalEntryLineItem.id.in_(original_line_item_id))
.where(JournalEntryLineItem.id.in_(original_line_item_id))
return db.session.scalar(select)
@property
@@ -301,7 +303,7 @@ class JournalEntryForm(FlaskForm):
if x.id.data is not None}
select: sa.Select = sa.select(sa.func.min(JournalEntry.date))\
.join(JournalEntryLineItem)\
.filter(JournalEntryLineItem.original_line_item_id
.where(JournalEntryLineItem.original_line_item_id
.in_(line_item_id))
return db.session.scalar(select)
@@ -202,7 +202,7 @@ class NotExceedingOriginalLineItemNetBalance:
else_=-JournalEntryLineItem.amount))
offset_total_but_form: Decimal | None = db.session.scalar(
sa.select(offset_total_func)
.filter(JournalEntryLineItem.original_line_item_id
.where(JournalEntryLineItem.original_line_item_id
== original_line_item.id,
JournalEntryLineItem.id.not_in(existing_line_item_id)))
if offset_total_but_form is None:
@@ -231,7 +231,7 @@ class NotLessThanOffsetTotal:
(JournalEntryLineItem.is_debit != is_debit,
JournalEntryLineItem.amount),
else_=-JournalEntryLineItem.amount)))\
.filter(JournalEntryLineItem.original_line_item_id == form.id.data)
.where(JournalEntryLineItem.original_line_item_id == form.id.data)
offset_total: Decimal | None = db.session.scalar(select_offset_total)
if offset_total is not None and field.data < offset_total:
raise ValidationError(lazy_gettext(
@@ -353,13 +353,14 @@ class LineItemForm(FlaskForm):
def get_offsets() -> list[JournalEntryLineItem]:
if not self.is_need_offset or self.id.data is None:
return []
return JournalEntryLineItem.query.join(JournalEntry)\
.filter(JournalEntryLineItem.original_line_item_id
== self.id.data)\
return db.session.scalars(
sa.select(JournalEntryLineItem).join(JournalEntry)
.where(JournalEntryLineItem.original_line_item_id
== self.id.data)
.order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.no)\
JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.journal_entry),
selectinload(JournalEntryLineItem.account)).all()
selectinload(JournalEntryLineItem.account))).all()
setattr(self, "__offsets", get_offsets())
return getattr(self, "__offsets")
@@ -37,9 +37,9 @@ def sort_journal_entries_in(date: dt.date, exclude: int | None = None) -> None:
conditions: list[sa.ColumnElement[bool]] = [JournalEntry.date == date]
if exclude is not None:
conditions.append(JournalEntry.id != exclude)
journal_entries: list[JournalEntry] = JournalEntry.query\
.filter(*conditions)\
.order_by(JournalEntry.no).all()
journal_entries: list[JournalEntry] = db.session.scalars(
sa.select(JournalEntry).where(*conditions)
.order_by(JournalEntry.no)).all()
for i in range(len(journal_entries)):
if journal_entries[i].no != i + 1:
journal_entries[i].no = i + 1
@@ -63,8 +63,9 @@ class JournalEntryReorderForm:
:return:
"""
journal_entries: list[JournalEntry] = JournalEntry.query\
.filter(JournalEntry.date == self.date).all()
journal_entries: list[JournalEntry] = db.session.scalars(
sa.select(JournalEntry)
.where(JournalEntry.date == self.date)).all()
# Collects the specified order.
orders: dict[JournalEntry, int] = {}
@@ -272,15 +272,17 @@ class DescriptionEditor:
select: sa.Select = sa.Select(debit_credit, tag_type, tag,
JournalEntryLineItem.account_id,
sa.func.count().label("freq"))\
.filter(JournalEntryLineItem.description.is_not(None),
.where(JournalEntryLineItem.description.is_not(None),
JournalEntryLineItem.description.like("_%—_%"),
JournalEntryLineItem.original_line_item_id.is_(None))\
.group_by(debit_credit, tag_type, tag,
JournalEntryLineItem.account_id)
result: list[sa.Row] = db.session.execute(select).all()
accounts: dict[int, Account] \
= {x.id: x for x in Account.query
.filter(Account.id.in_({x.account_id for x in result})).all()}
= {x.id: x for x in db.session.scalars(
sa.select(Account)
.where(Account.id.in_({x.account_id for x in result})))
.unique()}
debit_credit_dict: dict[Literal["debit", "credit"],
DescriptionDebitCredit] \
= {x.debit_credit: x for x in {self.debit, self.credit}}
@@ -326,7 +328,8 @@ class DescriptionEditor:
= [get_condition(x) for x in codes]
accounts: dict[str, Account] \
= {x.code: x for x in
Account.query.filter(sa.or_(*conditions)).all()}
db.session.scalars(
sa.select(Account).where(sa.or_(*conditions))).unique()}
for code in codes:
assert code in accounts, \
f"Unknown account \"{code}\" for regular transactions."
@@ -61,20 +61,21 @@ def get_selectable_original_line_items(
.join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True)\
.filter(*conditions)\
.where(*conditions)\
.group_by(JournalEntryLineItem.id)\
.having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0))
net_balances: dict[int, Decimal] \
= {x.id: x.net_balance
for x in db.session.execute(select_net_balances).all()}
line_items: list[JournalEntryLineItem] = JournalEntryLineItem.query\
.filter(JournalEntryLineItem.id.in_({x for x in net_balances}))\
.join(JournalEntry)\
for x in db.session.execute(select_net_balances)}
line_items: list[JournalEntryLineItem] = db.session.scalars(
sa.select(JournalEntryLineItem)
.where(JournalEntryLineItem.id.in_({x for x in net_balances}))
.join(JournalEntry)
.order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.no)\
JournalEntryLineItem.is_debit, JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.journal_entry)).all()
selectinload(JournalEntryLineItem.journal_entry))).all()
line_items.reverse()
for line_item in line_items:
line_item.net_balance = line_item.amount \
+3 -3
View File
@@ -195,9 +195,9 @@ def show_journal_entry_order(date: dt.date) -> str:
:param date: The date.
:return: The order of the journal entries in the date.
"""
journal_entries: list[JournalEntry] = JournalEntry.query \
.filter(JournalEntry.date == date) \
.order_by(JournalEntry.no).all()
journal_entries: list[JournalEntry] = db.session.scalars(
sa.select(JournalEntry).where(JournalEntry.date == date)
.order_by(JournalEntry.no)).all()
return render_template("accounting/journal-entry/order.html",
date=date, list=journal_entries)
+27 -18
View File
@@ -268,9 +268,10 @@ class Account(db.Model):
:return: None.
"""
AccountL10n.query.filter(AccountL10n.account == self).delete()
db.session.execute(sa.delete(AccountL10n)
.where(AccountL10n.account == self))
cls: type[Self] = self.__class__
cls.query.filter(cls.id == self.id).delete()
db.session.execute(sa.delete(cls).where(cls.id == self.id))
@classmethod
def find_by_code(cls, code: str) -> Self | None:
@@ -282,8 +283,9 @@ class Account(db.Model):
m: re.Match[str] | None = re.match(r"^([1-9]{4})-(\d{3})$", code)
if m is None:
return None
return cls.query.filter(cls.base_code == m.group(1),
cls.no == int(m.group(2))).first()
return db.session.scalar(
sa.select(cls).where(cls.base_code == m.group(1),
cls.no == int(m.group(2))))
@classmethod
def selectable_debit(cls) -> list[Self]:
@@ -292,7 +294,9 @@ class Account(db.Model):
:return: The selectable debit accounts.
"""
return cls.query.filter(sa.or_(cls.base_code.startswith("1"),
return db.session.scalars(
sa.select(cls)
.where(sa.or_(cls.base_code.startswith("1"),
sa.and_(cls.base_code.startswith("2"),
sa.not_(cls.is_need_offset)),
cls.base_code.startswith("3"),
@@ -304,8 +308,8 @@ class Account(db.Model):
cls.base_code.startswith("78"),
cls.base_code.startswith("8"),
cls.base_code.startswith("9")),
cls.base_code != "3353")\
.order_by(cls.base_code, cls.no).all()
cls.base_code != "3353")
.order_by(cls.base_code, cls.no)).unique().all()
@classmethod
def selectable_credit(cls) -> list[Self]:
@@ -314,7 +318,9 @@ class Account(db.Model):
:return: The selectable debit accounts.
"""
return cls.query.filter(sa.or_(sa.and_(cls.base_code.startswith("1"),
return db.session.scalars(
sa.select(cls)
.where(sa.or_(sa.and_(cls.base_code.startswith("1"),
sa.not_(cls.is_need_offset)),
cls.base_code.startswith("2"),
cls.base_code.startswith("3"),
@@ -325,8 +331,8 @@ class Account(db.Model):
cls.base_code.startswith("74"),
cls.base_code.startswith("8"),
cls.base_code.startswith("9")),
cls.base_code != "3353")\
.order_by(cls.base_code, cls.no).all()
cls.base_code != "3353")
.order_by(cls.base_code, cls.no)).unique().all()
@classmethod
def cash(cls) -> Self:
@@ -472,9 +478,10 @@ class Currency(db.Model):
:return: None.
"""
CurrencyL10n.query.filter(CurrencyL10n.currency == self).delete()
cls: type[Self] = self.__class__
cls.query.filter(cls.code == self.code).delete()
db.session.execute(
sa.delete(CurrencyL10n)
.where(CurrencyL10n.currency_code == self.code))
db.session.delete(self)
class CurrencyL10n(db.Model):
@@ -649,8 +656,9 @@ class JournalEntry(db.Model):
:return: None.
"""
JournalEntryLineItem.query\
.filter(JournalEntryLineItem.journal_entry_id == self.id).delete()
db.session.execute(
sa.delete(JournalEntryLineItem)
.where(JournalEntryLineItem.journal_entry_id == self.id))
db.session.delete(self)
@@ -816,10 +824,11 @@ class JournalEntryLineItem(db.Model):
"""
if not hasattr(self, "__offsets"):
cls: type[Self] = self.__class__
offsets: list[Self] = cls.query.join(JournalEntry)\
.filter(JournalEntryLineItem.original_line_item_id == self.id)\
offsets: list[Self] = db.session.scalars(
sa.select(cls).join(JournalEntry)
.where(cls.original_line_item_id == self.id)
.order_by(JournalEntry.date, JournalEntry.no,
cls.is_debit, cls.no).all()
cls.is_debit, cls.no)).unique().all()
setattr(self, "__offsets", offsets)
return getattr(self, "__offsets")
+5 -2
View File
@@ -23,9 +23,12 @@ This file is largely taken from the NanoParma ERP project, first written in
import datetime as dt
from collections.abc import Callable
import sqlalchemy as sa
from .period import Period
from .shortcuts import ThisMonth, LastMonth, SinceLastMonth, ThisYear, \
LastYear, Today, Yesterday, AllTime, TemplatePeriod, YearPeriod
from ... import db
from ...models import JournalEntry
from ...utils.timezone import get_tz_today
@@ -62,8 +65,8 @@ class PeriodChooser:
self.url_template: str = get_url(TemplatePeriod())
"""The URL template."""
first: JournalEntry | None \
= JournalEntry.query.order_by(JournalEntry.date).first()
first: JournalEntry | None = db.session.scalar(
sa.select(JournalEntry).order_by(JournalEntry.date))
start: dt.date | None = None if first is None else first.date
# Attributes
+13 -10
View File
@@ -133,16 +133,17 @@ class AccountCollector:
= sa.select(Account.id, Account.base_code, Account.no,
balance_func)\
.join(JournalEntry).join(Account)\
.filter(*conditions)\
.where(*conditions)\
.group_by(Account.id, Account.base_code, Account.no)\
.having(balance_func != 0)\
.order_by(Account.base_code, Account.no)
account_balances: list[sa.Row] \
= db.session.execute(select_balance).all()
self.__all_accounts: list[Account] = Account.query\
.filter(sa.or_(Account.id.in_({x.id for x in account_balances}),
self.__all_accounts: list[Account] = db.session.scalars(
sa.select(Account)
.where(sa.or_(Account.id.in_({x.id for x in account_balances}),
Account.base_code == "3351",
Account.base_code == "3353")).all()
Account.base_code == "3353"))).unique().all()
"""The accounts."""
account_by_id: dict[int, Account] \
= {x.id: x for x in self.__all_accounts}
@@ -219,7 +220,7 @@ class AccountCollector:
(JournalEntryLineItem.is_debit, JournalEntryLineItem.amount),
else_=-JournalEntryLineItem.amount))
select_balance: sa.Select = sa.select(balance_func)\
.join(JournalEntry).join(Account).filter(*conditions)
.join(JournalEntry).join(Account).where(*conditions)
return db.session.scalar(select_balance)
def __add_owner_s_equity(self, code: str, amount: Decimal | None,
@@ -383,11 +384,13 @@ class BalanceSheet(BaseReport):
balances: list[ReportAccount] = AccountCollector(
self.__currency, self.__period).accounts
titles: list[BaseAccount] = BaseAccount.query\
.filter(BaseAccount.code.in_({"1", "2", "3"})).all()
subtitles: list[BaseAccount] = BaseAccount.query\
.filter(BaseAccount.code.in_({x.account.base_code[:2]
for x in balances})).all()
titles: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(BaseAccount.code.in_({"1", "2", "3"}))).unique().all()
subtitle_codes: set[str] = {x.account.base_code[:2] for x in balances}
subtitles: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(BaseAccount.code.in_(subtitle_codes))).unique().all()
sections: dict[str, Section] = {x.code: Section(x) for x in titles}
subsections: dict[str, Subsection] = {x.code: Subsection(x)
@@ -119,7 +119,7 @@ class LineItemCollector:
else_=-JournalEntryLineItem.amount))
select: sa.Select[tuple[Decimal]] = sa.Select(balance_func)\
.join(JournalEntry).join(Account)\
.filter(JournalEntryLineItem.currency_code == self.__currency.code,
.where(JournalEntryLineItem.currency_code == self.__currency.code,
self.__account_condition,
JournalEntry.date < self.__period.start)
balance: Decimal | None = db.session.scalar(select)
@@ -150,12 +150,12 @@ class LineItemCollector:
if self.__period.end is not None:
conditions.append(JournalEntry.date <= self.__period.end)
journal_entry_with_account: sa.Select = sa.Select(JournalEntry.id).\
join(JournalEntryLineItem).join(Account).filter(*conditions)
join(JournalEntryLineItem).join(Account).where(*conditions)
return [ReportLineItem(x)
for x in JournalEntryLineItem.query
return [ReportLineItem(x) for x in db.session.scalars(
sa.select(JournalEntryLineItem)
.join(JournalEntry).join(Account)
.filter(JournalEntryLineItem.journal_entry_id
.where(JournalEntryLineItem.journal_entry_id
.in_(journal_entry_with_account),
JournalEntryLineItem.currency_code
== self.__currency.code,
@@ -165,7 +165,7 @@ class LineItemCollector:
JournalEntryLineItem.is_debit,
JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.journal_entry))]
selectinload(JournalEntryLineItem.journal_entry)))]
@property
def __account_condition(self) -> sa.ColumnElement[bool]:
@@ -343,7 +343,7 @@ class PageParams(BasePageParams):
self.account.id == 0)]
in_use: sa.Select = sa.Select(JournalEntryLineItem.account_id)\
.join(Account)\
.filter(JournalEntryLineItem.currency_code == self.currency.code,
.where(JournalEntryLineItem.currency_code == self.currency.code,
CurrentAccount.sql_condition())\
.group_by(JournalEntryLineItem.account_id)
options.extend([OptionLink(str(x),
@@ -352,8 +352,10 @@ class PageParams(BasePageParams):
CurrentAccount(x),
self.period),
x.id == self.account.id)
for x in Account.query.filter(Account.id.in_(in_use))
.order_by(Account.base_code, Account.no).all()])
for x in db.session.scalars(
sa.select(Account).where(Account.id.in_(in_use))
.order_by(Account.base_code, Account.no))
.unique()])
return options
@@ -218,11 +218,14 @@ class IncomeStatement(BaseReport):
"""
balances: list[ReportAccount] = self.__query_balances()
titles: list[BaseAccount] = BaseAccount.query\
.filter(BaseAccount.code.in_({"4", "5", "6", "7", "8", "9"})).all()
subtitles: list[BaseAccount] = BaseAccount.query\
.filter(BaseAccount.code.in_({x.account.base_code[:2]
for x in balances})).all()
title_codes: set[str] = {"4", "5", "6", "7", "8", "9"}
titles: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(BaseAccount.code.in_(title_codes))).unique().all()
subtitle_codes: set[str] = {x.account.base_code[:2] for x in balances}
subtitles: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(BaseAccount.code.in_(subtitle_codes))).unique().all()
total_titles: dict[str, str] \
= {"4": gettext("Total Operating Revenue"),
@@ -267,14 +270,15 @@ class IncomeStatement(BaseReport):
else_=JournalEntryLineItem.amount)).label("balance")
select_balances: sa.Select = sa.select(Account.id, balance_func)\
.join(JournalEntry).join(Account)\
.filter(*conditions)\
.where(*conditions)\
.group_by(Account.id)\
.having(balance_func != 0)\
.order_by(Account.base_code, Account.no)
balances: list[sa.Row] = db.session.execute(select_balances).all()
accounts: dict[int, Account] \
= {x.id: x for x in Account.query
.filter(Account.id.in_([x.id for x in balances])).all()}
= {x.id: x for x in db.session.scalars(
sa.select(Account)
.where(Account.id.in_([x.id for x in balances]))).unique()}
return [ReportAccount(account=accounts[x.id],
amount=x.balance,
url=ledger_url(self.__currency,
+6 -4
View File
@@ -31,6 +31,7 @@ from ..utils.csv_export import BaseCSVRow, csv_download, period_spec
from ..utils.report_chooser import ReportChooser
from ..utils.report_type import ReportType
from ..utils.urls import journal_url
from ... import db
from ...locale import gettext
from ...models import Currency, Account, JournalEntry, JournalEntryLineItem
from ...utils.pagination import Pagination
@@ -188,15 +189,16 @@ class Journal(BaseReport):
conditions.append(JournalEntry.date >= self.__period.start)
if self.__period.end is not None:
conditions.append(JournalEntry.date <= self.__period.end)
return JournalEntryLineItem.query.join(JournalEntry)\
.filter(*conditions)\
return db.session.scalars(
sa.select(JournalEntryLineItem).join(JournalEntry)
.where(*conditions)
.order_by(JournalEntry.date,
JournalEntry.no,
JournalEntryLineItem.is_debit.desc(),
JournalEntryLineItem.no)\
JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all()
selectinload(JournalEntryLineItem.journal_entry))).all()
def csv(self) -> Response:
"""Returns the report as CSV for download.
+10 -9
View File
@@ -115,7 +115,7 @@ class LineItemCollector:
(JournalEntryLineItem.is_debit, JournalEntryLineItem.amount),
else_=-JournalEntryLineItem.amount))
select: sa.Select = sa.Select(balance_func).join(JournalEntry)\
.filter(JournalEntryLineItem.currency_code == self.__currency.code,
.where(JournalEntryLineItem.currency_code == self.__currency.code,
JournalEntryLineItem.account_id == self.__account.id,
JournalEntry.date < self.__period.start)
balance: int | None = db.session.scalar(select)
@@ -144,15 +144,15 @@ class LineItemCollector:
conditions.append(JournalEntry.date >= self.__period.start)
if self.__period.end is not None:
conditions.append(JournalEntry.date <= self.__period.end)
return [ReportLineItem(x) for x in JournalEntryLineItem.query
.join(JournalEntry)
.filter(*conditions)
return [ReportLineItem(x) for x in db.session.scalars(
sa.select(JournalEntryLineItem).join(JournalEntry)
.where(*conditions)
.order_by(JournalEntry.date,
JournalEntry.no,
JournalEntryLineItem.is_debit.desc(),
JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.journal_entry))
.all()]
.options(selectinload(JournalEntryLineItem.journal_entry)))
.unique()]
def __get_total(self) -> ReportLineItem | None:
"""Composes the total line item.
@@ -308,12 +308,13 @@ class PageParams(BasePageParams):
:return: The account options.
"""
in_use: sa.Select = sa.Select(JournalEntryLineItem.account_id)\
.filter(JournalEntryLineItem.currency_code == self.currency.code)\
.where(JournalEntryLineItem.currency_code == self.currency.code)\
.group_by(JournalEntryLineItem.account_id)
return [OptionLink(str(x), ledger_url(self.currency, x, self.period),
x.id == self.account.id)
for x in Account.query.filter(Account.id.in_(in_use))
.order_by(Account.base_code, Account.no).all()]
for x in db.session.scalars(
sa.select(Account).where(Account.id.in_(in_use))
.order_by(Account.base_code, Account.no)).unique()]
class Ledger(BaseReport):
+12 -10
View File
@@ -30,6 +30,7 @@ from ..utils.base_report import BaseReport
from ..utils.csv_export import csv_download
from ..utils.report_chooser import ReportChooser
from ..utils.report_type import ReportType
from ... import db
from ...locale import gettext
from ...models import Currency, CurrencyL10n, Account, AccountL10n, \
JournalEntry, JournalEntryLineItem
@@ -69,15 +70,16 @@ class LineItemCollector:
except ArithmeticError:
pass
conditions.append(sa.or_(*sub_conditions))
return JournalEntryLineItem.query.join(JournalEntry)\
.filter(*conditions)\
return db.session.scalars(
sa.select(JournalEntryLineItem).join(JournalEntry)
.where(*conditions)
.order_by(JournalEntry.date,
JournalEntry.no,
JournalEntryLineItem.is_debit,
JournalEntryLineItem.no)\
JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.account),
selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all()
selectinload(JournalEntryLineItem.journal_entry))).all()
@staticmethod
def __get_account_condition(k: str) -> sa.Select:
@@ -91,7 +93,7 @@ class LineItemCollector:
sa.func.char_length(sa.cast(Account.no,
sa.String)) + 1)
select_l10n: sa.Select = sa.select(AccountL10n.account_id)\
.filter(AccountL10n.title.icontains(k))
.where(AccountL10n.title.icontains(k))
conditions: list[sa.ColumnElement[bool]] \
= [Account.base_code.contains(k),
Account.title_l10n.icontains(k),
@@ -99,7 +101,7 @@ class LineItemCollector:
Account.id.in_(select_l10n)]
if k in gettext("Needs Offset"):
conditions.append(Account.is_need_offset)
return sa.select(Account.id).filter(sa.or_(*conditions))
return sa.select(Account.id).where(sa.or_(*conditions))
@staticmethod
def __get_currency_condition(k: str) -> sa.Select:
@@ -109,9 +111,9 @@ class LineItemCollector:
:return: The condition to filter the currency.
"""
select_l10n: sa.Select = sa.select(CurrencyL10n.currency_code)\
.filter(CurrencyL10n.name.icontains(k))
return sa.select(Currency.code).filter(
sa.or_(Currency.code.icontains(k),
.where(CurrencyL10n.name.icontains(k))
return sa.select(Currency.code)\
.where(sa.or_(Currency.code.icontains(k),
Currency.name_l10n.icontains(k),
Currency.code.in_(select_l10n)))
@@ -153,7 +155,7 @@ class LineItemCollector:
sa.extract("day", JournalEntry.date) == date.day))
except ValueError:
pass
return sa.select(JournalEntry.id).filter(sa.or_(*conditions))
return sa.select(JournalEntry.id).where(sa.or_(*conditions))
class PageParams(BasePageParams):
@@ -187,14 +187,15 @@ class TrialBalance(BaseReport):
else_=-JournalEntryLineItem.amount)).label("balance")
select_balances: sa.Select = sa.select(Account.id, balance_func)\
.join(JournalEntry).join(Account)\
.filter(*conditions)\
.where(*conditions)\
.group_by(Account.id)\
.having(balance_func != 0)\
.order_by(Account.base_code, Account.no)
balances: list[sa.Row] = db.session.execute(select_balances).all()
accounts: dict[int, Account] \
= {x.id: x for x in Account.query
.filter(Account.id.in_([x.id for x in balances])).all()}
= {x.id: x for x in db.session.scalars(
sa.select(Account)
.where(Account.id.in_([x.id for x in balances]))).unique()}
self.__accounts = [ReportAccount(account=accounts[x.id],
amount=x.balance,
url=ledger_url(self.__currency,
+8 -5
View File
@@ -20,6 +20,7 @@
import datetime as dt
from decimal import Decimal
import sqlalchemy as sa
from flask import render_template, Response
from sqlalchemy.orm import selectinload
@@ -31,6 +32,7 @@ from ..utils.report_chooser import ReportChooser
from ..utils.report_type import ReportType
from ..utils.unapplied import get_accounts_with_unapplied, get_net_balances
from ..utils.urls import unapplied_url
from ... import db
from ...locale import gettext
from ...models import Currency, Account, JournalEntry, JournalEntryLineItem
from ...utils.pagination import Pagination
@@ -176,13 +178,14 @@ class UnappliedOriginalLineItems(BaseReport):
"""
net_balances: dict[int, Decimal | None] \
= get_net_balances(self.__currency, self.__account)
line_items: list[JournalEntryLineItem] = JournalEntryLineItem.query \
.join(Account).join(JournalEntry) \
.filter(JournalEntryLineItem.id.in_(net_balances)) \
line_items: list[JournalEntryLineItem] = db.session.scalars(
sa.select(JournalEntryLineItem).join(Account).join(JournalEntry)
.where(JournalEntryLineItem.id.in_(net_balances))
.order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.no) \
JournalEntryLineItem.is_debit, JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all()
selectinload(JournalEntryLineItem.journal_entry)))\
.unique().all()
for line_item in line_items:
line_item.net_balance = line_item.amount \
if net_balances[line_item.id] is None \
@@ -84,5 +84,6 @@ class BasePageParams(ABC):
sa.select(JournalEntryLineItem.currency_code)
.group_by(JournalEntryLineItem.currency_code)).all())
return [OptionLink(str(x), get_url(x), x.code == active_currency.code)
for x in Currency.query.filter(Currency.code.in_(in_use))
.order_by(Currency.code).all()]
for x in db.session.scalars(
sa.select(Currency).where(Currency.code.in_(in_use))
.order_by(Currency.code)).unique()]
@@ -24,6 +24,7 @@ from flask_babel import LazyString
from sqlalchemy.orm import selectinload
from ..utils.unapplied import get_net_balances
from ... import db
from ...locale import lazy_gettext
from ...models import Currency, Account, JournalEntry, JournalEntryLineItem
@@ -113,14 +114,15 @@ class OffsetMatcher:
JournalEntryLineItem.is_debit),
sa.and_(Account.base_code.startswith("1"),
sa.not_(JournalEntryLineItem.is_debit))))
self.line_items = JournalEntryLineItem.query \
.join(Account).join(JournalEntry) \
.filter(sa.or_(JournalEntryLineItem.id.in_(net_balances),
unmatched_offset_condition)) \
self.line_items = db.session.scalars(
sa.select(JournalEntryLineItem).join(Account).join(JournalEntry)
.where(sa.or_(JournalEntryLineItem.id.in_(net_balances),
unmatched_offset_condition))
.order_by(JournalEntry.date, JournalEntry.no,
JournalEntryLineItem.is_debit, JournalEntryLineItem.no) \
JournalEntryLineItem.is_debit, JournalEntryLineItem.no)
.options(selectinload(JournalEntryLineItem.currency),
selectinload(JournalEntryLineItem.journal_entry)).all()
selectinload(JournalEntryLineItem.journal_entry)))\
.unique().all()
for line_item in self.line_items:
line_item.is_offset = line_item.id not in net_balances
self.unapplied = [x for x in self.line_items if not x.is_offset]
+7 -6
View File
@@ -45,7 +45,7 @@ def get_accounts_with_unapplied(currency: Currency) -> list[Account]:
.join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True)\
.filter(Account.is_need_offset,
.where(Account.is_need_offset,
JournalEntryLineItem.currency_code == currency.code,
sa.or_(sa.and_(Account.base_code.startswith("2"),
sa.not_(JournalEntryLineItem.is_debit)),
@@ -58,13 +58,14 @@ def get_accounts_with_unapplied(currency: Currency) -> list[Account]:
= sa.func.count(JournalEntryLineItem.id).label("count")
select: sa.Select = sa.select(Account.id, count_func)\
.join(JournalEntryLineItem, isouter=True)\
.filter(JournalEntryLineItem.id.in_(select_unapplied))\
.where(JournalEntryLineItem.id.in_(select_unapplied))\
.group_by(Account.id)\
.having(count_func > 0)
counts: dict[int, int] \
= {x.id: x.count for x in db.session.execute(select)}
accounts: list[Account] = Account.query.filter(Account.id.in_(counts))\
.order_by(Account.base_code, Account.no).all()
accounts: list[Account] = db.session.scalars(
sa.select(Account).where(Account.id.in_(counts))
.order_by(Account.base_code, Account.no)).unique().all()
for account in accounts:
account.count = counts[account.id]
return accounts
@@ -91,7 +92,7 @@ def get_net_balances(currency: Currency, account: Account) \
.join(offset,
JournalEntryLineItem.id == offset.c.original_line_item_id,
isouter=True) \
.filter(Account.id == account.id,
.where(Account.id == account.id,
JournalEntryLineItem.currency_code == currency.code,
sa.or_(sa.and_(Account.base_code.startswith("2"),
sa.not_(JournalEntryLineItem.is_debit)),
@@ -100,4 +101,4 @@ def get_net_balances(currency: Currency, account: Account) \
.group_by(JournalEntryLineItem.id) \
.having(sa.or_(sa.func.count(offset.c.id) == 0, net_balance != 0))
return {x.id: x.net_balance
for x in db.session.execute(select_net_balances).all()}
for x in db.session.execute(select_net_balances)}
+4 -3
View File
@@ -35,7 +35,7 @@ def get_accounts_with_unmatched(currency: Currency) -> list[Account]:
select: sa.Select = sa.select(Account.id, count_func)\
.select_from(Account)\
.join(JournalEntryLineItem, isouter=True).join(JournalEntry)\
.filter(Account.is_need_offset,
.where(Account.is_need_offset,
JournalEntryLineItem.currency_code == currency.code,
JournalEntryLineItem.original_line_item_id.is_(None),
sa.or_(sa.and_(Account.base_code.startswith("2"),
@@ -46,8 +46,9 @@ def get_accounts_with_unmatched(currency: Currency) -> list[Account]:
.having(count_func > 0)
counts: dict[int, int] \
= {x.id: x.count for x in db.session.execute(select)}
accounts: list[Account] = Account.query.filter(Account.id.in_(counts))\
.order_by(Account.base_code, Account.no).all()
accounts: list[Account] = db.session.scalars(
sa.select(Account).where(Account.id.in_(counts))
.order_by(Account.base_code, Account.no)).unique().all()
for account in accounts:
account.count = counts[account.id]
return accounts
+5 -1
View File
@@ -17,6 +17,9 @@
"""The template globals.
"""
import sqlalchemy as sa
from . import db
from .models import Currency
from .utils.options import options
@@ -26,7 +29,8 @@ def currency_options() -> list[Currency]:
:return: The currency options.
"""
return Currency.query.order_by(Currency.code).all()
return db.session.scalars(
sa.select(Currency).order_by(Currency.code)).unique().all()
def default_currency_code() -> str:
+5 -3
View File
@@ -21,6 +21,7 @@ from typing import Self
import sqlalchemy as sa
from .. import db
from ..locale import gettext
from ..models import Account
@@ -74,9 +75,10 @@ class CurrentAccount:
"""
accounts: list[cls] = [cls.current_assets_and_liabilities()]
accounts.extend([cls(x)
for x in Account.query
.filter(cls.sql_condition())
.order_by(Account.base_code, Account.no)])
for x in db.session.scalars(
sa.select(Account).where(cls.sql_condition())
.order_by(Account.base_code, Account.no))
.unique()])
return accounts
@classmethod
+14 -6
View File
@@ -21,6 +21,7 @@ import datetime as dt
import unittest
import httpx
import sqlalchemy as sa
from flask import Flask
from accounting.utils.next_uri import encode_next
@@ -275,7 +276,9 @@ class AccountTestCase(unittest.TestCase):
response: httpx.Response
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code})
# Missing CSRF token
@@ -367,10 +370,11 @@ class AccountTestCase(unittest.TestCase):
f"{PREFIX}/{STOCK.base_code}-003")
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code, STOCK.code,
f"{STOCK.base_code}-002",
f"{STOCK.base_code}-003"})
f"{STOCK.base_code}-002", f"{STOCK.base_code}-003"})
account: Account | None = Account.find_by_code(STOCK.code)
self.assertIsNotNone(account)
@@ -621,7 +625,9 @@ class AccountTestCase(unittest.TestCase):
"currency-1-credit-1-amount": "20"})
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, PETTY.code, BANK.code})
# Cannot delete the cash account
@@ -645,7 +651,9 @@ class AccountTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], list_uri)
with self.__app.app_context():
self.assertEqual({x.code for x in Account.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Account)).unique()},
{CASH.code, BANK.code})
response = self.__client.get(detail_uri)
+16 -10
View File
@@ -101,7 +101,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
for x in rows}
with self.__app.app_context():
accounts: list[BaseAccount] = BaseAccount.query.all()
accounts: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)).unique().all()
self.assertEqual(len(accounts), len(data))
for account in accounts:
@@ -141,10 +142,14 @@ class ConsoleCommandTestCase(unittest.TestCase):
from accounting.models import BaseAccount, Account, AccountL10n
with self.__app.app_context():
bases: list[BaseAccount] = BaseAccount.query\
.filter(sa.func.char_length(BaseAccount.code) == 4).all()
accounts: list[Account] = Account.query.all()
l10n: list[AccountL10n] = AccountL10n.query.all()
bases: list[BaseAccount] = db.session.scalars(
sa.select(BaseAccount)
.where(sa.func.char_length(BaseAccount.code) == 4))\
.unique().all()
accounts: list[Account] = db.session.scalars(
sa.select(Account)).unique().all()
l10n: list[AccountL10n] = db.session.scalars(
sa.select(AccountL10n)).all()
self.assertEqual({x.code for x in bases},
{x.base_code for x in accounts})
@@ -175,7 +180,8 @@ class ConsoleCommandTestCase(unittest.TestCase):
for x in csv.DictReader(fp)}
with self.__app.app_context():
currencies: list[Currency] = Currency.query.all()
currencies: list[Currency] = db.session.scalars(
sa.select(Currency)).unique().all()
self.assertEqual(len(currencies), len(data))
for currency in currencies:
@@ -216,9 +222,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
result.output + str(result.exception))
# Turns the titles into lowercase.
for base in BaseAccount.query:
for base in db.session.scalars(sa.select(BaseAccount)).unique():
base.title_l10n = base.title_l10n.lower()
for account in Account.query:
for account in db.session.scalars(sa.select(Account)).unique():
account.title_l10n = account.title_l10n.lower()
account.created_at \
= account.created_at - dt.timedelta(seconds=5)
@@ -242,9 +248,9 @@ class ConsoleCommandTestCase(unittest.TestCase):
args=["accounting-titleize", "-u", "editor"])
self.assertEqual(result.exit_code, 0,
result.output + str(result.exception))
for base in BaseAccount.query:
for base in db.session.scalars(sa.select(BaseAccount)).unique():
self.__test_title_case(base.title_l10n)
for account in Account.query:
for account in db.session.scalars(sa.select(Account)).unique():
if account.id != new_account.id:
self.__test_title_case(account.title_l10n)
self.assertNotEqual(account.created_at, account.updated_at)
+13 -4
View File
@@ -21,6 +21,7 @@ import datetime as dt
import unittest
import httpx
import sqlalchemy as sa
from flask import Flask
from accounting.utils.next_uri import encode_next
@@ -221,7 +222,9 @@ class CurrencyTestCase(unittest.TestCase):
response: httpx.Response
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code})
# Missing CSRF token
@@ -287,7 +290,9 @@ class CurrencyTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], create_uri)
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code, TWD.code})
currency: Currency = db.session.get(Currency, TWD.code)
@@ -554,7 +559,9 @@ class CurrencyTestCase(unittest.TestCase):
"currency-1-credit-1-amount": "20"})
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code, JPY.code})
# Cannot delete the default currency
@@ -578,7 +585,9 @@ class CurrencyTestCase(unittest.TestCase):
self.assertEqual(response.headers["Location"], list_uri)
with self.__app.app_context():
self.assertEqual({x.code for x in Currency.query.all()},
self.assertEqual(
{x.code
for x in db.session.scalars(sa.select(Currency)).unique()},
{USD.code, EUR.code})
response = self.__client.get(detail_uri)
+6 -3
View File
@@ -20,6 +20,7 @@
import os
from secrets import token_urlsafe
import sqlalchemy as sa
from click.testing import Result
from flask import Flask, Blueprint, render_template, redirect, Response, \
url_for
@@ -112,8 +113,8 @@ def create_app(is_testing: bool = False, is_skip_accounts: bool = False,
return auth.current_user()
def get_by_username(self, username: str) -> auth.User | None:
return auth.User.query\
.filter(auth.User.username == username).first()
return db.session.scalar(
sa.select(auth.User).where(auth.User.username == username))
def get_pk(self, user: auth.User) -> int:
return user.id
@@ -140,7 +141,9 @@ def init_db(app: Flask, is_skip_accounts: bool,
db.create_all()
from .auth import User
for username in ["viewer", "editor", "admin", "nobody"]:
if User.query.filter(User.username == username).first() is None:
user: User | None = db.session.scalar(
sa.select(User).where(User.username == username))
if user is None:
db.session.add(User(username=username))
db.session.commit()
runner: FlaskCliRunner = app.test_cli_runner()
+3 -2
View File
@@ -19,6 +19,7 @@
"""
from collections.abc import Callable
import sqlalchemy as sa
from flask import Blueprint, render_template, Flask, redirect, url_for, \
session, request, g, Response, abort
from sqlalchemy.orm import Mapped, mapped_column
@@ -91,8 +92,8 @@ def current_user() -> User | None:
if "user" not in session:
g.user = None
else:
g.user = User.query.filter(
User.username == session["user"]).first()
g.user = db.session.scalar(
sa.select(User).where(User.username == session["user"]))
return g.user
+2 -2
View File
@@ -218,8 +218,8 @@ class BaseTestData(ABC):
self._app: Flask = app
"""The Flask application."""
with self._app.app_context():
current_user: User | None = User.query\
.filter(User.username == username).first()
current_user: User | None = db.session.scalar(
sa.select(User).where(User.username == username))
assert current_user is not None
self.__current_user_id: int = current_user.id
"""The current user ID."""
+9 -8
View File
@@ -19,6 +19,7 @@
"""
import datetime as dt
import sqlalchemy as sa
from flask import Flask, Blueprint, url_for, flash, redirect, session, \
render_template, current_app, Response
from flask_babel import lazy_gettext
@@ -83,14 +84,14 @@ def __reset_database() -> None:
from accounting.account import init_accounts_command
from accounting.currency import init_currencies_command
JournalEntryLineItem.query.delete()
JournalEntry.query.delete()
CurrencyL10n.query.delete()
Currency.query.delete()
AccountL10n.query.delete()
Account.query.delete()
BaseAccountL10n.query.delete()
BaseAccount.query.delete()
db.session.execute(sa.delete(JournalEntryLineItem))
db.session.execute(sa.delete(JournalEntry))
db.session.execute(sa.delete(CurrencyL10n))
db.session.execute(sa.delete(Currency))
db.session.execute(sa.delete(AccountL10n))
db.session.execute(sa.delete(Account))
db.session.execute(sa.delete(BaseAccountL10n))
db.session.execute(sa.delete(BaseAccount))
init_base_accounts_command()
init_accounts_command(session["user"])
init_currencies_command(session["user"])