From 591fb4a7ab441746213e88bcef6e0432be784298 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BE=9D=E7=91=AA=E8=B2=93?= Date: Mon, 6 Feb 2023 21:45:28 +0800 Subject: [PATCH] Replaced the UserClient class and the get_user_client function with the get_client function in the tests, for simplicity. --- tests/test_account.py | 86 +++++++++++++++++++------------------- tests/test_base_account.py | 20 ++++----- tests/testlib.py | 21 ++-------- 3 files changed, 56 insertions(+), 71 deletions(-) diff --git a/tests/test_account.py b/tests/test_account.py index 4f79f3e..343631f 100644 --- a/tests/test_account.py +++ b/tests/test_account.py @@ -25,7 +25,7 @@ from click.testing import Result from flask import Flask from flask.testing import FlaskCliRunner -from testlib import UserClient, get_user_client +from testlib import get_client from test_site import create_app @@ -108,9 +108,7 @@ class AccountTestCase(unittest.TestCase): Account.query.delete() db.session.commit() - editor: UserClient = get_user_client(self, self.app, "editor") - self.client: httpx.Client = editor.client - self.csrf_token: str = editor.csrf_token + self.client, self.csrf_token = get_client(self, self.app, "editor") response: httpx.Response response = self.client.post("/accounting/accounts/store", @@ -135,47 +133,47 @@ class AccountTestCase(unittest.TestCase): :return: None. """ from accounting.models import Account + client, csrf_token = get_client(self, self.app, "nobody") response: httpx.Response - nobody: UserClient = get_user_client(self, self.app, "nobody") - response = nobody.client.get("/accounting/accounts") + response = client.get("/accounting/accounts") self.assertEqual(response.status_code, 403) - response = nobody.client.get("/accounting/accounts/1111-001") + response = client.get("/accounting/accounts/1111-001") self.assertEqual(response.status_code, 403) - response = nobody.client.get("/accounting/accounts/create") + response = client.get("/accounting/accounts/create") self.assertEqual(response.status_code, 403) - response = nobody.client.post("/accounting/accounts/store", - data={"csrf_token": nobody.csrf_token, - "base_code": "1113", - "title": "1113 title"}) + response = client.post("/accounting/accounts/store", + data={"csrf_token": csrf_token, + "base_code": "1113", + "title": "1113 title"}) self.assertEqual(response.status_code, 403) - response = nobody.client.get("/accounting/accounts/1111-001/edit") + response = client.get("/accounting/accounts/1111-001/edit") self.assertEqual(response.status_code, 403) - response = nobody.client.post("/accounting/accounts/1111-001/update", - data={"csrf_token": nobody.csrf_token, - "base_code": "1111", - "title": "1111 title #2"}) + response = client.post("/accounting/accounts/1111-001/update", + data={"csrf_token": csrf_token, + "base_code": "1111", + "title": "1111 title #2"}) self.assertEqual(response.status_code, 403) - response = nobody.client.post("/accounting/accounts/1111-001/delete", - data={"csrf_token": nobody.csrf_token}) + response = client.post("/accounting/accounts/1111-001/delete", + data={"csrf_token": csrf_token}) self.assertEqual(response.status_code, 403) - response = nobody.client.get("/accounting/accounts/bases/1111") + response = client.get("/accounting/accounts/bases/1111") self.assertEqual(response.status_code, 403) with self.app.app_context(): account_id: int = Account.find_by_code("1112-001").id - response = nobody.client.post("/accounting/accounts/bases/1112", - data={"csrf_token": nobody.csrf_token, - "next": "/next", - f"{account_id}-no": "5"}) + response = client.post("/accounting/accounts/bases/1112", + data={"csrf_token": csrf_token, + "next": "/next", + f"{account_id}-no": "5"}) self.assertEqual(response.status_code, 403) def test_viewer(self) -> None: @@ -184,47 +182,47 @@ class AccountTestCase(unittest.TestCase): :return: None. """ from accounting.models import Account + client, csrf_token = get_client(self, self.app, "viewer") response: httpx.Response - viewer: UserClient = get_user_client(self, self.app, "viewer") - response = viewer.client.get("/accounting/accounts") + response = client.get("/accounting/accounts") self.assertEqual(response.status_code, 200) - response = viewer.client.get("/accounting/accounts/1111-001") + response = client.get("/accounting/accounts/1111-001") self.assertEqual(response.status_code, 200) - response = viewer.client.get("/accounting/accounts/create") + response = client.get("/accounting/accounts/create") self.assertEqual(response.status_code, 403) - response = viewer.client.post("/accounting/accounts/store", - data={"csrf_token": viewer.csrf_token, - "base_code": "1113", - "title": "1113 title"}) + response = client.post("/accounting/accounts/store", + data={"csrf_token": csrf_token, + "base_code": "1113", + "title": "1113 title"}) self.assertEqual(response.status_code, 403) - response = viewer.client.get("/accounting/accounts/1111-001/edit") + response = client.get("/accounting/accounts/1111-001/edit") self.assertEqual(response.status_code, 403) - response = viewer.client.post("/accounting/accounts/1111-001/update", - data={"csrf_token": viewer.csrf_token, - "base_code": "1111", - "title": "1111 title #2"}) + response = client.post("/accounting/accounts/1111-001/update", + data={"csrf_token": csrf_token, + "base_code": "1111", + "title": "1111 title #2"}) self.assertEqual(response.status_code, 403) - response = viewer.client.post("/accounting/accounts/1111-001/delete", - data={"csrf_token": viewer.csrf_token}) + response = client.post("/accounting/accounts/1111-001/delete", + data={"csrf_token": csrf_token}) self.assertEqual(response.status_code, 403) - response = viewer.client.get("/accounting/accounts/bases/1111") + response = client.get("/accounting/accounts/bases/1111") self.assertEqual(response.status_code, 200) with self.app.app_context(): account_id: int = Account.find_by_code("1112-001").id - response = viewer.client.post("/accounting/accounts/bases/1112", - data={"csrf_token": viewer.csrf_token, - "next": "/next", - f"{account_id}-no": "5"}) + response = client.post("/accounting/accounts/bases/1112", + data={"csrf_token": csrf_token, + "next": "/next", + f"{account_id}-no": "5"}) self.assertEqual(response.status_code, 403) def test_editor(self) -> None: diff --git a/tests/test_base_account.py b/tests/test_base_account.py index 1fff9ea..800c57c 100644 --- a/tests/test_base_account.py +++ b/tests/test_base_account.py @@ -24,7 +24,7 @@ from click.testing import Result from flask import Flask from flask.testing import FlaskCliRunner -from testlib import UserClient, get_user_client +from testlib import get_client from test_site import create_app @@ -92,13 +92,13 @@ class BaseAccountTestCase(unittest.TestCase): :return: None. """ + client, csrf_token = get_client(self, self.app, "nobody") response: httpx.Response - nobody: UserClient = get_user_client(self, self.app, "nobody") - response = nobody.client.get("/accounting/base-accounts") + response = client.get("/accounting/base-accounts") self.assertEqual(response.status_code, 403) - response = nobody.client.get("/accounting/base-accounts/1111") + response = client.get("/accounting/base-accounts/1111") self.assertEqual(response.status_code, 403) def test_viewer(self) -> None: @@ -106,13 +106,13 @@ class BaseAccountTestCase(unittest.TestCase): :return: None. """ + client, csrf_token = get_client(self, self.app, "viewer") response: httpx.Response - viewer: UserClient = get_user_client(self, self.app, "viewer") - response = viewer.client.get("/accounting/base-accounts") + response = client.get("/accounting/base-accounts") self.assertEqual(response.status_code, 200) - response = viewer.client.get("/accounting/base-accounts/1111") + response = client.get("/accounting/base-accounts/1111") self.assertEqual(response.status_code, 200) def test_editor(self) -> None: @@ -120,11 +120,11 @@ class BaseAccountTestCase(unittest.TestCase): :return: None. """ + client, csrf_token = get_client(self, self.app, "editor") response: httpx.Response - editor: UserClient = get_user_client(self, self.app, "editor") - response = editor.client.get("/accounting/base-accounts") + response = client.get("/accounting/base-accounts") self.assertEqual(response.status_code, 200) - response = editor.client.get("/accounting/base-accounts/1111") + response = client.get("/accounting/base-accounts/1111") self.assertEqual(response.status_code, 200) diff --git a/tests/testlib.py b/tests/testlib.py index 51438be..95f5ff9 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -24,27 +24,14 @@ import httpx from flask import Flask -class UserClient: - """A user client.""" - - def __init__(self, client: httpx.Client, csrf_token: str): - """Constructs a user client. - - :param client: The client. - :param csrf_token: The CSRF token. - """ - self.client: httpx.Client = client - self.csrf_token: str = csrf_token - - -def get_user_client(test_case: TestCase, app: Flask, username: str) \ - -> UserClient: +def get_client(test_case: TestCase, app: Flask, username: str) \ + -> tuple[httpx.Client, str]: """Returns a user client. :param test_case: The test case. :param app: The Flask application. :param username: The username. - :return: The user client. + :return: A tuple of the client and the CSRF token. """ client: httpx.Client = httpx.Client(app=app, base_url="https://testserver") client.headers["Referer"] = "https://testserver" @@ -54,7 +41,7 @@ def get_user_client(test_case: TestCase, app: Flask, username: str) \ "username": username}) test_case.assertEqual(response.status_code, 302) test_case.assertEqual(response.headers["Location"], "/") - return UserClient(client, csrf_token) + return client, csrf_token def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str: