summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/13910.misc1
-rw-r--r--pyproject.toml2
-rw-r--r--synapse/handlers/oidc.py15
-rw-r--r--tests/federation/test_federation_client.py36
-rw-r--r--tests/handlers/test_oidc.py580
-rw-r--r--tests/rest/client/test_auth.py32
-rw-r--r--tests/rest/client/test_login.py40
-rw-r--r--tests/rest/client/utils.py136
-rw-r--r--tests/test_utils/__init__.py40
-rw-r--r--tests/test_utils/oidc.py325
10 files changed, 747 insertions, 460 deletions
diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc
new file mode 100644
index 0000000000..e906952aab
--- /dev/null
+++ b/changelog.d/13910.misc
@@ -0,0 +1 @@
+Refactor OIDC tests to better mimic an actual OIDC provider.
diff --git a/pyproject.toml b/pyproject.toml
index 6ebac41ed1..7e0feb75aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py
 psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true }
 psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true }
 pysaml2 = { version = ">=4.5.0", optional = true }
-authlib = { version = ">=0.14.0", optional = true }
+authlib = { version = ">=0.15.1", optional = true }
 # systemd-python is necessary for logging to the systemd journal via
 # `systemd.journal.JournalHandler`, as is documented in
 # `contrib/systemd/log_config.yaml`.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index d7a8226900..9759daf043 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -275,6 +275,7 @@ class OidcProvider:
         provider: OidcProviderConfig,
     ):
         self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
 
         self._macaroon_generaton = macaroon_generator
 
@@ -673,6 +674,13 @@ class OidcProvider:
         Returns:
             The decoded claims in the ID token.
         """
+        id_token = token.get("id_token")
+        logger.debug("Attempting to decode JWT id_token %r", id_token)
+
+        # That has been theoritically been checked by the caller, so even though
+        # assertion are not enabled in production, it is mainly here to appease mypy
+        assert id_token is not None
+
         metadata = await self.load_metadata()
         claims_params = {
             "nonce": nonce,
@@ -688,9 +696,6 @@ class OidcProvider:
 
         claim_options = {"iss": {"values": [metadata["issuer"]]}}
 
-        id_token = token["id_token"]
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
-
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
@@ -715,7 +720,9 @@ class OidcProvider:
 
         logger.debug("Decoded id_token JWT %r; validating", claims)
 
-        claims.validate(leeway=120)  # allows 2 min of clock skew
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
 
         return claims
 
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index a538215931..51d3bb8fff 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 from unittest import mock
 
 import twisted.web.client
 from twisted.internet import defer
-from twisted.internet.protocol import Protocol
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import RoomVersions
@@ -26,10 +23,9 @@ from synapse.events import EventBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import event_injection
+from tests.test_utils import FakeResponse, event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
@@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "pdus": [
                         create_event_dict,
                         member_event_dict,
@@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     "pdus": [
@@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # We expect an outbound request to /backfill, so stub that out
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     # Mimic the other server returning our new `pulled_event`
@@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
         # other from "yet.another.server"
         self.assertEqual(backfill_num_attempts, 2)
-
-
-def _mock_response(resp: JsonDict):
-    body = json.dumps(resp).encode("utf-8")
-
-    def deliver_body(p: Protocol):
-        p.dataReceived(body)
-        p.connectionLost(Failure(twisted.web.client.ResponseDone()))
-
-    response = mock.Mock(
-        code=200,
-        phrase=b"OK",
-        headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
-        length=len(body),
-        deliverBody=deliver_body,
-    )
-    mock.seal(response)
-    return response
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
@@ -22,12 +21,15 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
 from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
 try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
 CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
 # config for common cases
 DEFAULT_CONFIG = {
     "enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
 EXPLICIT_ENDPOINT_CONFIG = {
     **DEFAULT_CONFIG,
     "discover": False,
-    "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-    "token_endpoint": TOKEN_ENDPOINT,
-    "jwks_uri": JWKS_URI,
+    "authorization_endpoint": ISSUER + "authorize",
+    "token_endpoint": ISSUER + "token",
+    "jwks_uri": ISSUER + "jwks",
 }
 
 
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url: str) -> JsonDict:
-    # Mock get_json calls to handle jwks & oidc discovery endpoints
-    if url == WELL_KNOWN:
-        # Minimal discovery document, as defined in OpenID.Discovery
-        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
-        return {
-            "issuer": ISSUER,
-            "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-            "token_endpoint": TOKEN_ENDPOINT,
-            "jwks_uri": JWKS_URI,
-            "userinfo_endpoint": USERINFO_ENDPOINT,
-            "response_types_supported": ["code"],
-            "subject_types_supported": ["public"],
-            "id_token_signing_alg_values_supported": ["RS256"],
-        }
-    elif url == JWKS_URI:
-        return {"keys": []}
-
-    return {}
-
-
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
 
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return config
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.http_client = Mock(spec=["get_json"])
-        self.http_client.get_json.side_effect = get_json
-        self.http_client.user_agent = b"Synapse Test"
+        self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
 
-        hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+        hs = self.setup_test_homeserver()
+        self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+        self.hs_patcher.start()
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
 
+        auth_handler = hs.get_auth_handler()
+        # Mock the complete SSO login method.
+        self.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
+
         return hs
 
+    def tearDown(self) -> None:
+        self.hs_patcher.stop()
+        return super().tearDown()
+
+    def reset_mocks(self):
+        """Reset all the Mocks."""
+        self.fake_server.reset_mocks()
+        self.render_error.reset_mock()
+        self.complete_sso_login.reset_mock()
+
     def metadata_edit(self, values):
         """Modify the result that will be returned by the well-known query"""
 
-        async def patched_get_json(uri):
-            res = await get_json(uri)
-            if uri == WELL_KNOWN:
-                res.update(values)
-            return res
+        metadata = self.fake_server.get_metadata()
+        metadata.update(values)
+        return patch.object(self.fake_server, "get_metadata", return_value=metadata)
 
-        return patch.object(self.http_client, "get_json", patched_get_json)
+    def start_authorization(
+        self,
+        userinfo: dict,
+        client_redirect_url: str = "http://client/redirect",
+        scope: str = "openid",
+        with_sid: bool = False,
+    ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+        """Start an authorization request, and get the callback request back."""
+        nonce = random_string(10)
+        state = random_string(10)
+
+        code, grant = self.fake_server.start_authorization(
+            userinfo=userinfo,
+            scope=scope,
+            client_id=self.provider._client_auth.client_id,
+            redirect_uri=self.provider._callback_url,
+            nonce=nonce,
+            with_sid=with_sid,
+        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        return _build_callback_request(code, state, session), grant
 
     def assertRenderedError(self, error, error_description=None):
         self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.fake_server.get_metadata_handler.assert_called_once()
 
-        self.assertEqual(metadata.issuer, ISSUER)
-        self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
-        self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
-        self.assertEqual(metadata.jwks_uri, JWKS_URI)
-        # FIXME: it seems like authlib does not have that defined in its metadata models
-        # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+        self.assertEqual(metadata.issuer, self.fake_server.issuer)
+        self.assertEqual(
+            metadata.authorization_endpoint,
+            self.fake_server.authorization_endpoint,
+        )
+        self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+        self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+        # It seems like authlib does not have that defined in its metadata models
+        self.assertEqual(
+            metadata.get("userinfo_endpoint"),
+            self.fake_server.userinfo_endpoint,
+        )
 
         # subsequent calls should be cached
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
-    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
-        self.assertEqual(jwks, {"keys": []})
+        self.fake_server.get_jwks_handler.assert_called_once()
+        self.assertEqual(jwks, self.fake_server.get_jwks())
 
         # subsequent calls should be cached…
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_jwks_handler.assert_not_called()
 
         # …unless forced
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks(force=True))
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
+        self.fake_server.get_jwks_handler.assert_called_once()
 
-        # Throw if the JWKS uri is missing
-        original = self.provider.load_metadata
-
-        async def patched_load_metadata():
-            m = (await original()).copy()
-            m.update({"jwks_uri": None})
-            return m
-
-        with patch.object(self.provider, "load_metadata", patched_load_metadata):
+        with self.metadata_edit({"jwks_uri": None}):
+            # If we don't do this, the load_metadata call will throw because of the
+            # missing jwks_uri
+            self.provider._user_profile_method = "userinfo_endpoint"
+            self.get_success(self.provider.load_metadata(force=True))
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 self.provider.handle_redirect_request(req, b"http://client/redirect")
             )
         )
-        auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+        auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
 
         self.assertEqual(url.scheme, auth_endpoint.scheme)
         self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         username = "bar"
         userinfo = {
             "sub": "foo",
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
 
-        code = "code"
-        state = "state"
-        nonce = "nonce"
         client_redirect_url = "http://client/redirect"
-        ip_address = "10.0.0.1"
-        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
-        request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+        request, _ = self.start_authorization(
+            userinfo, client_redirect_url=client_redirect_url
+        )
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
             client_redirect_url,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_not_called()
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
+        request, _ = self.start_authorization(userinfo)
         with patch.object(
             self.provider,
             "_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._exchange_code.reset_mock()
-        self.provider._parse_id_token.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        self.reset_mocks()
 
         # With userinfo fetching
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-        }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        # Without the "openid" scope, the FakeProvider does not generate an id_token
+        request, _ = self.start_authorization(userinfo, scope="")
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_not_called()
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
+        self.reset_mocks()
+
         # With an ID token, userinfo fetching and sid in the ID token
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-            "id_token": "id_token",
-        }
-        id_token = {
-            "sid": "abcdefgh",
-        }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        request, grant = self.start_authorization(userinfo, with_sid=True)
+        self.assertIsNotNone(grant.sid)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
-            auth_provider_session_id=id_token["sid"],
+            auth_provider_session_id=grant.sid,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(userinfo=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
-        # Handle code exchange failure
-        from synapse.handlers.oidc import OidcError
-
-        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
-            raises=OidcError("invalid_request")
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("invalid_request")
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(token=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
+        self.assertRenderedError("server_error")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
-        token = {"type": "bearer"}
-        token_json = json.dumps(token).encode("utf-8")
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
 
         self.assertEqual(ret, token)
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         args = parse_qs(kwargs["data"].decode("utf-8"))
         self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
 
         # Test error handling
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad Request",
-                body=b'{"error": "foo", "error_description": "bar"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={"error": "foo", "error_description": "bar"}
         )
         from synapse.handlers.oidc import OidcError
 
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(exc.value.error_description, "bar")
 
         # Internal server error with no JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b"Not JSON",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse(
+            code=500, body=b"Not JSON"
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b'{"error": "internal_server_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=500, payload={"error": "internal_server_error"}
         )
 
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad request",
-                body=b"{}",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200,
-                phrase=b"OK",
-                body=b'{"error": "some_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=200, payload={"error": "some_error"}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
 
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # the client secret provided to the should be a jwt which can be checked with
         # the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # check the POSTed data
         args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         userinfo = {
             "sub": "foo",
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
-        state = "state"
-        client_redirect_url = "http://client/redirect"
-        session = self._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-        )
-        request = _build_callback_request("code", state, session)
-
+        request, _ = self.start_authorization(userinfo)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@foo:test",
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             {"phone": "1234567"},
             new_user=True,
             auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error",
             "Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Subsequent calls should map to the same mxid.
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         args = self.assertRenderedError("mapping_error")
         self.assertTrue(
             args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@TEST_USER_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        self.get_success(
-            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
-        )
+        userinfo = {"sub": "test2", "username": "föö"}
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # test_user is already taken, so test_user1 gets registered instead.
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@test_user1:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # userinfo lacking "test": "foobar" attribute should fail.
         userinfo = {
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": "foobar" attribute should succeed.
         userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": "foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
         userinfo = {
             "sub": "tester",
             "username": "tester",
             "test": ["foobar", "foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
         """
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
         userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": ["foo", "bar"] attribute should fail
         userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": ["foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": False attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": False,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": None attribute should fail
         # a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 1 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 1,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 3.14 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 3.14,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
     def _generate_oidc_session_token(
         self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return self.handler._macaroon_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
-                idp_id="oidc",
+                idp_id=self.provider.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
 
-async def _make_callback_with_userinfo(
-    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
-    """Mock up an OIDC callback with the given userinfo dict
-
-    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
-    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
-    Args:
-        hs: the HomeServer impl to send the callback to.
-        userinfo: the OIDC userinfo dict
-        client_redirect_url: the URL to redirect to on success.
-    """
-
-    handler = hs.get_oidc_handler()
-    provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-
-    state = "state"
-    session = handler._macaroon_generator.generate_oidc_session_token(
-        state=state,
-        session_data=OidcSessionData(
-            idp_id="oidc",
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id="",
-        ),
-    )
-    request = _build_callback_request("code", state, session)
-
-    await handler.handle_oidc_callback(request)
-
-
 def _build_callback_request(
     code: str,
     state: str,
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 090cef5216..ebf653d018 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
           * checking that the original operation succeeds
         """
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
         remote_user_id = UserID.from_string(self.user).localpart
-        login_resp = self.helper.login_via_oidc(remote_user_id)
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
         self.assertEqual(login_resp["user_id"], self.user)
 
         # initiate a UI Auth process by attempting to delete the device
@@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # run the UIA-via-SSO flow
         session_id = channel.json_body["session"]
-        channel = self.helper.auth_via_oidc(
-            {"sub": remote_user_id}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
         )
 
         # that should serve a confirmation page
@@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self) -> None:
-        login_resp = self.helper.login_via_oidc("username")
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
         user_tok = login_resp["access_token"]
         device_id = login_resp["device_id"]
 
@@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self) -> None:
         """A user that had a password and then logged in with SSO should get both flows"""
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         channel = self.delete_device(
@@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
         """If the user tries to authenticate with the wrong SSO user, they get an error"""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         # start a UI Auth flow by attempting to delete a device
@@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session_id = channel.json_body["session"]
 
         # do the OIDC auth, but auth as the wrong user
-        channel = self.helper.auth_via_oidc(
-            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
         )
 
         # that should return a failure message
@@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """Tests that if we register a user via SSO while requiring approval for new
         accounts, we still raise the correct error before logging the user in.
         """
-        login_resp = self.helper.login_via_oidc("username", expected_status=403)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, "username", expected_status=403
+        )
 
         self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
         self.assertEqual(
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index e801ba8c8b..ff5baa9f0a 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -36,7 +36,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
-from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
@@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_login_via_oidc(self) -> None:
         """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
 
-        # pick the default OIDC provider
-        channel = self.make_request(
-            "GET",
-            "/_synapse/client/pick_idp?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
-            + "&idp=oidc",
-        )
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        with fake_oidc_server.patch_homeserver(hs=self.hs):
+            # pick the default OIDC provider
+            channel = self.make_request(
+                "GET",
+                "/_synapse/client/pick_idp?redirectUrl="
+                + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+                + "&idp=oidc",
+            )
         self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
@@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
-        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+        self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
 
         # ... and should have set a cookie including the redirect url
         cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
@@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             TEST_CLIENT_REDIRECT_URL,
         )
 
-        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+        channel, _ = self.helper.complete_oidc_auth(
+            fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
+        )
 
         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
@@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
     def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
-        channel = self._make_sso_redirect_request("oidc")
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        with fake_oidc_server.patch_homeserver(hs=self.hs):
+            channel = self._make_sso_redirect_request("oidc")
         self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
@@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
-        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+        self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
 
     def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
         """Send a request to /_matrix/client/r0/login/sso/redirect
@@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase):
     def test_username_picker(self) -> None:
         """Test the happy path of a username picker flow."""
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # do the start of the login flow
-        channel = self.helper.auth_via_oidc(
-            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            {"sub": "tester", "displayname": "Jonny"},
+            TEST_CLIENT_REDIRECT_URL,
         )
 
         # that should redirect to the username picker
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index c249a42bb6..967d229223 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,7 +31,6 @@ from typing import (
     Tuple,
     overload,
 )
-from unittest.mock import patch
 from urllib.parse import urlencode
 
 import attr
@@ -46,8 +45,19 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
 from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_ISSUER = "https://issuer.test/"
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "issuer": TEST_OIDC_ISSUER,
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["openid"],
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
 
 
 @attr.s(auto_attribs=True)
@@ -543,12 +553,28 @@ class RestHelper:
 
         return channel.json_body
 
+    def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
+        """Create a ``FakeOidcServer``.
+
+        This can be used in conjuction with ``login_via_oidc``::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+        """
+
+        return FakeOidcServer(
+            clock=self.hs.get_clock(),
+            issuer=issuer,
+        )
+
     def login_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         remote_user_id: str,
+        with_sid: bool = False,
         expected_status: int = 200,
-    ) -> JsonDict:
-        """Log in via OIDC
+    ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
+        """Log in (as a new user) via OIDC
 
         Returns the result of the final token login.
 
@@ -560,7 +586,10 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+        userinfo = {"sub": remote_user_id}
+        channel, grant = self.auth_via_oidc(
+            fake_server, userinfo, client_redirect_url, with_sid=with_sid
+        )
 
         # expect a confirmation page
         assert channel.code == HTTPStatus.OK, channel.result
@@ -585,14 +614,16 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body
+        return channel.json_body, grant
 
     def auth_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Perform an OIDC authentication flow via a mock OIDC provider.
 
         This can be used for either login or user-interactive auth.
@@ -616,6 +647,7 @@ class RestHelper:
                 the login redirect endpoint
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
                 of the UI auth.
+            with_sid: if True, generates a random `sid` (OIDC session ID)
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -625,14 +657,15 @@ class RestHelper:
 
         cookies: Dict[str, str] = {}
 
-        # if we're doing a ui auth, hit the ui auth redirect endpoint
-        if ui_auth_session_id:
-            # can't set the client redirect url for UI Auth
-            assert client_redirect_url is None
-            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
-        else:
-            # otherwise, hit the login redirect endpoint
-            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+        with fake_server.patch_homeserver(hs=self.hs):
+            # if we're doing a ui auth, hit the ui auth redirect endpoint
+            if ui_auth_session_id:
+                # can't set the client redirect url for UI Auth
+                assert client_redirect_url is None
+                oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+            else:
+                # otherwise, hit the login redirect endpoint
+                oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
 
         # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
@@ -640,17 +673,21 @@ class RestHelper:
         # that synapse passes to the client.
 
         oauth_uri_path, _ = oauth_uri.split("?", 1)
-        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+        assert oauth_uri_path == fake_server.authorization_endpoint, (
             "unexpected SSO URI " + oauth_uri_path
         )
-        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+        return self.complete_oidc_auth(
+            fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+        )
 
     def complete_oidc_auth(
         self,
+        fake_serer: FakeOidcServer,
         oauth_uri: str,
         cookies: Mapping[str, str],
         user_info_dict: JsonDict,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Mock out an OIDC authentication flow
 
         Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -661,50 +698,37 @@ class RestHelper:
         Requires the OIDC callback resource to be mounted at the normal place.
 
         Args:
+            fake_server: the fake OIDC server with which the auth should be done
             oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
                from initiate_sso_login or initiate_sso_ui_auth).
             cookies: the cookies set by synapse's redirect endpoint, which will be
                sent back to the callback endpoint.
             user_info_dict: the remote userinfo that the OIDC provider should present.
                 Typically this should be '{"sub": "<remote user id>"}'.
+            with_sid: if True, generates a random `sid` (OIDC session ID)
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
         """
         _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
+
+        code, grant = fake_serer.start_authorization(
+            scope=params["scope"][0],
+            userinfo=user_info_dict,
+            client_id=params["client_id"][0],
+            redirect_uri=params["redirect_uri"][0],
+            nonce=params["nonce"][0],
+            with_sid=with_sid,
+        )
+        state = params["state"][0]
+
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
-            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+            urllib.parse.urlencode({"state": state, "code": code}),
         )
 
-        # before we hit the callback uri, stub out some methods in the http client so
-        # that we don't have to handle full HTTPS requests.
-        # (expected url, json response) pairs, in the order we expect them.
-        expected_requests = [
-            # first we get a hit to the token endpoint, which we tell to return
-            # a dummy OIDC access token
-            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
-            # and then one to the user_info endpoint, which returns our remote user id.
-            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
-        ]
-
-        async def mock_req(
-            method: str,
-            uri: str,
-            data: Optional[dict] = None,
-            headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-        ):
-            (expected_uri, resp_obj) = expected_requests.pop(0)
-            assert uri == expected_uri
-            resp = FakeResponse(
-                code=HTTPStatus.OK,
-                phrase=b"OK",
-                body=json.dumps(resp_obj).encode("utf-8"),
-            )
-            return resp
-
-        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+        with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
                 self.hs.get_reactor(),
@@ -715,7 +739,7 @@ class RestHelper:
                     ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
                 ],
             )
-        return channel
+        return channel, grant
 
     def initiate_sso_login(
         self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
@@ -806,21 +830,3 @@ class RestHelper:
         assert len(p.links) == 1, "not exactly one link in confirmation page"
         oauth_uri = p.links[0]
         return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
-    "enabled": True,
-    "discover": False,
-    "issuer": "https://issuer.test",
-    "client_id": "test-client-id",
-    "client_secret": "test-client-secret",
-    "scopes": ["profile"],
-    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
-    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
-    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 0d0d6faf0d..e62ebcc6a5 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -15,17 +15,24 @@
 """
 Utilities for running the unit tests
 """
+import json
 import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, Tuple, TypeVar
 from unittest.mock import Mock
 
 import attr
+import zope.interface
 
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
+from twisted.web.http import RESPONSES
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.types import JsonDict
 
 TV = TypeVar("TV")
 
@@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock:
     return Mock(side_effect=cb)
 
 
-@attr.s
-class FakeResponse:
+# Type ignore: it does not fully implement IResponse, but is good enough for tests
+@zope.interface.implementer(IResponse)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeResponse:  # type: ignore[misc]
     """A fake twisted.web.IResponse object
 
     there is a similar class at treq.test.test_response, but it lacks a `phrase`
     attribute, and didn't support deliverBody until recently.
     """
 
-    # HTTP response code
-    code = attr.ib(type=int)
+    version: Tuple[bytes, int, int] = (b"HTTP", 1, 1)
 
-    # HTTP response phrase (eg b'OK' for a 200)
-    phrase = attr.ib(type=bytes)
+    # HTTP response code
+    code: int = 200
 
     # body of the response
-    body = attr.ib(type=bytes)
+    body: bytes = b""
+
+    headers: Headers = attr.Factory(Headers)
+
+    @property
+    def phrase(self):
+        return RESPONSES.get(self.code, b"Unknown Status")
+
+    @property
+    def length(self):
+        return len(self.body)
 
     def deliverBody(self, protocol):
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
 
+    @classmethod
+    def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+        headers = Headers({"Content-Type": ["application/json"]})
+        body = json.dumps(payload).encode("utf-8")
+        return cls(code=code, body=body, headers=headers)
+
 
 # A small image used in some tests.
 #
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
new file mode 100644
index 0000000000..de134bbc89
--- /dev/null
+++ b/tests/test_utils/oidc.py
@@ -0,0 +1,325 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import Mock, patch
+from urllib.parse import parse_qs
+
+import attr
+
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+from tests.test_utils import FakeResponse
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeAuthorizationGrant:
+    userinfo: dict
+    client_id: str
+    redirect_uri: str
+    scope: str
+    nonce: Optional[str]
+    sid: Optional[str]
+
+
+class FakeOidcServer:
+    """A fake OpenID Connect Provider."""
+
+    # All methods here are mocks, so we can track when they are called, and override
+    # their values
+    request: Mock
+    get_jwks_handler: Mock
+    get_metadata_handler: Mock
+    get_userinfo_handler: Mock
+    post_token_handler: Mock
+
+    def __init__(self, clock: Clock, issuer: str):
+        from authlib.jose import ECKey, KeySet
+
+        self._clock = clock
+        self.issuer = issuer
+
+        self.request = Mock(side_effect=self._request)
+        self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler)
+        self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler)
+        self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler)
+        self.post_token_handler = Mock(side_effect=self._post_token_handler)
+
+        # A code -> grant mapping
+        self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {}
+        # An access token -> grant mapping
+        self._sessions: Dict[str, FakeAuthorizationGrant] = {}
+
+        # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for
+        # signing JWTs. ECDSA keys are really quick to generate compared to RSA.
+        self._key = ECKey.generate_key(crv="P-256", is_private=True)
+        self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))])
+
+        self._id_token_overrides: Dict[str, Any] = {}
+
+    def reset_mocks(self):
+        self.request.reset_mock()
+        self.get_jwks_handler.reset_mock()
+        self.get_metadata_handler.reset_mock()
+        self.get_userinfo_handler.reset_mock()
+        self.post_token_handler.reset_mock()
+
+    def patch_homeserver(self, hs: HomeServer):
+        """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
+
+        This patch should be used whenever the HS is expected to perform request to the
+        OIDC provider, e.g.::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            with fake_oidc_server.patch_homeserver(hs):
+                self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
+        """
+        return patch.object(hs.get_proxied_http_client(), "request", self.request)
+
+    @property
+    def authorization_endpoint(self) -> str:
+        return self.issuer + "authorize"
+
+    @property
+    def token_endpoint(self) -> str:
+        return self.issuer + "token"
+
+    @property
+    def userinfo_endpoint(self) -> str:
+        return self.issuer + "userinfo"
+
+    @property
+    def metadata_endpoint(self) -> str:
+        return self.issuer + ".well-known/openid-configuration"
+
+    @property
+    def jwks_uri(self) -> str:
+        return self.issuer + "jwks"
+
+    def get_metadata(self) -> dict:
+        return {
+            "issuer": self.issuer,
+            "authorization_endpoint": self.authorization_endpoint,
+            "token_endpoint": self.token_endpoint,
+            "jwks_uri": self.jwks_uri,
+            "userinfo_endpoint": self.userinfo_endpoint,
+            "response_types_supported": ["code"],
+            "subject_types_supported": ["public"],
+            "id_token_signing_alg_values_supported": ["ES256"],
+        }
+
+    def get_jwks(self) -> dict:
+        return self._jwks.as_dict()
+
+    def get_userinfo(self, access_token: str) -> Optional[dict]:
+        """Given an access token, get the userinfo of the associated session."""
+        session = self._sessions.get(access_token, None)
+        if session is None:
+            return None
+        return session.userinfo
+
+    def _sign(self, payload: dict) -> str:
+        from authlib.jose import JsonWebSignature
+
+        jws = JsonWebSignature()
+        kid = self.get_jwks()["keys"][0]["kid"]
+        protected = {"alg": "ES256", "kid": kid}
+        json_payload = json.dumps(payload)
+        return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
+
+    def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = self._clock.time()
+        id_token = {
+            **grant.userinfo,
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "nbf": now,
+            "exp": now + 600,
+        }
+
+        if grant.nonce is not None:
+            id_token["nonce"] = grant.nonce
+
+        if grant.sid is not None:
+            id_token["sid"] = grant.sid
+
+        id_token.update(self._id_token_overrides)
+
+        return self._sign(id_token)
+
+    def id_token_override(self, overrides: dict):
+        """Temporarily patch the ID token generated by the token endpoint."""
+        return patch.object(self, "_id_token_overrides", overrides)
+
+    def start_authorization(
+        self,
+        client_id: str,
+        scope: str,
+        redirect_uri: str,
+        userinfo: dict,
+        nonce: Optional[str] = None,
+        with_sid: bool = False,
+    ) -> Tuple[str, FakeAuthorizationGrant]:
+        """Start an authorization request, and get back the code to use on the authorization endpoint."""
+        code = random_string(10)
+        sid = None
+        if with_sid:
+            sid = random_string(10)
+
+        grant = FakeAuthorizationGrant(
+            userinfo=userinfo,
+            scope=scope,
+            redirect_uri=redirect_uri,
+            nonce=nonce,
+            client_id=client_id,
+            sid=sid,
+        )
+        self._authorization_grants[code] = grant
+
+        return code, grant
+
+    def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
+        grant = self._authorization_grants.pop(code, None)
+        if grant is None:
+            return None
+
+        access_token = random_string(10)
+        self._sessions[access_token] = grant
+
+        token = {
+            "token_type": "Bearer",
+            "access_token": access_token,
+            "expires_in": 3600,
+            "scope": grant.scope,
+        }
+
+        if "openid" in grant.scope:
+            token["id_token"] = self.generate_id_token(grant)
+
+        return dict(token)
+
+    def buggy_endpoint(
+        self,
+        *,
+        jwks: bool = False,
+        metadata: bool = False,
+        token: bool = False,
+        userinfo: bool = False,
+    ):
+        """A context which makes a set of endpoints return a 500 error.
+
+        Args:
+            jwks: If True, makes the JWKS endpoint return a 500 error.
+            metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
+            token: If True, makes the token endpoint return a 500 error.
+            userinfo: If True, makes the userinfo endpoint return a 500 error.
+        """
+        buggy = FakeResponse(code=500, body=b"Internal server error")
+
+        patches = {}
+        if jwks:
+            patches["get_jwks_handler"] = Mock(return_value=buggy)
+        if metadata:
+            patches["get_metadata_handler"] = Mock(return_value=buggy)
+        if token:
+            patches["post_token_handler"] = Mock(return_value=buggy)
+        if userinfo:
+            patches["get_userinfo_handler"] = Mock(return_value=buggy)
+
+        return patch.multiple(self, **patches)
+
+    async def _request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
+        """The override of the SimpleHttpClient#request() method"""
+        access_token: Optional[str] = None
+
+        if headers is None:
+            headers = Headers()
+
+        # Try to find the access token in the headers if any
+        auth_headers = headers.getRawHeaders(b"Authorization")
+        if auth_headers:
+            parts = auth_headers[0].split(b" ")
+            if parts[0] == b"Bearer" and len(parts) == 2:
+                access_token = parts[1].decode("ascii")
+
+        if method == "POST":
+            # If the method is POST, assume it has an url-encoded body
+            if data is None or headers.getRawHeaders(b"Content-Type") != [
+                b"application/x-www-form-urlencoded"
+            ]:
+                return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+            params = parse_qs(data.decode("utf-8"))
+
+            if uri == self.token_endpoint:
+                # Even though this endpoint should be protected, this does not check
+                # for client authentication. We're not checking it for simplicity,
+                # and because client authentication is tested in other standalone tests.
+                return self.post_token_handler(params)
+
+        elif method == "GET":
+            if uri == self.jwks_uri:
+                return self.get_jwks_handler()
+            elif uri == self.metadata_endpoint:
+                return self.get_metadata_handler()
+            elif uri == self.userinfo_endpoint:
+                return self.get_userinfo_handler(access_token=access_token)
+
+        return FakeResponse(code=404, body=b"404 not found")
+
+    # Request handlers
+    def _get_jwks_handler(self) -> IResponse:
+        """Handles requests to the JWKS URI."""
+        return FakeResponse.json(payload=self.get_jwks())
+
+    def _get_metadata_handler(self) -> IResponse:
+        """Handles requests to the OIDC well-known document."""
+        return FakeResponse.json(payload=self.get_metadata())
+
+    def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse:
+        """Handles requests to the userinfo endpoint."""
+        if access_token is None:
+            return FakeResponse(code=401)
+        user_info = self.get_userinfo(access_token)
+        if user_info is None:
+            return FakeResponse(code=401)
+
+        return FakeResponse.json(payload=user_info)
+
+    def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse:
+        """Handles requests to the token endpoint."""
+        code = params.get("code", [])
+
+        if len(code) != 1:
+            return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+        grant = self.exchange_code(code=code[0])
+        if grant is None:
+            return FakeResponse.json(code=400, payload={"error": "invalid_grant"})
+
+        return FakeResponse.json(payload=grant)