Replaced the UserClient class and the get_user_client function with the get_client function in the tests, for simplicity.

This commit is contained in:
依瑪貓 2023-02-06 21:45:28 +08:00
parent 2a6c5de6d6
commit 591fb4a7ab
3 changed files with 56 additions and 71 deletions

View File

@ -25,7 +25,7 @@ from click.testing import Result
from flask import Flask from flask import Flask
from flask.testing import FlaskCliRunner from flask.testing import FlaskCliRunner
from testlib import UserClient, get_user_client from testlib import get_client
from test_site import create_app from test_site import create_app
@ -108,9 +108,7 @@ class AccountTestCase(unittest.TestCase):
Account.query.delete() Account.query.delete()
db.session.commit() db.session.commit()
editor: UserClient = get_user_client(self, self.app, "editor") self.client, self.csrf_token = get_client(self, self.app, "editor")
self.client: httpx.Client = editor.client
self.csrf_token: str = editor.csrf_token
response: httpx.Response response: httpx.Response
response = self.client.post("/accounting/accounts/store", response = self.client.post("/accounting/accounts/store",
@ -135,45 +133,45 @@ class AccountTestCase(unittest.TestCase):
:return: None. :return: None.
""" """
from accounting.models import Account from accounting.models import Account
client, csrf_token = get_client(self, self.app, "nobody")
response: httpx.Response response: httpx.Response
nobody: UserClient = get_user_client(self, self.app, "nobody")
response = nobody.client.get("/accounting/accounts") response = client.get("/accounting/accounts")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.get("/accounting/accounts/1111-001") response = client.get("/accounting/accounts/1111-001")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.get("/accounting/accounts/create") response = client.get("/accounting/accounts/create")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.post("/accounting/accounts/store", response = client.post("/accounting/accounts/store",
data={"csrf_token": nobody.csrf_token, data={"csrf_token": csrf_token,
"base_code": "1113", "base_code": "1113",
"title": "1113 title"}) "title": "1113 title"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.get("/accounting/accounts/1111-001/edit") response = client.get("/accounting/accounts/1111-001/edit")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.post("/accounting/accounts/1111-001/update", response = client.post("/accounting/accounts/1111-001/update",
data={"csrf_token": nobody.csrf_token, data={"csrf_token": csrf_token,
"base_code": "1111", "base_code": "1111",
"title": "1111 title #2"}) "title": "1111 title #2"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.post("/accounting/accounts/1111-001/delete", response = client.post("/accounting/accounts/1111-001/delete",
data={"csrf_token": nobody.csrf_token}) data={"csrf_token": csrf_token})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.get("/accounting/accounts/bases/1111") response = client.get("/accounting/accounts/bases/1111")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
with self.app.app_context(): with self.app.app_context():
account_id: int = Account.find_by_code("1112-001").id account_id: int = Account.find_by_code("1112-001").id
response = nobody.client.post("/accounting/accounts/bases/1112", response = client.post("/accounting/accounts/bases/1112",
data={"csrf_token": nobody.csrf_token, data={"csrf_token": csrf_token,
"next": "/next", "next": "/next",
f"{account_id}-no": "5"}) f"{account_id}-no": "5"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
@ -184,45 +182,45 @@ class AccountTestCase(unittest.TestCase):
:return: None. :return: None.
""" """
from accounting.models import Account from accounting.models import Account
client, csrf_token = get_client(self, self.app, "viewer")
response: httpx.Response response: httpx.Response
viewer: UserClient = get_user_client(self, self.app, "viewer")
response = viewer.client.get("/accounting/accounts") response = client.get("/accounting/accounts")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response = viewer.client.get("/accounting/accounts/1111-001") response = client.get("/accounting/accounts/1111-001")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response = viewer.client.get("/accounting/accounts/create") response = client.get("/accounting/accounts/create")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = viewer.client.post("/accounting/accounts/store", response = client.post("/accounting/accounts/store",
data={"csrf_token": viewer.csrf_token, data={"csrf_token": csrf_token,
"base_code": "1113", "base_code": "1113",
"title": "1113 title"}) "title": "1113 title"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = viewer.client.get("/accounting/accounts/1111-001/edit") response = client.get("/accounting/accounts/1111-001/edit")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = viewer.client.post("/accounting/accounts/1111-001/update", response = client.post("/accounting/accounts/1111-001/update",
data={"csrf_token": viewer.csrf_token, data={"csrf_token": csrf_token,
"base_code": "1111", "base_code": "1111",
"title": "1111 title #2"}) "title": "1111 title #2"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = viewer.client.post("/accounting/accounts/1111-001/delete", response = client.post("/accounting/accounts/1111-001/delete",
data={"csrf_token": viewer.csrf_token}) data={"csrf_token": csrf_token})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = viewer.client.get("/accounting/accounts/bases/1111") response = client.get("/accounting/accounts/bases/1111")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
with self.app.app_context(): with self.app.app_context():
account_id: int = Account.find_by_code("1112-001").id account_id: int = Account.find_by_code("1112-001").id
response = viewer.client.post("/accounting/accounts/bases/1112", response = client.post("/accounting/accounts/bases/1112",
data={"csrf_token": viewer.csrf_token, data={"csrf_token": csrf_token,
"next": "/next", "next": "/next",
f"{account_id}-no": "5"}) f"{account_id}-no": "5"})
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)

View File

@ -24,7 +24,7 @@ from click.testing import Result
from flask import Flask from flask import Flask
from flask.testing import FlaskCliRunner from flask.testing import FlaskCliRunner
from testlib import UserClient, get_user_client from testlib import get_client
from test_site import create_app from test_site import create_app
@ -92,13 +92,13 @@ class BaseAccountTestCase(unittest.TestCase):
:return: None. :return: None.
""" """
client, csrf_token = get_client(self, self.app, "nobody")
response: httpx.Response response: httpx.Response
nobody: UserClient = get_user_client(self, self.app, "nobody")
response = nobody.client.get("/accounting/base-accounts") response = client.get("/accounting/base-accounts")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
response = nobody.client.get("/accounting/base-accounts/1111") response = client.get("/accounting/base-accounts/1111")
self.assertEqual(response.status_code, 403) self.assertEqual(response.status_code, 403)
def test_viewer(self) -> None: def test_viewer(self) -> None:
@ -106,13 +106,13 @@ class BaseAccountTestCase(unittest.TestCase):
:return: None. :return: None.
""" """
client, csrf_token = get_client(self, self.app, "viewer")
response: httpx.Response response: httpx.Response
viewer: UserClient = get_user_client(self, self.app, "viewer")
response = viewer.client.get("/accounting/base-accounts") response = client.get("/accounting/base-accounts")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response = viewer.client.get("/accounting/base-accounts/1111") response = client.get("/accounting/base-accounts/1111")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
def test_editor(self) -> None: def test_editor(self) -> None:
@ -120,11 +120,11 @@ class BaseAccountTestCase(unittest.TestCase):
:return: None. :return: None.
""" """
client, csrf_token = get_client(self, self.app, "editor")
response: httpx.Response response: httpx.Response
editor: UserClient = get_user_client(self, self.app, "editor")
response = editor.client.get("/accounting/base-accounts") response = client.get("/accounting/base-accounts")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)
response = editor.client.get("/accounting/base-accounts/1111") response = client.get("/accounting/base-accounts/1111")
self.assertEqual(response.status_code, 200) self.assertEqual(response.status_code, 200)

View File

@ -24,27 +24,14 @@ import httpx
from flask import Flask from flask import Flask
class UserClient: def get_client(test_case: TestCase, app: Flask, username: str) \
"""A user client.""" -> tuple[httpx.Client, str]:
def __init__(self, client: httpx.Client, csrf_token: str):
"""Constructs a user client.
:param client: The client.
:param csrf_token: The CSRF token.
"""
self.client: httpx.Client = client
self.csrf_token: str = csrf_token
def get_user_client(test_case: TestCase, app: Flask, username: str) \
-> UserClient:
"""Returns a user client. """Returns a user client.
:param test_case: The test case. :param test_case: The test case.
:param app: The Flask application. :param app: The Flask application.
:param username: The username. :param username: The username.
:return: The user client. :return: A tuple of the client and the CSRF token.
""" """
client: httpx.Client = httpx.Client(app=app, base_url="https://testserver") client: httpx.Client = httpx.Client(app=app, base_url="https://testserver")
client.headers["Referer"] = "https://testserver" client.headers["Referer"] = "https://testserver"
@ -54,7 +41,7 @@ def get_user_client(test_case: TestCase, app: Flask, username: str) \
"username": username}) "username": username})
test_case.assertEqual(response.status_code, 302) test_case.assertEqual(response.status_code, 302)
test_case.assertEqual(response.headers["Location"], "/") test_case.assertEqual(response.headers["Location"], "/")
return UserClient(client, csrf_token) return client, csrf_token
def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str: def get_csrf_token(test_case: TestCase, client: httpx.Client, uri: str) -> str: