diff --git a/src/flask_digest_auth/auth.py b/src/flask_digest_auth/auth.py index bb53c8a..9d0b744 100644 --- a/src/flask_digest_auth/auth.py +++ b/src/flask_digest_auth/auth.py @@ -23,6 +23,7 @@ from __future__ import annotations import sys import typing as t +from abc import ABC, abstractmethod from functools import wraps from random import random from secrets import token_urlsafe @@ -35,6 +36,36 @@ from flask_digest_auth.algo import calc_response from flask_digest_auth.exception import UnauthorizedException +class BasePasswordHashGetter(ABC): + """The base password hash getter.""" + + @staticmethod + @abstractmethod + def __call__(username: str) -> t.Optional[str]: + """Returns the password hash of a user. + + :param username: The username. + :return: The password hash, or None if the user does not exist. + :raise UnboundLocalError: When the password hash getter function is + not registered yet. + """ + + +class BaseUserGetter(ABC): + """The base user getter.""" + + @staticmethod + @abstractmethod + def __call__(username: str) -> t.Optional[t.Any]: + """Returns a user. + + :param username: The username. + :return: The user, or None if the user does not exist. + :raise UnboundLocalError: When the user getter function is not + registered yet. + """ + + class DigestAuth: """The HTTP digest authentication.""" @@ -51,11 +82,43 @@ class DigestAuth: self.use_opaque: bool = True self.domain: t.List[str] = [] self.qop: t.List[str] = ["auth", "auth-int"] - self.__get_password_hash: t.Callable[[str], t.Optional[str]] \ - = lambda x: None - self.__get_user: t.Callable[[str], t.Optional] = lambda x: None self.app: t.Optional[Flask] = None + class DummyPasswordHashGetter(BasePasswordHashGetter): + """The dummy password hash getter.""" + + @staticmethod + def __call__(username: str) -> t.Optional[str]: + """Returns the password hash of a user. + + :param username: The username. + :return: The password hash, or None if the user does not exist. + :raise UnboundLocalError: When the password hash getter function + is not registered yet. + """ + raise UnboundLocalError("The function to return the password" + " hash was not registered yet.") + + self.__get_password_hash: BasePasswordHashGetter \ + = DummyPasswordHashGetter() + + class DummyUserGetter(BaseUserGetter): + """The dummy user getter.""" + + @staticmethod + def __call__(username: str) -> t.Optional[t.Any]: + """Returns a user. + + :param username: The username. + :return: The user, or None if the user does not exist. + :raise UnboundLocalError: When the user getter function is not + registered yet. + """ + raise UnboundLocalError("The function to return the user" + " was not registered yet.") + + self.__get_user: BaseUserGetter = DummyUserGetter() + def login_required(self, view) -> t.Callable: """The view decorator for HTTP digest authentication. @@ -127,8 +190,8 @@ class DigestAuth: except BadData: raise UnauthorizedException("Invalid opaque") state.opaque = authorization.opaque - password_hash: t.Optional[str] = self.__get_password_hash( - authorization.username) + password_hash: t.Optional[str] \ + = self.__get_password_hash(authorization.username) if password_hash is None: raise UnauthorizedException( f"No such user \"{authorization.username}\"") @@ -187,7 +250,20 @@ class DigestAuth: hash, or None if the user does not exist. :return: None. """ - self.__get_password_hash = func + + class PasswordHashGetter(BasePasswordHashGetter): + """The base password hash getter.""" + + @staticmethod + def __call__(username: str) -> t.Optional[str]: + """Returns the password hash of a user. + + :param username: The username. + :return: The password hash, or None if the user does not exist. + """ + return func(username) + + self.__get_password_hash = PasswordHashGetter() def register_get_user(self, func: t.Callable[[str], t.Optional[t.Any]])\ -> None: @@ -197,7 +273,20 @@ class DigestAuth: or None if the user does not exist. :return: None. """ - self.__get_user = func + + class UserGetter(BaseUserGetter): + """The user getter.""" + + @staticmethod + def __call__(username: str) -> t.Optional[t.Any]: + """Returns a user. + + :param username: The username. + :return: The user, or None if the user does not exist. + """ + return func(username) + + self.__get_user = UserGetter() def init_app(self, app: Flask) -> None: """Initializes the Flask application.