summary refs log tree commit diff
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2022-10-25 16:25:02 +0200
committerGitHub <noreply@github.com>2022-10-25 14:25:02 +0000
commit9192d74b0bf2f87b00d3e106a18baa9ce27acda1 (patch)
tree08bc76abec65c3124686f19f03849e6ccb12c820
parentImplementation for MSC3664: Pushrules for relations (#11804) (diff)
downloadsynapse-9192d74b0bf2f87b00d3e106a18baa9ce27acda1.tar.xz
Refactor OIDC tests to better mimic an actual OIDC provider. (#13910)
This implements a fake OIDC server, which intercepts calls to the HTTP client.
Improves accuracy of tests by covering more internal methods.

One particular example was the ID token validation, which previously mocked.

This uncovered an incorrect dependency: Synapse actually requires at least
authlib 0.15.1, not 0.14.0.
-rw-r--r--changelog.d/13910.misc1
-rw-r--r--pyproject.toml2
-rw-r--r--synapse/handlers/oidc.py15
-rw-r--r--tests/federation/test_federation_client.py36
-rw-r--r--tests/handlers/test_oidc.py580
-rw-r--r--tests/rest/client/test_auth.py32
-rw-r--r--tests/rest/client/test_login.py40
-rw-r--r--tests/rest/client/utils.py136
-rw-r--r--tests/test_utils/__init__.py40
-rw-r--r--tests/test_utils/oidc.py325
10 files changed, 747 insertions, 460 deletions
diff --git a/changelog.d/13910.misc b/changelog.d/13910.misc
new file mode 100644
index 0000000000..e906952aab
--- /dev/null
+++ b/changelog.d/13910.misc
@@ -0,0 +1 @@
+Refactor OIDC tests to better mimic an actual OIDC provider.
diff --git a/pyproject.toml b/pyproject.toml
index 6ebac41ed1..7e0feb75aa 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -192,7 +192,7 @@ psycopg2 = { version = ">=2.8", markers = "platform_python_implementation != 'Py
 psycopg2cffi = { version = ">=2.8", markers = "platform_python_implementation == 'PyPy'", optional = true }
 psycopg2cffi-compat = { version = "==1.1", markers = "platform_python_implementation == 'PyPy'", optional = true }
 pysaml2 = { version = ">=4.5.0", optional = true }
-authlib = { version = ">=0.14.0", optional = true }
+authlib = { version = ">=0.15.1", optional = true }
 # systemd-python is necessary for logging to the systemd journal via
 # `systemd.journal.JournalHandler`, as is documented in
 # `contrib/systemd/log_config.yaml`.
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index d7a8226900..9759daf043 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -275,6 +275,7 @@ class OidcProvider:
         provider: OidcProviderConfig,
     ):
         self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
 
         self._macaroon_generaton = macaroon_generator
 
@@ -673,6 +674,13 @@ class OidcProvider:
         Returns:
             The decoded claims in the ID token.
         """
+        id_token = token.get("id_token")
+        logger.debug("Attempting to decode JWT id_token %r", id_token)
+
+        # That has been theoritically been checked by the caller, so even though
+        # assertion are not enabled in production, it is mainly here to appease mypy
+        assert id_token is not None
+
         metadata = await self.load_metadata()
         claims_params = {
             "nonce": nonce,
@@ -688,9 +696,6 @@ class OidcProvider:
 
         claim_options = {"iss": {"values": [metadata["issuer"]]}}
 
-        id_token = token["id_token"]
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
-
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
@@ -715,7 +720,9 @@ class OidcProvider:
 
         logger.debug("Decoded id_token JWT %r; validating", claims)
 
-        claims.validate(leeway=120)  # allows 2 min of clock skew
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
 
         return claims
 
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index a538215931..51d3bb8fff 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -12,13 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 from unittest import mock
 
 import twisted.web.client
 from twisted.internet import defer
-from twisted.internet.protocol import Protocol
-from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import RoomVersions
@@ -26,10 +23,9 @@ from synapse.events import EventBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
-from synapse.types import JsonDict
 from synapse.util import Clock
 
-from tests.test_utils import event_injection
+from tests.test_utils import FakeResponse, event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
@@ -98,8 +94,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "pdus": [
                         create_event_dict,
                         member_event_dict,
@@ -208,8 +204,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # mock up the response, and have the agent return it
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     "pdus": [
@@ -269,8 +265,8 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         # We expect an outbound request to /backfill, so stub that out
         self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
-            _mock_response(
-                {
+            FakeResponse.json(
+                payload={
                     "origin": "yet.another.server",
                     "origin_server_ts": 900,
                     # Mimic the other server returning our new `pulled_event`
@@ -305,21 +301,3 @@ class FederationClientTest(FederatingHomeserverTestCase):
         # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
         # other from "yet.another.server"
         self.assertEqual(backfill_num_attempts, 2)
-
-
-def _mock_response(resp: JsonDict):
-    body = json.dumps(resp).encode("utf-8")
-
-    def deliver_body(p: Protocol):
-        p.dataReceived(body)
-        p.connectionLost(Failure(twisted.web.client.ResponseDone()))
-
-    response = mock.Mock(
-        code=200,
-        phrase=b"OK",
-        headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
-        length=len(body),
-        deliverBody=deliver_body,
-    )
-    mock.seal(response)
-    return response
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
@@ -22,12 +21,15 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
 from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
 try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
 CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
 # config for common cases
 DEFAULT_CONFIG = {
     "enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
 EXPLICIT_ENDPOINT_CONFIG = {
     **DEFAULT_CONFIG,
     "discover": False,
-    "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-    "token_endpoint": TOKEN_ENDPOINT,
-    "jwks_uri": JWKS_URI,
+    "authorization_endpoint": ISSUER + "authorize",
+    "token_endpoint": ISSUER + "token",
+    "jwks_uri": ISSUER + "jwks",
 }
 
 
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url: str) -> JsonDict:
-    # Mock get_json calls to handle jwks & oidc discovery endpoints
-    if url == WELL_KNOWN:
-        # Minimal discovery document, as defined in OpenID.Discovery
-        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
-        return {
-            "issuer": ISSUER,
-            "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-            "token_endpoint": TOKEN_ENDPOINT,
-            "jwks_uri": JWKS_URI,
-            "userinfo_endpoint": USERINFO_ENDPOINT,
-            "response_types_supported": ["code"],
-            "subject_types_supported": ["public"],
-            "id_token_signing_alg_values_supported": ["RS256"],
-        }
-    elif url == JWKS_URI:
-        return {"keys": []}
-
-    return {}
-
-
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
 
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return config
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.http_client = Mock(spec=["get_json"])
-        self.http_client.get_json.side_effect = get_json
-        self.http_client.user_agent = b"Synapse Test"
+        self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
 
-        hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+        hs = self.setup_test_homeserver()
+        self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+        self.hs_patcher.start()
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
 
+        auth_handler = hs.get_auth_handler()
+        # Mock the complete SSO login method.
+        self.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
+
         return hs
 
+    def tearDown(self) -> None:
+        self.hs_patcher.stop()
+        return super().tearDown()
+
+    def reset_mocks(self):
+        """Reset all the Mocks."""
+        self.fake_server.reset_mocks()
+        self.render_error.reset_mock()
+        self.complete_sso_login.reset_mock()
+
     def metadata_edit(self, values):
         """Modify the result that will be returned by the well-known query"""
 
-        async def patched_get_json(uri):
-            res = await get_json(uri)
-            if uri == WELL_KNOWN:
-                res.update(values)
-            return res
+        metadata = self.fake_server.get_metadata()
+        metadata.update(values)
+        return patch.object(self.fake_server, "get_metadata", return_value=metadata)
 
-        return patch.object(self.http_client, "get_json", patched_get_json)
+    def start_authorization(
+        self,
+        userinfo: dict,
+        client_redirect_url: str = "http://client/redirect",
+        scope: str = "openid",
+        with_sid: bool = False,
+    ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+        """Start an authorization request, and get the callback request back."""
+        nonce = random_string(10)
+        state = random_string(10)
+
+        code, grant = self.fake_server.start_authorization(
+            userinfo=userinfo,
+            scope=scope,
+            client_id=self.provider._client_auth.client_id,
+            redirect_uri=self.provider._callback_url,
+            nonce=nonce,
+            with_sid=with_sid,
+        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        return _build_callback_request(code, state, session), grant
 
     def assertRenderedError(self, error, error_description=None):
         self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.fake_server.get_metadata_handler.assert_called_once()
 
-        self.assertEqual(metadata.issuer, ISSUER)
-        self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
-        self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
-        self.assertEqual(metadata.jwks_uri, JWKS_URI)
-        # FIXME: it seems like authlib does not have that defined in its metadata models
-        # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+        self.assertEqual(metadata.issuer, self.fake_server.issuer)
+        self.assertEqual(
+            metadata.authorization_endpoint,
+            self.fake_server.authorization_endpoint,
+        )
+        self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+        self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+        # It seems like authlib does not have that defined in its metadata models
+        self.assertEqual(
+            metadata.get("userinfo_endpoint"),
+            self.fake_server.userinfo_endpoint,
+        )
 
         # subsequent calls should be cached
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
-    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
-        self.assertEqual(jwks, {"keys": []})
+        self.fake_server.get_jwks_handler.assert_called_once()
+        self.assertEqual(jwks, self.fake_server.get_jwks())
 
         # subsequent calls should be cached…
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_jwks_handler.assert_not_called()
 
         # …unless forced
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks(force=True))
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
+        self.fake_server.get_jwks_handler.assert_called_once()
 
-        # Throw if the JWKS uri is missing
-        original = self.provider.load_metadata
-
-        async def patched_load_metadata():
-            m = (await original()).copy()
-            m.update({"jwks_uri": None})
-            return m
-
-        with patch.object(self.provider, "load_metadata", patched_load_metadata):
+        with self.metadata_edit({"jwks_uri": None}):
+            # If we don't do this, the load_metadata call will throw because of the
+            # missing jwks_uri
+            self.provider._user_profile_method = "userinfo_endpoint"
+            self.get_success(self.provider.load_metadata(force=True))
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 self.provider.handle_redirect_request(req, b"http://client/redirect")
             )
         )
-        auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+        auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
 
         self.assertEqual(url.scheme, auth_endpoint.scheme)
         self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         username = "bar"
         userinfo = {
             "sub": "foo",
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
 
-        code = "code"
-        state = "state"
-        nonce = "nonce"
         client_redirect_url = "http://client/redirect"
-        ip_address = "10.0.0.1"
-        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
-        request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+        request, _ = self.start_authorization(
+            userinfo, client_redirect_url=client_redirect_url
+        )
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
             client_redirect_url,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_not_called()
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
+        request, _ = self.start_authorization(userinfo)
         with patch.object(
             self.provider,
             "_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._exchange_code.reset_mock()
-        self.provider._parse_id_token.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        self.reset_mocks()
 
         # With userinfo fetching
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-        }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        # Without the "openid" scope, the FakeProvider does not generate an id_token
+        request, _ = self.start_authorization(userinfo, scope="")
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_not_called()
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
+        self.reset_mocks()
+
         # With an ID token, userinfo fetching and sid in the ID token
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-            "id_token": "id_token",
-        }
-        id_token = {
-            "sid": "abcdefgh",
-        }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        request, grant = self.start_authorization(userinfo, with_sid=True)
+        self.assertIsNotNone(grant.sid)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
-            auth_provider_session_id=id_token["sid"],
+            auth_provider_session_id=grant.sid,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(userinfo=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
-        # Handle code exchange failure
-        from synapse.handlers.oidc import OidcError
-
-        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
-            raises=OidcError("invalid_request")
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("invalid_request")
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(token=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
+        self.assertRenderedError("server_error")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
-        token = {"type": "bearer"}
-        token_json = json.dumps(token).encode("utf-8")
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
 
         self.assertEqual(ret, token)
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         args = parse_qs(kwargs["data"].decode("utf-8"))
         self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
 
         # Test error handling
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad Request",
-                body=b'{"error": "foo", "error_description": "bar"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={"error": "foo", "error_description": "bar"}
         )
         from synapse.handlers.oidc import OidcError
 
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(exc.value.error_description, "bar")
 
         # Internal server error with no JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b"Not JSON",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse(
+            code=500, body=b"Not JSON"
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b'{"error": "internal_server_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=500, payload={"error": "internal_server_error"}
         )
 
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad request",
-                body=b"{}",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200,
-                phrase=b"OK",
-                body=b'{"error": "some_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=200, payload={"error": "some_error"}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
 
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # the client secret provided to the should be a jwt which can be checked with
         # the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # check the POSTed data
         args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         userinfo = {
             "sub": "foo",
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
-        state = "state"
-        client_redirect_url = "http://client/redirect"
-        session = self._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-        )
-        request = _build_callback_request("code", state, session)
-
+        request, _ = self.start_authorization(userinfo)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@foo:test",
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             {"phone": "1234567"},
             new_user=True,
             auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error",
             "Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Subsequent calls should map to the same mxid.
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         args = self.assertRenderedError("mapping_error")
         self.assertTrue(
             args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@TEST_USER_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        self.get_success(
-            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
-        )
+        userinfo = {"sub": "test2", "username": "föö"}
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # test_user is already taken, so test_user1 gets registered instead.
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@test_user1:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # userinfo lacking "test": "foobar" attribute should fail.
         userinfo = {
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": "foobar" attribute should succeed.
         userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": "foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
         userinfo = {
             "sub": "tester",
             "username": "tester",
             "test": ["foobar", "foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
         """
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
         userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": ["foo", "bar"] attribute should fail
         userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": ["foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": False attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": False,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": None attribute should fail
         # a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 1 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 1,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 3.14 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 3.14,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
     def _generate_oidc_session_token(
         self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return self.handler._macaroon_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
-                idp_id="oidc",
+                idp_id=self.provider.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
 
-async def _make_callback_with_userinfo(
-    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
-    """Mock up an OIDC callback with the given userinfo dict
-
-    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
-    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
-    Args:
-        hs: the HomeServer impl to send the callback to.
-        userinfo: the OIDC userinfo dict
-        client_redirect_url: the URL to redirect to on success.
-    """
-
-    handler = hs.get_oidc_handler()
-    provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-
-    state = "state"
-    session = handler._macaroon_generator.generate_oidc_session_token(
-        state=state,
-        session_data=OidcSessionData(
-            idp_id="oidc",
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id="",
-        ),
-    )
-    request = _build_callback_request("code", state, session)
-
-    await handler.handle_oidc_callback(request)
-
-
 def _build_callback_request(
     code: str,
     state: str,
diff --git a/tests/rest/client/test_auth.py b/tests/rest/client/test_auth.py
index 090cef5216..ebf653d018 100644
--- a/tests/rest/client/test_auth.py
+++ b/tests/rest/client/test_auth.py
@@ -465,9 +465,11 @@ class UIAuthTests(unittest.HomeserverTestCase):
           * checking that the original operation succeeds
         """
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
         remote_user_id = UserID.from_string(self.user).localpart
-        login_resp = self.helper.login_via_oidc(remote_user_id)
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, remote_user_id)
         self.assertEqual(login_resp["user_id"], self.user)
 
         # initiate a UI Auth process by attempting to delete the device
@@ -481,8 +483,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
 
         # run the UIA-via-SSO flow
         session_id = channel.json_body["session"]
-        channel = self.helper.auth_via_oidc(
-            {"sub": remote_user_id}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": remote_user_id}, ui_auth_session_id=session_id
         )
 
         # that should serve a confirmation page
@@ -499,7 +501,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @skip_unless(HAS_OIDC, "requires OIDC")
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_does_not_offer_password_for_sso_user(self) -> None:
-        login_resp = self.helper.login_via_oidc("username")
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(fake_oidc_server, "username")
         user_tok = login_resp["access_token"]
         device_id = login_resp["device_id"]
 
@@ -522,7 +525,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_offers_both_flows_for_upgraded_user(self) -> None:
         """A user that had a password and then logged in with SSO should get both flows"""
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         channel = self.delete_device(
@@ -539,8 +545,13 @@ class UIAuthTests(unittest.HomeserverTestCase):
     @override_config({"oidc_config": TEST_OIDC_CONFIG})
     def test_ui_auth_fails_for_incorrect_sso_user(self) -> None:
         """If the user tries to authenticate with the wrong SSO user, they get an error"""
+
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # log the user in
-        login_resp = self.helper.login_via_oidc(UserID.from_string(self.user).localpart)
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, UserID.from_string(self.user).localpart
+        )
         self.assertEqual(login_resp["user_id"], self.user)
 
         # start a UI Auth flow by attempting to delete a device
@@ -553,8 +564,8 @@ class UIAuthTests(unittest.HomeserverTestCase):
         session_id = channel.json_body["session"]
 
         # do the OIDC auth, but auth as the wrong user
-        channel = self.helper.auth_via_oidc(
-            {"sub": "wrong_user"}, ui_auth_session_id=session_id
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server, {"sub": "wrong_user"}, ui_auth_session_id=session_id
         )
 
         # that should return a failure message
@@ -584,7 +595,10 @@ class UIAuthTests(unittest.HomeserverTestCase):
         """Tests that if we register a user via SSO while requiring approval for new
         accounts, we still raise the correct error before logging the user in.
         """
-        login_resp = self.helper.login_via_oidc("username", expected_status=403)
+        fake_oidc_server = self.helper.fake_oidc_server()
+        login_resp, _ = self.helper.login_via_oidc(
+            fake_oidc_server, "username", expected_status=403
+        )
 
         self.assertEqual(login_resp["errcode"], Codes.USER_AWAITING_APPROVAL)
         self.assertEqual(
diff --git a/tests/rest/client/test_login.py b/tests/rest/client/test_login.py
index e801ba8c8b..ff5baa9f0a 100644
--- a/tests/rest/client/test_login.py
+++ b/tests/rest/client/test_login.py
@@ -36,7 +36,7 @@ from synapse.util import Clock
 from tests import unittest
 from tests.handlers.test_oidc import HAS_OIDC
 from tests.handlers.test_saml import has_saml2
-from tests.rest.client.utils import TEST_OIDC_AUTH_ENDPOINT, TEST_OIDC_CONFIG
+from tests.rest.client.utils import TEST_OIDC_CONFIG
 from tests.server import FakeChannel
 from tests.test_utils.html_parsers import TestHtmlParser
 from tests.unittest import HomeserverTestCase, override_config, skip_unless
@@ -612,13 +612,16 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
     def test_login_via_oidc(self) -> None:
         """If OIDC is chosen, should redirect to the OIDC auth endpoint"""
 
-        # pick the default OIDC provider
-        channel = self.make_request(
-            "GET",
-            "/_synapse/client/pick_idp?redirectUrl="
-            + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
-            + "&idp=oidc",
-        )
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        with fake_oidc_server.patch_homeserver(hs=self.hs):
+            # pick the default OIDC provider
+            channel = self.make_request(
+                "GET",
+                "/_synapse/client/pick_idp?redirectUrl="
+                + urllib.parse.quote_plus(TEST_CLIENT_REDIRECT_URL)
+                + "&idp=oidc",
+            )
         self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
@@ -626,7 +629,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
-        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+        self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
 
         # ... and should have set a cookie including the redirect url
         cookie_headers = channel.headers.getRawHeaders("Set-Cookie")
@@ -643,7 +646,9 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
             TEST_CLIENT_REDIRECT_URL,
         )
 
-        channel = self.helper.complete_oidc_auth(oidc_uri, cookies, {"sub": "user1"})
+        channel, _ = self.helper.complete_oidc_auth(
+            fake_oidc_server, oidc_uri, cookies, {"sub": "user1"}
+        )
 
         # that should serve a confirmation page
         self.assertEqual(channel.code, 200, channel.result)
@@ -693,7 +698,10 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
 
     def test_client_idp_redirect_to_oidc(self) -> None:
         """If the client pick a known IdP, redirect to it"""
-        channel = self._make_sso_redirect_request("oidc")
+        fake_oidc_server = self.helper.fake_oidc_server()
+
+        with fake_oidc_server.patch_homeserver(hs=self.hs):
+            channel = self._make_sso_redirect_request("oidc")
         self.assertEqual(channel.code, 302, channel.result)
         location_headers = channel.headers.getRawHeaders("Location")
         assert location_headers
@@ -701,7 +709,7 @@ class MultiSSOTestCase(unittest.HomeserverTestCase):
         oidc_uri_path, oidc_uri_query = oidc_uri.split("?", 1)
 
         # it should redirect us to the auth page of the OIDC server
-        self.assertEqual(oidc_uri_path, TEST_OIDC_AUTH_ENDPOINT)
+        self.assertEqual(oidc_uri_path, fake_oidc_server.authorization_endpoint)
 
     def _make_sso_redirect_request(self, idp_prov: Optional[str] = None) -> FakeChannel:
         """Send a request to /_matrix/client/r0/login/sso/redirect
@@ -1280,9 +1288,13 @@ class UsernamePickerTestCase(HomeserverTestCase):
     def test_username_picker(self) -> None:
         """Test the happy path of a username picker flow."""
 
+        fake_oidc_server = self.helper.fake_oidc_server()
+
         # do the start of the login flow
-        channel = self.helper.auth_via_oidc(
-            {"sub": "tester", "displayname": "Jonny"}, TEST_CLIENT_REDIRECT_URL
+        channel, _ = self.helper.auth_via_oidc(
+            fake_oidc_server,
+            {"sub": "tester", "displayname": "Jonny"},
+            TEST_CLIENT_REDIRECT_URL,
         )
 
         # that should redirect to the username picker
diff --git a/tests/rest/client/utils.py b/tests/rest/client/utils.py
index c249a42bb6..967d229223 100644
--- a/tests/rest/client/utils.py
+++ b/tests/rest/client/utils.py
@@ -31,7 +31,6 @@ from typing import (
     Tuple,
     overload,
 )
-from unittest.mock import patch
 from urllib.parse import urlencode
 
 import attr
@@ -46,8 +45,19 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests.server import FakeChannel, FakeSite, make_request
-from tests.test_utils import FakeResponse
 from tests.test_utils.html_parsers import TestHtmlParser
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
+
+# an 'oidc_config' suitable for login_via_oidc.
+TEST_OIDC_ISSUER = "https://issuer.test/"
+TEST_OIDC_CONFIG = {
+    "enabled": True,
+    "issuer": TEST_OIDC_ISSUER,
+    "client_id": "test-client-id",
+    "client_secret": "test-client-secret",
+    "scopes": ["openid"],
+    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
+}
 
 
 @attr.s(auto_attribs=True)
@@ -543,12 +553,28 @@ class RestHelper:
 
         return channel.json_body
 
+    def fake_oidc_server(self, issuer: str = TEST_OIDC_ISSUER) -> FakeOidcServer:
+        """Create a ``FakeOidcServer``.
+
+        This can be used in conjuction with ``login_via_oidc``::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            login_data, _ = self.helper.login_via_oidc(fake_oidc_server, "user")
+        """
+
+        return FakeOidcServer(
+            clock=self.hs.get_clock(),
+            issuer=issuer,
+        )
+
     def login_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         remote_user_id: str,
+        with_sid: bool = False,
         expected_status: int = 200,
-    ) -> JsonDict:
-        """Log in via OIDC
+    ) -> Tuple[JsonDict, FakeAuthorizationGrant]:
+        """Log in (as a new user) via OIDC
 
         Returns the result of the final token login.
 
@@ -560,7 +586,10 @@ class RestHelper:
         the normal places.
         """
         client_redirect_url = "https://x"
-        channel = self.auth_via_oidc({"sub": remote_user_id}, client_redirect_url)
+        userinfo = {"sub": remote_user_id}
+        channel, grant = self.auth_via_oidc(
+            fake_server, userinfo, client_redirect_url, with_sid=with_sid
+        )
 
         # expect a confirmation page
         assert channel.code == HTTPStatus.OK, channel.result
@@ -585,14 +614,16 @@ class RestHelper:
         assert (
             channel.code == expected_status
         ), f"unexpected status in response: {channel.code}"
-        return channel.json_body
+        return channel.json_body, grant
 
     def auth_via_oidc(
         self,
+        fake_server: FakeOidcServer,
         user_info_dict: JsonDict,
         client_redirect_url: Optional[str] = None,
         ui_auth_session_id: Optional[str] = None,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Perform an OIDC authentication flow via a mock OIDC provider.
 
         This can be used for either login or user-interactive auth.
@@ -616,6 +647,7 @@ class RestHelper:
                 the login redirect endpoint
             ui_auth_session_id: if set, we will perform a UI Auth flow. The session id
                 of the UI auth.
+            with_sid: if True, generates a random `sid` (OIDC session ID)
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
@@ -625,14 +657,15 @@ class RestHelper:
 
         cookies: Dict[str, str] = {}
 
-        # if we're doing a ui auth, hit the ui auth redirect endpoint
-        if ui_auth_session_id:
-            # can't set the client redirect url for UI Auth
-            assert client_redirect_url is None
-            oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
-        else:
-            # otherwise, hit the login redirect endpoint
-            oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
+        with fake_server.patch_homeserver(hs=self.hs):
+            # if we're doing a ui auth, hit the ui auth redirect endpoint
+            if ui_auth_session_id:
+                # can't set the client redirect url for UI Auth
+                assert client_redirect_url is None
+                oauth_uri = self.initiate_sso_ui_auth(ui_auth_session_id, cookies)
+            else:
+                # otherwise, hit the login redirect endpoint
+                oauth_uri = self.initiate_sso_login(client_redirect_url, cookies)
 
         # we now have a URI for the OIDC IdP, but we skip that and go straight
         # back to synapse's OIDC callback resource. However, we do need the "state"
@@ -640,17 +673,21 @@ class RestHelper:
         # that synapse passes to the client.
 
         oauth_uri_path, _ = oauth_uri.split("?", 1)
-        assert oauth_uri_path == TEST_OIDC_AUTH_ENDPOINT, (
+        assert oauth_uri_path == fake_server.authorization_endpoint, (
             "unexpected SSO URI " + oauth_uri_path
         )
-        return self.complete_oidc_auth(oauth_uri, cookies, user_info_dict)
+        return self.complete_oidc_auth(
+            fake_server, oauth_uri, cookies, user_info_dict, with_sid=with_sid
+        )
 
     def complete_oidc_auth(
         self,
+        fake_serer: FakeOidcServer,
         oauth_uri: str,
         cookies: Mapping[str, str],
         user_info_dict: JsonDict,
-    ) -> FakeChannel:
+        with_sid: bool = False,
+    ) -> Tuple[FakeChannel, FakeAuthorizationGrant]:
         """Mock out an OIDC authentication flow
 
         Assumes that an OIDC auth has been initiated by one of initiate_sso_login or
@@ -661,50 +698,37 @@ class RestHelper:
         Requires the OIDC callback resource to be mounted at the normal place.
 
         Args:
+            fake_server: the fake OIDC server with which the auth should be done
             oauth_uri: the OIDC URI returned by synapse's redirect endpoint (ie,
                from initiate_sso_login or initiate_sso_ui_auth).
             cookies: the cookies set by synapse's redirect endpoint, which will be
                sent back to the callback endpoint.
             user_info_dict: the remote userinfo that the OIDC provider should present.
                 Typically this should be '{"sub": "<remote user id>"}'.
+            with_sid: if True, generates a random `sid` (OIDC session ID)
 
         Returns:
             A FakeChannel containing the result of calling the OIDC callback endpoint.
         """
         _, oauth_uri_qs = oauth_uri.split("?", 1)
         params = urllib.parse.parse_qs(oauth_uri_qs)
+
+        code, grant = fake_serer.start_authorization(
+            scope=params["scope"][0],
+            userinfo=user_info_dict,
+            client_id=params["client_id"][0],
+            redirect_uri=params["redirect_uri"][0],
+            nonce=params["nonce"][0],
+            with_sid=with_sid,
+        )
+        state = params["state"][0]
+
         callback_uri = "%s?%s" % (
             urllib.parse.urlparse(params["redirect_uri"][0]).path,
-            urllib.parse.urlencode({"state": params["state"][0], "code": "TEST_CODE"}),
+            urllib.parse.urlencode({"state": state, "code": code}),
         )
 
-        # before we hit the callback uri, stub out some methods in the http client so
-        # that we don't have to handle full HTTPS requests.
-        # (expected url, json response) pairs, in the order we expect them.
-        expected_requests = [
-            # first we get a hit to the token endpoint, which we tell to return
-            # a dummy OIDC access token
-            (TEST_OIDC_TOKEN_ENDPOINT, {"access_token": "TEST"}),
-            # and then one to the user_info endpoint, which returns our remote user id.
-            (TEST_OIDC_USERINFO_ENDPOINT, user_info_dict),
-        ]
-
-        async def mock_req(
-            method: str,
-            uri: str,
-            data: Optional[dict] = None,
-            headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
-        ):
-            (expected_uri, resp_obj) = expected_requests.pop(0)
-            assert uri == expected_uri
-            resp = FakeResponse(
-                code=HTTPStatus.OK,
-                phrase=b"OK",
-                body=json.dumps(resp_obj).encode("utf-8"),
-            )
-            return resp
-
-        with patch.object(self.hs.get_proxied_http_client(), "request", mock_req):
+        with fake_serer.patch_homeserver(hs=self.hs):
             # now hit the callback URI with the right params and a made-up code
             channel = make_request(
                 self.hs.get_reactor(),
@@ -715,7 +739,7 @@ class RestHelper:
                     ("Cookie", "%s=%s" % (k, v)) for (k, v) in cookies.items()
                 ],
             )
-        return channel
+        return channel, grant
 
     def initiate_sso_login(
         self, client_redirect_url: Optional[str], cookies: MutableMapping[str, str]
@@ -806,21 +830,3 @@ class RestHelper:
         assert len(p.links) == 1, "not exactly one link in confirmation page"
         oauth_uri = p.links[0]
         return oauth_uri
-
-
-# an 'oidc_config' suitable for login_via_oidc.
-TEST_OIDC_AUTH_ENDPOINT = "https://issuer.test/auth"
-TEST_OIDC_TOKEN_ENDPOINT = "https://issuer.test/token"
-TEST_OIDC_USERINFO_ENDPOINT = "https://issuer.test/userinfo"
-TEST_OIDC_CONFIG = {
-    "enabled": True,
-    "discover": False,
-    "issuer": "https://issuer.test",
-    "client_id": "test-client-id",
-    "client_secret": "test-client-secret",
-    "scopes": ["profile"],
-    "authorization_endpoint": TEST_OIDC_AUTH_ENDPOINT,
-    "token_endpoint": TEST_OIDC_TOKEN_ENDPOINT,
-    "userinfo_endpoint": TEST_OIDC_USERINFO_ENDPOINT,
-    "user_mapping_provider": {"config": {"localpart_template": "{{ user.sub }}"}},
-}
diff --git a/tests/test_utils/__init__.py b/tests/test_utils/__init__.py
index 0d0d6faf0d..e62ebcc6a5 100644
--- a/tests/test_utils/__init__.py
+++ b/tests/test_utils/__init__.py
@@ -15,17 +15,24 @@
 """
 Utilities for running the unit tests
 """
+import json
 import sys
 import warnings
 from asyncio import Future
 from binascii import unhexlify
-from typing import Awaitable, Callable, TypeVar
+from typing import Awaitable, Callable, Tuple, TypeVar
 from unittest.mock import Mock
 
 import attr
+import zope.interface
 
 from twisted.python.failure import Failure
 from twisted.web.client import ResponseDone
+from twisted.web.http import RESPONSES
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.types import JsonDict
 
 TV = TypeVar("TV")
 
@@ -97,27 +104,44 @@ def simple_async_mock(return_value=None, raises=None) -> Mock:
     return Mock(side_effect=cb)
 
 
-@attr.s
-class FakeResponse:
+# Type ignore: it does not fully implement IResponse, but is good enough for tests
+@zope.interface.implementer(IResponse)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeResponse:  # type: ignore[misc]
     """A fake twisted.web.IResponse object
 
     there is a similar class at treq.test.test_response, but it lacks a `phrase`
     attribute, and didn't support deliverBody until recently.
     """
 
-    # HTTP response code
-    code = attr.ib(type=int)
+    version: Tuple[bytes, int, int] = (b"HTTP", 1, 1)
 
-    # HTTP response phrase (eg b'OK' for a 200)
-    phrase = attr.ib(type=bytes)
+    # HTTP response code
+    code: int = 200
 
     # body of the response
-    body = attr.ib(type=bytes)
+    body: bytes = b""
+
+    headers: Headers = attr.Factory(Headers)
+
+    @property
+    def phrase(self):
+        return RESPONSES.get(self.code, b"Unknown Status")
+
+    @property
+    def length(self):
+        return len(self.body)
 
     def deliverBody(self, protocol):
         protocol.dataReceived(self.body)
         protocol.connectionLost(Failure(ResponseDone()))
 
+    @classmethod
+    def json(cls, *, code: int = 200, payload: JsonDict) -> "FakeResponse":
+        headers = Headers({"Content-Type": ["application/json"]})
+        body = json.dumps(payload).encode("utf-8")
+        return cls(code=code, body=body, headers=headers)
+
 
 # A small image used in some tests.
 #
diff --git a/tests/test_utils/oidc.py b/tests/test_utils/oidc.py
new file mode 100644
index 0000000000..de134bbc89
--- /dev/null
+++ b/tests/test_utils/oidc.py
@@ -0,0 +1,325 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import json
+from typing import Any, Dict, List, Optional, Tuple
+from unittest.mock import Mock, patch
+from urllib.parse import parse_qs
+
+import attr
+
+from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
+
+from synapse.server import HomeServer
+from synapse.util import Clock
+from synapse.util.stringutils import random_string
+
+from tests.test_utils import FakeResponse
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FakeAuthorizationGrant:
+    userinfo: dict
+    client_id: str
+    redirect_uri: str
+    scope: str
+    nonce: Optional[str]
+    sid: Optional[str]
+
+
+class FakeOidcServer:
+    """A fake OpenID Connect Provider."""
+
+    # All methods here are mocks, so we can track when they are called, and override
+    # their values
+    request: Mock
+    get_jwks_handler: Mock
+    get_metadata_handler: Mock
+    get_userinfo_handler: Mock
+    post_token_handler: Mock
+
+    def __init__(self, clock: Clock, issuer: str):
+        from authlib.jose import ECKey, KeySet
+
+        self._clock = clock
+        self.issuer = issuer
+
+        self.request = Mock(side_effect=self._request)
+        self.get_jwks_handler = Mock(side_effect=self._get_jwks_handler)
+        self.get_metadata_handler = Mock(side_effect=self._get_metadata_handler)
+        self.get_userinfo_handler = Mock(side_effect=self._get_userinfo_handler)
+        self.post_token_handler = Mock(side_effect=self._post_token_handler)
+
+        # A code -> grant mapping
+        self._authorization_grants: Dict[str, FakeAuthorizationGrant] = {}
+        # An access token -> grant mapping
+        self._sessions: Dict[str, FakeAuthorizationGrant] = {}
+
+        # We generate here an ECDSA key with the P-256 curve (ES256 algorithm) used for
+        # signing JWTs. ECDSA keys are really quick to generate compared to RSA.
+        self._key = ECKey.generate_key(crv="P-256", is_private=True)
+        self._jwks = KeySet([ECKey.import_key(self._key.as_pem(is_private=False))])
+
+        self._id_token_overrides: Dict[str, Any] = {}
+
+    def reset_mocks(self):
+        self.request.reset_mock()
+        self.get_jwks_handler.reset_mock()
+        self.get_metadata_handler.reset_mock()
+        self.get_userinfo_handler.reset_mock()
+        self.post_token_handler.reset_mock()
+
+    def patch_homeserver(self, hs: HomeServer):
+        """Patch the ``HomeServer`` HTTP client to handle requests through the ``FakeOidcServer``.
+
+        This patch should be used whenever the HS is expected to perform request to the
+        OIDC provider, e.g.::
+
+            fake_oidc_server = self.helper.fake_oidc_server()
+            with fake_oidc_server.patch_homeserver(hs):
+                self.make_request("GET", "/_matrix/client/r0/login/sso/redirect")
+        """
+        return patch.object(hs.get_proxied_http_client(), "request", self.request)
+
+    @property
+    def authorization_endpoint(self) -> str:
+        return self.issuer + "authorize"
+
+    @property
+    def token_endpoint(self) -> str:
+        return self.issuer + "token"
+
+    @property
+    def userinfo_endpoint(self) -> str:
+        return self.issuer + "userinfo"
+
+    @property
+    def metadata_endpoint(self) -> str:
+        return self.issuer + ".well-known/openid-configuration"
+
+    @property
+    def jwks_uri(self) -> str:
+        return self.issuer + "jwks"
+
+    def get_metadata(self) -> dict:
+        return {
+            "issuer": self.issuer,
+            "authorization_endpoint": self.authorization_endpoint,
+            "token_endpoint": self.token_endpoint,
+            "jwks_uri": self.jwks_uri,
+            "userinfo_endpoint": self.userinfo_endpoint,
+            "response_types_supported": ["code"],
+            "subject_types_supported": ["public"],
+            "id_token_signing_alg_values_supported": ["ES256"],
+        }
+
+    def get_jwks(self) -> dict:
+        return self._jwks.as_dict()
+
+    def get_userinfo(self, access_token: str) -> Optional[dict]:
+        """Given an access token, get the userinfo of the associated session."""
+        session = self._sessions.get(access_token, None)
+        if session is None:
+            return None
+        return session.userinfo
+
+    def _sign(self, payload: dict) -> str:
+        from authlib.jose import JsonWebSignature
+
+        jws = JsonWebSignature()
+        kid = self.get_jwks()["keys"][0]["kid"]
+        protected = {"alg": "ES256", "kid": kid}
+        json_payload = json.dumps(payload)
+        return jws.serialize_compact(protected, json_payload, self._key).decode("utf-8")
+
+    def generate_id_token(self, grant: FakeAuthorizationGrant) -> str:
+        now = self._clock.time()
+        id_token = {
+            **grant.userinfo,
+            "iss": self.issuer,
+            "aud": grant.client_id,
+            "iat": now,
+            "nbf": now,
+            "exp": now + 600,
+        }
+
+        if grant.nonce is not None:
+            id_token["nonce"] = grant.nonce
+
+        if grant.sid is not None:
+            id_token["sid"] = grant.sid
+
+        id_token.update(self._id_token_overrides)
+
+        return self._sign(id_token)
+
+    def id_token_override(self, overrides: dict):
+        """Temporarily patch the ID token generated by the token endpoint."""
+        return patch.object(self, "_id_token_overrides", overrides)
+
+    def start_authorization(
+        self,
+        client_id: str,
+        scope: str,
+        redirect_uri: str,
+        userinfo: dict,
+        nonce: Optional[str] = None,
+        with_sid: bool = False,
+    ) -> Tuple[str, FakeAuthorizationGrant]:
+        """Start an authorization request, and get back the code to use on the authorization endpoint."""
+        code = random_string(10)
+        sid = None
+        if with_sid:
+            sid = random_string(10)
+
+        grant = FakeAuthorizationGrant(
+            userinfo=userinfo,
+            scope=scope,
+            redirect_uri=redirect_uri,
+            nonce=nonce,
+            client_id=client_id,
+            sid=sid,
+        )
+        self._authorization_grants[code] = grant
+
+        return code, grant
+
+    def exchange_code(self, code: str) -> Optional[Dict[str, Any]]:
+        grant = self._authorization_grants.pop(code, None)
+        if grant is None:
+            return None
+
+        access_token = random_string(10)
+        self._sessions[access_token] = grant
+
+        token = {
+            "token_type": "Bearer",
+            "access_token": access_token,
+            "expires_in": 3600,
+            "scope": grant.scope,
+        }
+
+        if "openid" in grant.scope:
+            token["id_token"] = self.generate_id_token(grant)
+
+        return dict(token)
+
+    def buggy_endpoint(
+        self,
+        *,
+        jwks: bool = False,
+        metadata: bool = False,
+        token: bool = False,
+        userinfo: bool = False,
+    ):
+        """A context which makes a set of endpoints return a 500 error.
+
+        Args:
+            jwks: If True, makes the JWKS endpoint return a 500 error.
+            metadata: If True, makes the OIDC Discovery endpoint return a 500 error.
+            token: If True, makes the token endpoint return a 500 error.
+            userinfo: If True, makes the userinfo endpoint return a 500 error.
+        """
+        buggy = FakeResponse(code=500, body=b"Internal server error")
+
+        patches = {}
+        if jwks:
+            patches["get_jwks_handler"] = Mock(return_value=buggy)
+        if metadata:
+            patches["get_metadata_handler"] = Mock(return_value=buggy)
+        if token:
+            patches["post_token_handler"] = Mock(return_value=buggy)
+        if userinfo:
+            patches["get_userinfo_handler"] = Mock(return_value=buggy)
+
+        return patch.multiple(self, **patches)
+
+    async def _request(
+        self,
+        method: str,
+        uri: str,
+        data: Optional[bytes] = None,
+        headers: Optional[Headers] = None,
+    ) -> IResponse:
+        """The override of the SimpleHttpClient#request() method"""
+        access_token: Optional[str] = None
+
+        if headers is None:
+            headers = Headers()
+
+        # Try to find the access token in the headers if any
+        auth_headers = headers.getRawHeaders(b"Authorization")
+        if auth_headers:
+            parts = auth_headers[0].split(b" ")
+            if parts[0] == b"Bearer" and len(parts) == 2:
+                access_token = parts[1].decode("ascii")
+
+        if method == "POST":
+            # If the method is POST, assume it has an url-encoded body
+            if data is None or headers.getRawHeaders(b"Content-Type") != [
+                b"application/x-www-form-urlencoded"
+            ]:
+                return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+            params = parse_qs(data.decode("utf-8"))
+
+            if uri == self.token_endpoint:
+                # Even though this endpoint should be protected, this does not check
+                # for client authentication. We're not checking it for simplicity,
+                # and because client authentication is tested in other standalone tests.
+                return self.post_token_handler(params)
+
+        elif method == "GET":
+            if uri == self.jwks_uri:
+                return self.get_jwks_handler()
+            elif uri == self.metadata_endpoint:
+                return self.get_metadata_handler()
+            elif uri == self.userinfo_endpoint:
+                return self.get_userinfo_handler(access_token=access_token)
+
+        return FakeResponse(code=404, body=b"404 not found")
+
+    # Request handlers
+    def _get_jwks_handler(self) -> IResponse:
+        """Handles requests to the JWKS URI."""
+        return FakeResponse.json(payload=self.get_jwks())
+
+    def _get_metadata_handler(self) -> IResponse:
+        """Handles requests to the OIDC well-known document."""
+        return FakeResponse.json(payload=self.get_metadata())
+
+    def _get_userinfo_handler(self, access_token: Optional[str]) -> IResponse:
+        """Handles requests to the userinfo endpoint."""
+        if access_token is None:
+            return FakeResponse(code=401)
+        user_info = self.get_userinfo(access_token)
+        if user_info is None:
+            return FakeResponse(code=401)
+
+        return FakeResponse.json(payload=user_info)
+
+    def _post_token_handler(self, params: Dict[str, List[str]]) -> IResponse:
+        """Handles requests to the token endpoint."""
+        code = params.get("code", [])
+
+        if len(code) != 1:
+            return FakeResponse.json(code=400, payload={"error": "invalid_request"})
+
+        grant = self.exchange_code(code=code[0])
+        if grant is None:
+            return FakeResponse.json(code=400, payload={"error": "invalid_grant"})
+
+        return FakeResponse.json(payload=grant)