summary refs log tree commit diff
path: root/tests/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client')
-rw-r--r--tests/rest/client/test_account.py10
-rw-r--r--tests/rest/client/test_filter.py14
-rw-r--r--tests/rest/client/test_login.py127
-rw-r--r--tests/rest/client/test_models.py53
-rw-r--r--tests/rest/client/test_redactions.py4
-rw-r--r--tests/rest/client/test_register.py94
-rw-r--r--tests/rest/client/test_relations.py6
-rw-r--r--tests/rest/client/test_report_event.py4
-rw-r--r--tests/rest/client/test_retention.py4
-rw-r--r--tests/rest/client/test_rooms.py7
-rw-r--r--tests/rest/client/test_shadow_banned.py6
-rw-r--r--tests/rest/client/test_sync.py79
-rw-r--r--tests/rest/client/test_third_party_rules.py27
-rw-r--r--tests/rest/client/utils.py28
14 files changed, 274 insertions, 189 deletions
diff --git a/tests/rest/client/test_account.py b/tests/rest/client/test_account.py
index 7ae926dc9c..c1a7fb2f8a 100644
--- a/tests/rest/client/test_account.py
+++ b/tests/rest/client/test_account.py
@@ -488,7 +488,7 @@ class DeactivateTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", "account/deactivate", request_data, access_token=tok
         )
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, channel.json_body)
 
 
 class WhoamiTestCase(unittest.HomeserverTestCase):
@@ -641,21 +641,21 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
     def test_add_email_no_at(self) -> None:
         self._request_token_invalid_email(
             "address-without-at.bar",
-            expected_errcode=Codes.UNKNOWN,
+            expected_errcode=Codes.BAD_JSON,
             expected_error="Unable to parse email address",
         )
 
     def test_add_email_two_at(self) -> None:
         self._request_token_invalid_email(
             "foo@foo@test.bar",
-            expected_errcode=Codes.UNKNOWN,
+            expected_errcode=Codes.BAD_JSON,
             expected_error="Unable to parse email address",
         )
 
     def test_add_email_bad_format(self) -> None:
         self._request_token_invalid_email(
             "user@bad.example.net@good.example.com",
-            expected_errcode=Codes.UNKNOWN,
+            expected_errcode=Codes.BAD_JSON,
             expected_error="Unable to parse email address",
         )
 
@@ -1001,7 +1001,7 @@ class ThreepidEmailRestTestCase(unittest.HomeserverTestCase):
             HTTPStatus.BAD_REQUEST, channel.code, msg=channel.result["body"]
         )
         self.assertEqual(expected_errcode, channel.json_body["errcode"])
-        self.assertEqual(expected_error, channel.json_body["error"])
+        self.assertIn(expected_error, channel.json_body["error"])
 
     def _validate_token(self, link: str) -> None:
         # Remove the host
diff --git a/tests/rest/client/test_filter.py b/tests/rest/client/test_filter.py
index 823e8ab8c4..afc8d641be 100644
--- a/tests/rest/client/test_filter.py
+++ b/tests/rest/client/test_filter.py
@@ -43,7 +43,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
             self.EXAMPLE_FILTER_JSON,
         )
 
-        self.assertEqual(channel.result["code"], b"200")
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, {"filter_id": "0"})
         filter = self.get_success(
             self.store.get_user_filter(user_localpart="apple", filter_id=0)
@@ -58,7 +58,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
             self.EXAMPLE_FILTER_JSON,
         )
 
-        self.assertEqual(channel.result["code"], b"403")
+        self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
     def test_add_filter_non_local_user(self) -> None:
@@ -71,7 +71,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
         )
 
         self.hs.is_mine = _is_mine
-        self.assertEqual(channel.result["code"], b"403")
+        self.assertEqual(channel.code, 403)
         self.assertEqual(channel.json_body["errcode"], Codes.FORBIDDEN)
 
     def test_get_filter(self) -> None:
@@ -85,7 +85,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
             "GET", "/_matrix/client/r0/user/%s/filter/%s" % (self.user_id, filter_id)
         )
 
-        self.assertEqual(channel.result["code"], b"200")
+        self.assertEqual(channel.code, 200)
         self.assertEqual(channel.json_body, self.EXAMPLE_FILTER)
 
     def test_get_filter_non_existant(self) -> None:
@@ -93,7 +93,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
             "GET", "/_matrix/client/r0/user/%s/filter/12382148321" % (self.user_id)
         )
 
-        self.assertEqual(channel.result["code"], b"404")
+        self.assertEqual(channel.code, 404)
         self.assertEqual(channel.json_body["errcode"], Codes.NOT_FOUND)
 
     # Currently invalid params do not have an appropriate errcode
@@ -103,7 +103,7 @@ class FilterTestCase(unittest.HomeserverTestCase):
             "GET", "/_matrix/client/r0/user/%s/filter/foobar" % (self.user_id)
         )
 
-        self.assertEqual(channel.result["code"], b"400")
+        self.assertEqual(channel.code, 400)
 
     # No ID also returns an invalid_id error
     def test_get_filter_no_id(self) -> None:
@@ -111,4 +111,4 @@ class FilterTestCase(unittest.HomeserverTestCase):
             "GET", "/_matrix/client/r0/user/%s/filter/" % (self.user_id)
         )
 
-        self.assertEqual(channel.result["code"], b"400")
+        self.assertEqual(channel.code, 400)
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index a2958f6959..e2a4d98275 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 import time
 import urllib.parse
-from http import HTTPStatus
 from typing import Any, Dict, List, Optional
 from unittest.mock import Mock
 from urllib.parse import urlencode
@@ -134,10 +133,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -152,7 +151,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config(
         {
@@ -179,10 +178,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -197,7 +196,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config(
         {
@@ -224,10 +223,10 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", LOGIN_URL, params)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"403", channel.result)
+                self.assertEqual(channel.code, 403, msg=channel.result)
 
         # Since we're ratelimiting at 1 request/min, retry_after_ms should be lower
         # than 1min.
@@ -242,7 +241,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_soft_logout(self) -> None:
@@ -250,7 +249,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we shouldn't be able to make requests without an access token
         channel = self.make_request(b"GET", TEST_URL)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_MISSING_TOKEN")
 
         # log in as normal
@@ -261,20 +260,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         access_token = channel.json_body["access_token"]
         device_id = channel.json_body["device_id"]
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -288,7 +287,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         # more requests with the expired token should still return a soft-logout
         self.reactor.advance(3600)
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
@@ -296,7 +295,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         self._delete_device(access_token_2, "kermit", "monkey", device_id)
 
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], False)
 
@@ -307,7 +306,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             b"DELETE", "devices/" + device_id, access_token=access_token
         )
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         # check it's a UI-Auth fail
         self.assertEqual(
             set(channel.json_body.keys()),
@@ -330,7 +329,7 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
             access_token=access_token,
             content={"auth": auth},
         )
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_session_can_hard_logout_after_being_soft_logged_out(self) -> None:
@@ -341,20 +340,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
         # Now try to hard logout this session
         channel = self.make_request(b"POST", "/logout", access_token=access_token)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config({"session_lifetime": "24h"})
     def test_session_can_hard_logout_all_sessions_after_being_soft_logged_out(
@@ -367,20 +366,20 @@ class LoginRestServletTestCase(unittest.HomeserverTestCase):
 
         # we should now be able to make requests with the access token
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # time passes
         self.reactor.advance(24 * 3600)
 
         # ... and we should be soft-logouted
         channel = self.make_request(b"GET", TEST_URL, access_token=access_token)
-        self.assertEqual(channel.code, HTTPStatus.UNAUTHORIZED, channel.result)
+        self.assertEqual(channel.code, 401, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_UNKNOWN_TOKEN")
         self.assertEqual(channel.json_body["soft_logout"], True)
 
         # Now try to hard log out all of the user's sessions
         channel = self.make_request(b"POST", "/logout/all", access_token=access_token)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_login_with_overly_long_device_id_fails(self) -> None:
         self.register_user("mickey", "cheese")
@@ -466,7 +465,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_get_login_flows(self) -> None:
         """GET /login should return password and SSO flows"""
         channel = self.make_request("GET", "/_matrix/client/r0/login")
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         expected_flow_types = [
             "m.login.cas",
@@ -494,14 +493,14 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         """/login/sso/redirect should redirect to an identity picker"""
         # first hit the redirect url, which should redirect to our idp picker
         channel = self._make_sso_redirect_request(None)
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         uri = location_headers[0]
 
         # hitting that picker should give us some HTML
         channel = self.make_request("GET", uri)
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         # parse the form to check it has fields assumed elsewhere in this class
         html = channel.result["body"].decode("utf-8")
@@ -530,7 +529,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + "&idp=cas",
             shorthand=False,
         )
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         cas_uri = location_headers[0]
@@ -555,7 +554,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=saml",
         )
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         saml_uri = location_headers[0]
@@ -579,7 +578,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
             + "&idp=oidc",
         )
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         oidc_uri = location_headers[0]
@@ -606,7 +605,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
 
         # that should serve a confirmation page
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_headers = channel.headers.getRawHeaders("Content-Type")
         assert content_type_headers
         self.assertTrue(content_type_headers[-1].startswith("text/html"))
@@ -634,7 +633,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+        self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@user1:test")
 
     def test_multi_sso_redirect_to_unknown(self) -> None:
@@ -643,18 +642,18 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             "GET",
             "/_synapse/client/pick_idp?redirectUrl=http://x&idp=xyz",
         )
-        self.assertEqual(channel.code, HTTPStatus.BAD_REQUEST, channel.result)
+        self.assertEqual(channel.code, 400, channel.result)
 
     def test_client_idp_redirect_to_unknown(self) -> None:
         """If the client tries to pick an unknown IdP, return a 404"""
         channel = self._make_sso_redirect_request("xxx")
-        self.assertEqual(channel.code, HTTPStatus.NOT_FOUND, channel.result)
+        self.assertEqual(channel.code, 404, channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_NOT_FOUND")
 
     def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
         channel = self._make_sso_redirect_request("oidc")
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         oidc_uri = location_headers[0]
@@ -765,7 +764,7 @@ class CASTestCase(unittest.HomeserverTestCase):
         channel = self.make_request("GET", cas_ticket_url)
 
         # Test that the response is HTML.
-        self.assertEqual(channel.code, HTTPStatus.OK, channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         content_type_header_value = ""
         for header in channel.result.get("headers", []):
             if header[0] == b"Content-Type":
@@ -878,17 +877,17 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_jwt_valid_registered(self) -> None:
         self.register_user("kermit", "monkey")
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     def test_login_jwt_valid_unregistered(self) -> None:
         channel = self.jwt_login({"sub": "frog"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@frog:test")
 
     def test_login_jwt_invalid_signature(self) -> None:
         channel = self.jwt_login({"sub": "frog"}, "notsecret")
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -897,7 +896,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_expired(self) -> None:
         channel = self.jwt_login({"sub": "frog", "exp": 864000})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -907,7 +906,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_jwt_not_before(self) -> None:
         now = int(time.time())
         channel = self.jwt_login({"sub": "frog", "nbf": now + 3600})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -916,7 +915,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
     def test_login_no_sub(self) -> None:
         channel = self.jwt_login({"username": "root"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Invalid JWT")
 
@@ -925,12 +924,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
         """Test validating the issuer claim."""
         # A valid issuer.
         channel = self.jwt_login({"sub": "kermit", "iss": "test-issuer"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
         # An invalid issuer.
         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -939,7 +938,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
         # Not providing an issuer.
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -949,7 +948,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_iss_no_config(self) -> None:
         """Test providing an issuer claim without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "iss": "invalid"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     @override_config({"jwt_config": {**base_config, "audiences": ["test-audience"]}})
@@ -957,12 +956,12 @@ class JWTTestCase(unittest.HomeserverTestCase):
         """Test validating the audience claim."""
         # A valid audience.
         channel = self.jwt_login({"sub": "kermit", "aud": "test-audience"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
         # An invalid audience.
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -971,7 +970,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
 
         # Not providing an audience.
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -981,7 +980,7 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_aud_no_config(self) -> None:
         """Test providing an audience without requiring it in the configuration."""
         channel = self.jwt_login({"sub": "kermit", "aud": "invalid"})
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -991,20 +990,20 @@ class JWTTestCase(unittest.HomeserverTestCase):
     def test_login_default_sub(self) -> None:
         """Test reading user ID from the default subject claim."""
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     @override_config({"jwt_config": {**base_config, "subject_claim": "username"}})
     def test_login_custom_sub(self) -> None:
         """Test reading user ID from a custom subject claim."""
         channel = self.jwt_login({"username": "frog"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@frog:test")
 
     def test_login_no_token(self) -> None:
         params = {"type": "org.matrix.login.jwt"}
         channel = self.make_request(b"POST", LOGIN_URL, params)
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(channel.json_body["error"], "Token field for JWT is missing")
 
@@ -1086,12 +1085,12 @@ class JWTPubKeyTestCase(unittest.HomeserverTestCase):
 
     def test_login_jwt_valid(self) -> None:
         channel = self.jwt_login({"sub": "kermit"})
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["user_id"], "@kermit:test")
 
     def test_login_jwt_invalid_signature(self) -> None:
         channel = self.jwt_login({"sub": "frog"}, self.bad_privatekey)
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
         self.assertEqual(
             channel.json_body["error"],
@@ -1152,7 +1151,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_login_appservice_user_bot(self) -> None:
         """Test that the appservice bot can use /login"""
@@ -1166,7 +1165,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_login_appservice_wrong_user(self) -> None:
         """Test that non-as users cannot login with the as token"""
@@ -1180,7 +1179,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.service.token
         )
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
 
     def test_login_appservice_wrong_as(self) -> None:
         """Test that as users cannot login with wrong as token"""
@@ -1194,7 +1193,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", LOGIN_URL, params, access_token=self.another_service.token
         )
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
 
     def test_login_appservice_no_token(self) -> None:
         """Test that users must provide a token when using the appservice
@@ -1208,7 +1207,7 @@ class AppserviceLoginRestServletTestCase(unittest.HomeserverTestCase):
         }
         channel = self.make_request(b"POST", LOGIN_URL, params)
 
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
 
 
 @skip_unless(HAS_OIDC, "requires OIDC")
@@ -1246,7 +1245,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
         )
 
         # that should redirect to the username picker
-        self.assertEqual(channel.code, HTTPStatus.FOUND, channel.result)
+        self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
         picker_url = location_headers[0]
@@ -1290,7 +1289,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
                 ("Content-Length", str(len(content))),
             ],
         )
-        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+        self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
         assert location_headers
 
@@ -1300,7 +1299,7 @@ class UsernamePickerTestCase(HomeserverTestCase):
             path=location_headers[0],
             custom_headers=[("Cookie", "username_mapping_session=" + session_id)],
         )
-        self.assertEqual(chan.code, HTTPStatus.FOUND, chan.result)
+        self.assertEqual(chan.code, 302, chan.result)
         location_headers = chan.headers.getRawHeaders("Location")
         assert location_headers
 
@@ -1325,5 +1324,5 @@ class UsernamePickerTestCase(HomeserverTestCase):
             "/login",
             content={"type": "m.login.token", "token": login_token},
         )
-        self.assertEqual(chan.code, HTTPStatus.OK, chan.result)
+        self.assertEqual(chan.code, 200, chan.result)
         self.assertEqual(chan.json_body["user_id"], "@bobby:test")
diff --git a/tests/rest/client/test_models.py b/tests/rest/client/test_models.py
new file mode 100644
index 0000000000..a9da00665e
--- /dev/null
+++ b/tests/rest/client/test_models.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import unittest
+
+from pydantic import ValidationError
+
+from synapse.rest.client.models import EmailRequestTokenBody
+
+
+class EmailRequestTokenBodyTestCase(unittest.TestCase):
+    base_request = {
+        "client_secret": "hunter2",
+        "email": "alice@wonderland.com",
+        "send_attempt": 1,
+    }
+
+    def test_token_required_if_id_server_provided(self) -> None:
+        with self.assertRaises(ValidationError):
+            EmailRequestTokenBody.parse_obj(
+                {
+                    **self.base_request,
+                    "id_server": "identity.wonderland.com",
+                }
+            )
+        with self.assertRaises(ValidationError):
+            EmailRequestTokenBody.parse_obj(
+                {
+                    **self.base_request,
+                    "id_server": "identity.wonderland.com",
+                    "id_access_token": None,
+                }
+            )
+
+    def test_token_typechecked_when_id_server_provided(self) -> None:
+        with self.assertRaises(ValidationError):
+            EmailRequestTokenBody.parse_obj(
+                {
+                    **self.base_request,
+                    "id_server": "identity.wonderland.com",
+                    "id_access_token": 1337,
+                }
+            )
diff --git a/tests/rest/client/test_redactions.py b/tests/rest/client/test_redactions.py
index 7401b5e0c0..be4c67d68e 100644
--- a/tests/rest/client/test_redactions.py
+++ b/tests/rest/client/test_redactions.py
@@ -76,12 +76,12 @@ class RedactionsTestCase(HomeserverTestCase):
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (room_id, event_id)
 
         channel = self.make_request("POST", path, content={}, access_token=access_token)
-        self.assertEqual(int(channel.result["code"]), expect_code)
+        self.assertEqual(channel.code, expect_code)
         return channel.json_body
 
     def _sync_room_timeline(self, access_token: str, room_id: str) -> List[JsonDict]:
         channel = self.make_request("GET", "sync", access_token=self.mod_access_token)
-        self.assertEqual(channel.result["code"], b"200")
+        self.assertEqual(channel.code, 200)
         room_sync = channel.json_body["rooms"]["join"][room_id]
         return room_sync["timeline"]["events"]
 
diff --git a/tests/rest/client/test_register.py b/tests/rest/client/test_register.py
index 071b488cc0..b781875d52 100644
--- a/tests/rest/client/test_register.py
+++ b/tests/rest/client/test_register.py
@@ -70,7 +70,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         det_data = {"user_id": user_id, "home_server": self.hs.hostname}
         self.assertDictContainsSubset(det_data, channel.json_body)
 
@@ -91,7 +91,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
-        self.assertEqual(channel.result["code"], b"400", channel.result)
+        self.assertEqual(channel.code, 400, msg=channel.result)
 
     def test_POST_appservice_registration_invalid(self) -> None:
         self.appservice = None  # no application service exists
@@ -100,20 +100,20 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             b"POST", self.url + b"?access_token=i_am_an_app_service", request_data
         )
 
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
 
     def test_POST_bad_password(self) -> None:
         request_data = {"username": "kermit", "password": 666}
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEqual(channel.result["code"], b"400", channel.result)
+        self.assertEqual(channel.code, 400, msg=channel.result)
         self.assertEqual(channel.json_body["error"], "Invalid password")
 
     def test_POST_bad_username(self) -> None:
         request_data = {"username": 777, "password": "monkey"}
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEqual(channel.result["code"], b"400", channel.result)
+        self.assertEqual(channel.code, 400, msg=channel.result)
         self.assertEqual(channel.json_body["error"], "Invalid username")
 
     def test_POST_user_valid(self) -> None:
@@ -132,7 +132,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
     @override_config({"enable_registration": False})
@@ -142,7 +142,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(b"POST", self.url, request_data)
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["error"], "Registration has been disabled")
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
@@ -153,7 +153,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
         det_data = {"home_server": self.hs.hostname, "device_id": "guest_device"}
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
     def test_POST_disabled_guest_registration(self) -> None:
@@ -161,7 +161,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(channel.json_body["error"], "Guest access is disabled")
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
@@ -171,16 +171,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", url, b"{}")
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config({"rc_registration": {"per_second": 0.17, "burst_count": 5}})
     def test_POST_ratelimiting(self) -> None:
@@ -194,16 +194,16 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             channel = self.make_request(b"POST", self.url, request_data)
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
         channel = self.make_request(b"POST", self.url + b"?kind=guest", b"{}")
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     @override_config({"registration_requires_token": True})
     def test_POST_registration_requires_token(self) -> None:
@@ -231,7 +231,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         # Request without auth to get flows and session
         channel = self.make_request(b"POST", self.url, params)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         flows = channel.json_body["flows"]
         # Synapse adds a dummy stage to differentiate flows where otherwise one
         # flow would be a subset of another flow.
@@ -248,7 +248,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session,
         }
         channel = self.make_request(b"POST", self.url, params)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         completed = channel.json_body["completed"]
         self.assertCountEqual([LoginType.REGISTRATION_TOKEN], completed)
 
@@ -263,7 +263,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "home_server": self.hs.hostname,
             "device_id": device_id,
         }
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertDictContainsSubset(det_data, channel.json_body)
 
         # Check the `completed` counter has been incremented and pending is 0
@@ -293,21 +293,21 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session,
         }
         channel = self.make_request(b"POST", self.url, params)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.MISSING_PARAM)
         self.assertEqual(channel.json_body["completed"], [])
 
         # Test with non-string (invalid)
         params["auth"]["token"] = 1234
         channel = self.make_request(b"POST", self.url, params)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.INVALID_PARAM)
         self.assertEqual(channel.json_body["completed"], [])
 
         # Test with unknown token (invalid)
         params["auth"]["token"] = "1234"
         channel = self.make_request(b"POST", self.url, params)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
         self.assertEqual(channel.json_body["completed"], [])
 
@@ -361,7 +361,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session2,
         }
         channel = self.make_request(b"POST", self.url, params2)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
         self.assertEqual(channel.json_body["completed"], [])
 
@@ -381,7 +381,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
         # Check auth still fails when using token with session2
         channel = self.make_request(b"POST", self.url, params2)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
         self.assertEqual(channel.json_body["completed"], [])
 
@@ -415,7 +415,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
             "session": session,
         }
         channel = self.make_request(b"POST", self.url, params)
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         self.assertEqual(channel.json_body["errcode"], Codes.UNAUTHORIZED)
         self.assertEqual(channel.json_body["completed"], [])
 
@@ -570,7 +570,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
 
     def test_advertised_flows(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         flows = channel.json_body["flows"]
 
         # with the stock config, we only expect the dummy flow
@@ -593,7 +593,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     )
     def test_advertised_flows_captcha_and_terms_and_3pids(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         flows = channel.json_body["flows"]
 
         self.assertCountEqual(
@@ -625,7 +625,7 @@ class RegisterRestServletTestCase(unittest.HomeserverTestCase):
     )
     def test_advertised_flows_no_msisdn_email_required(self) -> None:
         channel = self.make_request(b"POST", self.url, b"{}")
-        self.assertEqual(channel.result["code"], b"401", channel.result)
+        self.assertEqual(channel.code, 401, msg=channel.result)
         flows = channel.json_body["flows"]
 
         # with the stock config, we expect all four combinations of 3pid
@@ -797,13 +797,13 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         # endpoint.
         channel = self.make_request(b"GET", "/sync", access_token=tok)
 
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         self.reactor.advance(datetime.timedelta(weeks=1).total_seconds())
 
         channel = self.make_request(b"GET", "/sync", access_token=tok)
 
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
         )
@@ -823,12 +823,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
         url = "/_synapse/admin/v1/account_validity/validity"
         request_data = {"user_id": user_id}
         channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
         channel = self.make_request(b"GET", "/sync", access_token=tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_manual_expire(self) -> None:
         user_id = self.register_user("kermit", "monkey")
@@ -844,12 +844,12 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             "enable_renewal_emails": False,
         }
         channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # The specific endpoint doesn't matter, all we need is an authenticated
         # endpoint.
         channel = self.make_request(b"GET", "/sync", access_token=tok)
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, msg=channel.result)
         self.assertEqual(
             channel.json_body["errcode"], Codes.EXPIRED_ACCOUNT, channel.result
         )
@@ -868,18 +868,18 @@ class AccountValidityTestCase(unittest.HomeserverTestCase):
             "enable_renewal_emails": False,
         }
         channel = self.make_request(b"POST", url, request_data, access_token=admin_tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Try to log the user out
         channel = self.make_request(b"POST", "/logout", access_token=tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Log the user in again (allowed for expired accounts)
         tok = self.login("kermit", "monkey")
 
         # Try to log out all of the user's sessions
         channel = self.make_request(b"POST", "/logout/all", access_token=tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
 
 class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
@@ -954,7 +954,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         renewal_token = self.get_success(self.store.get_renewal_token_for_user(user_id))
         url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
         channel = self.make_request(b"GET", url)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Check that we're getting HTML back.
         content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -972,7 +972,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # Move 1 day forward. Try to renew with the same token again.
         url = "/_matrix/client/unstable/account_validity/renew?token=%s" % renewal_token
         channel = self.make_request(b"GET", url)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         # Check that we're getting HTML back.
         content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -992,14 +992,14 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
         # succeed.
         self.reactor.advance(datetime.timedelta(days=3).total_seconds())
         channel = self.make_request(b"GET", "/sync", access_token=tok)
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
     def test_renewal_invalid_token(self) -> None:
         # Hit the renewal endpoint with an invalid token and check that it behaves as
         # expected, i.e. that it responds with 404 Not Found and the correct HTML.
         url = "/_matrix/client/unstable/account_validity/renew?token=123"
         channel = self.make_request(b"GET", url)
-        self.assertEqual(channel.result["code"], b"404", channel.result)
+        self.assertEqual(channel.code, 404, msg=channel.result)
 
         # Check that we're getting HTML back.
         content_type = channel.headers.getRawHeaders(b"Content-Type")
@@ -1023,7 +1023,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         self.assertEqual(len(self.email_attempts), 1)
 
@@ -1096,7 +1096,7 @@ class AccountValidityRenewalByEmailTestCase(unittest.HomeserverTestCase):
             "/_matrix/client/unstable/account_validity/send_mail",
             access_token=tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
 
         self.assertEqual(len(self.email_attempts), 1)
 
@@ -1176,7 +1176,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             b"GET",
             f"{self.url}?token={token}",
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["valid"], True)
 
     def test_GET_token_invalid(self) -> None:
@@ -1185,7 +1185,7 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             b"GET",
             f"{self.url}?token={token}",
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         self.assertEqual(channel.json_body["valid"], False)
 
     @override_config(
@@ -1201,10 +1201,10 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             )
 
             if i == 5:
-                self.assertEqual(channel.result["code"], b"429", channel.result)
+                self.assertEqual(channel.code, 429, msg=channel.result)
                 retry_after_ms = int(channel.json_body["retry_after_ms"])
             else:
-                self.assertEqual(channel.result["code"], b"200", channel.result)
+                self.assertEqual(channel.code, 200, msg=channel.result)
 
         self.reactor.advance(retry_after_ms / 1000.0 + 1.0)
 
@@ -1212,4 +1212,4 @@ class RegistrationTokenValidityRestServletTestCase(unittest.HomeserverTestCase):
             b"GET",
             f"{self.url}?token={token}",
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, msg=channel.result)
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index ad03eee17b..d589f07314 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1060,6 +1060,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                     participated, bundled_aggregations.get("current_user_participated")
                 )
                 # The latest thread event has some fields that don't matter.
+                self.assertIn("latest_event", bundled_aggregations)
                 self.assert_dict(
                     {
                         "content": {
@@ -1072,7 +1073,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                         "sender": self.user2_id,
                         "type": "m.room.test",
                     },
-                    bundled_aggregations.get("latest_event"),
+                    bundled_aggregations["latest_event"],
                 )
 
             return assert_thread
@@ -1112,6 +1113,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             self.assertEqual(2, bundled_aggregations.get("count"))
             self.assertTrue(bundled_aggregations.get("current_user_participated"))
             # The latest thread event has some fields that don't matter.
+            self.assertIn("latest_event", bundled_aggregations)
             self.assert_dict(
                 {
                     "content": {
@@ -1124,7 +1126,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                     "sender": self.user_id,
                     "type": "m.room.test",
                 },
-                bundled_aggregations.get("latest_event"),
+                bundled_aggregations["latest_event"],
             )
             # Check the unsigned field on the latest event.
             self.assert_dict(
diff --git a/tests/rest/client/test_report_event.py b/tests/rest/client/test_report_event.py
index ad0d0209f7..7cb1017a4a 100644
--- a/tests/rest/client/test_report_event.py
+++ b/tests/rest/client/test_report_event.py
@@ -77,6 +77,4 @@ class ReportEventTestCase(unittest.HomeserverTestCase):
         channel = self.make_request(
             "POST", self.report_path, data, access_token=self.other_user_tok
         )
-        self.assertEqual(
-            response_status, int(channel.result["code"]), msg=channel.result["body"]
-        )
+        self.assertEqual(response_status, channel.code, msg=channel.result["body"])
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index ac9c113354..9c8c1889d3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -20,7 +20,7 @@ from synapse.api.constants import EventTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 from synapse.visibility import filter_events_for_client
 
@@ -188,7 +188,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         message_handler = self.hs.get_message_handler()
         create_event = self.get_success(
             message_handler.get_room_data(
-                self.user_id, room_id, EventTypes.Create, state_key=""
+                create_requester(self.user_id), room_id, EventTypes.Create, state_key=""
             )
         )
 
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index c45cb32090..aa2f578441 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -496,7 +496,7 @@ class RoomStateTestCase(RoomBase):
 
         self.assertEqual(HTTPStatus.OK, channel.code, msg=channel.result["body"])
         self.assertCountEqual(
-            [state_event["type"] for state_event in channel.json_body],
+            [state_event["type"] for state_event in channel.json_list],
             {
                 "m.room.create",
                 "m.room.power_levels",
@@ -2070,7 +2070,6 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
 
         config = self.default_config()
         config["allow_public_rooms_without_auth"] = True
-        config["experimental_features"] = {"msc3827_enabled": True}
         self.hs = self.setup_test_homeserver(config=config)
         self.url = b"/_matrix/client/r0/publicRooms"
 
@@ -2123,13 +2122,13 @@ class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
         chunk, count = self.make_public_rooms_request([None])
 
         self.assertEqual(count, 1)
-        self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None)
+        self.assertEqual(chunk[0].get("room_type", None), None)
 
     def test_returns_only_space_based_on_filter(self) -> None:
         chunk, count = self.make_public_rooms_request(["m.space"])
 
         self.assertEqual(count, 1)
-        self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space")
+        self.assertEqual(chunk[0].get("room_type", None), "m.space")
 
     def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
         chunk, count = self.make_public_rooms_request(["m.space", None])
diff --git a/tests/rest/client/test_shadow_banned.py b/tests/rest/client/test_shadow_banned.py
index d9bd8c4a28..c50f034b34 100644
--- a/tests/rest/client/test_shadow_banned.py
+++ b/tests/rest/client/test_shadow_banned.py
@@ -26,7 +26,7 @@ from synapse.rest.client import (
     room_upgrade_rest_servlet,
 )
 from synapse.server import HomeServer
-from synapse.types import UserID
+from synapse.types import UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -275,7 +275,7 @@ class ProfileTestCase(_ShadowBannedBase):
         message_handler = self.hs.get_message_handler()
         event = self.get_success(
             message_handler.get_room_data(
-                self.banned_user_id,
+                create_requester(self.banned_user_id),
                 room_id,
                 "m.room.member",
                 self.banned_user_id,
@@ -310,7 +310,7 @@ class ProfileTestCase(_ShadowBannedBase):
         message_handler = self.hs.get_message_handler()
         event = self.get_success(
             message_handler.get_room_data(
-                self.banned_user_id,
+                create_requester(self.banned_user_id),
                 room_id,
                 "m.room.member",
                 self.banned_user_id,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index b085c50356..de0dec8539 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -38,7 +38,6 @@ from tests.federation.transport.test_knocking import (
     KnockingStrippedStateEventHelperMixin,
 )
 from tests.server import TimedOutException
-from tests.unittest import override_config
 
 
 class FilterTestCase(unittest.HomeserverTestCase):
@@ -390,6 +389,12 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         sync.register_servlets,
     ]
 
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+        config["experimental_features"] = {"msc2285_enabled": True}
+
+        return self.setup_test_homeserver(config=config)
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.url = "/sync?since=%s"
         self.next_batch = "s0"
@@ -408,15 +413,17 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Join the second user
         self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
 
-    @override_config({"experimental_features": {"msc2285_enabled": True}})
-    def test_private_read_receipts(self) -> None:
+    @parameterized.expand(
+        [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
+    )
+    def test_private_read_receipts(self, receipt_type: str) -> None:
         # Send a message as the first user
         res = self.helper.send(self.room_id, body="hello", tok=self.tok)
 
         # Send a private read receipt to tell the server the first user's message was read
         channel = self.make_request(
             "POST",
-            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
             {},
             access_token=self.tok2,
         )
@@ -425,8 +432,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Test that the first user can't see the other user's private read receipt
         self.assertIsNone(self._get_read_receipt())
 
-    @override_config({"experimental_features": {"msc2285_enabled": True}})
-    def test_public_receipt_can_override_private(self) -> None:
+    @parameterized.expand(
+        [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
+    )
+    def test_public_receipt_can_override_private(self, receipt_type: str) -> None:
         """
         Sending a public read receipt to the same event which has a private read
         receipt should cause that receipt to become public.
@@ -437,7 +446,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Send a private read receipt
         channel = self.make_request(
             "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
             {},
             access_token=self.tok2,
         )
@@ -456,8 +465,10 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Test that we did override the private read receipt
         self.assertNotEqual(self._get_read_receipt(), None)
 
-    @override_config({"experimental_features": {"msc2285_enabled": True}})
-    def test_private_receipt_cannot_override_public(self) -> None:
+    @parameterized.expand(
+        [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
+    )
+    def test_private_receipt_cannot_override_public(self, receipt_type: str) -> None:
         """
         Sending a private read receipt to the same event which has a public read
         receipt should cause no change.
@@ -478,7 +489,7 @@ class ReadReceiptsTestCase(unittest.HomeserverTestCase):
         # Send a private read receipt
         channel = self.make_request(
             "POST",
-            f"/rooms/{self.room_id}/receipt/{ReceiptTypes.READ_PRIVATE}/{res['event_id']}",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
             {},
             access_token=self.tok2,
         )
@@ -590,7 +601,10 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
             tok=self.tok,
         )
 
-    def test_unread_counts(self) -> None:
+    @parameterized.expand(
+        [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE]
+    )
+    def test_unread_counts(self, receipt_type: str) -> None:
         """Tests that /sync returns the right value for the unread count (MSC2654)."""
 
         # Check that our own messages don't increase the unread count.
@@ -624,7 +638,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         # Send a read receipt to tell the server we've read the latest event.
         channel = self.make_request(
             "POST",
-            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res['event_id']}",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res['event_id']}",
             {},
             access_token=self.tok,
         )
@@ -700,7 +714,7 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         self._check_unread_count(5)
         res2 = self.helper.send(self.room_id, "hello", tok=self.tok2)
 
-        # Make sure both m.read and org.matrix.msc2285.read.private advance
+        # Make sure both m.read and m.read.private advance
         channel = self.make_request(
             "POST",
             f"/rooms/{self.room_id}/receipt/m.read/{res1['event_id']}",
@@ -712,16 +726,22 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
 
         channel = self.make_request(
             "POST",
-            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res2['event_id']}",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res2['event_id']}",
             {},
             access_token=self.tok,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
         self._check_unread_count(0)
 
-    # We test for both receipt types that influence notification counts
-    @parameterized.expand([ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE])
-    def test_read_receipts_only_go_down(self, receipt_type: ReceiptTypes) -> None:
+    # We test for all three receipt types that influence notification counts
+    @parameterized.expand(
+        [
+            ReceiptTypes.READ,
+            ReceiptTypes.READ_PRIVATE,
+            ReceiptTypes.UNSTABLE_READ_PRIVATE,
+        ]
+    )
+    def test_read_receipts_only_go_down(self, receipt_type: str) -> None:
         # Join the new user
         self.helper.join(room=self.room_id, user=self.user2, tok=self.tok2)
 
@@ -739,11 +759,11 @@ class UnreadMessagesTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.json_body)
         self._check_unread_count(0)
 
-        # Make sure neither m.read nor org.matrix.msc2285.read.private make the
+        # Make sure neither m.read nor m.read.private make the
         # read receipt go up to an older event
         channel = self.make_request(
             "POST",
-            f"/rooms/{self.room_id}/receipt/org.matrix.msc2285.read.private/{res1['event_id']}",
+            f"/rooms/{self.room_id}/receipt/{receipt_type}/{res1['event_id']}",
             {},
             access_token=self.tok,
         )
@@ -948,3 +968,24 @@ class ExcludeRoomTestCase(unittest.HomeserverTestCase):
 
         self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["invite"])
         self.assertIn(self.included_room_id, channel.json_body["rooms"]["invite"])
+
+    def test_incremental_sync(self) -> None:
+        """Tests that activity in the room is properly filtered out of incremental
+        syncs.
+        """
+        channel = self.make_request("GET", "/sync", access_token=self.tok)
+        self.assertEqual(channel.code, 200, channel.result)
+        next_batch = channel.json_body["next_batch"]
+
+        self.helper.send(self.excluded_room_id, tok=self.tok)
+        self.helper.send(self.included_room_id, tok=self.tok)
+
+        channel = self.make_request(
+            "GET",
+            f"/sync?since={next_batch}",
+            access_token=self.tok,
+        )
+        self.assertEqual(channel.code, 200, channel.result)
+
+        self.assertNotIn(self.excluded_room_id, channel.json_body["rooms"]["join"])
+        self.assertIn(self.included_room_id, channel.json_body["rooms"]["join"])
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 9a48e9286f..3325d43a2f 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -20,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
 from synapse.api.constants import EventTypes, LoginType, Membership
 from synapse.api.errors import SynapseError
 from synapse.api.room_versions import RoomVersion
+from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
 from synapse.events.third_party_rules import load_legacy_third_party_event_rules
 from synapse.rest import admin
@@ -154,7 +155,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             {},
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         callback.assert_called_once()
 
@@ -172,7 +173,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             {},
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"403", channel.result)
+        self.assertEqual(channel.code, 403, channel.result)
 
     def test_third_party_rules_workaround_synapse_errors_pass_through(self) -> None:
         """
@@ -185,12 +186,12 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         """
 
         class NastyHackException(SynapseError):
-            def error_dict(self) -> JsonDict:
+            def error_dict(self, config: Optional[HomeServerConfig]) -> JsonDict:
                 """
                 This overrides SynapseError's `error_dict` to nastily inject
                 JSON into the error response.
                 """
-                result = super().error_dict()
+                result = super().error_dict(config)
                 result["nasty"] = "very"
                 return result
 
@@ -210,7 +211,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             access_token=self.tok,
         )
         # Check the error code
-        self.assertEqual(channel.result["code"], b"429", channel.result)
+        self.assertEqual(channel.code, 429, channel.result)
         # Check the JSON body has had the `nasty` key injected
         self.assertEqual(
             channel.json_body,
@@ -259,7 +260,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             {"x": "x"},
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         event_id = channel.json_body["event_id"]
 
         # ... and check that it got modified
@@ -268,7 +269,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         ev = channel.json_body
         self.assertEqual(ev["content"]["x"], "y")
 
@@ -297,7 +298,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             },
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         orig_event_id = channel.json_body["event_id"]
 
         channel = self.make_request(
@@ -314,7 +315,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             },
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         edited_event_id = channel.json_body["event_id"]
 
         # ... and check that they both got modified
@@ -323,7 +324,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, orig_event_id),
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         ev = channel.json_body
         self.assertEqual(ev["content"]["body"], "ORIGINAL BODY")
 
@@ -332,7 +333,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, edited_event_id),
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
         ev = channel.json_body
         self.assertEqual(ev["content"]["body"], "EDITED BODY")
 
@@ -378,7 +379,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             },
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         event_id = channel.json_body["event_id"]
 
@@ -387,7 +388,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             "/_matrix/client/r0/rooms/%s/event/%s" % (self.room_id, event_id),
             access_token=self.tok,
         )
-        self.assertEqual(channel.result["code"], b"200", channel.result)
+        self.assertEqual(channel.code, 200, channel.result)
 
         self.assertIn("foo", channel.json_body["content"].keys())
         self.assertEqual(channel.json_body["content"]["foo"], "bar")
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index 105d418698..dd26145bf8 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -140,7 +140,7 @@ class RestHelper:
             custom_headers=custom_headers,
         )
 
-        assert channel.result["code"] == b"%d" % expect_code, channel.result
+        assert channel.code == expect_code, channel.result
         self.auth_user_id = temp_id
 
         if expect_code == HTTPStatus.OK:
@@ -213,11 +213,9 @@ class RestHelper:
             data,
         )
 
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
+        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
-            int(channel.result["code"]),
+            channel.code,
             channel.result["body"],
         )
 
@@ -312,11 +310,9 @@ class RestHelper:
             data,
         )
 
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
+        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
-            int(channel.result["code"]),
+            channel.code,
             channel.result["body"],
         )
 
@@ -396,11 +392,9 @@ class RestHelper:
             custom_headers=custom_headers,
         )
 
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
+        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
-            int(channel.result["code"]),
+            channel.code,
             channel.result["body"],
         )
 
@@ -449,11 +443,9 @@ class RestHelper:
 
         channel = make_request(self.hs.get_reactor(), self.site, method, path, content)
 
-        assert (
-            int(channel.result["code"]) == expect_code
-        ), "Expected: %d, got: %d, resp: %r" % (
+        assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
-            int(channel.result["code"]),
+            channel.code,
             channel.result["body"],
         )
 
@@ -545,7 +537,7 @@ class RestHelper:
 
         assert channel.code == expect_code, "Expected: %d, got: %d, resp: %r" % (
             expect_code,
-            int(channel.result["code"]),
+            channel.code,
             channel.result["body"],
         )