Replaced "unittest.TestCase.assert*" methods with "assert" in the common test functions, for simplicity.

This commit is contained in:
2023-02-13 19:18:15 +08:00
parent 36f55900c7
commit 2ab60b2224
4 changed files with 27 additions and 32 deletions

View File

@ -19,36 +19,32 @@
"""
import typing as t
from html.parser import HTMLParser
from unittest import TestCase
import httpx
from flask import Flask
def get_client(test_case: TestCase, app: Flask, username: str) \
-> tuple[httpx.Client, str]:
def get_client(app: Flask, username: str) -> tuple[httpx.Client, str]:
"""Returns a user client.
:param test_case: The test case.
:param app: The Flask application.
:param username: The username.
:return: A tuple of the client and the CSRF token.
"""
client: httpx.Client = httpx.Client(app=app, base_url="https://testserver")
client.headers["Referer"] = "https://testserver"
csrf_token: str = get_csrf_token(test_case, client, "/login")
csrf_token: str = get_csrf_token(client, "/login")
response: httpx.Response = client.post("/login",
data={"csrf_token": csrf_token,
"username": username})
test_case.assertEqual(response.status_code, 302)
test_case.assertEqual(response.headers["Location"], "/")
assert response.status_code == 302
assert response.headers["Location"] == "/"
return client, csrf_token
def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str:
def get_csrf_token(client: httpx.Client, uri: str) -> str:
"""Returns the CSRF token from a form in a URI.
:param test_case: The test case.
:param client: The httpx client.
:param uri: The URI.
:return: The CSRF token.
@ -71,18 +67,17 @@ def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str:
self.csrf_token = attrs_dict["value"]
response: httpx.Response = client.get(uri)
test_case.assertEqual(response.status_code, 200)
assert response.status_code == 200
parser: CsrfParser = CsrfParser()
parser.feed(response.text)
test_case.assertIsNotNone(parser.csrf_token)
assert parser.csrf_token is not None
return parser.csrf_token
def set_locale(test_case: TestCase, client: httpx.Client, csrf_token: str,
def set_locale(client: httpx.Client, csrf_token: str,
locale: t.Literal["en", "zh_Hant", "zh_Hans"]) -> None:
"""Sets the current locale.
:param test_case: The test case.
:param client: The test client.
:param csrf_token: The CSRF token.
:param locale: The locale.
@ -92,5 +87,5 @@ def set_locale(test_case: TestCase, client: httpx.Client, csrf_token: str,
data={"csrf_token": csrf_token,
"locale": locale,
"next": "/next"})
test_case.assertEqual(response.status_code, 302)
test_case.assertEqual(response.headers["Location"], "/next")
assert response.status_code == 302
assert response.headers["Location"] == "/next"