diff --git a/tests/handlers/oidc_test_key.p8 b/tests/handlers/oidc_test_key.p8
new file mode 100644
index 0000000000..bb92976333
--- /dev/null
+++ b/tests/handlers/oidc_test_key.p8
@@ -0,0 +1,5 @@
+-----BEGIN PRIVATE KEY-----
+MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgrHMvFcFjFhei6gHp
+Gfy4C8+6z7634MZbC7SSx4a17GahRANCAATp0YxEzGUXuqszggiFxczDdPgDpCJA
+P18rRuN7FLwZDuzYQPb8zVd8eGh4BqxjiVocICnVWyaSWD96N00I96SW
+-----END PRIVATE KEY-----
diff --git a/tests/handlers/oidc_test_key.pub.pem b/tests/handlers/oidc_test_key.pub.pem
new file mode 100644
index 0000000000..176d4a4b4b
--- /dev/null
+++ b/tests/handlers/oidc_test_key.pub.pem
@@ -0,0 +1,4 @@
+-----BEGIN PUBLIC KEY-----
+MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAE6dGMRMxlF7qrM4IIhcXMw3T4A6Qi
+QD9fK0bjexS8GQ7s2ED2/M1XfHhoeAasY4laHCAp1Vsmklg/ejdNCPeklg==
+-----END PUBLIC KEY-----
diff --git a/tests/handlers/test_auth.py b/tests/handlers/test_auth.py
index 0e42013bb9..c9f889b511 100644
--- a/tests/handlers/test_auth.py
+++ b/tests/handlers/test_auth.py
@@ -68,38 +68,45 @@ class AuthTestCase(unittest.HomeserverTestCase):
         v.verify(macaroon, self.hs.config.macaroon_secret_key)
 
     def test_short_term_login_token_gives_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
         )
-        self.assertEqual("a_user", user_id)
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("", res.auth_provider_id)
 
         # when we advance the clock, the token should be rejected
         self.reactor.advance(6)
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(token),
+            self.auth_handler.validate_short_term_login_token(token),
             AuthError,
         )
 
+    def test_short_term_login_token_gives_auth_provider(self):
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", auth_provider_id="my_idp"
+        )
+        res = self.get_success(self.auth_handler.validate_short_term_login_token(token))
+        self.assertEqual("a_user", res.user_id)
+        self.assertEqual("my_idp", res.auth_provider_id)
+
     def test_short_term_login_token_cannot_replace_user_id(self):
-        token = self.macaroon_generator.generate_short_term_login_token("a_user", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "a_user", "", 5000
+        )
         macaroon = pymacaroons.Macaroon.deserialize(token)
 
-        user_id = self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            )
+        res = self.get_success(
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize())
         )
-        self.assertEqual("a_user", user_id)
+        self.assertEqual("a_user", res.user_id)
 
         # add another "user_id" caveat, which might allow us to override the
         # user_id.
         macaroon.add_first_party_caveat("user_id = b_user")
 
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
-                macaroon.serialize()
-            ),
+            self.auth_handler.validate_short_term_login_token(macaroon.serialize()),
             AuthError,
         )
 
@@ -113,7 +120,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -135,7 +142,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.large_number_of_users)
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -159,7 +166,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             ResourceLimitError,
         )
         self.get_failure(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             ),
             ResourceLimitError,
@@ -175,7 +182,7 @@ class AuthTestCase(unittest.HomeserverTestCase):
             )
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
@@ -197,11 +204,13 @@ class AuthTestCase(unittest.HomeserverTestCase):
             return_value=make_awaitable(self.small_number_of_users)
         )
         self.get_success(
-            self.auth_handler.validate_short_term_login_token_and_get_user_id(
+            self.auth_handler.validate_short_term_login_token(
                 self._get_macaroon().serialize()
             )
         )
 
     def _get_macaroon(self):
-        token = self.macaroon_generator.generate_short_term_login_token("user_a", 5000)
+        token = self.macaroon_generator.generate_short_term_login_token(
+            "user_a", "", 5000
+        )
         return pymacaroons.Macaroon.deserialize(token)
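
As a side note on the caveat-tampering test above: pymacaroons requires a Verifier to satisfy every first-party caveat, so appending a second, conflicting user_id caveat makes verification fail rather than override the original value. Below is a minimal standalone sketch of that behaviour; the location, identifier and secret are illustrative and this is not Synapse's real token layout.

import pymacaroons

key = "some-secret-key"  # illustrative; Synapse uses its macaroon_secret_key
m = pymacaroons.Macaroon(location="test", identifier="key", key=key)
m.add_first_party_caveat("user_id = a_user")

v = pymacaroons.Verifier()
v.satisfy_exact("user_id = a_user")
assert v.verify(m, key)  # the untampered token verifies

# an attacker appends another user_id caveat to the serialized token
tampered = pymacaroons.Macaroon.deserialize(m.serialize())
tampered.add_first_party_caveat("user_id = b_user")

# verification now fails: the extra caveat is not satisfied by the verifier
try:
    v.verify(tampered, key)
    verified = True
except Exception:
    verified = False
assert not verified
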
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
index 6f992291b8..7975af243c 100644
--- a/tests/handlers/test_cas.py
+++ b/tests/handlers/test_cas.py
@@ -66,7 +66,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     def test_map_cas_user_to_existing_user(self):
@@ -89,7 +89,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -98,7 +98,7 @@ class CasHandlerTestCase(HomeserverTestCase):
             self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=False
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=False
         )
 
     def test_map_cas_user_to_invalid_localpart(self):
@@ -116,7 +116,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+            "@f=c3=b6=c3=b6:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
     @override_config(
@@ -160,7 +160,7 @@ class CasHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "cas", request, "redirect_uri", None, new_user=True
         )
 
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index cf1de28fa9..5e9c9c2e88 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
-from typing import Optional
+import os
 from urllib.parse import parse_qs, urlparse
 
 from mock import ANY, Mock, patch
@@ -23,6 +23,7 @@ import pymacaroons
 from synapse.handlers.sso import MappingException
 from synapse.server import HomeServer
 from synapse.types import UserID
+from synapse.util.macaroons import get_value_from_macaroon
 
 from tests.test_utils import FakeResponse, get_awaitable_result, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
@@ -50,7 +51,18 @@ WELL_KNOWN = ISSUER + ".well-known/openid-configuration"
 JWKS_URI = ISSUER + ".well-known/jwks.json"
 
 # config for common cases
-COMMON_CONFIG = {
+DEFAULT_CONFIG = {
+    "enabled": True,
+    "client_id": CLIENT_ID,
+    "client_secret": CLIENT_SECRET,
+    "issuer": ISSUER,
+    "scopes": SCOPES,
+    "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
+}
+
+# extends the default config with explicit OAuth2 endpoints instead of using discovery
+EXPLICIT_ENDPOINT_CONFIG = {
+    **DEFAULT_CONFIG,
     "discover": False,
     "authorization_endpoint": AUTHORIZATION_ENDPOINT,
     "token_endpoint": TOKEN_ENDPOINT,
@@ -107,6 +119,32 @@ async def get_json(url):
         return {"keys": []}
 
 
+def _key_file_path() -> str:
+    """path to a file containing the private half of a test key"""
+
+    # this key was generated with:
+    #   openssl ecparam -name prime256v1 -genkey -noout |
+    #       openssl pkcs8 -topk8 -nocrypt -out oidc_test_key.p8
+    #
+    # we use PKCS8 rather than SEC-1 (which is what openssl ecparam spits out), because
+    # that's what Apple use, and we want to be sure that we work with Apple's keys.
+    #
+    # (For the record: both PKCS8 and SEC-1 specify (different) ways of representing
+    # keys using ASN.1. Both are then typically formatted using PEM, which says: use the
+    # base64-encoded DER encoding of ASN.1, with headers and footers. But we don't
+    # really need to care about any of that.)
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.p8")
+
+
+def _public_key_file_path() -> str:
+    """path to a file containing the public half of a test key"""
+    # this was generated with:
+    #    openssl ec -in oidc_test_key.p8 -pubout -out oidc_test_key.pub.pem
+    #
+    # See above about where oidc_test_key.p8 came from
+    return os.path.join(os.path.dirname(__file__), "oidc_test_key.pub.pem")
+
+
 class OidcHandlerTestCase(HomeserverTestCase):
     if not HAS_OIDC:
         skip = "requires OIDC"
@@ -114,20 +152,6 @@ class OidcHandlerTestCase(HomeserverTestCase):
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
-        oidc_config = {
-            "enabled": True,
-            "client_id": CLIENT_ID,
-            "client_secret": CLIENT_SECRET,
-            "issuer": ISSUER,
-            "scopes": SCOPES,
-            "user_mapping_provider": {"module": __name__ + ".TestMappingProvider"},
-        }
-
-        # Update this config with what's in the default config so that
-        # override_config works as expected.
-        oidc_config.update(config.get("oidc_config", {}))
-        config["oidc_config"] = oidc_config
-
         return config
 
     def make_homeserver(self, reactor, clock):
@@ -170,13 +194,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.render_error.reset_mock()
         return args
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
         self.assertEqual(self.provider._callback_url, CALLBACK_URL)
         self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
         self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
-    @override_config({"oidc_config": {"discover": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
@@ -195,13 +220,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
-    @override_config({"oidc_config": COMMON_CONFIG})
+    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
         self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
-    @override_config({"oidc_config": COMMON_CONFIG})
+    @override_config({"oidc_config": EXPLICIT_ENDPOINT_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
         jwks = self.get_success(self.provider.load_jwks())
@@ -236,6 +261,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
         h = self.provider
@@ -318,13 +344,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             # Shouldn't raise with a valid userinfo, even without jwks
             force_load_metadata()
 
-    @override_config({"oidc_config": {"skip_verification": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "skip_verification": True}})
     def test_skip_verification(self):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
             get_awaitable_result(self.provider.load_metadata())
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["cookies"])
@@ -360,20 +387,15 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertEqual(name, b"oidc_session")
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "state"
-        )
-        nonce = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "nonce"
-        )
-        redirect = self.handler._token_generator._get_value_from_macaroon(
-            macaroon, "client_redirect_url"
-        )
+        state = get_value_from_macaroon(macaroon, "state")
+        nonce = get_value_from_macaroon(macaroon, "nonce")
+        redirect = get_value_from_macaroon(macaroon, "client_redirect_url")
 
         self.assertEqual(params["state"], [state])
         self.assertEqual(params["nonce"], [nonce])
         self.assertEqual(redirect, "http://client/redirect")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_error(self):
         """Errors from the provider returned in the callback are displayed."""
         request = Mock(args={})
@@ -385,6 +407,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_client", "some description")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback(self):
         """Code callback works and display errors if something went wrong.
 
@@ -434,7 +457,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=True
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=True
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
@@ -465,7 +488,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
 
         auth_handler.complete_sso_login.assert_called_once_with(
-            expected_user_id, request, client_redirect_url, None, new_user=False
+            expected_user_id, "oidc", request, client_redirect_url, None, new_user=False
         )
         self.provider._exchange_code.assert_called_once_with(code)
         self.provider._parse_id_token.assert_not_called()
@@ -486,6 +509,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_callback_session(self):
         """The callback verifies the session presence and validity"""
         request = Mock(spec=["args", "getCookie", "cookies"])
@@ -528,7 +552,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_request")
 
-    @override_config({"oidc_config": {"client_auth_method": "client_secret_post"}})
+    @override_config(
+        {"oidc_config": {**DEFAULT_CONFIG, "client_auth_method": "client_secret_post"}}
+    )
     def test_exchange_code(self):
         """Code exchange behaves correctly and handles various error scenarios."""
         token = {"type": "bearer"}
@@ -613,9 +639,105 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                "enabled": True,
+                "client_id": CLIENT_ID,
+                "issuer": ISSUER,
+                "client_auth_method": "client_secret_post",
+                "client_secret_jwt_key": {
+                    "key_file": _key_file_path(),
+                    "jwt_header": {"alg": "ES256", "kid": "ABC789"},
+                    "jwt_payload": {"iss": "DEFGHI"},
+                },
+            }
+        }
+    )
+    def test_exchange_code_jwt_key(self):
+        """Test that code exchange works with a JWK client secret."""
+        from authlib.jose import jwt
+
+        token = {"type": "bearer"}
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+            )
+        )
+        code = "code"
+
+        # advance the clock a bit before we start, so we aren't working with zero
+        # timestamps.
+        self.reactor.advance(1000)
+        start_time = self.reactor.seconds()
+        ret = self.get_success(self.provider._exchange_code(code))
+
+        self.assertEqual(ret, token)
+
+        # the request should have hit the token endpoint
+        kwargs = self.http_client.request.call_args[1]
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+        # the client_secret provided in the POST data should be a JWT which can be
+        # checked against the public key
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        secret = args["client_secret"][0]
+        with open(_public_key_file_path()) as f:
+            key = f.read()
+        claims = jwt.decode(secret, key)
+        self.assertEqual(claims.header["kid"], "ABC789")
+        self.assertEqual(claims["aud"], ISSUER)
+        self.assertEqual(claims["iss"], "DEFGHI")
+        self.assertEqual(claims["sub"], CLIENT_ID)
+        self.assertEqual(claims["iat"], start_time)
+        self.assertGreater(claims["exp"], start_time)
+
+        # check the rest of the POSTed data
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+    @override_config(
+        {
+            "oidc_config": {
+                "enabled": True,
+                "client_id": CLIENT_ID,
+                "issuer": ISSUER,
+                "client_auth_method": "none",
+            }
+        }
+    )
+    def test_exchange_code_no_auth(self):
+        """Test that code exchange works with no client secret."""
+        token = {"type": "bearer"}
+        self.http_client.request = simple_async_mock(
+            return_value=FakeResponse(
+                code=200, phrase=b"OK", body=json.dumps(token).encode("utf-8")
+            )
+        )
+        code = "code"
+        ret = self.get_success(self.provider._exchange_code(code))
+
+        self.assertEqual(ret, token)
+
+        # the request should have hit the token endpoint
+        kwargs = self.http_client.request.call_args[1]
+        self.assertEqual(kwargs["method"], "POST")
+        self.assertEqual(kwargs["uri"], TOKEN_ENDPOINT)
+
+        # check the POSTed data
+        args = parse_qs(kwargs["data"].decode("utf-8"))
+        self.assertEqual(args["grant_type"], ["authorization_code"])
+        self.assertEqual(args["code"], [code])
+        self.assertEqual(args["client_id"], [CLIENT_ID])
+        self.assertEqual(args["redirect_uri"], [CALLBACK_URL])
+
+    @override_config(
+        {
+            "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "module": __name__ + ".TestMappingProviderExtra"
-                }
+                },
             }
         }
     )
@@ -651,12 +773,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         auth_handler.complete_sso_login.assert_called_once_with(
             "@foo:test",
+            "oidc",
             request,
             client_redirect_url,
             {"phone": "1234567"},
             new_user=True,
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
         auth_handler = self.hs.get_auth_handler()
@@ -668,7 +792,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", ANY, ANY, None, new_user=True
+            "@test_user:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -679,7 +803,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user_2:test", ANY, ANY, None, new_user=True
+            "@test_user_2:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -697,7 +821,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
-    @override_config({"oidc_config": {"allow_existing_users": True}})
+    @override_config({"oidc_config": {**DEFAULT_CONFIG, "allow_existing_users": True}})
     def test_map_userinfo_to_existing_user(self):
         """Existing users can log in with OpenID Connect when allow_existing_users is True."""
         store = self.hs.get_datastore()
@@ -716,14 +840,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -738,7 +862,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            user.to_string(), ANY, ANY, None, new_user=False
+            user.to_string(), "oidc", ANY, ANY, None, new_user=False
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -774,9 +898,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
+            "@TEST_USER_2:test", "oidc", ANY, ANY, None, new_user=False
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
         self.get_success(
@@ -787,9 +912,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "module": __name__ + ".TestMappingProviderFailures"
-                }
+                },
             }
         }
     )
@@ -810,7 +936,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", ANY, ANY, None, new_user=True
+            "@test_user1:test", "oidc", ANY, ANY, None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -834,6 +960,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
 
+    @override_config({"oidc_config": DEFAULT_CONFIG})
     def test_empty_localpart(self):
         """Attempts to map onto an empty localpart should be rejected."""
         userinfo = {
@@ -846,9 +973,10 @@ class OidcHandlerTestCase(HomeserverTestCase):
     @override_config(
         {
             "oidc_config": {
+                **DEFAULT_CONFIG,
                 "user_mapping_provider": {
                     "config": {"localpart_template": "{{ user.username }}"}
-                }
+                },
             }
         }
     )
@@ -866,7 +994,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         state: str,
         nonce: str,
         client_redirect_url: str,
-        ui_auth_session_id: Optional[str] = None,
+        ui_auth_session_id: str = "",
     ) -> str:
         from synapse.handlers.oidc_handler import OidcSessionData
 
@@ -909,6 +1037,7 @@ async def _make_callback_with_userinfo(
             idp_id="oidc",
             nonce="nonce",
             client_redirect_url=client_redirect_url,
+            ui_auth_session_id="",
         ),
     )
     request = _build_callback_request("code", state, session)
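
For reference, here is a small self-contained sketch (not Synapse code) of the client-secret-JWT mechanism that test_exchange_code_jwt_key exercises: sign a JWT with the ES256 private key from oidc_test_key.p8 via authlib and verify it with the public half. The issuer/client values are illustrative placeholders, and the file paths assume the repository root as the working directory.

import time
from authlib.jose import jwt

with open("tests/handlers/oidc_test_key.p8") as f:
    private_key = f.read()
with open("tests/handlers/oidc_test_key.pub.pem") as f:
    public_key = f.read()

now = int(time.time())
header = {"alg": "ES256", "kid": "ABC789"}
payload = {
    "iss": "DEFGHI",           # from jwt_payload in the test config
    "sub": "client-id",        # illustrative OAuth client_id
    "aud": "https://issuer/",  # illustrative OIDC issuer
    "iat": now,
    "exp": now + 60,
}

# this is what gets sent as the client_secret in the token request
client_secret = jwt.encode(header, payload, private_key).decode("ascii")

# the provider (or the test) can then check it against the public key
claims = jwt.decode(client_secret, public_key)
claims.validate()
assert claims.header["kid"] == "ABC789"
assert claims["iss"] == "DEFGHI"
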
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 029af2853e..30efd43b40 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -131,7 +131,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
@@ -157,7 +157,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
         # Subsequent calls should map to the same mxid.
@@ -166,7 +166,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
             self.handler._handle_authn_response(request, saml_response, "")
         )
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "", None, new_user=False
+            "@test_user:test", "saml", request, "", None, new_user=False
         )
 
     def test_map_saml_response_to_invalid_localpart(self):
@@ -214,7 +214,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # test_user is already taken, so test_user1 gets registered instead.
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user1:test", request, "", None, new_user=True
+            "@test_user1:test", "saml", request, "", None, new_user=True
         )
         auth_handler.complete_sso_login.reset_mock()
 
@@ -310,7 +310,7 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # check that the auth handler got called as expected
         auth_handler.complete_sso_login.assert_called_once_with(
-            "@test_user:test", request, "redirect_uri", None, new_user=True
+            "@test_user:test", "saml", request, "redirect_uri", None, new_user=True
         )
 
 
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f6a6aed35e..20940c8107 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -22,6 +22,7 @@ from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 from twisted.web.resource import Resource
+from twisted.web.server import Request, Site
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -32,7 +33,10 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.replication.tcp.resource import (
+    ReplicationStreamProtocolFactory,
+    ServerReplicationStreamProtocol,
+)
 from synapse.server import HomeServer
 from synapse.util import Clock
 
@@ -59,7 +63,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
         self.streamer = hs.get_replication_streamer()
-        self.server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(
+            None
+        )  # type: ServerReplicationStreamProtocol
 
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
@@ -155,9 +161,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self.site
+        channel = _PushHTTPChannel(self.reactor, request_factory, self.site)
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -188,8 +192,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         fetching updates for given stream.
         """
 
+        path = request.path  # type: bytes  # type: ignore
         self.assertRegex(
-            request.path,
+            path,
             br"^/_synapse/replication/get_repl_stream_updates/%s/[^/]+$"
             % (stream_name.encode("ascii"),),
         )
@@ -390,9 +395,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         request_factory = OneShotRequestFactory()
 
         # Set up the server side protocol
-        channel = _PushHTTPChannel(self.reactor)
-        channel.requestFactory = request_factory
-        channel.site = self._hs_to_site[hs]
+        channel = _PushHTTPChannel(self.reactor, request_factory, self._hs_to_site[hs])
 
         # Connect client to server and vice versa.
         client_to_server_transport = FakeTransport(
@@ -475,9 +478,13 @@ class _PushHTTPChannel(HTTPChannel):
     makes it very hard to test.
     """
 
-    def __init__(self, reactor: IReactorTime):
+    def __init__(
+        self, reactor: IReactorTime, request_factory: Callable[..., Request], site: Site
+    ):
         super().__init__()
         self.reactor = reactor
+        self.requestFactory = request_factory
+        self.site = site
 
         self._pull_to_push_producer = None  # type: Optional[_PullToPushProducer]
 
diff --git a/tests/rest/media/v1/test_media_storage.py b/tests/rest/media/v1/test_media_storage.py
index 36d1e6bc4a..9f77125fd4 100644
--- a/tests/rest/media/v1/test_media_storage.py
+++ b/tests/rest/media/v1/test_media_storage.py
@@ -105,7 +105,7 @@ class MediaStorageTests(unittest.HomeserverTestCase):
         self.assertEqual(test_body, body)
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class _TestImage:
     """An image for testing thumbnailing with the expected results
 
@@ -117,13 +117,15 @@ class _TestImage:
             test should just check for success.
         expected_scaled: The expected bytes from scaled thumbnailing, or None if
             test should just check for a valid image returned.
+        expected_found: True if the file should exist on the server, or False if
+            a 404 is expected.
     """
 
     data = attr.ib(type=bytes)
     content_type = attr.ib(type=bytes)
     extension = attr.ib(type=bytes)
-    expected_cropped = attr.ib(type=Optional[bytes])
-    expected_scaled = attr.ib(type=Optional[bytes])
+    expected_cropped = attr.ib(type=Optional[bytes], default=None)
+    expected_scaled = attr.ib(type=Optional[bytes], default=None)
     expected_found = attr.ib(default=True, type=bool)
 
 
@@ -153,6 +155,21 @@ class _TestImage:
                 ),
             ),
         ),
+        # small png with transparency.
+        (
+            _TestImage(
+                unhexlify(
+                    b"89504e470d0a1a0a0000000d49484452000000010000000101000"
+                    b"00000376ef9240000000274524e5300010194fdae0000000a4944"
+                    b"4154789c636800000082008177cd72b60000000049454e44ae426"
+                    b"082"
+                ),
+                b"image/png",
+                b".png",
+                # Note that we don't check the output since it varies across
+                # different versions of Pillow.
+            ),
+        ),
         # small lossless webp
         (
             _TestImage(
@@ -162,8 +179,6 @@ class _TestImage:
                 ),
                 b"image/webp",
                 b".webp",
-                None,
-                None,
             ),
         ),
         # an empty file
@@ -172,9 +187,7 @@ class _TestImage:
                 b"",
                 b"image/gif",
                 b".gif",
-                None,
-                None,
-                False,
+                expected_found=False,
             ),
         ),
     ],
diff --git a/tests/server.py b/tests/server.py
index 939a0008ca..863f6da738 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -188,7 +188,7 @@ class FakeSite:
 
 def make_request(
     reactor,
-    site: Site,
+    site: Union[Site, FakeSite],
     method,
     path,
     content=b"",
diff --git a/tests/storage/test_purge.py b/tests/storage/test_purge.py
index a06ad2c03e..41af8c4847 100644
--- a/tests/storage/test_purge.py
+++ b/tests/storage/test_purge.py
@@ -13,9 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from twisted.internet import defer
-
-from synapse.api.errors import NotFoundError
+from synapse.api.errors import NotFoundError, SynapseError
 from synapse.rest.client.v1 import room
 
 from tests.unittest import HomeserverTestCase
@@ -33,9 +31,12 @@ class PurgeTests(HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         self.room_id = self.helper.create_room_as(self.user_id)
 
-    def test_purge(self):
+        self.store = hs.get_datastore()
+        self.storage = self.hs.get_storage()
+
+    def test_purge_history(self):
         """
-        Purging a room will delete everything before the topological point.
+        Purging a room's history will delete everything before the topological point.
         """
         # Send four messages to the room
         first = self.helper.send(self.room_id, body="test1")
@@ -43,30 +44,27 @@ class PurgeTests(HomeserverTestCase):
         third = self.helper.send(self.room_id, body="test3")
         last = self.helper.send(self.room_id, body="test4")
 
-        store = self.hs.get_datastore()
-        storage = self.hs.get_storage()
-
         # Get the topological token
         token = self.get_success(
-            store.get_topological_token_for_event(last["event_id"])
+            self.store.get_topological_token_for_event(last["event_id"])
         )
         token_str = self.get_success(token.to_string(self.hs.get_datastore()))
 
         # Purge everything before this topological token
         self.get_success(
-            storage.purge_events.purge_history(self.room_id, token_str, True)
+            self.storage.purge_events.purge_history(self.room_id, token_str, True)
         )
 
         # 1-3 should fail and last will succeed, meaning that 1-3 are deleted
         # and last is not.
-        self.get_failure(store.get_event(first["event_id"]), NotFoundError)
-        self.get_failure(store.get_event(second["event_id"]), NotFoundError)
-        self.get_failure(store.get_event(third["event_id"]), NotFoundError)
-        self.get_success(store.get_event(last["event_id"]))
+        self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
+        self.get_failure(self.store.get_event(second["event_id"]), NotFoundError)
+        self.get_failure(self.store.get_event(third["event_id"]), NotFoundError)
+        self.get_success(self.store.get_event(last["event_id"]))
 
-    def test_purge_wont_delete_extrems(self):
+    def test_purge_history_wont_delete_extrems(self):
         """
-        Purging a room will delete everything before the topological point.
+        Purging history beyond the forward extremities fails and deletes nothing.
         """
         # Send four messages to the room
         first = self.helper.send(self.room_id, body="test1")
@@ -74,22 +72,43 @@ class PurgeTests(HomeserverTestCase):
         third = self.helper.send(self.room_id, body="test3")
         last = self.helper.send(self.room_id, body="test4")
 
-        storage = self.hs.get_datastore()
-
         # Set the topological token higher than it should be
         token = self.get_success(
-            storage.get_topological_token_for_event(last["event_id"])
+            self.store.get_topological_token_for_event(last["event_id"])
         )
         event = "t{}-{}".format(token.topological + 1, token.stream + 1)
 
         # Purge everything before this topological token
-        purge = defer.ensureDeferred(storage.purge_history(self.room_id, event, True))
-        self.pump()
-        f = self.failureResultOf(purge)
+        f = self.get_failure(
+            self.storage.purge_events.purge_history(self.room_id, event, True),
+            SynapseError,
+        )
         self.assertIn("greater than forward", f.value.args[0])
 
         # Try and get the events
-        self.get_success(storage.get_event(first["event_id"]))
-        self.get_success(storage.get_event(second["event_id"]))
-        self.get_success(storage.get_event(third["event_id"]))
-        self.get_success(storage.get_event(last["event_id"]))
+        self.get_success(self.store.get_event(first["event_id"]))
+        self.get_success(self.store.get_event(second["event_id"]))
+        self.get_success(self.store.get_event(third["event_id"]))
+        self.get_success(self.store.get_event(last["event_id"]))
+
+    def test_purge_room(self):
+        """
+        Purging a room will delete everything about it.
+        """
+        # Send a message to the room
+        first = self.helper.send(self.room_id, body="test1")
+
+        # Get the current room state.
+        state_handler = self.hs.get_state_handler()
+        create_event = self.get_success(
+            state_handler.get_current_state(self.room_id, "m.room.create", "")
+        )
+        self.assertIsNotNone(create_event)
+
+        # Purge the room
+        self.get_success(self.storage.purge_events.purge_room(self.room_id))
+
+        # The events aren't found.
+        self.store._invalidate_get_event_cache(create_event.event_id)
+        self.get_failure(self.store.get_event(create_event.event_id), NotFoundError)
+        self.get_failure(self.store.get_event(first["event_id"]), NotFoundError)
diff --git a/tests/test_utils/logging_setup.py b/tests/test_utils/logging_setup.py
index 52ae5c5713..74568b34f8 100644
--- a/tests/test_utils/logging_setup.py
+++ b/tests/test_utils/logging_setup.py
@@ -28,7 +28,7 @@ class ToTwistedHandler(logging.Handler):
     def emit(self, record):
         log_entry = self.format(record)
         log_level = record.levelname.lower().replace("warning", "warn")
-        self.tx_log.emit(
+        self.tx_log.emit(  # type: ignore
             twisted.logger.LogLevel.levelWithName(log_level), "{entry}", entry=log_entry
         )
 
diff --git a/tests/util/caches/test_responsecache.py b/tests/util/caches/test_responsecache.py
new file mode 100644
index 0000000000..f9a187b8de
--- /dev/null
+++ b/tests/util/caches/test_responsecache.py
@@ -0,0 +1,131 @@
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.util.caches.response_cache import ResponseCache
+
+from tests.server import get_clock
+from tests.unittest import TestCase
+
+
+class ResponseCacheTestCase(TestCase):
+    """
+    A TestCase class for ResponseCache.
+
+    The test-case function naming follows a pattern; some notes on it:
+        wait: denotes tests that have an element of "waiting" before the wrapped result
+              becomes available (these use .delayed_return instead of .instant_return).
+        expire: denotes tests that check expiry after assured existence (these use a
+                cache with a short timeout_ms=, shorter than the clock is advanced by).
+    """
+
+    def setUp(self):
+        self.reactor, self.clock = get_clock()
+
+    def with_cache(self, name: str, ms: int = 0) -> ResponseCache:
+        return ResponseCache(self.clock, name, timeout_ms=ms)
+
+    @staticmethod
+    async def instant_return(o: str) -> str:
+        return o
+
+    async def delayed_return(self, o: str) -> str:
+        await self.clock.sleep(1)
+        return o
+
+    def test_cache_hit(self):
+        cache = self.with_cache("keeping_cache", ms=9001)
+
+        expected_result = "howdy"
+
+        wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+        self.assertEqual(
+            expected_result,
+            self.successResultOf(wrap_d),
+            "initial wrap result should be the same",
+        )
+        self.assertEqual(
+            expected_result,
+            self.successResultOf(cache.get(0)),
+            "cache should have the result",
+        )
+
+    def test_cache_miss(self):
+        cache = self.with_cache("trashing_cache", ms=0)
+
+        expected_result = "howdy"
+
+        wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+        self.assertEqual(
+            expected_result,
+            self.successResultOf(wrap_d),
+            "initial wrap result should be the same",
+        )
+        self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+    def test_cache_expire(self):
+        cache = self.with_cache("short_cache", ms=1000)
+
+        expected_result = "howdy"
+
+        wrap_d = cache.wrap(0, self.instant_return, expected_result)
+
+        self.assertEqual(expected_result, self.successResultOf(wrap_d))
+        self.assertEqual(
+            expected_result,
+            self.successResultOf(cache.get(0)),
+            "cache should still have the result",
+        )
+
+        # advance the clock so the cache eviction timer fires
+        self.reactor.pump((2,))
+
+        self.assertIsNone(cache.get(0), "cache should not have the result now")
+
+    def test_cache_wait_hit(self):
+        cache = self.with_cache("neutral_cache")
+
+        expected_result = "howdy"
+
+        wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+        self.assertNoResult(wrap_d)
+
+        # function wakes up, returns result
+        self.reactor.pump((2,))
+
+        self.assertEqual(expected_result, self.successResultOf(wrap_d))
+
+    def test_cache_wait_expire(self):
+        cache = self.with_cache("medium_cache", ms=3000)
+
+        expected_result = "howdy"
+
+        wrap_d = cache.wrap(0, self.delayed_return, expected_result)
+        self.assertNoResult(wrap_d)
+
+        # pump 1 second so the wrapped result arrives (scheduling the eviction callLater), then another to reach t=2
+        self.reactor.pump((1, 1))
+
+        self.assertEqual(expected_result, self.successResultOf(wrap_d))
+        self.assertEqual(
+            expected_result,
+            self.successResultOf(cache.get(0)),
+            "cache should still have the result",
+        )
+
+        # 1 + 1 + 2 = 4 seconds have passed (> the 3s timeout), so the eviction timer has fired
+        self.reactor.pump((2,))
+
+        self.assertIsNone(cache.get(0), "cache should not have the result now")
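
Finally, a rough usage sketch of the behaviour the new ResponseCache tests cover: concurrent wrap() calls with the same key share a single invocation of the wrapped coroutine, and the result remains retrievable via get() until timeout_ms expires. The call counter and names below are illustrative additions, not part of the cache API.

from synapse.util.caches.response_cache import ResponseCache
from tests.server import get_clock

reactor, clock = get_clock()
cache = ResponseCache(clock, "sketch", timeout_ms=1000)

calls = 0

async def compute(value: str) -> str:
    global calls
    calls += 1
    await clock.sleep(1)
    return value

# two callers ask for the same key while the first computation is in flight
d1 = cache.wrap("key", compute, "result")
d2 = cache.wrap("key", compute, "result")

results = []
d1.addCallback(results.append)
d2.addCallback(results.append)

reactor.pump((2,))  # let the sleep complete

assert results == ["result", "result"]
assert calls == 1  # the coroutine only ran once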