diff --git a/docs/source/examples.rst b/docs/source/examples.rst index 552a9b0..fb1bc6b 100644 --- a/docs/source/examples.rst +++ b/docs/source/examples.rst @@ -13,7 +13,7 @@ The following is an example configuration for *Mia! Accounting*. from flask import Response, redirect from .auth import current_user() - from .modules import User + from .modules import Base, User def create_app(test_config=None) -> Flask: app: Flask = Flask(__name__) @@ -36,6 +36,10 @@ The following is an example configuration for *Mia! Accounting*. def unauthorized(self) -> Response: return redirect("/login") + @property + def base(self) -> type[DeclarativeBase]: + return Base + @property def cls(self) -> type[User]: return User diff --git a/src/accounting/commands.py b/src/accounting/commands.py index f342d8e..aa60a94 100644 --- a/src/accounting/commands.py +++ b/src/accounting/commands.py @@ -29,7 +29,7 @@ from .base_account import init_base_accounts_command from .currency import init_currencies_command from .models import BaseAccount, Account from .utils.title_case import title_case -from .utils.user import has_user, get_user_pk +from .utils.user import base_cls, has_user, get_user_pk def __validate_username(ctx: click.core.Context, param: click.core.Option, @@ -62,7 +62,7 @@ def __validate_username(ctx: click.core.Context, param: click.core.Option, def init_db_command(username: str, skip_accounts: bool, skip_currencies: bool) -> None: """Initializes the accounting database.""" - db.create_all() + base_cls.metadata.create_all(db.engine) init_base_accounts_command() if not skip_accounts: init_accounts_command(username) diff --git a/src/accounting/models.py b/src/accounting/models.py index b8bb1ab..62dd295 100644 --- a/src/accounting/models.py +++ b/src/accounting/models.py @@ -31,10 +31,10 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship from . import db from .locale import gettext -from .utils.user import user_cls, user_pk_column +from .utils.user import base_cls, user_cls, user_pk_column -class BaseAccount(db.Model): +class BaseAccount(base_cls): """A base account.""" __tablename__ = "accounting_base_accounts" """The table name.""" @@ -78,7 +78,7 @@ class BaseAccount(db.Model): return [self.code, self.title_l10n] + [x.title for x in self.l10n] -class BaseAccountL10n(db.Model): +class BaseAccountL10n(base_cls): """A localized base account title.""" __tablename__ = "accounting_base_accounts_l10n" """The table name.""" @@ -95,7 +95,7 @@ class BaseAccountL10n(db.Model): """The localized title.""" -class Account(db.Model): +class Account(base_cls): """An account.""" __tablename__ = "accounting_accounts" """The table name.""" @@ -354,7 +354,7 @@ class Account(db.Model): return account -class AccountL10n(db.Model): +class AccountL10n(base_cls): """A localized account title.""" __tablename__ = "accounting_accounts_l10n" """The table name.""" @@ -371,7 +371,7 @@ class AccountL10n(db.Model): """The localized title.""" -class Currency(db.Model): +class Currency(base_cls): """A currency.""" __tablename__ = "accounting_currencies" """The table name.""" @@ -483,7 +483,7 @@ class Currency(db.Model): db.session.delete(self) -class CurrencyL10n(db.Model): +class CurrencyL10n(base_cls): """A localized currency name.""" __tablename__ = "accounting_currencies_l10n" """The table name.""" @@ -543,7 +543,7 @@ class JournalEntryCurrency: return sum([x.amount for x in self.credit]) -class JournalEntry(db.Model): +class JournalEntry(base_cls): """A journal entry.""" __tablename__ = "accounting_journal_entries" """The table name.""" @@ -661,7 +661,7 @@ class JournalEntry(db.Model): db.session.delete(self) -class JournalEntryLineItem(db.Model): +class JournalEntryLineItem(base_cls): """A line item in the journal entry.""" __tablename__ = "accounting_journal_entry_line_items" """The table name.""" @@ -888,7 +888,7 @@ class JournalEntryLineItem(db.Model): format_amount(self.amount)] -class Option(db.Model): +class Option(base_cls): """An option.""" __tablename__ = "accounting_options" """The table name.""" diff --git a/src/accounting/utils/random_id.py b/src/accounting/utils/random_id.py index e37ba80..9f806f2 100644 --- a/src/accounting/utils/random_id.py +++ b/src/accounting/utils/random_id.py @@ -22,9 +22,10 @@ This module should not import any other module from the application. from secrets import randbelow from .. import db +from ..utils.user import base_cls -def new_id(cls: type[db.Model]): +def new_id(cls: type[base_cls]): """Generates and returns a new, unused random ID for the data model. :param cls: The data model. diff --git a/src/accounting/utils/user.py b/src/accounting/utils/user.py index 8891d1b..3185c46 100644 --- a/src/accounting/utils/user.py +++ b/src/accounting/utils/user.py @@ -23,10 +23,10 @@ from abc import ABC, abstractmethod import sqlalchemy as sa from flask import g, Response -from flask_sqlalchemy.model import Model +from sqlalchemy.orm import DeclarativeBase -class UserUtilityInterface[T: Model](ABC): +class UserUtilityInterface[T: DeclarativeBase](ABC): """The interface for the user utilities.""" @abstractmethod @@ -67,6 +67,14 @@ class UserUtilityInterface[T: Model](ABC): :return: The response to require the user to log in. """ + @property + @abstractmethod + def base(self) -> type[DeclarativeBase]: + """Returns the base data model. + + :return: The base data model. + """ + @property @abstractmethod def cls(self) -> type[T]: @@ -109,7 +117,9 @@ class UserUtilityInterface[T: Model](ABC): __user_utils: UserUtilityInterface """The user utilities.""" -type user_cls = Model +base_cls = DeclarativeBase +"""The base data model.""" +type user_cls = DeclarativeBase """The user class.""" user_pk_column: sa.Column = sa.Column(sa.Integer) """The primary key column of the user class.""" @@ -121,8 +131,9 @@ def init_user_utils(utils: UserUtilityInterface) -> None: :param utils: The user utilities. :return: None. """ - global __user_utils, user_cls, user_pk_column + global __user_utils, base_cls, user_cls, user_pk_column __user_utils = utils + base_cls = utils.base user_cls = utils.cls user_pk_column = utils.pk_column diff --git a/tests/test_commands.py b/tests/test_commands.py index 0f25b71..3cd2685 100644 --- a/tests/test_commands.py +++ b/tests/test_commands.py @@ -29,7 +29,7 @@ from flask import Flask from flask.testing import FlaskCliRunner from sqlalchemy.sql.ddl import DropTable -from test_site import db +from test_site import db, Base from testlib import create_test_app @@ -63,7 +63,7 @@ class ConsoleCommandTestCase(unittest.TestCase): # Drop every accounting table, to see if accounting-init-db # recreates them correctly. tables: list[sa.Table] \ - = [db.metadata.tables[x] for x in db.metadata.tables + = [Base.metadata.tables[x] for x in Base.metadata.tables if x.startswith("accounting_")] for table in tables: db.session.execute(DropTable(table)) @@ -207,7 +207,7 @@ class ConsoleCommandTestCase(unittest.TestCase): with self.__app.app_context(): # Resets the accounts. tables: list[sa.Table] \ - = [db.metadata.tables[x] for x in db.metadata.tables + = [Base.metadata.tables[x] for x in Base.metadata.tables if x.startswith("accounting_")] for table in tables: db.session.execute(DropTable(table)) diff --git a/tests/test_site/__init__.py b/tests/test_site/__init__.py index 36cbf6b..6ee7598 100644 --- a/tests/test_site/__init__.py +++ b/tests/test_site/__init__.py @@ -28,6 +28,7 @@ from flask.testing import FlaskCliRunner from flask_babel_js import BabelJS from flask_sqlalchemy import SQLAlchemy from flask_wtf import CSRFProtect +from sqlalchemy.orm import DeclarativeBase bp: Blueprint = Blueprint("home", __name__) """The global blueprint.""" @@ -39,6 +40,10 @@ db: SQLAlchemy = SQLAlchemy() """The database instance.""" +class Base(DeclarativeBase): + """The base class for all models.""" + + def create_app(is_testing: bool = False, is_skip_accounts: bool = False, is_skip_currencies: bool = False) -> Flask: """Create and configure the application. @@ -99,6 +104,10 @@ def create_app(is_testing: bool = False, is_skip_accounts: bool = False, from accounting.utils.next_uri import append_next return redirect(append_next(url_for("auth.login-form"))) + @property + def base(self) -> type[DeclarativeBase]: + return Base + @property def cls(self) -> type[auth.User]: return auth.User @@ -137,7 +146,7 @@ def init_db(app: Flask, is_skip_accounts: bool, otherwise. :return: None. """ - db.create_all() + Base.metadata.create_all(db.engine) from .auth import User for username in ["viewer", "editor", "admin", "nobody"]: user: User | None = db.session.scalar( diff --git a/tests/test_site/auth.py b/tests/test_site/auth.py index 1baaf4c..a408ea3 100644 --- a/tests/test_site/auth.py +++ b/tests/test_site/auth.py @@ -24,13 +24,13 @@ from flask import Blueprint, render_template, Flask, redirect, url_for, \ session, request, g, Response, abort from sqlalchemy.orm import Mapped, mapped_column -from . import db +from . import db, Base bp: Blueprint = Blueprint("auth", __name__, url_prefix="/") """The authentication blueprint.""" -class User(db.Model): +class User(Base): """A user.""" __tablename__ = "users" """The table name."""