diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index a43a137273..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -11,10 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import json
import os
import re
from email.parser import Parser
+from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
@@ -95,10 +95,8 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
"""
body = {"type": "m.login.password", "user": username, "password": password}
- channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
- )
- self.assertEqual(channel.code, 403, channel.result)
+ channel = self.make_request("POST", "/_matrix/client/r0/login", body)
+ self.assertEqual(channel.code, HTTPStatus.FORBIDDEN, channel.result)
def test_basic_password_reset(self) -> None:
"""Test basic password reset flow"""
@@ -347,7 +345,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
# Now POST to the same endpoint, mimicking the same behaviour as clicking the
# password reset confirm button
@@ -362,7 +360,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
shorthand=False,
content_is_form=True,
)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -390,7 +388,7 @@ class PasswordResetTestCase(unittest.HomeserverTestCase):
new_password: str,
session_id: str,
client_secret: str,
- expected_code: int = 200,
+ expected_code: int = HTTPStatus.OK,
) -> None:
channel = self.make_request(
"POST",
@@ -479,20 +477,18 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
self.assertEqual(memberships[0].room_id, room_id, memberships)
def deactivate(self, user_id: str, tok: str) -> None:
- request_data = json.dumps(
- {
- "auth": {
- "type": "m.login.password",
- "user": user_id,
- "password": "test",
- },
- "erase": False,
- }
- )
+ request_data = {
+ "auth": {
+ "type": "m.login.password",
+ "user": user_id,
+ "password": "test",
+ },
+ "erase": False,
+ }
channel = self.make_request(
"POST", "account/deactivate", request_data, access_token=tok
)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -645,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
def test_add_email_no_at(self) -> None:
self._request_token_invalid_email(
"address-without-at.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_two_at(self) -> None:
self._request_token_invalid_email(
"foo@foo@test.bar",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
def test_add_email_bad_format(self) -> None:
self._request_token_invalid_email(
"user@bad.example.net@good.example.com",
- expected_errcode=Codes.UNKNOWN,
+ expected_errcode=Codes.BAD_JSON,
expected_error="Unable to parse email address",
)
@@ -715,7 +711,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -725,7 +723,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email(self) -> None:
@@ -747,7 +745,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
{"medium": "email", "address": self.email},
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -756,7 +754,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_delete_email_if_disabled(self) -> None:
@@ -781,7 +779,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.FORBIDDEN, channel.json_body["errcode"])
# Get user
@@ -791,7 +791,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
self.assertEqual(self.email, channel.json_body["threepids"][0]["address"])
@@ -817,7 +817,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -827,7 +829,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
def test_no_valid_token(self) -> None:
@@ -852,7 +854,9 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
},
access_token=self.user_id_tok,
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(Codes.THREEPID_AUTH_FAILED, channel.json_body["errcode"])
# Get user
@@ -862,7 +866,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertFalse(channel.json_body["threepids"])
@override_config({"next_link_domain_whitelist": None})
@@ -872,7 +876,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/good/site",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -884,7 +888,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="some-protocol://abcdefghijklmopqrstuvwxyz",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
@override_config({"next_link_domain_whitelist": None})
@@ -895,7 +899,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="file:///host/path",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": ["example.com", "example.org"]})
@@ -907,28 +911,28 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link=None,
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.com/some/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://example.org/some/also/good/page",
- expect_code=200,
+ expect_code=HTTPStatus.OK,
)
self._request_token(
"something@example.com",
"some_secret",
next_link="https://bad.example.org/some/bad/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
@override_config({"next_link_domain_whitelist": []})
@@ -940,7 +944,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
"something@example.com",
"some_secret",
next_link="https://example.com/a/page",
- expect_code=400,
+ expect_code=HTTPStatus.BAD_REQUEST,
)
def _request_token(
@@ -948,8 +952,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
email: str,
client_secret: str,
next_link: Optional[str] = None,
- expect_code: int = 200,
- ) -> str:
+ expect_code: int = HTTPStatus.OK,
+ ) -> Optional[str]:
"""Request a validation token to add an email address to a user's account
Args:
@@ -959,7 +963,8 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
expect_code: Expected return code of the call
Returns:
- The ID of the new threepid validation session
+ The ID of the new threepid validation session, or None if the response
+ did not contain a session ID.
"""
body = {"client_secret": client_secret, "email": email, "send_attempt": 1}
if next_link:
@@ -992,16 +997,18 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
b"account/3pid/email/requestToken",
{"client_secret": client_secret, "email": email, "send_attempt": 1},
)
- self.assertEqual(400, channel.code, msg=channel.result["body"])
+ self.assertEqual(
+ HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
+ )
self.assertEqual(expected_errcode, channel.json_body["errcode"])
- self.assertEqual(expected_error, channel.json_body["error"])
+ self.assertIn(expected_error, channel.json_body["error"])
def _validate_token(self, link: str) -> None:
# Remove the host
path = link.replace("https://example.com", "")
channel = self.make_request("GET", path, shorthand=False)
- self.assertEqual(200, channel.code, channel.result)
+ self.assertEqual(HTTPStatus.OK, channel.code, channel.result)
def _get_link_from_email(self) -> str:
assert self.email_attempts, "No emails have been sent"
@@ -1051,7 +1058,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
# Get user
channel = self.make_request(
@@ -1060,7 +1067,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
access_token=self.user_id_tok,
)
- self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
self.assertEqual("email", channel.json_body["threepids"][0]["medium"])
threepids = {threepid["address"] for threepid in channel.json_body["threepids"]}
@@ -1091,7 +1098,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that not providing any MXID raises an error."""
self._test_status(
users=None,
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.MISSING_PARAM,
)
@@ -1099,7 +1106,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
"""Tests that providing an invalid MXID raises an error."""
self._test_status(
users=["bad:test"],
- expected_status_code=400,
+ expected_status_code=HTTPStatus.BAD_REQUEST,
expected_errcode=Codes.INVALID_PARAM,
)
@@ -1285,7 +1292,7 @@ class AccountStatusTestCase(unittest.HomeserverTestCase):
def _test_status(
self,
users: Optional[List[str]],
- expected_status_code: int = 200,
+ expected_status_code: int = HTTPStatus.OK,
expected_statuses: Optional[Dict[str, Dict[str, bool]]] = None,
expected_failures: Optional[List[str]] = None,
expected_errcode: Optional[str] = None,
|