diff --git a/src/accounting/utils/next_uri.py b/src/accounting/utils/next_uri.py index 9d276dd..48bf47f 100644 --- a/src/accounting/utils/next_uri.py +++ b/src/accounting/utils/next_uri.py @@ -41,11 +41,8 @@ def inherit_next(uri: str) -> str: :param uri: The URI. :return: The URI with the current next URI added at the query argument. """ - next_uri: str | None = request.form.get("next") \ - if request.method == "POST" else request.args.get("next") - if next_uri is None: - return uri - return __set_next(uri, next_uri) + next_uri: str | None = __get_next_uri() + return uri if next_uri is None else __set_next(uri, next_uri) def or_next(uri: str) -> str: @@ -54,9 +51,22 @@ def or_next(uri: str) -> str: :param uri: The URI. :return: The next URI or the supplied URI. """ + next_uri: str | None = __get_next_uri() + return uri if next_uri is None else next_uri + + +def __get_next_uri() -> str | None: + """Returns the valid next URI. + + :return: The valid next URI. + """ next_uri: str | None = request.form.get("next") \ if request.method == "POST" else request.args.get("next") - return uri if next_uri is None else next_uri + if next_uri is None or not next_uri.startswith("/"): + return None + if len(next_uri) > 512: + return next_uri[:512] + return next_uri def __set_next(uri: str, next_uri: str) -> str: diff --git a/tests/test_utils.py b/tests/test_utils.py index 8d8eb07..307d8e9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -101,6 +101,60 @@ class NextUriTestCase(unittest.TestCase): "name": "viewer"}) self.assertEqual(response.status_code, 200) + def test_invalid(self) -> None: + """Tests the next URI utilities without an invalid next URI. + + :return: None. + """ + def test_invalid_next_uri_view() -> str: + """The test view without the next URI.""" + self.assertEqual(inherit_next(self.TARGET), + request.args.get("inherit-expected")) + self.assertEqual(or_next(self.TARGET), + request.args.get("or-expected")) + return "" + + self.app.add_url_rule("/test-invalid-next", + view_func=test_invalid_next_uri_view, + methods=["GET", "POST"]) + client: httpx.Client = httpx.Client(app=self.app, base_url=TEST_SERVER) + client.headers["Referer"] = TEST_SERVER + csrf_token: str = get_csrf_token(client) + next_uri: str + expected1: str + expected2: str + response: httpx.Response + + # A foreign URI + next_uri = "https://example.com" + expected1 = self.TARGET + expected2 = self.TARGET + response = client.get(f"/test-invalid-next?next={quote_plus(next_uri)}" + f"&inherit-expected={quote_plus(expected1)}" + f"&or-expected={quote_plus(expected2)}") + self.assertEqual(response.status_code, 200) + response = client.post("/test-invalid-next" + f"?inherit-expected={quote_plus(expected1)}" + f"&or-expected={quote_plus(expected2)}", + data={"csrf_token": csrf_token, + "next": next_uri}) + self.assertEqual(response.status_code, 200) + + # An extremely-long URI to trigger the error + next_uri = "/" + "x" * 1024 + expected2 = next_uri[:512] + expected1 = f"{self.TARGET}?next={quote_plus(expected2)}" + response = client.get(f"/test-invalid-next?next={quote_plus(next_uri)}" + f"&inherit-expected={quote_plus(expected1)}" + f"&or-expected={quote_plus(expected2)}") + self.assertEqual(response.status_code, 200) + response = client.post("/test-invalid-next" + f"?inherit-expected={quote_plus(expected1)}" + f"&or-expected={quote_plus(expected2)}", + data={"csrf_token": csrf_token, + "next": next_uri}) + self.assertEqual(response.status_code, 200) + class QueryKeywordParserTestCase(unittest.TestCase): """The test case for the query keyword parser."""