diff --git a/tests/test_utils.py b/tests/test_utils.py index 607fd50..73f5110 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -21,12 +21,12 @@ import unittest from urllib.parse import quote_plus import httpx -from flask import Flask, request +from flask import Flask, request, render_template_string from accounting.utils.next_uri import append_next, inherit_next, or_next from accounting.utils.pagination import Pagination, DEFAULT_PAGE_SIZE from accounting.utils.query import parse_query_keywords -from test_site import create_app, csrf +from test_site import create_app class NextUriTestCase(unittest.TestCase): @@ -40,8 +40,12 @@ class NextUriTestCase(unittest.TestCase): app: Flask = create_app(is_testing=True) target: str = "/target" + @app.get("/test-csrf") + def test_csrf() -> str: + """The test view to return the CSRF token.""" + return render_template_string("{{csrf_token()}}") + @app.route("/test-next", methods=["GET", "POST"]) - @csrf.exempt def test_next_view() -> str: """The test view with the next URI.""" current_uri: str = request.full_path if request.query_string \ @@ -56,7 +60,6 @@ class NextUriTestCase(unittest.TestCase): return "" @app.route("/test-no-next", methods=["GET", "POST"]) - @csrf.exempt def test_no_next_view() -> str: """The test view without the next URI.""" current_uri: str = request.full_path if request.query_string \ @@ -70,19 +73,22 @@ class NextUriTestCase(unittest.TestCase): client: httpx.Client = httpx.Client(app=app, base_url="https://testserver") client.headers["Referer"] = "https://testserver" + csrf_token: str = client.get("/test-csrf").text response: httpx.Response # With the next URI response = client.get("/test-next?next=/next&q=abc&page-no=4") self.assertEqual(response.status_code, 200) - response = client.post("/test-next", data={"next": "/next", + response = client.post("/test-next", data={"csrf_token": csrf_token, + "next": "/next", "name": "viewer"}) self.assertEqual(response.status_code, 200) # Without the next URI response = client.get("/test-no-next?q=abc&page-no=4") self.assertEqual(response.status_code, 200) - response = client.post("/test-no-next", data={"name": "viewer"}) + response = client.post("/test-no-next", data={"csrf_token": csrf_token, + "name": "viewer"}) self.assertEqual(response.status_code, 200)