summary refs log tree commit diff
path: root/tests/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers')
-rw-r--r--tests/handlers/test_cas.py121
-rw-r--r--tests/handlers/test_device.py4
-rw-r--r--tests/handlers/test_directory.py14
-rw-r--r--tests/handlers/test_federation.py54
-rw-r--r--tests/handlers/test_message.py2
-rw-r--r--tests/handlers/test_oidc.py519
-rw-r--r--tests/handlers/test_password_providers.py29
-rw-r--r--tests/handlers/test_presence.py2
-rw-r--r--tests/handlers/test_profile.py32
-rw-r--r--tests/handlers/test_saml.py155
-rw-r--r--tests/handlers/test_typing.py19
-rw-r--r--tests/handlers/test_user_directory.py48
12 files changed, 693 insertions, 306 deletions
diff --git a/tests/handlers/test_cas.py b/tests/handlers/test_cas.py
new file mode 100644
index 0000000000..7baf224f7e
--- /dev/null
+++ b/tests/handlers/test_cas.py
@@ -0,0 +1,121 @@
+#  Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+#  Licensed under the Apache License, Version 2.0 (the "License");
+#  you may not use this file except in compliance with the License.
+#  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+from mock import Mock
+
+from synapse.handlers.cas_handler import CasResponse
+
+from tests.test_utils import simple_async_mock
+from tests.unittest import HomeserverTestCase
+
+# These are a few constants that are used as config parameters in the tests.
+BASE_URL = "https://synapse/"
+SERVER_URL = "https://issuer/"
+
+
+class CasHandlerTestCase(HomeserverTestCase):
+    def default_config(self):
+        config = super().default_config()
+        config["public_baseurl"] = BASE_URL
+        cas_config = {
+            "enabled": True,
+            "server_url": SERVER_URL,
+            "service_url": BASE_URL,
+        }
+        config["cas_config"] = cas_config
+
+        return config
+
+    def make_homeserver(self, reactor, clock):
+        hs = self.setup_test_homeserver()
+
+        self.handler = hs.get_cas_handler()
+
+        # Reduce the number of attempts when generating MXIDs.
+        sso_handler = hs.get_sso_handler()
+        sso_handler._MAP_USERNAME_RETRIES = 3
+
+        return hs
+
+    def test_map_cas_user_to_user(self):
+        """Ensure that mapping the CAS user returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
+        )
+
+    def test_map_cas_user_to_existing_user(self):
+        """Existing users can log in with CAS account."""
+        store = self.hs.get_datastore()
+        self.get_success(
+            store.register_user(user_id="@test_user:test", password_hash=None)
+        )
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # Map a user via SSO.
+        cas_response = CasResponse("test_user", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=False
+        )
+
+        # Subsequent calls should map to the same mxid.
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=False
+        )
+
+    def test_map_cas_user_to_invalid_localpart(self):
+        """CAS automaps invalid characters to base-64 encoding."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        cas_response = CasResponse("föö", {})
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_cas_response(request, cas_response, "redirect_uri", "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@f=c3=b6=c3=b6:test", request, "redirect_uri", None, new_user=True
+        )
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 875aaec2c6..5dfeccfeb6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -27,7 +27,7 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.store = hs.get_datastore()
         return hs
@@ -229,7 +229,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver("server", http_client=None)
+        hs = self.setup_test_homeserver("server", federation_http_client=None)
         self.handler = hs.get_device_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py
index ee6ef5e6fa..a39f898608 100644
--- a/tests/handlers/test_directory.py
+++ b/tests/handlers/test_directory.py
@@ -42,8 +42,6 @@ class DirectoryTestCase(unittest.HomeserverTestCase):
         self.mock_registry.register_query_handler = register_query_handler
 
         hs = self.setup_test_homeserver(
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_registry=self.mock_registry,
         )
@@ -407,7 +405,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
     def test_denied(self):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             b"directory/room/%23test%3Atest",
             ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -417,7 +415,7 @@ class TestCreateAliasACL(unittest.HomeserverTestCase):
     def test_allowed(self):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT",
             b"directory/room/%23unofficial_test%3Atest",
             ('{"room_id":"%s"}' % (room_id,)).encode("ascii"),
@@ -433,7 +431,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
     def prepare(self, reactor, clock, hs):
         room_id = self.helper.create_room_as(self.user_id)
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -448,7 +446,7 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.directory_handler.enable_room_list_search = True
 
         # Room list is enabled so we should get some results
-        request, channel = self.make_request("GET", b"publicRooms")
+        channel = self.make_request("GET", b"publicRooms")
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) > 0)
 
@@ -456,13 +454,13 @@ class TestRoomListSearchDisabled(unittest.HomeserverTestCase):
         self.directory_handler.enable_room_list_search = False
 
         # Room list disabled so we should get no results
-        request, channel = self.make_request("GET", b"publicRooms")
+        channel = self.make_request("GET", b"publicRooms")
         self.assertEquals(200, channel.code, channel.result)
         self.assertTrue(len(channel.json_body["chunk"]) == 0)
 
         # Room list disabled so we shouldn't be allowed to publish rooms
         room_id = self.helper.create_room_as(self.user_id)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "PUT", b"directory/list/room/%s" % (room_id.encode("ascii"),), b"{}"
         )
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bf866dacf3..983e368592 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -16,7 +16,7 @@ import logging
 from unittest import TestCase
 
 from synapse.api.constants import EventTypes
-from synapse.api.errors import AuthError, Codes, SynapseError
+from synapse.api.errors import AuthError, Codes, LimitExceededError, SynapseError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
 from synapse.federation.federation_base import event_from_pdu_json
@@ -37,7 +37,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor, clock):
-        hs = self.setup_test_homeserver(http_client=None)
+        hs = self.setup_test_homeserver(federation_http_client=None)
         self.handler = hs.get_federation_handler()
         self.store = hs.get_datastore()
         return hs
@@ -126,7 +126,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             room_version,
         )
 
-        with LoggingContext(request="send_rejected"):
+        with LoggingContext("send_rejected"):
             d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
         self.get_success(d)
 
@@ -178,7 +178,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
             room_version,
         )
 
-        with LoggingContext(request="send_rejected"):
+        with LoggingContext("send_rejected"):
             d = run_in_background(self.handler.on_receive_pdu, OTHER_SERVER, ev)
         self.get_success(d)
 
@@ -191,6 +191,50 @@ class FederationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(sg, sg2)
 
+    @unittest.override_config(
+        {"rc_invites": {"per_user": {"per_second": 0.5, "burst_count": 3}}}
+    )
+    def test_invite_by_user_ratelimit(self):
+        """Tests that invites from federation to a particular user are
+        actually rate-limited.
+        """
+        other_server = "otherserver"
+        other_user = "@otheruser:" + other_server
+
+        # create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+
+        def create_invite():
+            room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+            room_version = self.get_success(self.store.get_room_version(room_id))
+            return event_from_pdu_json(
+                {
+                    "type": EventTypes.Member,
+                    "content": {"membership": "invite"},
+                    "room_id": room_id,
+                    "sender": other_user,
+                    "state_key": "@user:test",
+                    "depth": 32,
+                    "prev_events": [],
+                    "auth_events": [],
+                    "origin_server_ts": self.clock.time_msec(),
+                },
+                room_version,
+            )
+
+        for i in range(3):
+            event = create_invite()
+            self.get_success(
+                self.handler.on_invite_request(other_server, event, event.room_version,)
+            )
+
+        event = create_invite()
+        self.get_failure(
+            self.handler.on_invite_request(other_server, event, event.room_version,),
+            exc=LimitExceededError,
+        )
+
     def _build_and_send_join_event(self, other_server, other_user, room_id):
         join_event = self.get_success(
             self.handler.on_make_join_request(other_server, room_id, other_user)
@@ -198,7 +242,7 @@ class FederationTestCase(unittest.HomeserverTestCase):
         # the auth code requires that a signature exists, but doesn't check that
         # signature... go figure.
         join_event.signatures[other_server] = {"x": "y"}
-        with LoggingContext(request="send_join"):
+        with LoggingContext("send_join"):
             d = run_in_background(
                 self.handler.on_send_join_request, other_server, join_event
             )
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index af42775815..f955dfa490 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -206,7 +206,7 @@ class ServerAclValidationTestCase(unittest.HomeserverTestCase):
 
         # Redaction of event should fail.
         path = "/_matrix/client/r0/rooms/%s/redact/%s" % (self.room_id, event_id)
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", path, content={}, access_token=self.access_token
         )
         self.assertEqual(int(channel.result["code"]), 403)
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index a308c46da9..ad20400b1d 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -13,32 +13,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from typing import Optional
 from urllib.parse import parse_qs, urlparse
 
-from mock import Mock, patch
+from mock import ANY, Mock, patch
 
-import attr
 import pymacaroons
 
-from twisted.python.failure import Failure
-from twisted.web._newclient import ResponseDone
-
-from synapse.handlers.oidc_handler import OidcError, OidcMappingProvider
 from synapse.handlers.sso import MappingException
+from synapse.server import HomeServer
 from synapse.types import UserID
 
+from tests.test_utils import FakeResponse, simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+try:
+    import authlib  # noqa: F401
 
-@attr.s
-class FakeResponse:
-    code = attr.ib()
-    body = attr.ib()
-    phrase = attr.ib()
-
-    def deliverBody(self, protocol):
-        protocol.dataReceived(self.body)
-        protocol.connectionLost(Failure(ResponseDone()))
+    HAS_OIDC = True
+except ImportError:
+    HAS_OIDC = False
 
 
 # These are a few constants that are used as config parameters in the tests.
@@ -46,7 +40,7 @@ ISSUER = "https://issuer/"
 CLIENT_ID = "test-client-id"
 CLIENT_SECRET = "test-client-secret"
 BASE_URL = "https://synapse/"
-CALLBACK_URL = BASE_URL + "_synapse/oidc/callback"
+CALLBACK_URL = BASE_URL + "_synapse/client/oidc/callback"
 SCOPES = ["openid"]
 
 AUTHORIZATION_ENDPOINT = ISSUER + "authorize"
@@ -64,17 +58,14 @@ COMMON_CONFIG = {
 }
 
 
-# The cookie name and path don't really matter, just that it has to be coherent
-# between the callback & redirect handlers.
-COOKIE_NAME = b"oidc_session"
-COOKIE_PATH = "/_synapse/oidc"
-
-
-class TestMappingProvider(OidcMappingProvider):
+class TestMappingProvider:
     @staticmethod
     def parse_config(config):
         return
 
+    def __init__(self, config):
+        pass
+
     def get_remote_user_id(self, userinfo):
         return userinfo["sub"]
 
@@ -97,16 +88,6 @@ class TestMappingProviderFailures(TestMappingProvider):
         }
 
 
-def simple_async_mock(return_value=None, raises=None):
-    # AsyncMock is not available in python3.5, this mimics part of its behaviour
-    async def cb(*args, **kwargs):
-        if raises:
-            raise raises
-        return return_value
-
-    return Mock(side_effect=cb)
-
-
 async def get_json(url):
     # Mock get_json calls to handle jwks & oidc discovery endpoints
     if url == WELL_KNOWN:
@@ -127,6 +108,9 @@ async def get_json(url):
 
 
 class OidcHandlerTestCase(HomeserverTestCase):
+    if not HAS_OIDC:
+        skip = "requires OIDC"
+
     def default_config(self):
         config = super().default_config()
         config["public_baseurl"] = BASE_URL
@@ -155,6 +139,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         hs = self.setup_test_homeserver(proxied_http_client=self.http_client)
 
         self.handler = hs.get_oidc_handler()
+        self.provider = self.handler._providers["oidc"]
         sso_handler = hs.get_sso_handler()
         # Mock the render error method.
         self.render_error = Mock(return_value=None)
@@ -166,27 +151,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return hs
 
     def metadata_edit(self, values):
-        return patch.dict(self.handler._provider_metadata, values)
+        return patch.dict(self.provider._provider_metadata, values)
 
     def assertRenderedError(self, error, error_description=None):
+        self.render_error.assert_called_once()
         args = self.render_error.call_args[0]
         self.assertEqual(args[1], error)
         if error_description is not None:
             self.assertEqual(args[2], error_description)
         # Reset the render_error mock
         self.render_error.reset_mock()
+        return args
 
     def test_config(self):
         """Basic config correctly sets up the callback URL and client auth correctly."""
-        self.assertEqual(self.handler._callback_url, CALLBACK_URL)
-        self.assertEqual(self.handler._client_auth.client_id, CLIENT_ID)
-        self.assertEqual(self.handler._client_auth.client_secret, CLIENT_SECRET)
+        self.assertEqual(self.provider._callback_url, CALLBACK_URL)
+        self.assertEqual(self.provider._client_auth.client_id, CLIENT_ID)
+        self.assertEqual(self.provider._client_auth.client_secret, CLIENT_SECRET)
 
     @override_config({"oidc_config": {"discover": True}})
     def test_discovery(self):
         """The handler should discover the endpoints from OIDC discovery document."""
         # This would throw if some metadata were invalid
-        metadata = self.get_success(self.handler.load_metadata())
+        metadata = self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_called_once_with(WELL_KNOWN)
 
         self.assertEqual(metadata.issuer, ISSUER)
@@ -198,47 +185,47 @@ class OidcHandlerTestCase(HomeserverTestCase):
 
         # subsequent calls should be cached
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_no_discovery(self):
         """When discovery is disabled, it should not try to load from discovery document."""
-        self.get_success(self.handler.load_metadata())
+        self.get_success(self.provider.load_metadata())
         self.http_client.get_json.assert_not_called()
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_load_jwks(self):
         """JWKS loading is done once (then cached) if used."""
-        jwks = self.get_success(self.handler.load_jwks())
+        jwks = self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
         self.assertEqual(jwks, {"keys": []})
 
         # subsequent calls should be cached…
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks())
+        self.get_success(self.provider.load_jwks())
         self.http_client.get_json.assert_not_called()
 
         # …unless forced
         self.http_client.reset_mock()
-        self.get_success(self.handler.load_jwks(force=True))
+        self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_called_once_with(JWKS_URI)
 
         # Throw if the JWKS uri is missing
         with self.metadata_edit({"jwks_uri": None}):
-            self.get_failure(self.handler.load_jwks(force=True), RuntimeError)
+            self.get_failure(self.provider.load_jwks(force=True), RuntimeError)
 
         # Return empty key set if JWKS are not used
-        self.handler._scopes = []  # not asking the openid scope
+        self.provider._scopes = []  # not asking the openid scope
         self.http_client.get_json.reset_mock()
-        jwks = self.get_success(self.handler.load_jwks(force=True))
+        jwks = self.get_success(self.provider.load_jwks(force=True))
         self.http_client.get_json.assert_not_called()
         self.assertEqual(jwks, {"keys": []})
 
     @override_config({"oidc_config": COMMON_CONFIG})
     def test_validate_config(self):
         """Provider metadatas are extensively validated."""
-        h = self.handler
+        h = self.provider
 
         # Default test config does not throw
         h._validate_metadata()
@@ -317,13 +304,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
         """Provider metadata validation can be disabled by config."""
         with self.metadata_edit({"issuer": "http://insecure"}):
             # This should not throw
-            self.handler._validate_metadata()
+            self.provider._validate_metadata()
 
     def test_redirect_request(self):
         """The redirect request has the right arguments & generates a valid session cookie."""
         req = Mock(spec=["addCookie"])
         url = self.get_success(
-            self.handler.handle_redirect_request(req, b"http://client/redirect")
+            self.provider.handle_redirect_request(req, b"http://client/redirect")
         )
         url = urlparse(url)
         auth_endpoint = urlparse(AUTHORIZATION_ENDPOINT)
@@ -347,14 +334,21 @@ class OidcHandlerTestCase(HomeserverTestCase):
         # For some reason, call.args does not work with python3.5
         args = calls[0][0]
         kwargs = calls[0][1]
-        self.assertEqual(args[0], COOKIE_NAME)
-        self.assertEqual(kwargs["path"], COOKIE_PATH)
+
+        # The cookie name and path don't really matter, just that it has to be coherent
+        # between the callback & redirect handlers.
+        self.assertEqual(args[0], b"oidc_session")
+        self.assertEqual(kwargs["path"], "/_synapse/client/oidc")
         cookie = args[1]
 
         macaroon = pymacaroons.Macaroon.deserialize(cookie)
-        state = self.handler._get_value_from_macaroon(macaroon, "state")
-        nonce = self.handler._get_value_from_macaroon(macaroon, "nonce")
-        redirect = self.handler._get_value_from_macaroon(
+        state = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "state"
+        )
+        nonce = self.handler._token_generator._get_value_from_macaroon(
+            macaroon, "nonce"
+        )
+        redirect = self.handler._token_generator._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
 
@@ -384,31 +378,29 @@ class OidcHandlerTestCase(HomeserverTestCase):
          - when the userinfo fetching fails
          - when the code exchange fails
         """
+
+        # ensure that we are correctly testing the fallback when "get_extra_attributes"
+        # is not implemented.
+        mapping_provider = self.provider._user_mapping_provider
+        with self.assertRaises(AttributeError):
+            _ = mapping_provider.get_extra_attributes
+
         token = {
             "type": "bearer",
             "id_token": "id_token",
             "access_token": "access_token",
         }
+        username = "bar"
         userinfo = {
             "sub": "foo",
-            "preferred_username": "bar",
+            "username": username,
         }
-        user_id = "@foo:domain.org"
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._fetch_userinfo = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        expected_user_id = "@%s:%s" % (username, self.hs.hostname)
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        self.provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         code = "code"
         state = "state"
@@ -416,74 +408,61 @@ class OidcHandlerTestCase(HomeserverTestCase):
         client_redirect_url = "http://client/redirect"
         user_agent = "Browser"
         ip_address = "10.0.0.1"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce=nonce,
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(state, nonce, client_redirect_url)
+        request = _build_callback_request(
+            code, state, session, user_agent=user_agent, ip_address=ip_address
         )
 
-        request.args = {}
-        request.args[b"code"] = [code.encode("utf-8")]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = ip_address
-        request.get_user_agent.return_value = user_agent
-
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None, new_user=True
         )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_called_once_with(token, nonce=nonce)
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
-        )
-        self.handler._fetch_userinfo.assert_not_called()
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_called_once_with(token, nonce=nonce)
+        self.provider._fetch_userinfo.assert_not_called()
         self.render_error.assert_not_called()
 
         # Handle mapping errors
-        self.handler._map_userinfo_to_user = simple_async_mock(
-            raises=MappingException()
-        )
-        self.get_success(self.handler.handle_oidc_callback(request))
-        self.assertRenderedError("mapping_error")
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
+        with patch.object(
+            self.provider,
+            "_remote_id_from_userinfo",
+            new=Mock(side_effect=MappingException()),
+        ):
+            self.get_success(self.handler.handle_oidc_callback(request))
+            self.assertRenderedError("mapping_error")
 
         # Handle ID token errors
-        self.handler._parse_id_token = simple_async_mock(raises=Exception())
+        self.provider._parse_id_token = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("invalid_token")
 
-        self.handler._auth_handler.complete_sso_login.reset_mock()
-        self.handler._exchange_code.reset_mock()
-        self.handler._parse_id_token.reset_mock()
-        self.handler._map_userinfo_to_user.reset_mock()
-        self.handler._fetch_userinfo.reset_mock()
+        auth_handler.complete_sso_login.reset_mock()
+        self.provider._exchange_code.reset_mock()
+        self.provider._parse_id_token.reset_mock()
+        self.provider._fetch_userinfo.reset_mock()
 
         # With userinfo fetching
-        self.handler._scopes = []  # do not ask the "openid" scope
+        self.provider._scopes = []  # do not ask the "openid" scope
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {},
-        )
-        self.handler._exchange_code.assert_called_once_with(code)
-        self.handler._parse_id_token.assert_not_called()
-        self.handler._map_userinfo_to_user.assert_called_once_with(
-            userinfo, token, user_agent, ip_address
+        auth_handler.complete_sso_login.assert_called_once_with(
+            expected_user_id, request, client_redirect_url, None, new_user=False
         )
-        self.handler._fetch_userinfo.assert_called_once_with(token)
+        self.provider._exchange_code.assert_called_once_with(code)
+        self.provider._parse_id_token.assert_not_called()
+        self.provider._fetch_userinfo.assert_called_once_with(token)
         self.render_error.assert_not_called()
 
         # Handle userinfo fetching error
-        self.handler._fetch_userinfo = simple_async_mock(raises=Exception())
+        self.provider._fetch_userinfo = simple_async_mock(raises=Exception())
         self.get_success(self.handler.handle_oidc_callback(request))
         self.assertRenderedError("fetch_error")
 
         # Handle code exchange failure
-        self.handler._exchange_code = simple_async_mock(
+        from synapse.handlers.oidc_handler import OidcError
+
+        self.provider._exchange_code = simple_async_mock(
             raises=OidcError("invalid_request")
         )
         self.get_success(self.handler.handle_oidc_callback(request))
@@ -513,11 +492,8 @@ class OidcHandlerTestCase(HomeserverTestCase):
         self.assertRenderedError("invalid_session")
 
         # Mismatching session
-        session = self.handler._generate_oidc_session_token(
-            state="state",
-            nonce="nonce",
-            client_redirect_url="http://client/redirect",
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state="state", nonce="nonce", client_redirect_url="http://client/redirect",
         )
         request.args = {}
         request.args[b"state"] = [b"mismatching state"]
@@ -541,7 +517,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
             return_value=FakeResponse(code=200, phrase=b"OK", body=token_json)
         )
         code = "code"
-        ret = self.get_success(self.handler._exchange_code(code))
+        ret = self.get_success(self.provider._exchange_code(code))
         kwargs = self.http_client.request.call_args[1]
 
         self.assertEqual(ret, token)
@@ -563,7 +539,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 body=b'{"error": "foo", "error_description": "bar"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        from synapse.handlers.oidc_handler import OidcError
+
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "foo")
         self.assertEqual(exc.value.error_description, "bar")
 
@@ -573,7 +551,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=500, phrase=b"Internal Server Error", body=b"Not JSON",
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # Internal server error with JSON body
@@ -585,14 +563,14 @@ class OidcHandlerTestCase(HomeserverTestCase):
             )
         )
 
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "internal_server_error")
 
         # 4xx error without "error" field
         self.http_client.request = simple_async_mock(
             return_value=FakeResponse(code=400, phrase=b"Bad request", body=b"{}",)
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "server_error")
 
         # 2xx error with "error" field
@@ -601,7 +579,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
                 code=200, phrase=b"OK", body=b'{"error": "some_error"}',
             )
         )
-        exc = self.get_failure(self.handler._exchange_code(code), OidcError)
+        exc = self.get_failure(self.provider._exchange_code(code), OidcError)
         self.assertEqual(exc.value.error, "some_error")
 
     @override_config(
@@ -624,72 +602,56 @@ class OidcHandlerTestCase(HomeserverTestCase):
         }
         userinfo = {
             "sub": "foo",
+            "username": "foo",
             "phone": "1234567",
         }
-        user_id = "@foo:domain.org"
-        self.handler._exchange_code = simple_async_mock(return_value=token)
-        self.handler._parse_id_token = simple_async_mock(return_value=userinfo)
-        self.handler._map_userinfo_to_user = simple_async_mock(return_value=user_id)
-        self.handler._auth_handler.complete_sso_login = simple_async_mock()
-        request = Mock(
-            spec=[
-                "args",
-                "getCookie",
-                "addCookie",
-                "requestHeaders",
-                "getClientIP",
-                "get_user_agent",
-            ]
-        )
+        self.provider._exchange_code = simple_async_mock(return_value=token)
+        self.provider._parse_id_token = simple_async_mock(return_value=userinfo)
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
 
         state = "state"
         client_redirect_url = "http://client/redirect"
-        request.getCookie.return_value = self.handler._generate_oidc_session_token(
-            state=state,
-            nonce="nonce",
-            client_redirect_url=client_redirect_url,
-            ui_auth_session_id=None,
+        session = self._generate_oidc_session_token(
+            state=state, nonce="nonce", client_redirect_url=client_redirect_url,
         )
-
-        request.args = {}
-        request.args[b"code"] = [b"code"]
-        request.args[b"state"] = [state.encode("utf-8")]
-
-        request.getClientIP.return_value = "10.0.0.1"
-        request.get_user_agent.return_value = "Browser"
+        request = _build_callback_request("code", state, session)
 
         self.get_success(self.handler.handle_oidc_callback(request))
 
-        self.handler._auth_handler.complete_sso_login.assert_called_once_with(
-            user_id, request, client_redirect_url, {"phone": "1234567"},
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@foo:test",
+            request,
+            client_redirect_url,
+            {"phone": "1234567"},
+            new_user=True,
         )
 
     def test_map_userinfo_to_user(self):
         """Ensure that mapping the userinfo returned from a provider to an MXID works properly."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         userinfo = {
             "sub": "test_user",
             "username": "test_user",
         }
-        # The token doesn't matter with the default user mapping provider.
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", ANY, ANY, None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Some providers return an integer ID.
         userinfo = {
             "sub": 1234,
             "username": "test_user_2",
         }
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user_2:test", ANY, ANY, None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user_2:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Test if the mxid is already taken
         store = self.hs.get_datastore()
@@ -698,14 +660,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user3.to_string(), password_hash=None)
         )
         userinfo = {"sub": "test3", "username": "test_user_3"}
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
-        self.assertEqual(
-            str(e.value), "Mapping provider does not support de-duplicating Matrix IDs",
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error",
+            "Mapping provider does not support de-duplicating Matrix IDs",
         )
 
     @override_config({"oidc_config": {"allow_existing_users": True}})
@@ -717,26 +676,26 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user.to_string(), password_hash=None)
         )
 
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         userinfo = {
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Note that a second SSO user can be mapped to the same Matrix ID. (This
         # requires a unique sub, but something that maps to the same matrix ID,
@@ -747,13 +706,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test1",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            user.to_string(), ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register some non-exact matching cases.
         user2 = UserID.from_string("@TEST_user_2:test")
@@ -770,14 +727,11 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test2",
             "username": "TEST_USER_2",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        args = self.assertRenderedError("mapping_error")
         self.assertTrue(
-            str(e.value).startswith(
+            args[2].startswith(
                 "Attempted to login as '@TEST_USER_2:test' but it matches more than one user inexactly:"
             )
         )
@@ -788,28 +742,17 @@ class OidcHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id=user2.to_string(), password_hash=None)
         )
 
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@TEST_USER_2:test", ANY, ANY, None, new_user=False
         )
-        self.assertEqual(mxid, "@TEST_USER_2:test")
 
     def test_map_userinfo_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
-        userinfo = {
-            "sub": "test2",
-            "username": "föö",
-        }
-        token = {}
-
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            _make_callback_with_userinfo(self.hs, {"sub": "test2", "username": "föö"})
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        self.assertRenderedError("mapping_error", "localpart is invalid: föö")
 
     @override_config(
         {
@@ -822,6 +765,9 @@ class OidcHandlerTestCase(HomeserverTestCase):
     )
     def test_map_userinfo_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
@@ -830,14 +776,13 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "test",
             "username": "test_user",
         }
-        token = {}
-        mxid = self.get_success(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            )
-        )
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", ANY, ANY, None, new_user=True
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular OIDC username.
         self.get_success(
@@ -853,12 +798,128 @@ class OidcHandlerTestCase(HomeserverTestCase):
             "sub": "tester",
             "username": "tester",
         }
-        e = self.get_failure(
-            self.handler._map_userinfo_to_user(
-                userinfo, token, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        auth_handler.complete_sso_login.assert_not_called()
+        self.assertRenderedError(
+            "mapping_error", "Unable to generate a Matrix ID from the SSO response"
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+
+    def test_empty_localpart(self):
+        """Attempts to map onto an empty localpart should be rejected."""
+        userinfo = {
+            "sub": "tester",
+            "username": "",
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    @override_config(
+        {
+            "oidc_config": {
+                "user_mapping_provider": {
+                    "config": {"localpart_template": "{{ user.username }}"}
+                }
+            }
+        }
+    )
+    def test_null_localpart(self):
+        """Mapping onto a null localpart via an empty OIDC attribute should be rejected"""
+        userinfo = {
+            "sub": "tester",
+            "username": None,
+        }
+        self.get_success(_make_callback_with_userinfo(self.hs, userinfo))
+        self.assertRenderedError("mapping_error", "localpart is invalid: ")
+
+    def _generate_oidc_session_token(
+        self,
+        state: str,
+        nonce: str,
+        client_redirect_url: str,
+        ui_auth_session_id: Optional[str] = None,
+    ) -> str:
+        from synapse.handlers.oidc_handler import OidcSessionData
+
+        return self.handler._token_generator.generate_oidc_session_token(
+            state=state,
+            session_data=OidcSessionData(
+                idp_id="oidc",
+                nonce=nonce,
+                client_redirect_url=client_redirect_url,
+                ui_auth_session_id=ui_auth_session_id,
+            ),
         )
+
+
+async def _make_callback_with_userinfo(
+    hs: HomeServer, userinfo: dict, client_redirect_url: str = "http://client/redirect"
+) -> None:
+    """Mock up an OIDC callback with the given userinfo dict
+
+    We'll pull out the OIDC handler from the homeserver, stub out a couple of methods,
+    and poke in the userinfo dict as if it were the response to an OIDC userinfo call.
+
+    Args:
+        hs: the HomeServer impl to send the callback to.
+        userinfo: the OIDC userinfo dict
+        client_redirect_url: the URL to redirect to on success.
+    """
+    from synapse.handlers.oidc_handler import OidcSessionData
+
+    handler = hs.get_oidc_handler()
+    provider = handler._providers["oidc"]
+    provider._exchange_code = simple_async_mock(return_value={})
+    provider._parse_id_token = simple_async_mock(return_value=userinfo)
+    provider._fetch_userinfo = simple_async_mock(return_value=userinfo)
+
+    state = "state"
+    session = handler._token_generator.generate_oidc_session_token(
+        state=state,
+        session_data=OidcSessionData(
+            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
+        ),
+    )
+    request = _build_callback_request("code", state, session)
+
+    await handler.handle_oidc_callback(request)
+
+
+def _build_callback_request(
+    code: str,
+    state: str,
+    session: str,
+    user_agent: str = "Browser",
+    ip_address: str = "10.0.0.1",
+):
+    """Builds a fake SynapseRequest to mock the browser callback
+
+    Returns a Mock object which looks like the SynapseRequest we get from a browser
+    after SSO (before we return to the client)
+
+    Args:
+        code: the authorization code which would have been returned by the OIDC
+           provider
+        state: the "state" param which would have been passed around in the
+           query param. Should be the same as was embedded in the session in
+           _build_oidc_session.
+        session: the "session" which would have been passed around in the cookie.
+        user_agent: the user-agent to present
+        ip_address: the IP address to pretend the request came from
+    """
+    request = Mock(
+        spec=[
+            "args",
+            "getCookie",
+            "addCookie",
+            "requestHeaders",
+            "getClientIP",
+            "getHeader",
+        ]
+    )
+
+    request.getCookie.return_value = session
+    request.args = {}
+    request.args[b"code"] = [code.encode("utf-8")]
+    request.args[b"state"] = [state.encode("utf-8")]
+    request.getClientIP.return_value = ip_address
+    return request
diff --git a/tests/handlers/test_password_providers.py b/tests/handlers/test_password_providers.py
index ceaf0902d2..f816594ee4 100644
--- a/tests/handlers/test_password_providers.py
+++ b/tests/handlers/test_password_providers.py
@@ -432,6 +432,29 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     @override_config(
         {
+            **providers_config(CustomAuthProvider),
+            "password_config": {"enabled": False, "localdb_enabled": False},
+        }
+    )
+    def test_custom_auth_password_disabled_localdb_enabled(self):
+        """Check the localdb_enabled == enabled == False
+
+        Regression test for https://github.com/matrix-org/synapse/issues/8914: check
+        that setting *both* `localdb_enabled` *and* `password: enabled` to False doesn't
+        cause an exception.
+        """
+        self.register_user("localuser", "localpass")
+
+        flows = self._get_login_flows()
+        self.assertEqual(flows, [{"type": "test.login_type"}] + ADDITIONAL_LOGIN_FLOWS)
+
+        # login shouldn't work and should be rejected with a 400 ("unknown login type")
+        channel = self._send_password_login("localuser", "localpass")
+        self.assertEqual(channel.code, 400, channel.result)
+        mock_password_provider.check_auth.assert_not_called()
+
+    @override_config(
+        {
             **providers_config(PasswordCustomAuthProvider),
             "password_config": {"enabled": False},
         }
@@ -528,7 +551,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 400, channel.result)
 
     def _get_login_flows(self) -> JsonDict:
-        _, channel = self.make_request("GET", "/_matrix/client/r0/login")
+        channel = self.make_request("GET", "/_matrix/client/r0/login")
         self.assertEqual(channel.code, 200, channel.result)
         return channel.json_body["flows"]
 
@@ -537,7 +560,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
 
     def _send_login(self, type, user, **params) -> FakeChannel:
         params.update({"identifier": {"type": "m.id.user", "user": user}, "type": type})
-        _, channel = self.make_request("POST", "/_matrix/client/r0/login", params)
+        channel = self.make_request("POST", "/_matrix/client/r0/login", params)
         return channel
 
     def _start_delete_device_session(self, access_token, device_id) -> str:
@@ -574,7 +597,7 @@ class PasswordAuthProviderTests(unittest.HomeserverTestCase):
         self, access_token: str, device: str, body: Union[JsonDict, bytes] = b"",
     ) -> FakeChannel:
         """Delete an individual device."""
-        _, channel = self.make_request(
+        channel = self.make_request(
             "DELETE", "devices/" + device, body, access_token=access_token
         )
         return channel
diff --git a/tests/handlers/test_presence.py b/tests/handlers/test_presence.py
index 8ed67640f8..0794b32c9c 100644
--- a/tests/handlers/test_presence.py
+++ b/tests/handlers/test_presence.py
@@ -463,7 +463,7 @@ class PresenceJoinTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor, clock):
         hs = self.setup_test_homeserver(
-            "server", http_client=None, federation_sender=Mock()
+            "server", federation_http_client=None, federation_sender=Mock()
         )
         return hs
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index a69fa28b41..022943a10a 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -44,8 +44,6 @@ class ProfileTestCase(unittest.TestCase):
 
         hs = yield setup_test_homeserver(
             self.addCleanup,
-            http_client=None,
-            resource_for_federation=Mock(),
             federation_client=self.mock_federation,
             federation_server=Mock(),
             federation_registry=self.mock_registry,
@@ -107,6 +105,21 @@ class ProfileTestCase(unittest.TestCase):
             "Frank",
         )
 
+        # Set displayname to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_displayname(
+                self.frank, synapse.types.create_requester(self.frank), ""
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_displayname(self.frank.localpart)
+                )
+            )
+        )
+
     @defer.inlineCallbacks
     def test_set_my_name_if_disabled(self):
         self.hs.config.enable_set_displayname = False
@@ -225,6 +238,21 @@ class ProfileTestCase(unittest.TestCase):
             "http://my.server/me.png",
         )
 
+        # Set avatar to an empty string
+        yield defer.ensureDeferred(
+            self.handler.set_avatar_url(
+                self.frank, synapse.types.create_requester(self.frank), "",
+            )
+        )
+
+        self.assertIsNone(
+            (
+                yield defer.ensureDeferred(
+                    self.store.get_profile_avatar_url(self.frank.localpart)
+                )
+            ),
+        )
+
     @defer.inlineCallbacks
     def test_set_my_avatar_if_disabled(self):
         self.hs.config.enable_set_avatar_url = False
diff --git a/tests/handlers/test_saml.py b/tests/handlers/test_saml.py
index 45dc17aba5..a8d6c0f617 100644
--- a/tests/handlers/test_saml.py
+++ b/tests/handlers/test_saml.py
@@ -12,13 +12,35 @@
 #  See the License for the specific language governing permissions and
 #  limitations under the License.
 
+from typing import Optional
+
+from mock import Mock
+
 import attr
 
 from synapse.api.errors import RedirectException
-from synapse.handlers.sso import MappingException
 
+from tests.test_utils import simple_async_mock
 from tests.unittest import HomeserverTestCase, override_config
 
+# Check if we have the dependencies to run the tests.
+try:
+    import saml2.config
+    from saml2.sigver import SigverError
+
+    has_saml2 = True
+
+    # pysaml2 can be installed and imported, but might not be able to find xmlsec1.
+    config = saml2.config.SPConfig()
+    try:
+        config.load({"metadata": {}})
+        has_xmlsec1 = True
+    except SigverError:
+        has_xmlsec1 = False
+except ImportError:
+    has_saml2 = False
+    has_xmlsec1 = False
+
 # These are a few constants that are used as config parameters in the tests.
 BASE_URL = "https://synapse/"
 
@@ -26,6 +48,8 @@ BASE_URL = "https://synapse/"
 @attr.s
 class FakeAuthnResponse:
     ava = attr.ib(type=dict)
+    assertions = attr.ib(type=list, factory=list)
+    in_response_to = attr.ib(type=Optional[str], default=None)
 
 
 class TestMappingProvider:
@@ -86,17 +110,29 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         return hs
 
+    if not has_saml2:
+        skip = "Requires pysaml2"
+    elif not has_xmlsec1:
+        skip = "Requires xmlsec1"
+
     def test_map_saml_response_to_user(self):
         """Ensure that mapping the SAML response returned from a provider to an MXID works properly."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # send a mocked-up SAML response to the callback
         saml_response = FakeAuthnResponse({"uid": "test_user", "username": "test_user"})
-        # The redirect_url doesn't matter with the default user mapping provider.
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "redirect_uri")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "redirect_uri", None, new_user=True
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     @override_config({"saml2_config": {"grandfathered_mxid_source_attribute": "mxid"}})
     def test_map_saml_response_to_existing_user(self):
@@ -106,53 +142,81 @@ class SamlHandlerTestCase(HomeserverTestCase):
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
 
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
         # Map a user via SSO.
         saml_response = FakeAuthnResponse(
             {"uid": "tester", "mxid": ["test_user"], "username": "test_user"}
         )
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+
+        # check that the auth handler got called as expected
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "", None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
 
         # Subsequent calls should map to the same mxid.
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        auth_handler.complete_sso_login.reset_mock()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, "")
+        )
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user:test", request, "", None, new_user=False
         )
-        self.assertEqual(mxid, "@test_user:test")
 
     def test_map_saml_response_to_invalid_localpart(self):
         """If the mapping provider generates an invalid localpart it should be rejected."""
+
+        # stub out the auth handler
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+
+        # mock out the error renderer too
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "föö"})
-        redirect_url = ""
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
+        )
+        sso_handler.render_error.assert_called_once_with(
+            request, "mapping_error", "localpart is invalid: föö"
         )
-        self.assertEqual(str(e.value), "localpart is invalid: föö")
+        auth_handler.complete_sso_login.assert_not_called()
 
     def test_map_saml_response_to_user_retries(self):
         """The mapping provider can retry generating an MXID if the MXID is already in use."""
+
+        # stub out the auth handler and error renderer
+        auth_handler = self.hs.get_auth_handler()
+        auth_handler.complete_sso_login = simple_async_mock()
+        sso_handler = self.hs.get_sso_handler()
+        sso_handler.render_error = Mock(return_value=None)
+
+        # register a user to occupy the first-choice MXID
         store = self.hs.get_datastore()
         self.get_success(
             store.register_user(user_id="@test_user:test", password_hash=None)
         )
+
+        # send the fake SAML response
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
-        mxid = self.get_success(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            )
+        request = _mock_request()
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
+
         # test_user is already taken, so test_user1 gets registered instead.
-        self.assertEqual(mxid, "@test_user1:test")
+        auth_handler.complete_sso_login.assert_called_once_with(
+            "@test_user1:test", request, "", None, new_user=True
+        )
+        auth_handler.complete_sso_login.reset_mock()
 
         # Register all of the potential mxids for a particular SAML username.
         self.get_success(
@@ -165,15 +229,15 @@ class SamlHandlerTestCase(HomeserverTestCase):
 
         # Now attempt to map to a username, this will fail since all potential usernames are taken.
         saml_response = FakeAuthnResponse({"uid": "tester", "username": "tester"})
-        e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
-            MappingException,
+        self.get_success(
+            self.handler._handle_authn_response(request, saml_response, ""),
         )
-        self.assertEqual(
-            str(e.value), "Unable to generate a Matrix ID from the SSO response"
+        sso_handler.render_error.assert_called_once_with(
+            request,
+            "mapping_error",
+            "Unable to generate a Matrix ID from the SSO response",
         )
+        auth_handler.complete_sso_login.assert_not_called()
 
     @override_config(
         {
@@ -185,12 +249,17 @@ class SamlHandlerTestCase(HomeserverTestCase):
         }
     )
     def test_map_saml_response_redirect(self):
+        """Test a mapping provider that raises a RedirectException"""
+
         saml_response = FakeAuthnResponse({"uid": "test", "username": "test_user"})
-        redirect_url = ""
+        request = _mock_request()
         e = self.get_failure(
-            self.handler._map_saml_response_to_user(
-                saml_response, redirect_url, "user-agent", "10.10.10.10"
-            ),
+            self.handler._handle_authn_response(request, saml_response, ""),
             RedirectException,
         )
         self.assertEqual(e.value.location, b"https://custom-saml-redirect/")
+
+
+def _mock_request():
+    """Returns a mock which will stand in as a SynapseRequest"""
+    return Mock(spec=["getClientIP", "getHeader"])
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..96e5bdac4a 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
 
 
 import json
+from typing import Dict
 
 from mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
-from tests.utils import register_federation_servlets
 
 # Some local users to test with
 U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    servlets = [register_federation_servlets]
-
     def make_homeserver(self, reactor, clock):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
@@ -70,13 +70,18 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         hs = self.setup_test_homeserver(
             notifier=Mock(),
-            http_client=mock_federation_client,
+            federation_http_client=mock_federation_client,
             keyring=mock_keyring,
             replication_streams={},
         )
 
         return hs
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
+
     def prepare(self, reactor, clock, hs):
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
@@ -192,7 +197,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
@@ -215,7 +220,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.assertEquals(self.event_source.get_current_key(), 0)
 
-        (request, channel) = self.make_request(
+        channel = self.make_request(
             "PUT",
             "/_matrix/federation/v1/send/1000000",
             _make_edu_transaction_json(
@@ -270,7 +275,7 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         self.on_new_event.assert_has_calls([call("typing_key", 1, rooms=[ROOM_ID])])
 
-        put_json = self.hs.get_http_client().put_json
+        put_json = self.hs.get_federation_http_client().put_json
         put_json.assert_called_once_with(
             "farm",
             path="/_matrix/federation/v1/send/1000000",
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 98e5af2072..9c886d671a 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -54,6 +54,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
                 user_id=support_user_id, password_hash=None, user_type=UserTypes.SUPPORT
             )
         )
+        regular_user_id = "@regular:test"
+        self.get_success(
+            self.store.register_user(user_id=regular_user_id, password_hash=None)
+        )
 
         self.get_success(
             self.handler.handle_local_profile_change(support_user_id, None)
@@ -63,13 +67,47 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         display_name = "display_name"
 
         profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
-        regular_user_id = "@regular:test"
         self.get_success(
             self.handler.handle_local_profile_change(regular_user_id, profile_info)
         )
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
         self.assertTrue(profile["display_name"] == display_name)
 
+    def test_handle_local_profile_change_with_deactivated_user(self):
+        # create user
+        r_user_id = "@regular:test"
+        self.get_success(
+            self.store.register_user(user_id=r_user_id, password_hash=None)
+        )
+
+        # update profile
+        display_name = "Regular User"
+        profile_info = ProfileInfo(avatar_url="avatar_url", display_name=display_name)
+        self.get_success(
+            self.handler.handle_local_profile_change(r_user_id, profile_info)
+        )
+
+        # profile is in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile["display_name"] == display_name)
+
+        # deactivate user
+        self.get_success(self.store.set_user_deactivated_status(r_user_id, True))
+        self.get_success(self.handler.handle_user_deactivated(r_user_id))
+
+        # profile is not in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile is None)
+
+        # update profile after deactivation
+        self.get_success(
+            self.handler.handle_local_profile_change(r_user_id, profile_info)
+        )
+
+        # profile is furthermore not in directory
+        profile = self.get_success(self.store.get_user_in_directory(r_user_id))
+        self.assertTrue(profile is None)
+
     def test_handle_user_deactivated_support_user(self):
         s_user_id = "@support:test"
         self.get_success(
@@ -270,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         spam_checker = self.hs.get_spam_checker()
 
         class AllowAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # Allow all users.
                 return False
 
@@ -283,7 +321,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         # Configure a spam checker that filters all users.
         class BlockAll:
-            def check_username_for_spam(self, user_profile):
+            async def check_username_for_spam(self, user_profile):
                 # All users are spammy.
                 return True
 
@@ -534,7 +572,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2)
 
         # Assert user directory is not empty
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", b"user_directory/search", b'{"search_term":"user2"}'
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -542,7 +580,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
 
         # Disable user directory and check search returns nothing
         self.config.user_directory_search_enabled = False
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", b"user_directory/search", b'{"search_term":"user2"}'
         )
         self.assertEquals(200, channel.code, channel.result)