diff --git a/tests/test_account.py b/tests/test_account.py index 88c6432..68b2bf7 100644 --- a/tests/test_account.py +++ b/tests/test_account.py @@ -141,7 +141,7 @@ class AccountTestCase(unittest.TestCase): Account.query.delete() db.session.commit() - self.client, self.csrf_token = get_client(self, self.app, "editor") + self.client, self.csrf_token = get_client(self.app, "editor") response: httpx.Response response = self.client.post(f"{PREFIX}/store", @@ -166,7 +166,7 @@ class AccountTestCase(unittest.TestCase): :return: None. """ from accounting.models import Account - client, csrf_token = get_client(self, self.app, "nobody") + client, csrf_token = get_client(self.app, "nobody") response: httpx.Response response = client.get(PREFIX) @@ -215,7 +215,7 @@ class AccountTestCase(unittest.TestCase): :return: None. """ from accounting.models import Account - client, csrf_token = get_client(self, self.app, "viewer") + client, csrf_token = get_client(self.app, "viewer") response: httpx.Response response = client.get(PREFIX) @@ -527,7 +527,7 @@ class AccountTestCase(unittest.TestCase): """ from accounting.models import Account editor_username, editor2_username = "editor", "editor2" - client, csrf_token = get_client(self, self.app, editor2_username) + client, csrf_token = get_client(self.app, editor2_username) detail_uri: str = f"{PREFIX}/{cash.code}" update_uri: str = f"{PREFIX}/{cash.code}/update" response: httpx.Response @@ -566,7 +566,7 @@ class AccountTestCase(unittest.TestCase): self.assertEqual(cash_account.title_l10n, cash.title) self.assertEqual(cash_account.l10n, []) - set_locale(self, self.client, self.csrf_token, "zh_Hant") + set_locale(self.client, self.csrf_token, "zh_Hant") response = self.client.post(update_uri, data={"csrf_token": self.csrf_token, @@ -581,7 +581,7 @@ class AccountTestCase(unittest.TestCase): self.assertEqual({(x.locale, x.title) for x in cash_account.l10n}, {("zh_Hant", f"{cash.title}-zh_Hant")}) - set_locale(self, self.client, self.csrf_token, "en") + set_locale(self.client, self.csrf_token, "en") response = self.client.post(update_uri, data={"csrf_token": self.csrf_token, @@ -596,7 +596,7 @@ class AccountTestCase(unittest.TestCase): self.assertEqual({(x.locale, x.title) for x in cash_account.l10n}, {("zh_Hant", f"{cash.title}-zh_Hant")}) - set_locale(self, self.client, self.csrf_token, "zh_Hant") + set_locale(self.client, self.csrf_token, "zh_Hant") response = self.client.post(update_uri, data={"csrf_token": self.csrf_token, diff --git a/tests/test_base_account.py b/tests/test_base_account.py index 4d808e1..d77f5f5 100644 --- a/tests/test_base_account.py +++ b/tests/test_base_account.py @@ -108,7 +108,7 @@ class BaseAccountTestCase(unittest.TestCase): :return: None. """ - client, csrf_token = get_client(self, self.app, "nobody") + client, csrf_token = get_client(self.app, "nobody") response: httpx.Response response = client.get("/accounting/base-accounts") @@ -122,7 +122,7 @@ class BaseAccountTestCase(unittest.TestCase): :return: None. """ - client, csrf_token = get_client(self, self.app, "viewer") + client, csrf_token = get_client(self.app, "viewer") response: httpx.Response response = client.get("/accounting/base-accounts") @@ -136,7 +136,7 @@ class BaseAccountTestCase(unittest.TestCase): :return: None. """ - client, csrf_token = get_client(self, self.app, "editor") + client, csrf_token = get_client(self.app, "editor") response: httpx.Response response = client.get("/accounting/base-accounts") diff --git a/tests/test_currency.py b/tests/test_currency.py index f9dcfeb..a657462 100644 --- a/tests/test_currency.py +++ b/tests/test_currency.py @@ -137,7 +137,7 @@ class CurrencyTestCase(unittest.TestCase): Currency.query.delete() db.session.commit() - self.client, self.csrf_token = get_client(self, self.app, "editor") + self.client, self.csrf_token = get_client(self.app, "editor") response: httpx.Response response = self.client.post(f"{PREFIX}/store", @@ -159,7 +159,7 @@ class CurrencyTestCase(unittest.TestCase): :return: None. """ - client, csrf_token = get_client(self, self.app, "nobody") + client, csrf_token = get_client(self.app, "nobody") response: httpx.Response response = client.get(PREFIX) @@ -195,7 +195,7 @@ class CurrencyTestCase(unittest.TestCase): :return: None. """ - client, csrf_token = get_client(self, self.app, "viewer") + client, csrf_token = get_client(self.app, "viewer") response: httpx.Response response = client.get(PREFIX) @@ -474,7 +474,7 @@ class CurrencyTestCase(unittest.TestCase): from accounting.models import Currency from test_site import db editor_username, editor2_username = "editor", "editor2" - client, csrf_token = get_client(self, self.app, editor2_username) + client, csrf_token = get_client(self.app, editor2_username) detail_uri: str = f"{PREFIX}/{zza.code}" update_uri: str = f"{PREFIX}/{zza.code}/update" response: httpx.Response @@ -533,7 +533,7 @@ class CurrencyTestCase(unittest.TestCase): self.assertEqual(zza_currency.name_l10n, zza.name) self.assertEqual(zza_currency.l10n, []) - set_locale(self, self.client, self.csrf_token, "zh_Hant") + set_locale(self.client, self.csrf_token, "zh_Hant") response = self.client.post(update_uri, data={"csrf_token": self.csrf_token, @@ -548,7 +548,7 @@ class CurrencyTestCase(unittest.TestCase): self.assertEqual({(x.locale, x.name) for x in zza_currency.l10n}, {("zh_Hant", f"{zza.name}-zh_Hant")}) - set_locale(self, self.client, self.csrf_token, "en") + set_locale(self.client, self.csrf_token, "en") response = self.client.post(update_uri, data={"csrf_token": self.csrf_token, @@ -563,7 +563,7 @@ class CurrencyTestCase(unittest.TestCase): self.assertEqual({(x.locale, x.name) for x in zza_currency.l10n}, {("zh_Hant", f"{zza.name}-zh_Hant")}) - set_locale(self, self.client, self.csrf_token, "zh_Hant") + set_locale(self.client, self.csrf_token, "zh_Hant") response = self.client.post(update_uri, data={"csrf_token": self.csrf_token, diff --git a/tests/testlib.py b/tests/testlib.py index 8b812f1..3f0a97f 100644 --- a/tests/testlib.py +++ b/tests/testlib.py @@ -19,36 +19,32 @@ """ import typing as t from html.parser import HTMLParser -from unittest import TestCase import httpx from flask import Flask -def get_client(test_case: TestCase, app: Flask, username: str) \ - -> tuple[httpx.Client, str]: +def get_client(app: Flask, username: str) -> tuple[httpx.Client, str]: """Returns a user client. - :param test_case: The test case. :param app: The Flask application. :param username: The username. :return: A tuple of the client and the CSRF token. """ client: httpx.Client = httpx.Client(app=app, base_url="https://testserver") client.headers["Referer"] = "https://testserver" - csrf_token: str = get_csrf_token(test_case, client, "/login") + csrf_token: str = get_csrf_token(client, "/login") response: httpx.Response = client.post("/login", data={"csrf_token": csrf_token, "username": username}) - test_case.assertEqual(response.status_code, 302) - test_case.assertEqual(response.headers["Location"], "/") + assert response.status_code == 302 + assert response.headers["Location"] == "/" return client, csrf_token -def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str: +def get_csrf_token(client: httpx.Client, uri: str) -> str: """Returns the CSRF token from a form in a URI. - :param test_case: The test case. :param client: The httpx client. :param uri: The URI. :return: The CSRF token. @@ -71,18 +67,17 @@ def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str: self.csrf_token = attrs_dict["value"] response: httpx.Response = client.get(uri) - test_case.assertEqual(response.status_code, 200) + assert response.status_code == 200 parser: CsrfParser = CsrfParser() parser.feed(response.text) - test_case.assertIsNotNone(parser.csrf_token) + assert parser.csrf_token is not None return parser.csrf_token -def set_locale(test_case: TestCase, client: httpx.Client, csrf_token: str, +def set_locale(client: httpx.Client, csrf_token: str, locale: t.Literal["en", "zh_Hant", "zh_Hans"]) -> None: """Sets the current locale. - :param test_case: The test case. :param client: The test client. :param csrf_token: The CSRF token. :param locale: The locale. @@ -92,5 +87,5 @@ def set_locale(test_case: TestCase, client: httpx.Client, csrf_token: str, data={"csrf_token": csrf_token, "locale": locale, "next": "/next"}) - test_case.assertEqual(response.status_code, 302) - test_case.assertEqual(response.headers["Location"], "/next") + assert response.status_code == 302 + assert response.headers["Location"] == "/next"