summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
authorQuentin Gliech <quenting@element.io>2022-10-25 16:25:02 +0200
committerGitHub <noreply@github.com>2022-10-25 14:25:02 +0000
commit9192d74b0bf2f87b00d3e106a18baa9ce27acda1 (patch)
tree08bc76abec65c3124686f19f03849e6ccb12c820 /tests/handlers
parentImplementation for MSC3664: Pushrules for relations (#11804) (diff)
downloadsynapse-9192d74b0bf2f87b00d3e106a18baa9ce27acda1.tar.xz
Refactor OIDC tests to better mimic an actual OIDC provider. (#13910)
This implements a fake OIDC server, which intercepts calls to the HTTP client.
Improves accuracy of tests by covering more internal methods.

One particular example was the ID token validation, which previously mocked.

This uncovered an incorrect dependency: Synapse actually requires at least
authlib 0.15.1, not 0.14.0.
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_oidc.py580
1 files changed, 250 insertions, 330 deletions
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index e6cd3af7b7..5955410524 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -11,9 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import json
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 from unittest.mock import ANY, Mock, patch
 from urllib.parse import parse_qs, urlparse
 
@@ -22,12 +21,15 @@ import pymacaroons
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.handlers.sso import MappingException
+from synapse.http.site import SynapseRequest
 from synapse.server import HomeServer
-from synapse.types import JsonDict, UserID
+from synapse.types import UserID
 from synapse.util import Clock
-from synapse.util.macaroons import OidcSessionData, get_value_from_macaroon
+from synapse.util.macaroons import get_value_from_macaroon
+from synapse.util.stringutils import random_string
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
+from tests.test_utils.oidc import FakeAuthorizationGrant, FakeOidcServer
 from tests.unittest import HomeserverTestCase, override_config
 
 try:
@@ -46,12 +48,6 @@ BASE_URL = "https://synapse/"
 CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
-AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
-TOKEN_ENDPOINT = ISSUER + "token"
-USERINFO_ENDPOINT = ISSUER + "userinfo"
-WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
-JWKS_URI = ISSUER + ".well-known/jwks.json"
-
 # config for common cases
 DEFAULT_CONFIG = {
     "enabled": True,
@@ -66,9 +62,9 @@ DEFAULT_CONFIG = {
 EXPLICIT_ENDPOINT_CONFIG = {
     **DEFAULT_CONFIG,
     "discover": False,
-    "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-    "token_endpoint": TOKEN_ENDPOINT,
-    "jwks_uri": JWKS_URI,
+    "authorization_endpoint": ISSUER + "authorize",
+    "token_endpoint": ISSUER + "token",
+    "jwks_uri": ISSUER + "jwks",
 }
 
 
@@ -102,27 +98,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-async def get_json(url: str) -> JsonDict:
-    # Mock get_json calls to handle jwks & oidc discovery endpoints
-    if url == WELL_KNOWN:
-        # Minimal discovery document, as defined in OpenID.Discovery
-        # https://openid.net/specs/openid-connect-discovery-1_0.html#ProviderMetadata
-        return {
-            "issuer": ISSUER,
-            "authorization_endpoint": AUTHORIZATION_ENDPOINT,
-            "token_endpoint": TOKEN_ENDPOINT,
-            "jwks_uri": JWKS_URI,
-            "userinfo_endpoint": USERINFO_ENDPOINT,
-            "response_types_supported": ["code"],
-            "subject_types_supported": ["public"],
-            "id_token_signing_alg_values_supported": ["RS256"],
-        }
-    elif url == JWKS_URI:
-        return {"keys": []}
-
-    return {}
-
-
 def _key_file_path() -> str:
     """path to a file containing the private half of a test key"""
 
@@ -159,11 +134,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return config
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        self.http_client = Mock(spec=["get_json"])
-        self.http_client.get_json.side_effect = get_json
-        self.http_client.user_agent = b"Synapse Test"
+        self.fake_server = FakeOidcServer(clock=clock, issuer=ISSUER)
 
-        hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
+        hs = self.setup_test_homeserver()
+        self.hs_patcher = self.fake_server.patch_homeserver(hs=hs)
+        self.hs_patcher.start()
 
         self.handler = hs.get_oidc_handler()
         self.provider = self.handler._providers["oidc"]
@@ -175,18 +150,51 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # Reduce the number of attempts when generating MXIDs.
         sso_handler._MAP_USERNAME_RETRIES = 3
 
+        auth_handler = hs.get_auth_handler()
+        # Mock the complete SSO login method.
+        self.complete_sso_login = simple_async_mock()
+        auth_handler.complete_sso_login = self.complete_sso_login  # type: ignore[assignment]
+
         return hs
 
+    def tearDown(self) -> None:
+        self.hs_patcher.stop()
+        return super().tearDown()
+
+    def reset_mocks(self):
+        """Reset all the Mocks."""
+        self.fake_server.reset_mocks()
+        self.render_error.reset_mock()
+        self.complete_sso_login.reset_mock()
+
     def metadata_edit(self, values):
         """Modify the result that will be returned by the well-known query"""
 
-        async def patched_get_json(uri):
-            res = await get_json(uri)
-            if uri == WELL_KNOWN:
-                res.update(values)
-            return res
+        metadata = self.fake_server.get_metadata()
+        metadata.update(values)
+        return patch.object(self.fake_server, "get_metadata", return_value=metadata)
 
-        return patch.object(self.http_client, "get_json", patched_get_json)
+    def start_authorization(
+        self,
+        userinfo: dict,
+        client_redirect_url: str = "http://client/redirect",
+        scope: str = "openid",
+        with_sid: bool = False,
+    ) -> Tuple[SynapseRequest, FakeAuthorizationGrant]:
+        """Start an authorization request, and get the callback request back."""
+        nonce = random_string(10)
+        state = random_string(10)
+
+        code, grant = self.fake_server.start_authorization(
+            userinfo=userinfo,
+            scope=scope,
+            client_id=self.provider._client_auth.client_id,
+            redirect_uri=self.provider._callback_url,
+            nonce=nonce,
+            with_sid=with_sid,
+        )
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        return _build_callback_request(code, state, session), grant
 
     def assertRenderedError(self, error, error_description=None):
         self.render_error.assert_called_once()
@@ -210,52 +218,54 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
         metadata = self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
+        self.fake_server.get_metadata_handler.assert_called_once()
 
-        self.assertEqual(metadata.issuer, ISSUER)
-        self.assertEqual(metadata.authorization_endpoint, AUTHORIZATION_ENDPOINT)
-        self.assertEqual(metadata.token_endpoint, TOKEN_ENDPOINT)
-        self.assertEqual(metadata.jwks_uri, JWKS_URI)
-        # FIXME: it seems like authlib does not have that defined in its metadata models
-        # self.assertEqual(metadata.userinfo_endpoint, USERINFO_ENDPOINT)
+        self.assertEqual(metadata.issuer, self.fake_server.issuer)
+        self.assertEqual(
+            metadata.authorization_endpoint,
+            self.fake_server.authorization_endpoint,
+        )
+        self.assertEqual(metadata.token_endpoint, self.fake_server.token_endpoint)
+        self.assertEqual(metadata.jwks_uri, self.fake_server.jwks_uri)
+        # It seems like authlib does not have that defined in its metadata models
+        self.assertEqual(
+            metadata.get("userinfo_endpoint"),
+            self.fake_server.userinfo_endpoint,
+        )
 
         # subsequent calls should be cached
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
     @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self) -> None:
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_metadata_handler.assert_not_called()
 
-    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_load_jwks(self) -> None:
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
-        self.assertEqual(jwks, {"keys": []})
+        self.fake_server.get_jwks_handler.assert_called_once()
+        self.assertEqual(jwks, self.fake_server.get_jwks())
 
         # subsequent calls should be cached…
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks())
-        self.http_client.get_json.assert_not_called()
+        self.fake_server.get_jwks_handler.assert_not_called()
 
         # …unless forced
-        self.http_client.reset_mock()
+        self.reset_mocks()
         self.get_success(self.provider.load_jwks(force=True))
-        self.http_client.get_json.assert_called_once_with(JWKS_URI)
+        self.fake_server.get_jwks_handler.assert_called_once()
 
-        # Throw if the JWKS uri is missing
-        original = self.provider.load_metadata
-
-        async def patched_load_metadata():
-            m = (await original()).copy()
-            m.update({"jwks_uri": None})
-            return m
-
-        with patch.object(self.provider, "load_metadata", patched_load_metadata):
+        with self.metadata_edit({"jwks_uri": None}):
+            # If we don't do this, the load_metadata call will throw because of the
+            # missing jwks_uri
+            self.provider._user_profile_method = "userinfo_endpoint"
+            self.get_success(self.provider.load_metadata(force=True))
             self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
@@ -359,7 +369,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 self.provider.handle_redirect_request(req, b"http://client/redirect")
             )
         )
-        auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
+        auth_endpoint = urlparse(self.fake_server.authorization_endpoint)
 
         self.assertEqual(url.scheme, auth_endpoint.scheme)
         self.assertEqual(url.netloc, auth_endpoint.netloc)
@@ -424,48 +434,34 @@ class OidcHandlerTestCase(HomeserverTestCase):
         with self.assertRaises(AttributeError):
             _ = mapping_provider.get_extra_attributes
 
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         username = "bar"
         userinfo = {
             "sub": "foo",
             "username": username,
         }
         expected_user_id = "@%s:%s" % (username, self.hs.hostname)
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
 
-        code = "code"
-        state = "state"
-        nonce = "nonce"
         client_redirect_url = "http://client/redirect"
-        ip_address = "10.0.0.1"
-        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
-        request = _build_callback_request(code, state, session, ip_address=ip_address)
-
+        request, _ = self.start_authorization(
+            userinfo, client_redirect_url=client_redirect_url
+        )
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
             client_redirect_url,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_not_called()
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
+        request, _ = self.start_authorization(userinfo)
         with patch.object(
             self.provider,
             "_remote_id_from_userinfo",
@@ -475,81 +471,63 @@ class OidcHandlerTestCase(HomeserverTestCase):
             self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.provider._parse_id_token = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.id_token_override({"iss": "https://bad.issuer/"}):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._exchange_code.reset_mock()
-        self.provider._parse_id_token.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        self.reset_mocks()
 
         # With userinfo fetching
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-        }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
+        # Without the "openid" scope, the FakeProvider does not generate an id_token
+        request, _ = self.start_authorization(userinfo, scope="")
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_not_called()
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
+        self.reset_mocks()
+
         # With an ID token, userinfo fetching and sid in the ID token
         self.provider._user_profile_method = "userinfo_endpoint"
-        token = {
-            "type": "bearer",
-            "access_token": "access_token",
-            "id_token": "id_token",
-        }
-        id_token = {
-            "sid": "abcdefgh",
-        }
-        self.provider._parse_id_token = simple_async_mock(return_value=id_token)  # type: ignore[assignment]
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        auth_handler.complete_sso_login.reset_mock()
-        self.provider._fetch_userinfo.reset_mock()
+        request, grant = self.start_authorization(userinfo, with_sid=True)
+        self.assertIsNotNone(grant.sid)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             expected_user_id,
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             None,
             new_user=False,
-            auth_provider_session_id=id_token["sid"],
+            auth_provider_session_id=grant.sid,
         )
-        self.provider._exchange_code.assert_called_once_with(code)
-        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.provider._fetch_userinfo.assert_called_once_with(token)
+        self.fake_server.post_token_handler.assert_called_once()
+        self.fake_server.get_userinfo_handler.assert_called_once()
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())  # type: ignore[assignment]
-        self.get_success(self.handler.handle_oidc_callback(request))
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(userinfo=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
-        # Handle code exchange failure
-        from synapse.handlers.oidc import OidcError
-
-        self.provider._exchange_code = simple_async_mock(  # type: ignore[assignment]
-            raises=OidcError("invalid_request")
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("invalid_request")
+        request, _ = self.start_authorization(userinfo)
+        with self.fake_server.buggy_endpoint(token=True):
+            self.get_success(self.handler.handle_oidc_callback(request))
+        self.assertRenderedError("server_error")
 
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self) -> None:
@@ -599,18 +577,22 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code(self) -> None:
         """Code exchange behaves correctly and handles various error scenarios."""
-        token = {"type": "bearer"}
-        token_json = json.dumps(token).encode("utf-8")
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
 
         self.assertEqual(ret, token)
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         args = parse_qs(kwargs["data"].decode("utf-8"))
         self.assertEqual(args["grant_type"], ["authorization_code"])
@@ -620,12 +602,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
 
         # Test error handling
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad Request",
-                body=b'{"error": "foo", "error_description": "bar"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={"error": "foo", "error_description": "bar"}
         )
         from synapse.handlers.oidc import OidcError
 
@@ -634,46 +612,30 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(exc.value.error_description, "bar")
 
         # Internal server error with no JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b"Not JSON",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse(
+            code=500, body=b"Not JSON"
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=500,
-                phrase=b"Internal Server Error",
-                body=b'{"error": "internal_server_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=500, payload={"error": "internal_server_error"}
         )
 
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=400,
-                phrase=b"Bad request",
-                body=b"{}",
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=400, payload={}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200,
-                phrase=b"OK",
-                body=b'{"error": "some_error"}',
-            )
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            code=200, payload={"error": "some_error"}
         )
         exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
@@ -697,11 +659,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Test that code exchange works with a JWK client secret."""
         from authlib.jose import jwt
 
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
 
@@ -714,9 +679,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # the client secret provided to the should be a jwt which can be checked with
         # the public key
@@ -750,11 +715,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_exchange_code_no_auth(self) -> None:
         """Test that code exchange works with no client secret."""
-        token = {"type": "bearer"}
-        self.http_client.request = simple_async_mock(
-            return_value=FakeResponse(
-                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
-            )
+        token = {
+            "type": "Bearer",
+            "access_token": "aabbcc",
+        }
+
+        self.fake_server.post_token_handler.side_effect = None
+        self.fake_server.post_token_handler.return_value = FakeResponse.json(
+            payload=token
         )
         code = "code"
         ret = self.get_success(self.provider._exchange_code(code))
@@ -762,9 +730,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(ret, token)
 
         # the request should have hit the token endpoint
-        kwargs = self.http_client.request.call_args[1]
+        kwargs = self.fake_server.request.call_args[1]
         self.assertEqual(kwargs["method"], "POST")
-        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+        self.assertEqual(kwargs["uri"], self.fake_server.token_endpoint)
 
         # check the POSTed data
         args = parse_qs(kwargs["data"].decode("utf-8"))
@@ -787,37 +755,19 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """
         Login while using a mapping provider that implements get_extra_attributes.
         """
-        token = {
-            "type": "bearer",
-            "id_token": "id_token",
-            "access_token": "access_token",
-        }
         userinfo = {
             "sub": "foo",
             "username": "foo",
             "phone": "1234567",
         }
-        self.provider._exchange_code = simple_async_mock(return_value=token)  # type: ignore[assignment]
-        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
-        state = "state"
-        client_redirect_url = "http://client/redirect"
-        session = self._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-        )
-        request = _build_callback_request("code", state, session)
-
+        request, _ = self.start_authorization(userinfo)
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@foo:test",
-            "oidc",
+            self.provider.idp_id,
             request,
-            client_redirect_url,
+            ANY,
             {"phone": "1234567"},
             new_user=True,
             auth_provider_session_id=None,
@@ -826,41 +776,40 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self) -> None:
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         userinfo: dict = {
             "sub": "test_user",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@test_user_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastores().main
@@ -869,8 +818,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error",
             "Mapping provider does not support de-duplicating Matrix IDs",
@@ -885,38 +835,37 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Subsequent calls should map to the same mxid.
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -927,17 +876,18 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             user.to_string(),
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -954,8 +904,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         args = self.assertRenderedError("mapping_error")
         self.assertTrue(
             args[2].startswith(
@@ -969,11 +920,12 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_called_once_with(
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_called_once_with(
             "@TEST_USER_2:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=False,
@@ -983,9 +935,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self) -> None:
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        self.get_success(
-            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
-        )
+        userinfo = {"sub": "test2", "username": "föö"}
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
@@ -1000,9 +952,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self) -> None:
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         store = self.hs.get_datastores().main
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -1011,19 +960,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # test_user is already taken, so test_user1 gets registered instead.
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@test_user1:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
             auth_provider_session_id=None,
         )
-        auth_handler.complete_sso_login.reset_mock()
+        self.reset_mocks()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -1039,8 +989,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
         self.assertRenderedError(
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
@@ -1052,7 +1003,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1071,7 +1023,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("mapping_error", "localpart is invalid: ")
 
     @override_config(
@@ -1084,16 +1037,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements(self) -> None:
         """The required attributes must be met from the OIDC userinfo response."""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
-
         # userinfo lacking "test": "foobar" attribute should fail.
         userinfo = {
             "sub": "tester",
             "username": "tester",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": "foobar" attribute should succeed.
         userinfo = {
@@ -1101,13 +1052,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": "foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1124,21 +1076,20 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_attribute_requirements_contains(self) -> None:
         """Test that auth succeeds if userinfo attribute CONTAINS required value"""
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": ["foobar", "foo", "bar"] attribute should succeed.
         userinfo = {
             "sub": "tester",
             "username": "tester",
             "test": ["foobar", "foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
 
         # check that the auth handler got called as expected
-        auth_handler.complete_sso_login.assert_called_once_with(
+        self.complete_sso_login.assert_called_once_with(
             "@tester:test",
-            "oidc",
-            ANY,
+            self.provider.idp_id,
+            request,
             ANY,
             None,
             new_user=True,
@@ -1158,16 +1109,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         Test that auth fails if attributes exist but don't match,
         or are non-string values.
         """
-        auth_handler = self.hs.get_auth_handler()
-        auth_handler.complete_sso_login = simple_async_mock()
         # userinfo with "test": "not_foobar" attribute should fail
         userinfo: dict = {
             "sub": "tester",
             "username": "tester",
             "test": "not_foobar",
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": ["foo", "bar"] attribute should fail
         userinfo = {
@@ -1175,8 +1125,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": ["foo", "bar"],
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": False attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1185,8 +1136,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": False,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": None attribute should fail
         # a value of None breaks the OIDC spec, but it's important to not crash here
@@ -1195,8 +1147,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": None,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 1 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1205,8 +1158,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 1,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
         # userinfo with "test": 3.14 attribute should fail
         # this is largely just to ensure we don't crash here
@@ -1215,8 +1169,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "username": "tester",
             "test": 3.14,
         }
-        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
-        auth_handler.complete_sso_login.assert_not_called()
+        request, _ = self.start_authorization(userinfo)
+        self.get_success(self.handler.handle_oidc_callback(request))
+        self.complete_sso_login.assert_not_called()
 
     def _generate_oidc_session_token(
         self,
@@ -1230,7 +1185,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return self.handler._macaroon_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
-                idp_id="oidc",
+                idp_id=self.provider.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
@@ -1238,41 +1193,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
         )
 
 
-async def _make_callback_with_userinfo(
-    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
-) -> None:
-    """Mock up an OIDC callback with the given userinfo dict
-
-    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
-    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
-
-    Args:
-        hs: the HomeServer impl to send the callback to.
-        userinfo: the OIDC userinfo dict
-        client_redirect_url: the URL to redirect to on success.
-    """
-
-    handler = hs.get_oidc_handler()
-    provider = handler._providers["oidc"]
-    provider._exchange_code = simple_async_mock(return_value={"id_token": ""})  # type: ignore[assignment]
-    provider._parse_id_token = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)  # type: ignore[assignment]
-
-    state = "state"
-    session = handler._macaroon_generator.generate_oidc_session_token(
-        state=state,
-        session_data=OidcSessionData(
-            idp_id="oidc",
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id="",
-        ),
-    )
-    request = _build_callback_request("code", state, session)
-
-    await handler.handle_oidc_callback(request)
-
-
 def _build_callback_request(
     code: str,
     state: str,