Added the on-login callback for the log in bookkeeping.

This commit is contained in:
依瑪貓 2022-11-27 06:58:29 +11:00
parent 2aaaa9f47f
commit be163d35fb
4 changed files with 110 additions and 12 deletions

View File

@ -66,6 +66,12 @@ Flask-Digest-Auth supports log out. The user will be prompted for
new username and password.
Log In Bookkeeping
##################
You can register a callback to run when the user logs in.
.. _HTTP Digest Authentication: https://en.wikipedia.org/wiki/Digest_access_authentication
.. _RFC 2617: https://www.rfc-editor.org/rfc/rfc2617
.. _Flask: https://flask.palletsprojects.com
@ -336,6 +342,19 @@ the next browser automatic authentication to fail, forcing the browser
to ask the user for the username and password again.
Log In Bookkeeping
=================#
You can register a callback to run when the user logs in, for ex.,
logging the log in event, adding the log in counter, etc.
::
@auth.register_on_login
def on_login(user: User) -> None:
user.visits = user.visits + 1
Writing Tests
=============

View File

@ -67,6 +67,18 @@ class BaseUserGetter:
" was not registered yet.")
class BaseOnLogInCallback:
"""The base callback when the user logs in."""
@staticmethod
def __call__(user: t.Any) -> None:
"""Runs the callback when the user logs in.
:param user: The logged-in user.
:return: None.
"""
class DigestAuth:
"""The HTTP digest authentication."""
@ -87,6 +99,7 @@ class DigestAuth:
self.__get_password_hash: BasePasswordHashGetter \
= BasePasswordHashGetter()
self.__get_user: BaseUserGetter = BaseUserGetter()
self.__on_login: BaseOnLogInCallback = BaseOnLogInCallback()
def login_required(self, view) -> t.Callable:
"""The view decorator for HTTP digest authentication.
@ -125,7 +138,9 @@ class DigestAuth:
"Not an HTTP digest authorization")
self.authenticate(state)
session["user"] = authorization.username
g.user = self.__get_user(authorization.username)
user = self.__get_user(authorization.username)
g.user = user
self.__on_login(user)
return view(*args, **kwargs)
except UnauthorizedException as e:
if len(e.args) > 0:
@ -257,6 +272,27 @@ class DigestAuth:
self.__get_user = UserGetter()
def register_on_login(self, func: t.Callable[[t.Any], None]) -> None:
"""Registers the callback when the user logs in.
:param func: The callback given the logged-in user.
:return: None.
"""
class OnLogInCallback:
"""The callback when the user logs in."""
@staticmethod
def __call__(user: t.Any) -> None:
"""Runs the callback when the user logs in.
:param user: The logged-in user.
:return: None.
"""
func(user)
self.__on_login = OnLogInCallback()
def init_app(self, app: Flask) -> None:
"""Initializes the Flask application.
@ -303,6 +339,7 @@ class DigestAuth:
user = login_manager.user_callback(
authorization.username)
login_user(user)
self.__on_login(user)
return user
except UnauthorizedException as e:
if str(e) != "":

View File

@ -20,7 +20,6 @@
"""
import typing as t
from secrets import token_urlsafe
from types import SimpleNamespace
from flask import Response, Flask, g, redirect, request
from flask_testing import TestCase
@ -33,6 +32,20 @@ _USERNAME: str = "Mufasa"
_PASSWORD: str = "Circle Of Life"
class User:
"""A dummy user"""
def __init__(self, username: str, password_hash: str):
"""Constructs a dummy user.
:param username: The username.
:param password_hash: The password hash.
"""
self.username: str = username
self.password_hash: str = password_hash
self.visits: int = 0
class AuthenticationTestCase(TestCase):
"""The test case for the HTTP digest authentication."""
@ -50,8 +63,9 @@ class AuthenticationTestCase(TestCase):
auth: DigestAuth = DigestAuth(realm=_REALM)
auth.init_app(app)
user_db: t.Dict[str, str] \
= {_USERNAME: make_password_hash(_REALM, _USERNAME, _PASSWORD)}
user_db: t.Dict[str, User] \
= {_USERNAME: User(
_USERNAME, make_password_hash(_REALM, _USERNAME, _PASSWORD))}
@auth.register_get_password
def get_password_hash(username: str) -> t.Optional[str]:
@ -60,7 +74,8 @@ class AuthenticationTestCase(TestCase):
:param username: The username.
:return: The password hash, or None if the user does not exist.
"""
return user_db[username] if username in user_db else None
return user_db[username].password_hash if username in user_db \
else None
@auth.register_get_user
def get_user(username: str) -> t.Optional[t.Any]:
@ -69,8 +84,16 @@ class AuthenticationTestCase(TestCase):
:param username: The username.
:return: The user, or None if the user does not exist.
"""
return SimpleNamespace(username=username) if username in user_db \
else None
return user_db[username] if username in user_db else None
@auth.register_on_login
def on_login(user: User):
"""The callback when the user logs in.
:param user: The logged-in user.
:return: None.
"""
user.visits = user.visits + 1
@app.get("/admin-1/auth", endpoint="admin-1")
@auth.login_required
@ -118,6 +141,7 @@ class AuthenticationTestCase(TestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.decode("UTF-8"),
f"Hello, {_USERNAME}! #2")
self.assertEqual(g.user.visits, 1)
def test_stale_opaque(self) -> None:
"""Tests the stale and opaque value.
@ -194,3 +218,4 @@ class AuthenticationTestCase(TestCase):
response = self.client.get(admin_uri)
self.assertEqual(response.status_code, 200)
self.assertEqual(g.user.visits, 2)

View File

@ -21,6 +21,7 @@
import typing as t
from secrets import token_urlsafe
import flask_login
from flask import Response, Flask, g, redirect, request
from flask_testing import TestCase
from werkzeug.datastructures import WWWAuthenticate, Authorization
@ -35,12 +36,15 @@ _PASSWORD: str = "Circle Of Life"
class User:
"""A dummy user."""
def __init__(self, username: str):
def __init__(self, username: str, password_hash: str):
"""Constructs a dummy user.
:param username: The username.
:param password_hash: The password hash.
"""
self.username: str = username
self.password_hash: str = password_hash
self.visits: int = 0
self.is_authenticated: bool = True
self.is_active: bool = True
self.is_anonymous: bool = False
@ -82,8 +86,9 @@ class FlaskLoginTestCase(TestCase):
auth: DigestAuth = DigestAuth(realm=_REALM)
auth.init_app(app)
user_db: t.Dict[str, str] \
= {_USERNAME: make_password_hash(_REALM, _USERNAME, _PASSWORD)}
user_db: t.Dict[str, User] \
= {_USERNAME: User(
_USERNAME, make_password_hash(_REALM, _USERNAME, _PASSWORD))}
@auth.register_get_password
def get_password_hash(username: str) -> t.Optional[str]:
@ -92,7 +97,17 @@ class FlaskLoginTestCase(TestCase):
:param username: The username.
:return: The password hash, or None if the user does not exist.
"""
return user_db[username] if username in user_db else None
return user_db[username].password_hash if username in user_db \
else None
@auth.register_on_login
def on_login(user: User):
"""The callback when the user logs in.
:param user: The logged-in user.
:return: None.
"""
user.visits = user.visits + 1
@login_manager.user_loader
def load_user(user_id: str) -> t.Optional[User]:
@ -101,7 +116,7 @@ class FlaskLoginTestCase(TestCase):
:param user_id: The username.
:return: The user, or None if the user does not exist.
"""
return User(user_id) if user_id in user_db else None
return user_db[user_id] if user_id in user_db else None
@app.get("/admin-1/auth", endpoint="admin-1")
@flask_login.login_required
@ -152,6 +167,7 @@ class FlaskLoginTestCase(TestCase):
self.assertEqual(response.status_code, 200)
self.assertEqual(response.data.decode("UTF-8"),
f"Hello, {_USERNAME}! #2")
self.assertEqual(flask_login.current_user.visits, 1)
def test_stale_opaque(self) -> None:
"""Tests the stale and opaque value.
@ -237,3 +253,4 @@ class FlaskLoginTestCase(TestCase):
response = self.client.get(admin_uri)
self.assertEqual(response.status_code, 200)
self.assertEqual(flask_login.current_user.visits, 2)