Replaced "unittest.TestCase.assert*" methods with "assert" in the common test functions, for simplicity.
This commit is contained in:
@ -19,36 +19,32 @@
|
||||
"""
|
||||
import typing as t
|
||||
from html.parser import HTMLParser
|
||||
from unittest import TestCase
|
||||
|
||||
import httpx
|
||||
from flask import Flask
|
||||
|
||||
|
||||
def get_client(test_case: TestCase, app: Flask, username: str) \
|
||||
-> tuple[httpx.Client, str]:
|
||||
def get_client(app: Flask, username: str) -> tuple[httpx.Client, str]:
|
||||
"""Returns a user client.
|
||||
|
||||
:param test_case: The test case.
|
||||
:param app: The Flask application.
|
||||
:param username: The username.
|
||||
:return: A tuple of the client and the CSRF token.
|
||||
"""
|
||||
client: httpx.Client = httpx.Client(app=app, base_url="https://testserver")
|
||||
client.headers["Referer"] = "https://testserver"
|
||||
csrf_token: str = get_csrf_token(test_case, client, "/login")
|
||||
csrf_token: str = get_csrf_token(client, "/login")
|
||||
response: httpx.Response = client.post("/login",
|
||||
data={"csrf_token": csrf_token,
|
||||
"username": username})
|
||||
test_case.assertEqual(response.status_code, 302)
|
||||
test_case.assertEqual(response.headers["Location"], "/")
|
||||
assert response.status_code == 302
|
||||
assert response.headers["Location"] == "/"
|
||||
return client, csrf_token
|
||||
|
||||
|
||||
def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str:
|
||||
def get_csrf_token(client: httpx.Client, uri: str) -> str:
|
||||
"""Returns the CSRF token from a form in a URI.
|
||||
|
||||
:param test_case: The test case.
|
||||
:param client: The httpx client.
|
||||
:param uri: The URI.
|
||||
:return: The CSRF token.
|
||||
@ -71,18 +67,17 @@ def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str:
|
||||
self.csrf_token = attrs_dict["value"]
|
||||
|
||||
response: httpx.Response = client.get(uri)
|
||||
test_case.assertEqual(response.status_code, 200)
|
||||
assert response.status_code == 200
|
||||
parser: CsrfParser = CsrfParser()
|
||||
parser.feed(response.text)
|
||||
test_case.assertIsNotNone(parser.csrf_token)
|
||||
assert parser.csrf_token is not None
|
||||
return parser.csrf_token
|
||||
|
||||
|
||||
def set_locale(test_case: TestCase, client: httpx.Client, csrf_token: str,
|
||||
def set_locale(client: httpx.Client, csrf_token: str,
|
||||
locale: t.Literal["en", "zh_Hant", "zh_Hans"]) -> None:
|
||||
"""Sets the current locale.
|
||||
|
||||
:param test_case: The test case.
|
||||
:param client: The test client.
|
||||
:param csrf_token: The CSRF token.
|
||||
:param locale: The locale.
|
||||
@ -92,5 +87,5 @@ def set_locale(test_case: TestCase, client: httpx.Client, csrf_token: str,
|
||||
data={"csrf_token": csrf_token,
|
||||
"locale": locale,
|
||||
"next": "/next"})
|
||||
test_case.assertEqual(response.status_code, 302)
|
||||
test_case.assertEqual(response.headers["Location"], "/next")
|
||||
assert response.status_code == 302
|
||||
assert response.headers["Location"] == "/next"
|
||||
|
Reference in New Issue
Block a user