diff --git a/src/accounting/__init__.py b/src/accounting/__init__.py index 4c0cc09..628986e 100644 --- a/src/accounting/__init__.py +++ b/src/accounting/__init__.py @@ -18,55 +18,11 @@ """ import typing as t -from abc import ABC, abstractmethod -import sqlalchemy as sa from flask import Flask, Blueprint from flask_sqlalchemy.model import Model -T = t.TypeVar("T", bound=Model) - - -class AbstractUserUtils(t.Generic[T], ABC): - """The abstract user utilities.""" - - @property - @abstractmethod - def cls(self) -> t.Type[T]: - """Returns the user class. - - :return: The user class. - """ - - @property - @abstractmethod - def pk_column(self) -> sa.Column: - """Returns the primary key column. - - :return: The primary key column. - """ - - @property - @abstractmethod - def current_user(self) -> T: - """Returns the current user. - - :return: The current user. - """ - - @abstractmethod - def get_by_username(self, username: str) -> T | None: - """Returns the user by her username. - - :return: The user by her username, or None if the user was not found. - """ - - @abstractmethod - def get_pk(self, user: T) -> int: - """Returns the primary key of the user. - - :return: The primary key of the user. - """ +from accounting.utils.user import AbstractUserUtils def init_app(app: Flask, user_utils: AbstractUserUtils, @@ -87,7 +43,9 @@ def init_app(app: Flask, user_utils: AbstractUserUtils, # The database instance must be set before loading everything # in the application. from .database import set_db - set_db(app.extensions["sqlalchemy"], user_utils) + set_db(app.extensions["sqlalchemy"]) + from .utils.user import init_user_utils + init_user_utils(user_utils) bp: Blueprint = Blueprint("accounting", __name__, url_prefix=url_prefix, diff --git a/src/accounting/account/commands.py b/src/accounting/account/commands.py index a84175c..ea4785d 100644 --- a/src/accounting/account/commands.py +++ b/src/accounting/account/commands.py @@ -24,8 +24,9 @@ from secrets import randbelow import click from flask.cli import with_appcontext -from accounting.database import db, user_utils +from accounting.database import db from accounting.models import BaseAccount, Account, AccountL10n +from accounting.utils.user import has_user, get_user_pk AccountData = tuple[int, str, int, str, str, str, bool] """The format of the account data, as a list of (ID, base account code, number, @@ -45,8 +46,7 @@ def __validate_username(ctx: click.core.Context, param: click.core.Option, value = value.strip() if value == "": raise click.BadParameter("Username empty.") - user: user_utils.cls | None = user_utils.get_by_username(value) - if user is None: + if not has_user(value): raise click.BadParameter(f"User {value} does not exist.") return value @@ -58,7 +58,7 @@ def __validate_username(ctx: click.core.Context, param: click.core.Option, @with_appcontext def init_accounts_command(username: str) -> None: """Initializes the accounts.""" - creator_pk: int = user_utils.get_pk(user_utils.get_by_username(username)) + creator_pk: int = get_user_pk(username) bases: list[BaseAccount] = BaseAccount.query\ .filter(db.func.length(BaseAccount.code) == 4)\ diff --git a/src/accounting/account/forms.py b/src/accounting/account/forms.py index f8ff484..b9c392d 100644 --- a/src/accounting/account/forms.py +++ b/src/accounting/account/forms.py @@ -22,11 +22,12 @@ from flask_wtf import FlaskForm from wtforms import StringField, BooleanField from wtforms.validators import DataRequired, ValidationError -from accounting.database import db, user_utils +from accounting.database import db from accounting.locale import lazy_gettext from accounting.models import BaseAccount, Account from accounting.utils.random_id import new_id from accounting.utils.strip_text import strip_text +from accounting.utils.user import get_current_user_pk class BaseAccountExists: @@ -74,7 +75,7 @@ class AccountForm(FlaskForm): obj.title = self.title.data obj.is_offset_needed = self.is_offset_needed.data if is_new: - current_user_pk: int = user_utils.get_pk(user_utils.current_user) + current_user_pk: int = get_current_user_pk() obj.created_by_id = current_user_pk obj.updated_by_id = current_user_pk if prev_base_code is not None \ @@ -87,7 +88,7 @@ class AccountForm(FlaskForm): :return: None """ - current_user_pk: int = user_utils.get_pk(user_utils.current_user) + current_user_pk: int = get_current_user_pk() obj.updated_by_id = current_user_pk obj.updated_at = sa.func.now() if hasattr(self, "__post_update"): diff --git a/src/accounting/database.py b/src/accounting/database.py index 77423db..96ad6e6 100644 --- a/src/accounting/database.py +++ b/src/accounting/database.py @@ -25,21 +25,15 @@ time. from flask_sqlalchemy import SQLAlchemy -from accounting import AbstractUserUtils - db: SQLAlchemy """The database instance.""" -user_utils: AbstractUserUtils -"""The user utilities.""" -def set_db(new_db: SQLAlchemy, new_user_utils: AbstractUserUtils) -> None: +def set_db(new_db: SQLAlchemy) -> None: """Sets the database instance. :param new_db: The database instance. - :param new_user_utils: The user utilities. :return: None. """ - global db, user_utils + global db db = new_db - user_utils = new_user_utils diff --git a/src/accounting/models.py b/src/accounting/models.py index edaa90f..1261cf6 100644 --- a/src/accounting/models.py +++ b/src/accounting/models.py @@ -25,10 +25,8 @@ from flask import current_app from flask_babel import get_locale from sqlalchemy import text -from accounting.database import db, user_utils - -user_cls: db.Model = user_utils.cls -user_pk_column: db.Column = user_utils.pk_column +from accounting.database import db +from accounting.utils.user import user_cls, user_pk_column class BaseAccount(db.Model): diff --git a/src/accounting/utils/user.py b/src/accounting/utils/user.py new file mode 100644 index 0000000..dae6739 --- /dev/null +++ b/src/accounting/utils/user.py @@ -0,0 +1,116 @@ +# The Mia! Accounting Flask Project. +# Author: imacat@mail.imacat.idv.tw (imacat), 2023/2/1 + +# Copyright (c) 2023 imacat. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The user utilities. + +This module should not import any other module from the application. + +""" +import typing as t +from abc import ABC, abstractmethod + +import sqlalchemy as sa +from flask_sqlalchemy.model import Model + +T = t.TypeVar("T", bound=Model) + + +class AbstractUserUtils(t.Generic[T], ABC): + """The abstract user utilities.""" + + @property + @abstractmethod + def cls(self) -> t.Type[T]: + """Returns the user class. + + :return: The user class. + """ + + @property + @abstractmethod + def pk_column(self) -> sa.Column: + """Returns the primary key column. + + :return: The primary key column. + """ + + @property + @abstractmethod + def current_user(self) -> T: + """Returns the current user. + + :return: The current user. + """ + + @abstractmethod + def get_by_username(self, username: str) -> T | None: + """Returns the user by her username. + + :return: The user by her username, or None if the user was not found. + """ + + @abstractmethod + def get_pk(self, user: T) -> int: + """Returns the primary key of the user. + + :return: The primary key of the user. + """ + + +__user_utils: AbstractUserUtils +"""The user utilities.""" +user_cls: t.Type[Model] +"""The user class.""" +user_pk_column: sa.Column +"""The primary key column of the user class.""" + + +def init_user_utils(utils: AbstractUserUtils) -> None: + """Initializes the user utilities. + + :param utils: The user utilities. + :return: None. + """ + global __user_utils, user_cls, user_pk_column + __user_utils = utils + user_cls = utils.cls + user_pk_column = utils.pk_column + + +def get_current_user_pk() -> int: + """Returns the primary key value of the currently logged-in user. + + :return: The primary key value of the currently logged-in user. + """ + return __user_utils.get_pk(__user_utils.current_user) + + +def has_user(username: str) -> bool: + """Returns whether a user by the username exists. + + :param username: The username. + :return: True if the user by the username exists, or False otherwise. + """ + return __user_utils.get_by_username(username) is not None + + +def get_user_pk(username: str) -> int: + """Returns the primary key value of the user by the username. + + :param username: The username. + :return: The primary key value of the user by the username. + """ + return __user_utils.get_pk(__user_utils.get_by_username(username)) diff --git a/tests/testsite/__init__.py b/tests/testsite/__init__.py index aac97e4..edb0995 100644 --- a/tests/testsite/__init__.py +++ b/tests/testsite/__init__.py @@ -29,6 +29,8 @@ from flask_sqlalchemy import SQLAlchemy from flask_wtf import CSRFProtect from sqlalchemy import Column +import accounting.utils.user + bp: Blueprint = Blueprint("home", __name__) babel_js: BabelJS = BabelJS() csrf: CSRFProtect = CSRFProtect() @@ -68,7 +70,7 @@ def create_app(is_testing: bool = False) -> Flask: from . import auth auth.init_app(app) - class UserUtils(accounting.AbstractUserUtils[auth.User]): + class UserUtils(accounting.utils.user.AbstractUserUtils[auth.User]): @property def cls(self) -> t.Type[auth.User]: