Replace db.Model with DeclarativeBase from SQLAlchemy for Flask-SQLAlchemy-Lite migration

This commit is contained in:
2026-04-06 01:49:44 +08:00
parent e6d25882fc
commit 9c6cc1f3eb
8 changed files with 49 additions and 24 deletions
+5 -1
View File
@@ -13,7 +13,7 @@ The following is an example configuration for *Mia! Accounting*.
from flask import Response, redirect from flask import Response, redirect
from .auth import current_user() from .auth import current_user()
from .modules import User from .modules import Base, User
def create_app(test_config=None) -> Flask: def create_app(test_config=None) -> Flask:
app: Flask = Flask(__name__) app: Flask = Flask(__name__)
@@ -36,6 +36,10 @@ The following is an example configuration for *Mia! Accounting*.
def unauthorized(self) -> Response: def unauthorized(self) -> Response:
return redirect("/login") return redirect("/login")
@property
def base(self) -> type[DeclarativeBase]:
return Base
@property @property
def cls(self) -> type[User]: def cls(self) -> type[User]:
return User return User
+2 -2
View File
@@ -29,7 +29,7 @@ from .base_account import init_base_accounts_command
from .currency import init_currencies_command from .currency import init_currencies_command
from .models import BaseAccount, Account from .models import BaseAccount, Account
from .utils.title_case import title_case from .utils.title_case import title_case
from .utils.user import has_user, get_user_pk from .utils.user import base_cls, has_user, get_user_pk
def __validate_username(ctx: click.core.Context, param: click.core.Option, def __validate_username(ctx: click.core.Context, param: click.core.Option,
@@ -62,7 +62,7 @@ def __validate_username(ctx: click.core.Context, param: click.core.Option,
def init_db_command(username: str, skip_accounts: bool, def init_db_command(username: str, skip_accounts: bool,
skip_currencies: bool) -> None: skip_currencies: bool) -> None:
"""Initializes the accounting database.""" """Initializes the accounting database."""
db.create_all() base_cls.metadata.create_all(db.engine)
init_base_accounts_command() init_base_accounts_command()
if not skip_accounts: if not skip_accounts:
init_accounts_command(username) init_accounts_command(username)
+10 -10
View File
@@ -31,10 +31,10 @@ from sqlalchemy.orm import Mapped, mapped_column, relationship
from . import db from . import db
from .locale import gettext from .locale import gettext
from .utils.user import user_cls, user_pk_column from .utils.user import base_cls, user_cls, user_pk_column
class BaseAccount(db.Model): class BaseAccount(base_cls):
"""A base account.""" """A base account."""
__tablename__ = "accounting_base_accounts" __tablename__ = "accounting_base_accounts"
"""The table name.""" """The table name."""
@@ -78,7 +78,7 @@ class BaseAccount(db.Model):
return [self.code, self.title_l10n] + [x.title for x in self.l10n] return [self.code, self.title_l10n] + [x.title for x in self.l10n]
class BaseAccountL10n(db.Model): class BaseAccountL10n(base_cls):
"""A localized base account title.""" """A localized base account title."""
__tablename__ = "accounting_base_accounts_l10n" __tablename__ = "accounting_base_accounts_l10n"
"""The table name.""" """The table name."""
@@ -95,7 +95,7 @@ class BaseAccountL10n(db.Model):
"""The localized title.""" """The localized title."""
class Account(db.Model): class Account(base_cls):
"""An account.""" """An account."""
__tablename__ = "accounting_accounts" __tablename__ = "accounting_accounts"
"""The table name.""" """The table name."""
@@ -354,7 +354,7 @@ class Account(db.Model):
return account return account
class AccountL10n(db.Model): class AccountL10n(base_cls):
"""A localized account title.""" """A localized account title."""
__tablename__ = "accounting_accounts_l10n" __tablename__ = "accounting_accounts_l10n"
"""The table name.""" """The table name."""
@@ -371,7 +371,7 @@ class AccountL10n(db.Model):
"""The localized title.""" """The localized title."""
class Currency(db.Model): class Currency(base_cls):
"""A currency.""" """A currency."""
__tablename__ = "accounting_currencies" __tablename__ = "accounting_currencies"
"""The table name.""" """The table name."""
@@ -483,7 +483,7 @@ class Currency(db.Model):
db.session.delete(self) db.session.delete(self)
class CurrencyL10n(db.Model): class CurrencyL10n(base_cls):
"""A localized currency name.""" """A localized currency name."""
__tablename__ = "accounting_currencies_l10n" __tablename__ = "accounting_currencies_l10n"
"""The table name.""" """The table name."""
@@ -543,7 +543,7 @@ class JournalEntryCurrency:
return sum([x.amount for x in self.credit]) return sum([x.amount for x in self.credit])
class JournalEntry(db.Model): class JournalEntry(base_cls):
"""A journal entry.""" """A journal entry."""
__tablename__ = "accounting_journal_entries" __tablename__ = "accounting_journal_entries"
"""The table name.""" """The table name."""
@@ -661,7 +661,7 @@ class JournalEntry(db.Model):
db.session.delete(self) db.session.delete(self)
class JournalEntryLineItem(db.Model): class JournalEntryLineItem(base_cls):
"""A line item in the journal entry.""" """A line item in the journal entry."""
__tablename__ = "accounting_journal_entry_line_items" __tablename__ = "accounting_journal_entry_line_items"
"""The table name.""" """The table name."""
@@ -888,7 +888,7 @@ class JournalEntryLineItem(db.Model):
format_amount(self.amount)] format_amount(self.amount)]
class Option(db.Model): class Option(base_cls):
"""An option.""" """An option."""
__tablename__ = "accounting_options" __tablename__ = "accounting_options"
"""The table name.""" """The table name."""
+2 -1
View File
@@ -22,9 +22,10 @@ This module should not import any other module from the application.
from secrets import randbelow from secrets import randbelow
from .. import db from .. import db
from ..utils.user import base_cls
def new_id(cls: type[db.Model]): def new_id(cls: type[base_cls]):
"""Generates and returns a new, unused random ID for the data model. """Generates and returns a new, unused random ID for the data model.
:param cls: The data model. :param cls: The data model.
+15 -4
View File
@@ -23,10 +23,10 @@ from abc import ABC, abstractmethod
import sqlalchemy as sa import sqlalchemy as sa
from flask import g, Response from flask import g, Response
from flask_sqlalchemy.model import Model from sqlalchemy.orm import DeclarativeBase
class UserUtilityInterface[T: Model](ABC): class UserUtilityInterface[T: DeclarativeBase](ABC):
"""The interface for the user utilities.""" """The interface for the user utilities."""
@abstractmethod @abstractmethod
@@ -67,6 +67,14 @@ class UserUtilityInterface[T: Model](ABC):
:return: The response to require the user to log in. :return: The response to require the user to log in.
""" """
@property
@abstractmethod
def base(self) -> type[DeclarativeBase]:
"""Returns the base data model.
:return: The base data model.
"""
@property @property
@abstractmethod @abstractmethod
def cls(self) -> type[T]: def cls(self) -> type[T]:
@@ -109,7 +117,9 @@ class UserUtilityInterface[T: Model](ABC):
__user_utils: UserUtilityInterface __user_utils: UserUtilityInterface
"""The user utilities.""" """The user utilities."""
type user_cls = Model base_cls = DeclarativeBase
"""The base data model."""
type user_cls = DeclarativeBase
"""The user class.""" """The user class."""
user_pk_column: sa.Column = sa.Column(sa.Integer) user_pk_column: sa.Column = sa.Column(sa.Integer)
"""The primary key column of the user class.""" """The primary key column of the user class."""
@@ -121,8 +131,9 @@ def init_user_utils(utils: UserUtilityInterface) -> None:
:param utils: The user utilities. :param utils: The user utilities.
:return: None. :return: None.
""" """
global __user_utils, user_cls, user_pk_column global __user_utils, base_cls, user_cls, user_pk_column
__user_utils = utils __user_utils = utils
base_cls = utils.base
user_cls = utils.cls user_cls = utils.cls
user_pk_column = utils.pk_column user_pk_column = utils.pk_column
+3 -3
View File
@@ -29,7 +29,7 @@ from flask import Flask
from flask.testing import FlaskCliRunner from flask.testing import FlaskCliRunner
from sqlalchemy.sql.ddl import DropTable from sqlalchemy.sql.ddl import DropTable
from test_site import db from test_site import db, Base
from testlib import create_test_app from testlib import create_test_app
@@ -63,7 +63,7 @@ class ConsoleCommandTestCase(unittest.TestCase):
# Drop every accounting table, to see if accounting-init-db # Drop every accounting table, to see if accounting-init-db
# recreates them correctly. # recreates them correctly.
tables: list[sa.Table] \ tables: list[sa.Table] \
= [db.metadata.tables[x] for x in db.metadata.tables = [Base.metadata.tables[x] for x in Base.metadata.tables
if x.startswith("accounting_")] if x.startswith("accounting_")]
for table in tables: for table in tables:
db.session.execute(DropTable(table)) db.session.execute(DropTable(table))
@@ -207,7 +207,7 @@ class ConsoleCommandTestCase(unittest.TestCase):
with self.__app.app_context(): with self.__app.app_context():
# Resets the accounts. # Resets the accounts.
tables: list[sa.Table] \ tables: list[sa.Table] \
= [db.metadata.tables[x] for x in db.metadata.tables = [Base.metadata.tables[x] for x in Base.metadata.tables
if x.startswith("accounting_")] if x.startswith("accounting_")]
for table in tables: for table in tables:
db.session.execute(DropTable(table)) db.session.execute(DropTable(table))
+10 -1
View File
@@ -28,6 +28,7 @@ from flask.testing import FlaskCliRunner
from flask_babel_js import BabelJS from flask_babel_js import BabelJS
from flask_sqlalchemy import SQLAlchemy from flask_sqlalchemy import SQLAlchemy
from flask_wtf import CSRFProtect from flask_wtf import CSRFProtect
from sqlalchemy.orm import DeclarativeBase
bp: Blueprint = Blueprint("home", __name__) bp: Blueprint = Blueprint("home", __name__)
"""The global blueprint.""" """The global blueprint."""
@@ -39,6 +40,10 @@ db: SQLAlchemy = SQLAlchemy()
"""The database instance.""" """The database instance."""
class Base(DeclarativeBase):
"""The base class for all models."""
def create_app(is_testing: bool = False, is_skip_accounts: bool = False, def create_app(is_testing: bool = False, is_skip_accounts: bool = False,
is_skip_currencies: bool = False) -> Flask: is_skip_currencies: bool = False) -> Flask:
"""Create and configure the application. """Create and configure the application.
@@ -99,6 +104,10 @@ def create_app(is_testing: bool = False, is_skip_accounts: bool = False,
from accounting.utils.next_uri import append_next from accounting.utils.next_uri import append_next
return redirect(append_next(url_for("auth.login-form"))) return redirect(append_next(url_for("auth.login-form")))
@property
def base(self) -> type[DeclarativeBase]:
return Base
@property @property
def cls(self) -> type[auth.User]: def cls(self) -> type[auth.User]:
return auth.User return auth.User
@@ -137,7 +146,7 @@ def init_db(app: Flask, is_skip_accounts: bool,
otherwise. otherwise.
:return: None. :return: None.
""" """
db.create_all() Base.metadata.create_all(db.engine)
from .auth import User from .auth import User
for username in ["viewer", "editor", "admin", "nobody"]: for username in ["viewer", "editor", "admin", "nobody"]:
user: User | None = db.session.scalar( user: User | None = db.session.scalar(
+2 -2
View File
@@ -24,13 +24,13 @@ from flask import Blueprint, render_template, Flask, redirect, url_for, \
session, request, g, Response, abort session, request, g, Response, abort
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from . import db from . import db, Base
bp: Blueprint = Blueprint("auth", __name__, url_prefix="/") bp: Blueprint = Blueprint("auth", __name__, url_prefix="/")
"""The authentication blueprint.""" """The authentication blueprint."""
class User(db.Model): class User(Base):
"""A user.""" """A user."""
__tablename__ = "users" __tablename__ = "users"
"""The table name.""" """The table name."""