summary refs log tree commit diff
path: root/tests/rest/client/test_login.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-08-24 09:58:29 +0100
committerGitHub <noreply@github.com>2022-08-24 09:58:29 +0100
commitcbb157548676865793f39b4da0b7f3fa5ee01058 (patch)
tree77ff41ddcb9b01938b4efdd109cf90981c86a23e /tests/rest/client/test_login.py
parentNewsfile (diff)
parentInstrument `_check_sigs_and_hash_and_fetch` to trace time spent in child conc... (diff)
downloadsynapse-github/erikj/less_state_membership.tar.xz
Merge branch 'develop' into erikj/less_state_membership github/erikj/less_state_membership erikj/less_state_membership
Diffstat (limited to 'tests/rest/client/test_login.py')
-rw-r--r--tests/rest/client/test_login.py127
1 files changed, 63 insertions, 64 deletions
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a2958f6959..e2a4d98275 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import time
 import urllib.parse
-from http import HTTPStatus
 from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
 from urllib.parse import urlencode
@@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config(
         {
@@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config(
         {
@@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"403", channel.result)
+                self.assertEqual(channel.code, 403, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_soft_logout(self) -> None:
@@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we shouldn't be able to make requests without an access token
         channel = self.make_request(b"GET", TEST_URL)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
 
         # log in as normal
@@ -261,20 +260,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         access_token = channel.json_body["access_token"]
         device_id = channel.json_body["device_id"]
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -288,7 +287,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # more requests with the expired token should still return a soft-logout
         self.reactor.advance(3600)
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -296,7 +295,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
 
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], False)
 
@@ -307,7 +306,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         # check it's a UI-Auth fail
         self.assertEqual(
             set(channel.json_body.keys()),
@@ -330,7 +329,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             content={"auth": auth},
         )
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,20 +340,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
         # Now try to hard logout this session
         channel = self.make_request(b"POST", "/logout", access_token=access_token)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@@ -367,20 +366,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
         # Now try to hard log out all of the user's sessions
         channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_login_with_overly_long_device_id_fails(self) -> None:
         self.register_user("mickey", "cheese")
@@ -466,7 +465,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_get_login_flows(self) -> None:
         """GET /login should return password and SSO flows"""
         channel = self.make_request("GET", "/_matrix/client/r0/login")
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         expected_flow_types = [
             "m.login.cas",
@@ -494,14 +493,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
         channel = self._make_sso_redirect_request(None)
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         uri = location_headers[0]
 
         # hitting that picker should give us some HTML
         channel = self.make_request("GET", uri)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
         html = channel.result["body"].decode("utf-8")
@@ -530,7 +529,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=cas",
             shorthand=False,
         )
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         cas_uri = location_headers[0]
@@ -555,7 +554,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=saml",
         )
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         saml_uri = location_headers[0]
@@ -579,7 +578,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=oidc",
         )
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         oidc_uri = location_headers[0]
@@ -606,7 +605,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
 
         # that should serve a confirmation page
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_headers = channel.headers.getRawHeaders("Content-Type")
         assert content_type_headers
         self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +633,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+        self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@user1:test")
 
     def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,18 +642,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             "GET",
             "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
         )
-        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(channel.code, 400, channel.result)
 
     def test_client_idp_redirect_to_unknown(self) -> None:
         """If the client tries to pick an unknown IdP, return a 404"""
         channel = self._make_sso_redirect_request("xxx")
-        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+        self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
     def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
         channel = self._make_sso_redirect_request("oidc")
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         oidc_uri = location_headers[0]
@@ -765,7 +764,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -878,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_jwt_valid_registered(self) -> None:
         self.register_user("kermit", "monkey")
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     def test_login_jwt_valid_unregistered(self) -> None:
         channel = self.jwt_login({"sub": "frog"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@frog:test")
 
     def test_login_jwt_invalid_signature(self) -> None:
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -897,7 +896,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_expired(self) -> None:
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -907,7 +906,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_jwt_not_before(self) -> None:
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -916,7 +915,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_no_sub(self) -> None:
         channel = self.jwt_login({"username": "root"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
@@ -925,12 +924,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
         """Test validating the issuer claim."""
         # A valid issuer.
         channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
         # An invalid issuer.
         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -939,7 +938,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
         # Not providing an issuer.
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -949,7 +948,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_iss_no_config(self) -> None:
         """Test providing an issuer claim without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@@ -957,12 +956,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
         """Test validating the audience claim."""
         # A valid audience.
         channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
         # An invalid audience.
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -971,7 +970,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
         # Not providing an audience.
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -981,7 +980,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_aud_no_config(self) -> None:
         """Test providing an audience without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -991,20 +990,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_default_sub(self) -> None:
         """Test reading user ID from the default subject claim."""
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
     def test_login_custom_sub(self) -> None:
         """Test reading user ID from a custom subject claim."""
         channel = self.jwt_login({"username": "frog"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@frog:test")
 
     def test_login_no_token(self) -> None:
         params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
 
@@ -1086,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_valid(self) -> None:
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     def test_login_jwt_invalid_signature(self) -> None:
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -1152,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_login_appservice_user_bot(self) -> None:
         """Test that the appservice bot can use /login"""
@@ -1166,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_login_appservice_wrong_user(self) -> None:
         """Test that non-as users cannot login with the as token"""
@@ -1180,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
 
     def test_login_appservice_wrong_as(self) -> None:
         """Test that as users cannot login with wrong as token"""
@@ -1194,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.another_service.token
         )
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
 
     def test_login_appservice_no_token(self) -> None:
         """Test that users must provide a token when using the appservice
@@ -1208,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
 
 
 @skip_unless(HAS_OIDC, "requires OIDC")
@@ -1246,7 +1245,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
 
         # that should redirect to the username picker
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         picker_url = location_headers[0]
@@ -1290,7 +1289,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
                 ("Content-Length", str(len(content))),
             ],
         )
-        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+        self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
         assert location_headers
 
@@ -1300,7 +1299,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
             path=location_headers[0],
             custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
         )
-        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+        self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
         assert location_headers
 
@@ -1325,5 +1324,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+        self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@bobby:test")