summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/app/test_openid_listener.py2
-rw-r--r--tests/federation/test_federation_client.py36
-rw-r--r--tests/federation/transport/test_client.py7
-rw-r--r--tests/handlers/test_auth.py135
-rw-r--r--tests/handlers/test_oidc.py580
-rw-r--r--tests/push/test_bulk_push_rule_evaluator.py2
-rw-r--r--tests/push/test_push_rule_evaluator.py215
-rw-r--r--tests/replication/_base.py2
-rw-r--r--tests/rest/admin/test_user.py35
-rw-r--r--tests/rest/client/test_auth.py32
-rw-r--r--tests/rest/client/test_login.py40
-rw-r--r--tests/rest/client/utils.py136
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py4
-rw-r--r--tests/storage/test_room_search.py213
-rw-r--r--tests/test_utils/__init__.py40
-rw-r--r--tests/test_utils/oidc.py325
-rw-r--r--tests/util/caches/test_descriptors.py42
-rw-r--r--tests/util/test_macaroons.py28
18 files changed, 1286 insertions, 588 deletions
diff --git a/tests/app/test_openid_listener.py b/tests/app/test_openid_listener.py
index c7dae58eb5..8d03da7f96 100644
--- a/tests/app/test_openid_listener.py
+++ b/tests/app/test_openid_listener.py
@@ -79,7 +79,7 @@ class FederationReaderOpenIDListenerTests(HomeserverTestCase):
         self.assertEqual(channel.code, 401)
 
 
-@patch("synapse.app.homeserver.KeyApiV2Resource", new=Mock())
+@patch("synapse.app.homeserver.KeyResource", new=Mock())
 class SynapseHomeserverOpenIDListenerTests(HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index a538215931..51d3bb8fff 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 from unittest import mock
 
 import twisted.web.client
 from twisted.internet import defer
-from twisted.internet.protocol import Protocol
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import RoomVersions
@@ -26,10 +23,9 @@ from synapse.events import EventBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import event_injection
+from tests.test_utils import FakeResponse, event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
@@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "pdus": [
                         create_event_dict,
                         member_event_dict,
@@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     "pdus": [
@@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # We expect an outbound request to /backfill, so stub that out
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     # Mimic the other server returning our new `pulled_event`
@@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
         # other from "yet.another.server"
         self.assertEqual(backfill_num_attempts, 2)
-
-
-def _mock_response(resp: JsonDict):
-    body = json.dumps(resp).encode("utf-8")
-
-    def deliver_body(p: Protocol):
-        p.dataReceived(body)
-        p.connectionLost(Failure(twisted.web.client.ResponseDone()))
-
-    response = mock.Mock(
-        code=200,
-        phrase=b"OK",
-        headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
-        length=len(body),
-        deliverBody=deliver_body,
-    )
-    mock.seal(response)
-    return response
diff --git a/tests/federation/transport/test_client.py b/tests/federation/transport/test_client.py
index 0926e0583d..dd4d1b56de 100644
--- a/tests/federation/transport/test_client.py
+++ b/tests/federation/transport/test_client.py
@@ -17,6 +17,7 @@ from unittest.mock import Mock
 
 from synapse.api.room_versions import RoomVersions
 from synapse.federation.transport.client import SendJoinParser
+from synapse.util import ExceptionBundle
 
 from tests.unittest import TestCase
 
@@ -121,10 +122,8 @@ class SendJoinParserTestCase(TestCase):
         # Send half of the data to the parser
         parser.write(serialisation[: len(serialisation) // 2])
 
-        # Close the parser. There should be _some_ kind of exception, but it need not
-        # be that RuntimeError directly. E.g. we might want to raise a wrapper
-        # encompassing multiple errors from multiple coroutines.
-        with self.assertRaises(Exception):
+        # Close the parser. There should be _some_ kind of exception.
+        with self.assertRaises(ExceptionBundle):
             parser.finish()
 
         # In any case, we should have tried to close both coros.
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 7106799d44..036dbbc45b 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional
 from unittest.mock import Mock
 
 import pymacaroons
@@ -19,6 +20,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import AuthError, ResourceLimitError
 from synapse.rest import admin
+from synapse.rest.client import login
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -29,6 +31,7 @@ from tests.test_utils import make_awaitable
 class AuthTestCase(unittest.HomeserverTestCase):
     servlets = [
         admin.register_servlets,
+        login.register_servlets,
     ]
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
@@ -46,6 +49,23 @@ class AuthTestCase(unittest.HomeserverTestCase):
 
         self.user1 = self.register_user("a_user", "pass")
 
+    def token_login(self, token: str) -> Optional[str]:
+        body = {
+            "type": "m.login.token",
+            "token": token,
+        }
+
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/v3/login",
+            body,
+        )
+
+        if channel.code == 200:
+            return channel.json_body["user_id"]
+
+        return None
+
     def test_macaroon_caveats(self) -> None:
         token = self.macaroon_generator.generate_guest_access_token("a_user")
         macaroon = pymacaroons.Macaroon.deserialize(token)
@@ -73,49 +93,62 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.satisfy_general(verify_guest)
         v.verify(macaroon, self.hs.config.key.macaroon_secret_key)
 
-    def test_short_term_login_token_gives_user_id(self) -> None:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+    def test_login_token_gives_user_id(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                duration_ms=(5 * 1000),
+            )
         )
-        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+
+        res = self.get_success(self.auth_handler.consume_login_token(token))
         self.assertEqual(self.user1, res.user_id)
-        self.assertEqual("", res.auth_provider_id)
+        self.assertEqual(None, res.auth_provider_id)
 
-        # when we advance the clock, the token should be rejected
-        self.reactor.advance(6)
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(token),
-            AuthError,
+    def test_login_token_reuse_fails(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                duration_ms=(5 * 1000),
+            )
         )
 
-    def test_short_term_login_token_gives_auth_provider(self) -> None:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, auth_provider_id="my_idp"
-        )
-        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
-        self.assertEqual(self.user1, res.user_id)
-        self.assertEqual("my_idp", res.auth_provider_id)
+        self.get_success(self.auth_handler.consume_login_token(token))
 
-    def test_short_term_login_token_cannot_replace_user_id(self) -> None:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+        self.get_failure(
+            self.auth_handler.consume_login_token(token),
+            AuthError,
         )
-        macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        res = self.get_success(
-            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
+    def test_login_token_expires(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                duration_ms=(5 * 1000),
+            )
         )
-        self.assertEqual(self.user1, res.user_id)
-
-        # add another "user_id" caveat, which might allow us to override the
-        # user_id.
-        macaroon.add_first_party_caveat("user_id = b_user")
 
+        # when we advance the clock, the token should be rejected
+        self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
+            self.auth_handler.consume_login_token(token),
             AuthError,
         )
 
+    def test_login_token_gives_auth_provider(self) -> None:
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(
+                self.user1,
+                auth_provider_id="my_idp",
+                auth_provider_session_id="11-22-33-44",
+                duration_ms=(5 * 1000),
+            )
+        )
+        res = self.get_success(self.auth_handler.consume_login_token(token))
+        self.assertEqual(self.user1, res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+        self.assertEqual("11-22-33-44", res.auth_provider_session_id)
+
     def test_mau_limits_disabled(self) -> None:
         self.auth_blocking._limit_usage_by_mau = False
         # Ensure does not throw exception
@@ -125,12 +158,12 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
 
+        self.assertIsNotNone(self.token_login(token))
+
     def test_mau_limits_exceeded_large(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
@@ -147,12 +180,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.large_number_of_users)
         )
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            ),
-            ResourceLimitError,
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNone(self.token_login(token))
 
     def test_mau_limits_parity(self) -> None:
         # Ensure we're not at the unix epoch.
@@ -171,12 +202,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ),
             ResourceLimitError,
         )
-        self.get_failure(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            ),
-            ResourceLimitError,
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNone(self.token_login(token))
 
         # If in monthly active cohort
         self.hs.get_datastores().main.user_last_seen_monthly_active = Mock(
@@ -187,11 +216,10 @@ class AuthTestCase(unittest.HomeserverTestCase):
                 self.user1, device_id=None, valid_until_ms=None
             )
         )
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
+        self.assertIsNotNone(self.token_login(token))
 
     def test_mau_limits_not_exceeded(self) -> None:
         self.auth_blocking._limit_usage_by_mau = True
@@ -209,14 +237,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         self.hs.get_datastores().main.get_monthly_active_count = Mock(
             return_value=make_awaitable(self.small_number_of_users)
         )
-        self.get_success(
-            self.auth_handler.validate_short_term_login_token(
-                self._get_macaroon().serialize()
-            )
-        )
-
-    def _get_macaroon(self) -> pymacaroons.Macaroon:
-        token = self.macaroon_generator.generate_short_term_login_token(
-            self.user1, "", duration_in_ms=5000
+        token = self.get_success(
+            self.auth_handler.create_login_token_for_user_id(self.user1)
         )
-        return pymacaroons.Macaroon.deserialize(token)
+        self.assertIsNotNone(self.token_login(token))
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
@@ -22,12 +21,15 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
 from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
 try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
 CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
 # config for common cases
 DEFAULT_CONFIG = {
     "enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
 EXPLICIT_ENDPOINT_CONFIG = {
     **DEFAULT_CONFIG,
     "discover": False,
-    "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-    "token_endpoint": TOKEN_ENDPOINT,
-    "jwks_uri": JWKS_URI,
+    "authorization_endpoint": ISSUER + "authorize",
+    "token_endpoint": ISSUER + "token",
+    "jwks_uri": ISSUER + "jwks",
 }
 
 
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url: str) -> JsonDict:
-    # Mock get_json calls to handle jwks & oidc discovery endpoints
-    if url == WELL_KNOWN:
-        # Minimal discovery document, as defined in OpenID.Discovery
-        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
-        return {
-            "issuer": ISSUER,
-            "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-            "token_endpoint": TOKEN_ENDPOINT,
-            "jwks_uri": JWKS_URI,
-            "userinfo_endpoint": USERINFO_ENDPOINT,
-            "response_types_supported": ["code"],
-            "subject_types_supported": ["public"],
-            "id_token_signing_alg_values_supported": ["RS256"],
-        }
-    elif url == JWKS_URI:
-        return {"keys": []}
-
-    return {}
-
-
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
 
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return config
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.http_client = Mock(spec=["get_json"])
-        self.http_client.get_json.side_effect = get_json
-        self.http_client.user_agent = b"Synapse Test"
+        self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
 
-        hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+        hs = self.setup_test_homeserver()
+        self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+        self.hs_patcher.start()
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
 
+        auth_handler = hs.get_auth_handler()
+        # Mock the complete SSO login method.
+        self.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
+
         return hs
 
+    def tearDown(self) -> None:
+        self.hs_patcher.stop()
+        return super().tearDown()
+
+    def reset_mocks(self):
+        """Reset all the Mocks."""
+        self.fake_server.reset_mocks()
+        self.render_error.reset_mock()
+        self.complete_sso_login.reset_mock()
+
     def metadata_edit(self, values):
         """Modify the result that will be returned by the well-known query"""
 
-        async def patched_get_json(uri):
-            res = await get_json(uri)
-            if uri == WELL_KNOWN:
-                res.update(values)
-            return res
+        metadata = self.fake_server.get_metadata()
+        metadata.update(values)
+        return patch.object(self.fake_server, "get_metadata", return_value=metadata)
 
-        return patch.object(self.http_client, "get_json", patched_get_json)
+    def start_authorization(
+        self,
+        userinfo: dict,
+        client_redirect_url: str = "http://client/redirect",
+        scope: str = "openid",
+        with_sid: bool = False,
+    ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+        """Start an authorization request, and get the callback request back."""
+        nonce = random_string(10)
+        state = random_string(10)
+
+        code, grant = self.fake_server.start_authorization(
+            userinfo=userinfo,
+            scope=scope,
+            client_id=self.provider._client_auth.client_id,
+            redirect_uri=self.provider._callback_url,
+            nonce=nonce,
+            with_sid=with_sid,
+        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        return _build_callback_request(code, state, session), grant
 
     def assertRenderedError(self, error, error_description=None):
         self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.fake_server.get_metadata_handler.assert_called_once()
 
-        self.assertEqual(metadata.issuer, ISSUER)
-        self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
-        self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
-        self.assertEqual(metadata.jwks_uri, JWKS_URI)
-        # FIXME: it seems like authlib does not have that defined in its metadata models
-        # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+        self.assertEqual(metadata.issuer, self.fake_server.issuer)
+        self.assertEqual(
+            metadata.authorization_endpoint,
+            self.fake_server.authorization_endpoint,
+        )
+        self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+        self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+        # It seems like authlib does not have that defined in its metadata models
+        self.assertEqual(
+            metadata.get("userinfo_endpoint"),
+            self.fake_server.userinfo_endpoint,
+        )
 
         # subsequent calls should be cached
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
-    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
-        self.assertEqual(jwks, {"keys": []})
+        self.fake_server.get_jwks_handler.assert_called_once()
+        self.assertEqual(jwks, self.fake_server.get_jwks())
 
         # subsequent calls should be cached…
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_jwks_handler.assert_not_called()
 
         # …unless forced
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks(force=True))
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
+        self.fake_server.get_jwks_handler.assert_called_once()
 
-        # Throw if the JWKS uri is missing
-        original = self.provider.load_metadata
-
-        async def patched_load_metadata():
-            m = (await original()).copy()
-            m.update({"jwks_uri": None})
-            return m
-
-        with patch.object(self.provider, "load_metadata", patched_load_metadata):
+        with self.metadata_edit({"jwks_uri": None}):
+            # If we don't do this, the load_metadata call will throw because of the
+            # missing jwks_uri
+            self.provider._user_profile_method = "userinfo_endpoint"
+            self.get_success(self.provider.load_metadata(force=True))
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 self.provider.handle_redirect_request(req, b"http://client/redirect")
             )
         )
-        auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+        auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
 
         self.assertEqual(url.scheme, auth_endpoint.scheme)
         self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         username = "bar"
         userinfo = {
             "sub": "foo",
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
 
-        code = "code"
-        state = "state"
-        nonce = "nonce"
         client_redirect_url = "http://client/redirect"
-        ip_address = "10.0.0.1"
-        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
-        request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+        request, _ = self.start_authorization(
+            userinfo, client_redirect_url=client_redirect_url
+        )
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
             client_redirect_url,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_not_called()
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
+        request, _ = self.start_authorization(userinfo)
         with patch.object(
             self.provider,
             "_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._exchange_code.reset_mock()
-        self.provider._parse_id_token.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        self.reset_mocks()
 
         # With userinfo fetching
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-        }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        # Without the "openid" scope, the FakeProvider does not generate an id_token
+        request, _ = self.start_authorization(userinfo, scope="")
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_not_called()
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
+        self.reset_mocks()
+
         # With an ID token, userinfo fetching and sid in the ID token
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-            "id_token": "id_token",
-        }
-        id_token = {
-            "sid": "abcdefgh",
-        }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        request, grant = self.start_authorization(userinfo, with_sid=True)
+        self.assertIsNotNone(grant.sid)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
-            auth_provider_session_id=id_token["sid"],
+            auth_provider_session_id=grant.sid,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(userinfo=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
-        # Handle code exchange failure
-        from synapse.handlers.oidc import OidcError
-
-        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
-            raises=OidcError("invalid_request")
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("invalid_request")
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(token=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
+        self.assertRenderedError("server_error")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
-        token = {"type": "bearer"}
-        token_json = json.dumps(token).encode("utf-8")
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
 
         self.assertEqual(ret, token)
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         args = parse_qs(kwargs["data"].decode("utf-8"))
         self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
 
         # Test error handling
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad Request",
-                body=b'{"error": "foo", "error_description": "bar"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={"error": "foo", "error_description": "bar"}
         )
         from synapse.handlers.oidc import OidcError
 
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(exc.value.error_description, "bar")
 
         # Internal server error with no JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b"Not JSON",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse(
+            code=500, body=b"Not JSON"
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b'{"error": "internal_server_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=500, payload={"error": "internal_server_error"}
         )
 
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad request",
-                body=b"{}",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200,
-                phrase=b"OK",
-                body=b'{"error": "some_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=200, payload={"error": "some_error"}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
 
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # the client secret provided to the should be a jwt which can be checked with
         # the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # check the POSTed data
         args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         userinfo = {
             "sub": "foo",
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
-        state = "state"
-        client_redirect_url = "http://client/redirect"
-        session = self._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-        )
-        request = _build_callback_request("code", state, session)
-
+        request, _ = self.start_authorization(userinfo)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@foo:test",
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             {"phone": "1234567"},
             new_user=True,
             auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error",
             "Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Subsequent calls should map to the same mxid.
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         args = self.assertRenderedError("mapping_error")
         self.assertTrue(
             args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@TEST_USER_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        self.get_success(
-            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
-        )
+        userinfo = {"sub": "test2", "username": "föö"}
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # test_user is already taken, so test_user1 gets registered instead.
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@test_user1:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # userinfo lacking "test": "foobar" attribute should fail.
         userinfo = {
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": "foobar" attribute should succeed.
         userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": "foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
         userinfo = {
             "sub": "tester",
             "username": "tester",
             "test": ["foobar", "foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
         """
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
         userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": ["foo", "bar"] attribute should fail
         userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": ["foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": False attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": False,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": None attribute should fail
         # a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 1 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 1,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 3.14 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 3.14,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
     def _generate_oidc_session_token(
         self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return self.handler._macaroon_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
-                idp_id="oidc",
+                idp_id=self.provider.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
 
-async def _make_callback_with_userinfo(
-    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
-    """Mock up an OIDC callback with the given userinfo dict
-
-    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
-    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
-    Args:
-        hs: the HomeServer impl to send the callback to.
-        userinfo: the OIDC userinfo dict
-        client_redirect_url: the URL to redirect to on success.
-    """
-
-    handler = hs.get_oidc_handler()
-    provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-
-    state = "state"
-    session = handler._macaroon_generator.generate_oidc_session_token(
-        state=state,
-        session_data=OidcSessionData(
-            idp_id="oidc",
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id="",
-        ),
-    )
-    request = _build_callback_request("code", state, session)
-
-    await handler.handle_oidc_callback(request)
-
-
 def _build_callback_request(
     code: str,
     state: str,
diff --git a/tests/push/test_bulk_push_rule_evaluator.py b/tests/push/test_bulk_push_rule_evaluator.py
index 675d7df2ac..594e7937a8 100644
--- a/tests/push/test_bulk_push_rule_evaluator.py
+++ b/tests/push/test_bulk_push_rule_evaluator.py
@@ -71,4 +71,4 @@ class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
 
         bulk_evaluator = BulkPushRuleEvaluator(self.hs)
         # should not raise
-        self.get_success(bulk_evaluator.action_for_event_by_user(event, context))
+        self.get_success(bulk_evaluator.action_for_events_by_user([(event, context)]))
diff --git a/tests/push/test_push_rule_evaluator.py b/tests/push/test_push_rule_evaluator.py
index decf619466..fe7c145840 100644
--- a/tests/push/test_push_rule_evaluator.py
+++ b/tests/push/test_push_rule_evaluator.py
@@ -38,7 +38,9 @@ from tests.test_utils.event_injection import create_event, inject_member_event
 
 
 class PushRuleEvaluatorTestCase(unittest.TestCase):
-    def _get_evaluator(self, content: JsonDict) -> PushRuleEvaluator:
+    def _get_evaluator(
+        self, content: JsonDict, related_events=None
+    ) -> PushRuleEvaluator:
         event = FrozenEvent(
             {
                 "event_id": "$event_id",
@@ -58,6 +60,8 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             room_member_count,
             sender_power_level,
             power_levels.get("notifications", {}),
+            {} if related_events is None else related_events,
+            True,
         )
 
     def test_display_name(self) -> None:
@@ -292,6 +296,215 @@ class PushRuleEvaluatorTestCase(unittest.TestCase):
             {"sound": "default", "highlight": True},
         )
 
+    def test_related_event_match(self):
+        evaluator = self._get_evaluator(
+            {
+                "m.relates_to": {
+                    "event_id": "$parent_event_id",
+                    "key": "😀",
+                    "rel_type": "m.annotation",
+                    "m.in_reply_to": {
+                        "event_id": "$parent_event_id",
+                    },
+                }
+            },
+            {
+                "m.in_reply_to": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                },
+                "m.annotation": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                },
+            },
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@user:test",
+                },
+                "@other_user:test",
+                "display_name",
+            )
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.annotation",
+                    "pattern": "@other_user:test",
+                },
+                "@other_user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "rel_type": "m.replace",
+                },
+                "@other_user:test",
+                "display_name",
+            )
+        )
+
+    def test_related_event_match_with_fallback(self):
+        evaluator = self._get_evaluator(
+            {
+                "m.relates_to": {
+                    "event_id": "$parent_event_id",
+                    "key": "😀",
+                    "rel_type": "m.thread",
+                    "is_falling_back": True,
+                    "m.in_reply_to": {
+                        "event_id": "$parent_event_id",
+                    },
+                }
+            },
+            {
+                "m.in_reply_to": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                    "im.vector.is_falling_back": "",
+                },
+                "m.thread": {
+                    "event_id": "$parent_event_id",
+                    "type": "m.room.message",
+                    "sender": "@other_user:test",
+                    "room_id": "!room:test",
+                    "content.msgtype": "m.text",
+                    "content.body": "Original message",
+                },
+            },
+        )
+        self.assertTrue(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                    "include_fallbacks": True,
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                    "include_fallbacks": False,
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+
+    def test_related_event_match_no_related_event(self):
+        evaluator = self._get_evaluator(
+            {"msgtype": "m.text", "body": "Message without related event"}
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                    "pattern": "@other_user:test",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "key": "sender",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+        self.assertFalse(
+            evaluator.matches(
+                {
+                    "kind": "im.nheko.msc3664.related_event_match",
+                    "rel_type": "m.in_reply_to",
+                },
+                "@user:test",
+                "display_name",
+            )
+        )
+
 
 class TestBulkPushRuleEvaluator(unittest.HomeserverTestCase):
     """Tests for the bulk push rule evaluator"""
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index ce53f808db..121f3d8d65 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -371,7 +371,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             config=worker_hs.config.server.listeners[0],
             resource=resource,
             server_version_string="1",
-            max_request_body_size=4096,
+            max_request_body_size=8192,
             reactor=self.reactor,
         )
 
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index 4c1ce33463..63410ffdf1 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -31,7 +31,7 @@ from synapse.api.room_versions import RoomVersions
 from synapse.rest.client import devices, login, logout, profile, register, room, sync
 from synapse.rest.media.v1.filepath import MediaFilePaths
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -924,6 +924,36 @@ class UsersListTestCase(unittest.HomeserverTestCase):
         self.assertEqual(1, len(non_admin_user_ids), non_admin_user_ids)
         self.assertEqual(not_approved_user, non_admin_user_ids[0])
 
+    def test_erasure_status(self) -> None:
+        # Create a new user.
+        user_id = self.register_user("eraseme", "eraseme")
+
+        # They should appear in the list users API, marked as not erased.
+        channel = self.make_request(
+            "GET",
+            self.url + "?deactivated=true",
+            access_token=self.admin_user_tok,
+        )
+        users = {user["name"]: user for user in channel.json_body["users"]}
+        self.assertIs(users[user_id]["erased"], False)
+
+        # Deactivate that user, requesting erasure.
+        deactivate_account_handler = self.hs.get_deactivate_account_handler()
+        self.get_success(
+            deactivate_account_handler.deactivate_account(
+                user_id, erase_data=True, requester=create_requester(user_id)
+            )
+        )
+
+        # Repeat the list users query. They should now be marked as erased.
+        channel = self.make_request(
+            "GET",
+            self.url + "?deactivated=true",
+            access_token=self.admin_user_tok,
+        )
+        users = {user["name"]: user for user in channel.json_body["users"]}
+        self.assertIs(users[user_id]["erased"], True)
+
     def _order_test(
         self,
         expected_user_list: List[str],
@@ -1195,6 +1225,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
         self.assertEqual("foo@bar.com", channel.json_body["threepids"][0]["address"])
         self.assertEqual("mxc://servername/mediaid", channel.json_body["avatar_url"])
         self.assertEqual("User1", channel.json_body["displayname"])
+        self.assertFalse(channel.json_body["erased"])
 
         # Deactivate and erase user
         channel = self.make_request(
@@ -1219,6 +1250,7 @@ class DeactivateAccountTestCase(unittest.HomeserverTestCase):
         self.assertEqual(0, len(channel.json_body["threepids"]))
         self.assertIsNone(channel.json_body["avatar_url"])
         self.assertIsNone(channel.json_body["displayname"])
+        self.assertTrue(channel.json_body["erased"])
 
         self._is_erased("@user:test", True)
 
@@ -2757,6 +2789,7 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         self.assertIn("avatar_url", content)
         self.assertIn("admin", content)
         self.assertIn("deactivated", content)
+        self.assertIn("erased", content)
         self.assertIn("shadow_banned", content)
         self.assertIn("creation_ts", content)
         self.assertIn("appservice_id", content)
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 090cef5216..ebf653d018 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
           * checking that the original operation succeeds
         """
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
         remote_user_id = UserID.from_string(self.user).localpart
-        login_resp = self.helper.login_via_oidc(remote_user_id)
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
         self.assertEqual(login_resp["user_id"], self.user)
 
         # initiate a UI Auth process by attempting to delete the device
@@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # run the UIA-via-SSO flow
         session_id = channel.json_body["session"]
-        channel = self.helper.auth_via_oidc(
-            {"sub": remote_user_id}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
         )
 
         # that should serve a confirmation page
@@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self) -> None:
-        login_resp = self.helper.login_via_oidc("username")
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
         user_tok = login_resp["access_token"]
         device_id = login_resp["device_id"]
 
@@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self) -> None:
         """A user that had a password and then logged in with SSO should get both flows"""
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         channel = self.delete_device(
@@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
         """If the user tries to authenticate with the wrong SSO user, they get an error"""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         # start a UI Auth flow by attempting to delete a device
@@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session_id = channel.json_body["session"]
 
         # do the OIDC auth, but auth as the wrong user
-        channel = self.helper.auth_via_oidc(
-            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
         )
 
         # that should return a failure message
@@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """Tests that if we register a user via SSO while requiring approval for new
         accounts, we still raise the correct error before logging the user in.
         """
-        login_resp = self.helper.login_via_oidc("username", expected_status=403)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, "username", expected_status=403
+        )
 
         self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
         self.assertEqual(
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index e801ba8c8b..ff5baa9f0a 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -36,7 +36,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
-from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
@@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_login_via_oidc(self) -> None:
         """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
 
-        # pick the default OIDC provider
-        channel = self.make_request(
-            "GET",
-            "/_synapse/client/pick_idp?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
-            + "&idp=oidc",
-        )
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        with fake_oidc_server.patch_homeserver(hs=self.hs):
+            # pick the default OIDC provider
+            channel = self.make_request(
+                "GET",
+                "/_synapse/client/pick_idp?redirectUrl="
+                + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+                + "&idp=oidc",
+            )
         self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
@@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
-        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+        self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
 
         # ... and should have set a cookie including the redirect url
         cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
@@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             TEST_CLIENT_REDIRECT_URL,
         )
 
-        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+        channel, _ = self.helper.complete_oidc_auth(
+            fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
+        )
 
         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
@@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
     def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
-        channel = self._make_sso_redirect_request("oidc")
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        with fake_oidc_server.patch_homeserver(hs=self.hs):
+            channel = self._make_sso_redirect_request("oidc")
         self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
@@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
-        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+        self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
 
     def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
         """Send a request to /_matrix/client/r0/login/sso/redirect
@@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase):
     def test_username_picker(self) -> None:
         """Test the happy path of a username picker flow."""
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # do the start of the login flow
-        channel = self.helper.auth_via_oidc(
-            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            {"sub": "tester", "displayname": "Jonny"},
+            TEST_CLIENT_REDIRECT_URL,
         )
 
         # that should redirect to the username picker
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index c249a42bb6..967d229223 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,7 +31,6 @@ from typing import (
     Tuple,
     overload,
 )
-from unittest.mock import patch
 from urllib.parse import urlencode
 
 import attr
@@ -46,8 +45,19 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
 from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_ISSUER = "https://issuer.test/"
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "issuer": TEST_OIDC_ISSUER,
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["openid"],
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
 
 
 @attr.s(auto_attribs=True)
@@ -543,12 +553,28 @@ class RestHelper:
 
         return channel.json_body
 
+    def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
+        """Create a ``FakeOidcServer``.
+
+        This can be used in conjuction with ``login_via_oidc``::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+        """
+
+        return FakeOidcServer(
+            clock=self.hs.get_clock(),
+            issuer=issuer,
+        )
+
     def login_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         remote_user_id: str,
+        with_sid: bool = False,
         expected_status: int = 200,
-    ) -> JsonDict:
-        """Log in via OIDC
+    ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
+        """Log in (as a new user) via OIDC
 
         Returns the result of the final token login.
 
@@ -560,7 +586,10 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+        userinfo = {"sub": remote_user_id}
+        channel, grant = self.auth_via_oidc(
+            fake_server, userinfo, client_redirect_url, with_sid=with_sid
+        )
 
         # expect a confirmation page
         assert channel.code == HTTPStatus.OK, channel.result
@@ -585,14 +614,16 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body
+        return channel.json_body, grant
 
     def auth_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Perform an OIDC authentication flow via a mock OIDC provider.
 
         This can be used for either login or user-interactive auth.
@@ -616,6 +647,7 @@ class RestHelper:
                 the login redirect endpoint
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
                 of the UI auth.
+            with_sid: if True, generates a random `sid` (OIDC session ID)
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -625,14 +657,15 @@ class RestHelper:
 
         cookies: Dict[str, str] = {}
 
-        # if we're doing a ui auth, hit the ui auth redirect endpoint
-        if ui_auth_session_id:
-            # can't set the client redirect url for UI Auth
-            assert client_redirect_url is None
-            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
-        else:
-            # otherwise, hit the login redirect endpoint
-            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+        with fake_server.patch_homeserver(hs=self.hs):
+            # if we're doing a ui auth, hit the ui auth redirect endpoint
+            if ui_auth_session_id:
+                # can't set the client redirect url for UI Auth
+                assert client_redirect_url is None
+                oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+            else:
+                # otherwise, hit the login redirect endpoint
+                oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
 
         # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
@@ -640,17 +673,21 @@ class RestHelper:
         # that synapse passes to the client.
 
         oauth_uri_path, _ = oauth_uri.split("?", 1)
-        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+        assert oauth_uri_path == fake_server.authorization_endpoint, (
             "unexpected SSO URI " + oauth_uri_path
         )
-        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+        return self.complete_oidc_auth(
+            fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+        )
 
     def complete_oidc_auth(
         self,
+        fake_serer: FakeOidcServer,
         oauth_uri: str,
         cookies: Mapping[str, str],
         user_info_dict: JsonDict,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Mock out an OIDC authentication flow
 
         Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -661,50 +698,37 @@ class RestHelper:
         Requires the OIDC callback resource to be mounted at the normal place.
 
         Args:
+            fake_server: the fake OIDC server with which the auth should be done
             oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
                from initiate_sso_login or initiate_sso_ui_auth).
             cookies: the cookies set by synapse's redirect endpoint, which will be
                sent back to the callback endpoint.
             user_info_dict: the remote userinfo that the OIDC provider should present.
                 Typically this should be '{"sub": "<remote user id>"}'.
+            with_sid: if True, generates a random `sid` (OIDC session ID)
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
         """
         _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
+
+        code, grant = fake_serer.start_authorization(
+            scope=params["scope"][0],
+            userinfo=user_info_dict,
+            client_id=params["client_id"][0],
+            redirect_uri=params["redirect_uri"][0],
+            nonce=params["nonce"][0],
+            with_sid=with_sid,
+        )
+        state = params["state"][0]
+
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
-            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+            urllib.parse.urlencode({"state": state, "code": code}),
         )
 
-        # before we hit the callback uri, stub out some methods in the http client so
-        # that we don't have to handle full HTTPS requests.
-        # (expected url, json response) pairs, in the order we expect them.
-        expected_requests = [
-            # first we get a hit to the token endpoint, which we tell to return
-            # a dummy OIDC access token
-            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
-            # and then one to the user_info endpoint, which returns our remote user id.
-            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
-        ]
-
-        async def mock_req(
-            method: str,
-            uri: str,
-            data: Optional[dict] = None,
-            headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-        ):
-            (expected_uri, resp_obj) = expected_requests.pop(0)
-            assert uri == expected_uri
-            resp = FakeResponse(
-                code=HTTPStatus.OK,
-                phrase=b"OK",
-                body=json.dumps(resp_obj).encode("utf-8"),
-            )
-            return resp
-
-        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+        with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
                 self.hs.get_reactor(),
@@ -715,7 +739,7 @@ class RestHelper:
                     ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
                 ],
             )
-        return channel
+        return channel, grant
 
     def initiate_sso_login(
         self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
@@ -806,21 +830,3 @@ class RestHelper:
         assert len(p.links) == 1, "not exactly one link in confirmation page"
         oauth_uri = p.links[0]
         return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
-    "enabled": True,
-    "discover": False,
-    "issuer": "https://issuer.test",
-    "client_id": "test-client-id",
-    "client_secret": "test-client-secret",
-    "scopes": ["profile"],
-    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
-    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
-    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index ac0ac06b7e..7f1fba1086 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -26,7 +26,7 @@ from twisted.web.resource import NoResource, Resource
 
 from synapse.crypto.keyring import PerspectivesKeyFetcher
 from synapse.http.site import SynapseRequest
-from synapse.rest.key.v2 import KeyApiV2Resource
+from synapse.rest.key.v2 import KeyResource
 from synapse.server import HomeServer
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
@@ -46,7 +46,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase):
 
     def create_test_resource(self) -> Resource:
         return create_resource_tree(
-            {"/_matrix/key/v2": KeyApiV2Resource(self.hs)}, root_resource=NoResource()
+            {"/_matrix/key/v2": KeyResource(self.hs)}, root_resource=NoResource()
         )
 
     def expect_outgoing_key_request(
diff --git a/tests/storage/test_room_search.py b/tests/storage/test_room_search.py
index e747c6b50e..9ddc19900a 100644
--- a/tests/storage/test_room_search.py
+++ b/tests/storage/test_room_search.py
@@ -12,11 +12,22 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from typing import List, Tuple, Union
+from unittest.case import SkipTest
+from unittest.mock import PropertyMock, patch
+
+from twisted.test.proto_helpers import MemoryReactor
+
 import synapse.rest.admin
 from synapse.api.constants import EventTypes
 from synapse.api.errors import StoreError
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage.databases.main import DataStore
+from synapse.storage.databases.main.search import Phrase, SearchToken, _tokenize_query
 from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines.sqlite import Sqlite3Engine
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, skip_unless
 from tests.utils import USE_POSTGRES_FOR_TESTS
@@ -187,3 +198,205 @@ class EventSearchInsertionTest(HomeserverTestCase):
             ),
         )
         self.assertCountEqual(values, ["hi", "2"])
+
+
+class MessageSearchTest(HomeserverTestCase):
+    """
+    Check message search.
+
+    A powerful way to check the behaviour is to run the following in Postgres >= 11:
+
+        # SELECT websearch_to_tsquery('english', <your string>);
+
+    The result can be compared to the tokenized version for SQLite and Postgres < 11.
+
+    """
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    PHRASE = "the quick brown fox jumps over the lazy dog"
+
+    # Each entry is a search query, followed by either a boolean of whether it is
+    # in the phrase OR a tuple of booleans: whether it matches using websearch
+    # and using plain search.
+    COMMON_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [
+        ("nope", False),
+        ("brown", True),
+        ("quick brown", True),
+        ("brown quick", True),
+        ("quick \t brown", True),
+        ("jump", True),
+        ("brown nope", False),
+        ('"brown quick"', (False, True)),
+        ('"jumps over"', True),
+        ('"quick fox"', (False, True)),
+        ("nope OR doublenope", False),
+        ("furphy OR fox", (True, False)),
+        ("fox -nope", (True, False)),
+        ("fox -brown", (False, True)),
+        ('"fox" quick', True),
+        ('"fox quick', True),
+        ('"quick brown', True),
+        ('" quick "', True),
+        ('" nope"', False),
+    ]
+    # TODO Test non-ASCII cases.
+
+    # Case that fail on SQLite.
+    POSTGRES_CASES: List[Tuple[str, Union[bool, Tuple[bool, bool]]]] = [
+        # SQLite treats NOT as a binary operator.
+        ("- fox", (False, True)),
+        ("- nope", (True, False)),
+        ('"-fox quick', (False, True)),
+        # PostgreSQL skips stop words.
+        ('"the quick brown"', True),
+        ('"over lazy"', True),
+    ]
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        # Register a user and create a room, create some messages
+        self.register_user("alice", "password")
+        self.access_token = self.login("alice", "password")
+        self.room_id = self.helper.create_room_as("alice", tok=self.access_token)
+
+        # Send the phrase as a message and check it was created
+        response = self.helper.send(self.room_id, self.PHRASE, tok=self.access_token)
+        self.assertIn("event_id", response)
+
+    def test_tokenize_query(self) -> None:
+        """Test the custom logic to tokenize a user's query."""
+        cases = (
+            ("brown", ["brown"]),
+            ("quick brown", ["quick", SearchToken.And, "brown"]),
+            ("quick \t brown", ["quick", SearchToken.And, "brown"]),
+            ('"brown quick"', [Phrase(["brown", "quick"])]),
+            ("furphy OR fox", ["furphy", SearchToken.Or, "fox"]),
+            ("fox -brown", ["fox", SearchToken.Not, "brown"]),
+            ("- fox", [SearchToken.Not, "fox"]),
+            ('"fox" quick', [Phrase(["fox"]), SearchToken.And, "quick"]),
+            # No trailing double quoe.
+            ('"fox quick', ["fox", SearchToken.And, "quick"]),
+            ('"-fox quick', [SearchToken.Not, "fox", SearchToken.And, "quick"]),
+            ('" quick "', [Phrase(["quick"])]),
+            (
+                'q"uick brow"n',
+                [
+                    "q",
+                    SearchToken.And,
+                    Phrase(["uick", "brow"]),
+                    SearchToken.And,
+                    "n",
+                ],
+            ),
+            (
+                '-"quick brown"',
+                [SearchToken.Not, Phrase(["quick", "brown"])],
+            ),
+        )
+
+        for query, expected in cases:
+            tokenized = _tokenize_query(query)
+            self.assertEqual(
+                tokenized, expected, f"{tokenized} != {expected} for {query}"
+            )
+
+    def _check_test_cases(
+        self,
+        store: DataStore,
+        cases: List[Tuple[str, Union[bool, Tuple[bool, bool]]]],
+        index=0,
+    ) -> None:
+        # Run all the test cases versus search_msgs
+        for query, expect_to_contain in cases:
+            if isinstance(expect_to_contain, tuple):
+                expect_to_contain = expect_to_contain[index]
+
+            result = self.get_success(
+                store.search_msgs([self.room_id], query, ["content.body"])
+            )
+            self.assertEquals(
+                result["count"],
+                1 if expect_to_contain else 0,
+                f"expected '{query}' to match '{self.PHRASE}'"
+                if expect_to_contain
+                else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+            )
+            self.assertEquals(
+                len(result["results"]),
+                1 if expect_to_contain else 0,
+                "results array length should match count",
+            )
+
+        # Run them again versus search_rooms
+        for query, expect_to_contain in cases:
+            if isinstance(expect_to_contain, tuple):
+                expect_to_contain = expect_to_contain[index]
+
+            result = self.get_success(
+                store.search_rooms([self.room_id], query, ["content.body"], 10)
+            )
+            self.assertEquals(
+                result["count"],
+                1 if expect_to_contain else 0,
+                f"expected '{query}' to match '{self.PHRASE}'"
+                if expect_to_contain
+                else f"'{query}' unexpectedly matched '{self.PHRASE}'",
+            )
+            self.assertEquals(
+                len(result["results"]),
+                1 if expect_to_contain else 0,
+                "results array length should match count",
+            )
+
+    def test_postgres_web_search_for_phrase(self):
+        """
+        Test searching for phrases using typical web search syntax, as per postgres' websearch_to_tsquery.
+        This test is skipped unless the postgres instance supports websearch_to_tsquery.
+        """
+
+        store = self.hs.get_datastores().main
+        if not isinstance(store.database_engine, PostgresEngine):
+            raise SkipTest("Test only applies when postgres is used as the database")
+
+        if store.database_engine.tsquery_func != "websearch_to_tsquery":
+            raise SkipTest(
+                "Test only applies when postgres supporting websearch_to_tsquery is used as the database"
+            )
+
+        self._check_test_cases(store, self.COMMON_CASES + self.POSTGRES_CASES, index=0)
+
+    def test_postgres_non_web_search_for_phrase(self):
+        """
+        Test postgres searching for phrases without using web search, which is used when websearch_to_tsquery isn't
+        supported by the current postgres version.
+        """
+
+        store = self.hs.get_datastores().main
+        if not isinstance(store.database_engine, PostgresEngine):
+            raise SkipTest("Test only applies when postgres is used as the database")
+
+        # Patch supports_websearch_to_tsquery to always return False to ensure we're testing the plainto_tsquery path.
+        with patch(
+            "synapse.storage.engines.postgres.PostgresEngine.tsquery_func",
+            new_callable=PropertyMock,
+        ) as supports_websearch_to_tsquery:
+            supports_websearch_to_tsquery.return_value = "plainto_tsquery"
+            self._check_test_cases(
+                store, self.COMMON_CASES + self.POSTGRES_CASES, index=1
+            )
+
+    def test_sqlite_search(self):
+        """
+        Test sqlite searching for phrases.
+        """
+        store = self.hs.get_datastores().main
+        if not isinstance(store.database_engine, Sqlite3Engine):
+            raise SkipTest("Test only applies when sqlite is used as the database")
+
+        self._check_test_cases(store, self.COMMON_CASES, index=0)
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 0d0d6faf0d..e62ebcc6a5 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -15,17 +15,24 @@
 """
 Utilities for running the unit tests
 """
+import json
 import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, Tuple, TypeVar
 from unittest.mock import Mock
 
 import attr
+import zope.interface
 
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
+from twisted.web.http import RESPONSES
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.types import JsonDict
 
 TV = TypeVar("TV")
 
@@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock:
     return Mock(side_effect=cb)
 
 
-@attr.s
-class FakeResponse:
+# Type ignore: it does not fully implement IResponse, but is good enough for tests
+@zope.interface.implementer(IResponse)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeResponse:  # type: ignore[misc]
     """A fake twisted.web.IResponse object
 
     there is a similar class at treq.test.test_response, but it lacks a `phrase`
     attribute, and didn't support deliverBody until recently.
     """
 
-    # HTTP response code
-    code = attr.ib(type=int)
+    version: Tuple[bytes, int, int] = (b"HTTP", 1, 1)
 
-    # HTTP response phrase (eg b'OK' for a 200)
-    phrase = attr.ib(type=bytes)
+    # HTTP response code
+    code: int = 200
 
     # body of the response
-    body = attr.ib(type=bytes)
+    body: bytes = b""
+
+    headers: Headers = attr.Factory(Headers)
+
+    @property
+    def phrase(self):
+        return RESPONSES.get(self.code, b"Unknown Status")
+
+    @property
+    def length(self):
+        return len(self.body)
 
     def deliverBody(self, protocol):
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
 
+    @classmethod
+    def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+        headers = Headers({"Content-Type": ["application/json"]})
+        body = json.dumps(payload).encode("utf-8")
+        return cls(code=code, body=body, headers=headers)
+
 
 # A small image used in some tests.
 #
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
new file mode 100644
index 0000000000..de134bbc89
--- /dev/null
+++ b/tests/test_utils/oidc.py
@@ -0,0 +1,325 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import Mock, patch
+from urllib.parse import parse_qs
+
+import attr
+
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+from tests.test_utils import FakeResponse
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeAuthorizationGrant:
+    userinfo: dict
+    client_id: str
+    redirect_uri: str
+    scope: str
+    nonce: Optional[str]
+    sid: Optional[str]
+
+
+class FakeOidcServer:
+    """A fake OpenID Connect Provider."""
+
+    # All methods here are mocks, so we can track when they are called, and override
+    # their values
+    request: Mock
+    get_jwks_handler: Mock
+    get_metadata_handler: Mock
+    get_userinfo_handler: Mock
+    post_token_handler: Mock
+
+    def __init__(self, clock: Clock, issuer: str):
+        from authlib.jose import ECKey, KeySet
+
+        self._clock = clock
+        self.issuer = issuer
+
+        self.request = Mock(side_effect=self._request)
+        self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler)
+        self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler)
+        self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler)
+        self.post_token_handler = Mock(side_effect=self._post_token_handler)
+
+        # A code -> grant mapping
+        self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {}
+        # An access token -> grant mapping
+        self._sessions: Dict[str, FakeAuthorizationGrant] = {}
+
+        # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for
+        # signing JWTs. ECDSA keys are really quick to generate compared to RSA.
+        self._key = ECKey.generate_key(crv="P-256", is_private=True)
+        self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))])
+
+        self._id_token_overrides: Dict[str, Any] = {}
+
+    def reset_mocks(self):
+        self.request.reset_mock()
+        self.get_jwks_handler.reset_mock()
+        self.get_metadata_handler.reset_mock()
+        self.get_userinfo_handler.reset_mock()
+        self.post_token_handler.reset_mock()
+
+    def patch_homeserver(self, hs: HomeServer):
+        """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
+
+        This patch should be used whenever the HS is expected to perform request to the
+        OIDC provider, e.g.::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            with fake_oidc_server.patch_homeserver(hs):
+                self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
+        """
+        return patch.object(hs.get_proxied_http_client(), "request", self.request)
+
+    @property
+    def authorization_endpoint(self) -> str:
+        return self.issuer + "authorize"
+
+    @property
+    def token_endpoint(self) -> str:
+        return self.issuer + "token"
+
+    @property
+    def userinfo_endpoint(self) -> str:
+        return self.issuer + "userinfo"
+
+    @property
+    def metadata_endpoint(self) -> str:
+        return self.issuer + ".well-known/openid-configuration"
+
+    @property
+    def jwks_uri(self) -> str:
+        return self.issuer + "jwks"
+
+    def get_metadata(self) -> dict:
+        return {
+            "issuer": self.issuer,
+            "authorization_endpoint": self.authorization_endpoint,
+            "token_endpoint": self.token_endpoint,
+            "jwks_uri": self.jwks_uri,
+            "userinfo_endpoint": self.userinfo_endpoint,
+            "response_types_supported": ["code"],
+            "subject_types_supported": ["public"],
+            "id_token_signing_alg_values_supported": ["ES256"],
+        }
+
+    def get_jwks(self) -> dict:
+        return self._jwks.as_dict()
+
+    def get_userinfo(self, access_token: str) -> Optional[dict]:
+        """Given an access token, get the userinfo of the associated session."""
+        session = self._sessions.get(access_token, None)
+        if session is None:
+            return None
+        return session.userinfo
+
+    def _sign(self, payload: dict) -> str:
+        from authlib.jose import JsonWebSignature
+
+        jws = JsonWebSignature()
+        kid = self.get_jwks()["keys"][0]["kid"]
+        protected = {"alg": "ES256", "kid": kid}
+        json_payload = json.dumps(payload)
+        return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
+
+    def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = self._clock.time()
+        id_token = {
+            **grant.userinfo,
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "nbf": now,
+            "exp": now + 600,
+        }
+
+        if grant.nonce is not None:
+            id_token["nonce"] = grant.nonce
+
+        if grant.sid is not None:
+            id_token["sid"] = grant.sid
+
+        id_token.update(self._id_token_overrides)
+
+        return self._sign(id_token)
+
+    def id_token_override(self, overrides: dict):
+        """Temporarily patch the ID token generated by the token endpoint."""
+        return patch.object(self, "_id_token_overrides", overrides)
+
+    def start_authorization(
+        self,
+        client_id: str,
+        scope: str,
+        redirect_uri: str,
+        userinfo: dict,
+        nonce: Optional[str] = None,
+        with_sid: bool = False,
+    ) -> Tuple[str, FakeAuthorizationGrant]:
+        """Start an authorization request, and get back the code to use on the authorization endpoint."""
+        code = random_string(10)
+        sid = None
+        if with_sid:
+            sid = random_string(10)
+
+        grant = FakeAuthorizationGrant(
+            userinfo=userinfo,
+            scope=scope,
+            redirect_uri=redirect_uri,
+            nonce=nonce,
+            client_id=client_id,
+            sid=sid,
+        )
+        self._authorization_grants[code] = grant
+
+        return code, grant
+
+    def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
+        grant = self._authorization_grants.pop(code, None)
+        if grant is None:
+            return None
+
+        access_token = random_string(10)
+        self._sessions[access_token] = grant
+
+        token = {
+            "token_type": "Bearer",
+            "access_token": access_token,
+            "expires_in": 3600,
+            "scope": grant.scope,
+        }
+
+        if "openid" in grant.scope:
+            token["id_token"] = self.generate_id_token(grant)
+
+        return dict(token)
+
+    def buggy_endpoint(
+        self,
+        *,
+        jwks: bool = False,
+        metadata: bool = False,
+        token: bool = False,
+        userinfo: bool = False,
+    ):
+        """A context which makes a set of endpoints return a 500 error.
+
+        Args:
+            jwks: If True, makes the JWKS endpoint return a 500 error.
+            metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
+            token: If True, makes the token endpoint return a 500 error.
+            userinfo: If True, makes the userinfo endpoint return a 500 error.
+        """
+        buggy = FakeResponse(code=500, body=b"Internal server error")
+
+        patches = {}
+        if jwks:
+            patches["get_jwks_handler"] = Mock(return_value=buggy)
+        if metadata:
+            patches["get_metadata_handler"] = Mock(return_value=buggy)
+        if token:
+            patches["post_token_handler"] = Mock(return_value=buggy)
+        if userinfo:
+            patches["get_userinfo_handler"] = Mock(return_value=buggy)
+
+        return patch.multiple(self, **patches)
+
+    async def _request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
+        """The override of the SimpleHttpClient#request() method"""
+        access_token: Optional[str] = None
+
+        if headers is None:
+            headers = Headers()
+
+        # Try to find the access token in the headers if any
+        auth_headers = headers.getRawHeaders(b"Authorization")
+        if auth_headers:
+            parts = auth_headers[0].split(b" ")
+            if parts[0] == b"Bearer" and len(parts) == 2:
+                access_token = parts[1].decode("ascii")
+
+        if method == "POST":
+            # If the method is POST, assume it has an url-encoded body
+            if data is None or headers.getRawHeaders(b"Content-Type") != [
+                b"application/x-www-form-urlencoded"
+            ]:
+                return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+            params = parse_qs(data.decode("utf-8"))
+
+            if uri == self.token_endpoint:
+                # Even though this endpoint should be protected, this does not check
+                # for client authentication. We're not checking it for simplicity,
+                # and because client authentication is tested in other standalone tests.
+                return self.post_token_handler(params)
+
+        elif method == "GET":
+            if uri == self.jwks_uri:
+                return self.get_jwks_handler()
+            elif uri == self.metadata_endpoint:
+                return self.get_metadata_handler()
+            elif uri == self.userinfo_endpoint:
+                return self.get_userinfo_handler(access_token=access_token)
+
+        return FakeResponse(code=404, body=b"404 not found")
+
+    # Request handlers
+    def _get_jwks_handler(self) -> IResponse:
+        """Handles requests to the JWKS URI."""
+        return FakeResponse.json(payload=self.get_jwks())
+
+    def _get_metadata_handler(self) -> IResponse:
+        """Handles requests to the OIDC well-known document."""
+        return FakeResponse.json(payload=self.get_metadata())
+
+    def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse:
+        """Handles requests to the userinfo endpoint."""
+        if access_token is None:
+            return FakeResponse(code=401)
+        user_info = self.get_userinfo(access_token)
+        if user_info is None:
+            return FakeResponse(code=401)
+
+        return FakeResponse.json(payload=user_info)
+
+    def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse:
+        """Handles requests to the token endpoint."""
+        code = params.get("code", [])
+
+        if len(code) != 1:
+            return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+        grant = self.exchange_code(code=code[0])
+        if grant is None:
+            return FakeResponse.json(code=400, payload={"error": "invalid_grant"})
+
+        return FakeResponse.json(payload=grant)
diff --git a/tests/util/caches/test_descriptors.py b/tests/util/caches/test_descriptors.py
index 90861fe522..43475a307f 100644
--- a/tests/util/caches/test_descriptors.py
+++ b/tests/util/caches/test_descriptors.py
@@ -28,7 +28,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
 )
 from synapse.util.caches import descriptors
-from synapse.util.caches.descriptors import cached, cachedList, lru_cache
+from synapse.util.caches.descriptors import cached, cachedList
 
 from tests import unittest
 from tests.test_utils import get_awaitable_result
@@ -36,38 +36,6 @@ from tests.test_utils import get_awaitable_result
 logger = logging.getLogger(__name__)
 
 
-class LruCacheDecoratorTestCase(unittest.TestCase):
-    def test_base(self):
-        class Cls:
-            def __init__(self):
-                self.mock = mock.Mock()
-
-            @lru_cache()
-            def fn(self, arg1, arg2):
-                return self.mock(arg1, arg2)
-
-        obj = Cls()
-        obj.mock.return_value = "fish"
-        r = obj.fn(1, 2)
-        self.assertEqual(r, "fish")
-        obj.mock.assert_called_once_with(1, 2)
-        obj.mock.reset_mock()
-
-        # a call with different params should call the mock again
-        obj.mock.return_value = "chips"
-        r = obj.fn(1, 3)
-        self.assertEqual(r, "chips")
-        obj.mock.assert_called_once_with(1, 3)
-        obj.mock.reset_mock()
-
-        # the two values should now be cached
-        r = obj.fn(1, 2)
-        self.assertEqual(r, "fish")
-        r = obj.fn(1, 3)
-        self.assertEqual(r, "chips")
-        obj.mock.assert_not_called()
-
-
 def run_on_reactor():
     d = defer.Deferred()
     reactor.callLater(0, d.callback, 0)
@@ -478,10 +446,10 @@ class DescriptorTestCase(unittest.TestCase):
 
             @cached(cache_context=True)
             async def func2(self, key, cache_context):
-                return self.func3(key, on_invalidate=cache_context.invalidate)
+                return await self.func3(key, on_invalidate=cache_context.invalidate)
 
-            @lru_cache(cache_context=True)
-            def func3(self, key, cache_context):
+            @cached(cache_context=True)
+            async def func3(self, key, cache_context):
                 self.invalidate = cache_context.invalidate
                 return 42
 
@@ -1037,5 +1005,5 @@ class CachedListDescriptorTestCase(unittest.TestCase):
         obj = Cls()
 
         # Make sure this raises an error about the arg mismatch
-        with self.assertRaises(Exception):
+        with self.assertRaises(TypeError):
             obj.list_fn([("foo", "bar")])
diff --git a/tests/util/test_macaroons.py b/tests/util/test_macaroons.py
index 32125f7bb7..40754a4711 100644
--- a/tests/util/test_macaroons.py
+++ b/tests/util/test_macaroons.py
@@ -84,34 +84,6 @@ class MacaroonGeneratorTestCase(TestCase):
         )
         self.assertEqual(user_id, "@user:tesths")
 
-    def test_short_term_login_token(self):
-        """Test the generation and verification of short-term login tokens"""
-        token = self.macaroon_generator.generate_short_term_login_token(
-            user_id="@user:tesths",
-            auth_provider_id="oidc",
-            auth_provider_session_id="sid",
-            duration_in_ms=2 * 60 * 1000,
-        )
-
-        info = self.macaroon_generator.verify_short_term_login_token(token)
-        self.assertEqual(info.user_id, "@user:tesths")
-        self.assertEqual(info.auth_provider_id, "oidc")
-        self.assertEqual(info.auth_provider_session_id, "sid")
-
-        # Raises with another secret key
-        with self.assertRaises(MacaroonVerificationFailedException):
-            self.other_macaroon_generator.verify_short_term_login_token(token)
-
-        # Wait a minute
-        self.reactor.pump([60])
-        # Shouldn't raise
-        self.macaroon_generator.verify_short_term_login_token(token)
-        # Wait another minute
-        self.reactor.pump([60])
-        # Should raise since it expired
-        with self.assertRaises(MacaroonVerificationFailedException):
-            self.macaroon_generator.verify_short_term_login_token(token)
-
     def test_oidc_session_token(self):
         """Test the generation and verification of OIDC session cookies"""
         state = "arandomstate"