diff --git a/README.rst b/README.rst index 2ac6347..3fb2f55 100644 --- a/README.rst +++ b/README.rst @@ -66,6 +66,12 @@ Flask-Digest-Auth supports log out. The user will be prompted for new username and password. +Log In Bookkeeping +################## + +You can register a callback to run when the user logs in. + + .. _HTTP Digest Authentication: https://en.wikipedia.org/wiki/Digest_access_authentication .. _RFC 2617: https://www.rfc-editor.org/rfc/rfc2617 .. _Flask: https://flask.palletsprojects.com @@ -336,6 +342,19 @@ the next browser automatic authentication to fail, forcing the browser to ask the user for the username and password again. +Log In Bookkeeping +=================# + +You can register a callback to run when the user logs in, for ex., +logging the log in event, adding the log in counter, etc. + +:: + + @auth.register_on_login + def on_login(user: User) -> None: + user.visits = user.visits + 1 + + Writing Tests ============= diff --git a/src/flask_digest_auth/auth.py b/src/flask_digest_auth/auth.py index 167d8df..a0b4c0e 100644 --- a/src/flask_digest_auth/auth.py +++ b/src/flask_digest_auth/auth.py @@ -67,6 +67,18 @@ class BaseUserGetter: " was not registered yet.") +class BaseOnLogInCallback: + """The base callback when the user logs in.""" + + @staticmethod + def __call__(user: t.Any) -> None: + """Runs the callback when the user logs in. + + :param user: The logged-in user. + :return: None. + """ + + class DigestAuth: """The HTTP digest authentication.""" @@ -87,6 +99,7 @@ class DigestAuth: self.__get_password_hash: BasePasswordHashGetter \ = BasePasswordHashGetter() self.__get_user: BaseUserGetter = BaseUserGetter() + self.__on_login: BaseOnLogInCallback = BaseOnLogInCallback() def login_required(self, view) -> t.Callable: """The view decorator for HTTP digest authentication. @@ -125,7 +138,9 @@ class DigestAuth: "Not an HTTP digest authorization") self.authenticate(state) session["user"] = authorization.username - g.user = self.__get_user(authorization.username) + user = self.__get_user(authorization.username) + g.user = user + self.__on_login(user) return view(*args, **kwargs) except UnauthorizedException as e: if len(e.args) > 0: @@ -257,6 +272,27 @@ class DigestAuth: self.__get_user = UserGetter() + def register_on_login(self, func: t.Callable[[t.Any], None]) -> None: + """Registers the callback when the user logs in. + + :param func: The callback given the logged-in user. + :return: None. + """ + + class OnLogInCallback: + """The callback when the user logs in.""" + + @staticmethod + def __call__(user: t.Any) -> None: + """Runs the callback when the user logs in. + + :param user: The logged-in user. + :return: None. + """ + func(user) + + self.__on_login = OnLogInCallback() + def init_app(self, app: Flask) -> None: """Initializes the Flask application. @@ -303,6 +339,7 @@ class DigestAuth: user = login_manager.user_callback( authorization.username) login_user(user) + self.__on_login(user) return user except UnauthorizedException as e: if str(e) != "": diff --git a/tests/test_auth.py b/tests/test_auth.py index fcb6751..241a211 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -20,7 +20,6 @@ """ import typing as t from secrets import token_urlsafe -from types import SimpleNamespace from flask import Response, Flask, g, redirect, request from flask_testing import TestCase @@ -33,6 +32,20 @@ _USERNAME: str = "Mufasa" _PASSWORD: str = "Circle Of Life" +class User: + """A dummy user""" + + def __init__(self, username: str, password_hash: str): + """Constructs a dummy user. + + :param username: The username. + :param password_hash: The password hash. + """ + self.username: str = username + self.password_hash: str = password_hash + self.visits: int = 0 + + class AuthenticationTestCase(TestCase): """The test case for the HTTP digest authentication.""" @@ -50,8 +63,9 @@ class AuthenticationTestCase(TestCase): auth: DigestAuth = DigestAuth(realm=_REALM) auth.init_app(app) - user_db: t.Dict[str, str] \ - = {_USERNAME: make_password_hash(_REALM, _USERNAME, _PASSWORD)} + user_db: t.Dict[str, User] \ + = {_USERNAME: User( + _USERNAME, make_password_hash(_REALM, _USERNAME, _PASSWORD))} @auth.register_get_password def get_password_hash(username: str) -> t.Optional[str]: @@ -60,7 +74,8 @@ class AuthenticationTestCase(TestCase): :param username: The username. :return: The password hash, or None if the user does not exist. """ - return user_db[username] if username in user_db else None + return user_db[username].password_hash if username in user_db \ + else None @auth.register_get_user def get_user(username: str) -> t.Optional[t.Any]: @@ -69,8 +84,16 @@ class AuthenticationTestCase(TestCase): :param username: The username. :return: The user, or None if the user does not exist. """ - return SimpleNamespace(username=username) if username in user_db \ - else None + return user_db[username] if username in user_db else None + + @auth.register_on_login + def on_login(user: User): + """The callback when the user logs in. + + :param user: The logged-in user. + :return: None. + """ + user.visits = user.visits + 1 @app.get("/admin-1/auth", endpoint="admin-1") @auth.login_required @@ -118,6 +141,7 @@ class AuthenticationTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.data.decode("UTF-8"), f"Hello, {_USERNAME}! #2") + self.assertEqual(g.user.visits, 1) def test_stale_opaque(self) -> None: """Tests the stale and opaque value. @@ -194,3 +218,4 @@ class AuthenticationTestCase(TestCase): response = self.client.get(admin_uri) self.assertEqual(response.status_code, 200) + self.assertEqual(g.user.visits, 2) diff --git a/tests/test_flask_login.py b/tests/test_flask_login.py index f1348f5..679c3af 100644 --- a/tests/test_flask_login.py +++ b/tests/test_flask_login.py @@ -21,6 +21,7 @@ import typing as t from secrets import token_urlsafe +import flask_login from flask import Response, Flask, g, redirect, request from flask_testing import TestCase from werkzeug.datastructures import WWWAuthenticate, Authorization @@ -35,12 +36,15 @@ _PASSWORD: str = "Circle Of Life" class User: """A dummy user.""" - def __init__(self, username: str): + def __init__(self, username: str, password_hash: str): """Constructs a dummy user. :param username: The username. + :param password_hash: The password hash. """ self.username: str = username + self.password_hash: str = password_hash + self.visits: int = 0 self.is_authenticated: bool = True self.is_active: bool = True self.is_anonymous: bool = False @@ -82,8 +86,9 @@ class FlaskLoginTestCase(TestCase): auth: DigestAuth = DigestAuth(realm=_REALM) auth.init_app(app) - user_db: t.Dict[str, str] \ - = {_USERNAME: make_password_hash(_REALM, _USERNAME, _PASSWORD)} + user_db: t.Dict[str, User] \ + = {_USERNAME: User( + _USERNAME, make_password_hash(_REALM, _USERNAME, _PASSWORD))} @auth.register_get_password def get_password_hash(username: str) -> t.Optional[str]: @@ -92,7 +97,17 @@ class FlaskLoginTestCase(TestCase): :param username: The username. :return: The password hash, or None if the user does not exist. """ - return user_db[username] if username in user_db else None + return user_db[username].password_hash if username in user_db \ + else None + + @auth.register_on_login + def on_login(user: User): + """The callback when the user logs in. + + :param user: The logged-in user. + :return: None. + """ + user.visits = user.visits + 1 @login_manager.user_loader def load_user(user_id: str) -> t.Optional[User]: @@ -101,7 +116,7 @@ class FlaskLoginTestCase(TestCase): :param user_id: The username. :return: The user, or None if the user does not exist. """ - return User(user_id) if user_id in user_db else None + return user_db[user_id] if user_id in user_db else None @app.get("/admin-1/auth", endpoint="admin-1") @flask_login.login_required @@ -152,6 +167,7 @@ class FlaskLoginTestCase(TestCase): self.assertEqual(response.status_code, 200) self.assertEqual(response.data.decode("UTF-8"), f"Hello, {_USERNAME}! #2") + self.assertEqual(flask_login.current_user.visits, 1) def test_stale_opaque(self) -> None: """Tests the stale and opaque value. @@ -237,3 +253,4 @@ class FlaskLoginTestCase(TestCase): response = self.client.get(admin_uri) self.assertEqual(response.status_code, 200) + self.assertEqual(flask_login.current_user.visits, 2)