diff --git a/src/flask_digest_auth/auth.py b/src/flask_digest_auth/auth.py index a5c7858..a60e89b 100644 --- a/src/flask_digest_auth/auth.py +++ b/src/flask_digest_auth/auth.py @@ -110,6 +110,36 @@ class DigestAuth: class NoLogInException(Exception): """The exception thrown when the user is not authorized.""" + def get_logged_in_user() -> t.Optional[t.Any]: + """Returns the currently logged-in user. + + :return: The currently logged-in user. + :raise NoLogInException: When the user is not logged in. + """ + if "user" not in session: + raise NoLogInException + user: t.Optional[t.Any] = self.__get_user(session["user"]) + if user is None: + raise NoLogInException + return user + + def auth_user(state: AuthState) -> t.Any: + """Authenticates a user. + + :param state: The authentication state. + :return: The user. + :raise UnauthorizedException: When the authentication fails. + """ + authorization: Authorization = request.authorization + if authorization is None: + raise UnauthorizedException + if authorization.type != "digest": + raise UnauthorizedException( + "Not an HTTP digest authorization") + self.authenticate(state) + session["user"] = authorization.username + return self.__get_user(authorization.username) + @wraps(view) def login_required_view(*args, **kwargs) -> t.Any: """The login-protected view. @@ -119,36 +149,24 @@ class DigestAuth: :return: The response. """ try: - if "user" not in session: - raise NoLogInException - user: t.Optional[t.Any] = self.__get_user(session["user"]) - if user is None: - raise NoLogInException - g.user = user + g.user = get_logged_in_user() return view(*args, **kwargs) except NoLogInException: - state: AuthState = AuthState() - authorization: Authorization = request.authorization - try: - if authorization is None: - raise UnauthorizedException - if authorization.type != "digest": - raise UnauthorizedException( - "Not an HTTP digest authorization") - self.authenticate(state) - session["user"] = authorization.username - user = self.__get_user(authorization.username) - g.user = user - self.__on_login(user) - return view(*args, **kwargs) - except UnauthorizedException as e: - if len(e.args) > 0: - sys.stderr.write(e.args[0] + "\n") - response: Response = Response() - response.status = 401 - response.headers["WWW-Authenticate"] \ - = self.make_response_header(state) - abort(response) + pass + + state: AuthState = AuthState() + try: + g.user = auth_user(state) + self.__on_login(g.user) + return view(*args, **kwargs) + except UnauthorizedException as e: + if len(e.args) > 0: + sys.stderr.write(e.args[0] + "\n") + response: Response = Response() + response.status = 401 + response.headers["WWW-Authenticate"] \ + = self.make_response_header(state) + abort(response) return login_required_view