summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/_base.py4
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py412
-rw-r--r--synapse/handlers/cas_handler.py71
-rw-r--r--synapse/handlers/deactivate_account.py5
-rw-r--r--synapse/handlers/federation.py10
-rw-r--r--synapse/handlers/identity.py3
-rw-r--r--synapse/handlers/message.py21
-rw-r--r--synapse/handlers/oidc_handler.py203
-rw-r--r--synapse/handlers/pagination.py17
-rw-r--r--synapse/handlers/presence.py5
-rw-r--r--synapse/handlers/profile.py8
-rw-r--r--synapse/handlers/receipts.py3
-rw-r--r--synapse/handlers/register.py238
-rw-r--r--synapse/handlers/room.py10
-rw-r--r--synapse/handlers/room_member.py67
-rw-r--r--synapse/handlers/saml_handler.py157
-rw-r--r--synapse/handlers/sso.py244
-rw-r--r--synapse/handlers/sync.py4
19 files changed, 1006 insertions, 478 deletions
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index bd8e71ae56..bb81c0e81d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -169,7 +169,9 @@ class BaseHandler:
                 # and having homeservers have their own users leave keeps more
                 # of that decision-making and control local to the guest-having
                 # homeserver.
-                requester = synapse.types.create_requester(target_user, is_guest=True)
+                requester = synapse.types.create_requester(
+                    target_user, is_guest=True, authenticated_entity=self.server_name
+                )
                 handler = self.hs.get_room_member_handler()
                 await handler.update_membership(
                     requester,
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9fc8444228..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -226,7 +226,7 @@ class ApplicationServicesHandler:
         new_token: Optional[int],
         users: Collection[Union[str, UserID]],
     ):
-        logger.info("Checking interested services for %s" % (stream_key))
+        logger.debug("Checking interested services for %s" % (stream_key))
         with Measure(self.clock, "notify_interested_services_ephemeral"):
             for service in services:
                 # Only handle typing if we have the latest token
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 213baea2e3..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -25,6 +26,7 @@ from typing import (
     Dict,
     Iterable,
     List,
+    Mapping,
     Optional,
     Tuple,
     Union,
@@ -181,17 +183,12 @@ class AuthHandler(BaseHandler):
         #   better way to break the loop
         account_handler = ModuleApi(hs, self)
 
-        self.password_providers = []
-        for module, config in hs.config.password_providers:
-            try:
-                self.password_providers.append(
-                    module(config=config, account_handler=account_handler)
-                )
-            except Exception as e:
-                logger.error("Error while initializing %r: %s", module, e)
-                raise
+        self.password_providers = [
+            PasswordProvider.load(module, config, account_handler)
+            for module, config in hs.config.password_providers
+        ]
 
-        logger.info("Extra password_providers: %r", self.password_providers)
+        logger.info("Extra password_providers: %s", self.password_providers)
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
@@ -205,15 +202,23 @@ class AuthHandler(BaseHandler):
         # type in the list. (NB that the spec doesn't require us to do so and
         # clients which favour types that they don't understand over those that
         # they do are technically broken)
+
+        # start out by assuming PASSWORD is enabled; we will remove it later if not.
         login_types = []
-        if self._password_enabled:
+        if hs.config.password_localdb_enabled:
             login_types.append(LoginType.PASSWORD)
+
         for provider in self.password_providers:
             if hasattr(provider, "get_supported_login_types"):
                 for t in provider.get_supported_login_types().keys():
                     if t not in login_types:
                         login_types.append(t)
+
+        if not self._password_enabled:
+            login_types.remove(LoginType.PASSWORD)
+
         self._supported_login_types = login_types
+
         # Login types and UI Auth types have a heavy overlap, but are not
         # necessarily identical. Login types have SSO (and other login types)
         # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
@@ -230,6 +235,13 @@ class AuthHandler(BaseHandler):
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
         )
 
+        # Ratelimitier for failed /login attempts
+        self._failed_login_attempts_ratelimiter = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+        )
+
         self._clock = self.hs.get_clock()
 
         # Expire old UI auth sessions after a period of time.
@@ -642,14 +654,8 @@ class AuthHandler(BaseHandler):
             res = await checker.check_auth(authdict, clientip=clientip)
             return res
 
-        # build a v1-login-style dict out of the authdict and fall back to the
-        # v1 code
-        user_id = authdict.get("user")
-
-        if user_id is None:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
-
-        (canonical_id, callback) = await self.validate_login(user_id, authdict)
+        # fall back to the v1 login flow
+        canonical_id, _ = await self.validate_login(authdict)
         return canonical_id
 
     def _get_params_recaptcha(self) -> dict:
@@ -698,8 +704,12 @@ class AuthHandler(BaseHandler):
         }
 
     async def get_access_token_for_user_id(
-        self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        valid_until_ms: Optional[int],
+        puppets_user_id: Optional[str] = None,
+    ) -> str:
         """
         Creates a new access token for the user with the given user ID.
 
@@ -725,13 +735,25 @@ class AuthHandler(BaseHandler):
             fmt_expiry = time.strftime(
                 " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
             )
-        logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+        if puppets_user_id:
+            logger.info(
+                "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+            )
+        else:
+            logger.info(
+                "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+            )
 
         await self.auth.check_auth_blocking(user_id)
 
         access_token = self.macaroon_gen.generate_access_token(user_id)
         await self.store.add_access_token_to_user(
-            user_id, access_token, device_id, valid_until_ms
+            user_id=user_id,
+            token=access_token,
+            device_id=device_id,
+            valid_until_ms=valid_until_ms,
+            puppets_user_id=puppets_user_id,
         )
 
         # the device *should* have been registered before we got here; however,
@@ -808,17 +830,17 @@ class AuthHandler(BaseHandler):
         return self._supported_login_types
 
     async def validate_login(
-        self, username: str, login_submission: Dict[str, Any]
+        self, login_submission: Dict[str, Any], ratelimit: bool = False,
     ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
         """Authenticates the user for the /login API
 
-        Also used by the user-interactive auth flow to validate
-        m.login.password auth types.
+        Also used by the user-interactive auth flow to validate auth types which don't
+        have an explicit UIA handler, including m.password.auth.
 
         Args:
-            username: username supplied by the user
             login_submission: the whole of the login submission
                 (including 'type' and other relevant fields)
+            ratelimit: whether to apply the failed_login_attempt ratelimiter
         Returns:
             A tuple of the canonical user id, and optional callback
                 to be called once the access token and device id are issued
@@ -827,38 +849,160 @@ class AuthHandler(BaseHandler):
             SynapseError if there was a problem with the request
             LoginError if there was an authentication problem.
         """
-
-        if username.startswith("@"):
-            qualified_user_id = username
-        else:
-            qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
         login_type = login_submission.get("type")
-        known_login_type = False
+        if not isinstance(login_type, str):
+            raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+        # ideally, we wouldn't be checking the identifier unless we know we have a login
+        # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+        #
+        # But the auth providers' check_auth interface requires a username, so in
+        # practice we can only support login methods which we can map to a username
+        # anyway.
 
         # special case to check for "password" for the check_password interface
         # for the auth providers
         password = login_submission.get("password")
-
         if login_type == LoginType.PASSWORD:
             if not self._password_enabled:
                 raise SynapseError(400, "Password login has been disabled.")
-            if not password:
-                raise SynapseError(400, "Missing parameter: password")
+            if not isinstance(password, str):
+                raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
 
-        for provider in self.password_providers:
-            if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
-                known_login_type = True
-                is_valid = await provider.check_password(qualified_user_id, password)
-                if is_valid:
-                    return qualified_user_id, None
+        # map old-school login fields into new-school "identifier" fields.
+        identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+            login_submission
+        )
 
-            if not hasattr(provider, "get_supported_login_types") or not hasattr(
-                provider, "check_auth"
-            ):
-                # this password provider doesn't understand custom login types
-                continue
+        # convert phone type identifiers to generic threepids
+        if identifier_dict["type"] == "m.id.phone":
+            identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+        # convert threepid identifiers to user IDs
+        if identifier_dict["type"] == "m.id.thirdparty":
+            address = identifier_dict.get("address")
+            medium = identifier_dict.get("medium")
+
+            if medium is None or address is None:
+                raise SynapseError(400, "Invalid thirdparty identifier")
+
+            # For emails, canonicalise the address.
+            # We store all email addresses canonicalised in the DB.
+            # (See add_threepid in synapse/handlers/auth.py)
+            if medium == "email":
+                try:
+                    address = canonicalise_email(address)
+                except ValueError as e:
+                    raise SynapseError(400, str(e))
+
+            # We also apply account rate limiting using the 3PID as a key, as
+            # otherwise using 3PID bypasses the ratelimiting based on user ID.
+            if ratelimit:
+                self._failed_login_attempts_ratelimiter.ratelimit(
+                    (medium, address), update=False
+                )
+
+            # Check for login providers that support 3pid login types
+            if login_type == LoginType.PASSWORD:
+                # we've already checked that there is a (valid) password field
+                assert isinstance(password, str)
+                (
+                    canonical_user_id,
+                    callback_3pid,
+                ) = await self.check_password_provider_3pid(medium, address, password)
+                if canonical_user_id:
+                    # Authentication through password provider and 3pid succeeded
+                    return canonical_user_id, callback_3pid
+
+            # No password providers were able to handle this 3pid
+            # Check local store
+            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+                medium, address
+            )
+            if not user_id:
+                logger.warning(
+                    "unknown 3pid identifier medium %s, address %r", medium, address
+                )
+                # We mark that we've failed to log in here, as
+                # `check_password_provider_3pid` might have returned `None` due
+                # to an incorrect password, rather than the account not
+                # existing.
+                #
+                # If it returned None but the 3PID was bound then we won't hit
+                # this code path, which is fine as then the per-user ratelimit
+                # will kick in below.
+                if ratelimit:
+                    self._failed_login_attempts_ratelimiter.can_do_action(
+                        (medium, address)
+                    )
+                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+            identifier_dict = {"type": "m.id.user", "user": user_id}
 
+        # by this point, the identifier should be an m.id.user: if it's anything
+        # else, we haven't understood it.
+        if identifier_dict["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+
+        username = identifier_dict.get("user")
+        if not username:
+            raise SynapseError(400, "User identifier is missing 'user' key")
+
+        if username.startswith("@"):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+        # Check if we've hit the failed ratelimit (but don't update it)
+        if ratelimit:
+            self._failed_login_attempts_ratelimiter.ratelimit(
+                qualified_user_id.lower(), update=False
+            )
+
+        try:
+            return await self._validate_userid_login(username, login_submission)
+        except LoginError:
+            # The user has failed to log in, so we need to update the rate
+            # limiter. Using `can_do_action` avoids us raising a ratelimit
+            # exception and masking the LoginError. The actual ratelimiting
+            # should have happened above.
+            if ratelimit:
+                self._failed_login_attempts_ratelimiter.can_do_action(
+                    qualified_user_id.lower()
+                )
+            raise
+
+    async def _validate_userid_login(
+        self, username: str, login_submission: Dict[str, Any],
+    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+        """Helper for validate_login
+
+        Handles login, once we've mapped 3pids onto userids
+
+        Args:
+            username: the username, from the identifier dict
+            login_submission: the whole of the login submission
+                (including 'type' and other relevant fields)
+        Returns:
+            A tuple of the canonical user id, and optional callback
+                to be called once the access token and device id are issued
+        Raises:
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
+        """
+        if username.startswith("@"):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+        login_type = login_submission.get("type")
+        # we already checked that we have a valid login type
+        assert isinstance(login_type, str)
+
+        known_login_type = False
+
+        for provider in self.password_providers:
             supported_login_types = provider.get_supported_login_types()
             if login_type not in supported_login_types:
                 # this password provider doesn't understand this login type
@@ -883,15 +1027,17 @@ class AuthHandler(BaseHandler):
 
             result = await provider.check_auth(username, login_type, login_dict)
             if result:
-                if isinstance(result, str):
-                    result = (result, None)
                 return result
 
         if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
             known_login_type = True
 
+            # we've already checked that there is a (valid) password field
+            password = login_submission["password"]
+            assert isinstance(password, str)
+
             canonical_user_id = await self._check_local_password(
-                qualified_user_id, password  # type: ignore
+                qualified_user_id, password
             )
 
             if canonical_user_id:
@@ -922,19 +1068,9 @@ class AuthHandler(BaseHandler):
             unsuccessful, `user_id` and `callback` are both `None`.
         """
         for provider in self.password_providers:
-            if hasattr(provider, "check_3pid_auth"):
-                # This function is able to return a deferred that either
-                # resolves None, meaning authentication failure, or upon
-                # success, to a str (which is the user_id) or a tuple of
-                # (user_id, callback_func), where callback_func should be run
-                # after we've finished everything else
-                result = await provider.check_3pid_auth(medium, address, password)
-                if result:
-                    # Check if the return value is a str or a tuple
-                    if isinstance(result, str):
-                        # If it's a str, set callback function to None
-                        result = (result, None)
-                    return result
+            result = await provider.check_3pid_auth(medium, address, password)
+            if result:
+                return result
 
         return None, None
 
@@ -992,16 +1128,11 @@ class AuthHandler(BaseHandler):
 
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
-            if hasattr(provider, "on_logged_out"):
-                # This might return an awaitable, if it does block the log out
-                # until it completes.
-                result = provider.on_logged_out(
-                    user_id=user_info.user_id,
-                    device_id=user_info.device_id,
-                    access_token=access_token,
-                )
-                if inspect.isawaitable(result):
-                    await result
+            await provider.on_logged_out(
+                user_id=user_info.user_id,
+                device_id=user_info.device_id,
+                access_token=access_token,
+            )
 
         # delete pushers associated with this access token
         if user_info.token_id is not None:
@@ -1030,11 +1161,10 @@ class AuthHandler(BaseHandler):
 
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
-            if hasattr(provider, "on_logged_out"):
-                for token, token_id, device_id in tokens_and_devices:
-                    await provider.on_logged_out(
-                        user_id=user_id, device_id=device_id, access_token=token
-                    )
+            for token, token_id, device_id in tokens_and_devices:
+                await provider.on_logged_out(
+                    user_id=user_id, device_id=device_id, access_token=token
+                )
 
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1358,3 +1488,127 @@ class MacaroonGenerator:
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
+
+
+class PasswordProvider:
+    """Wrapper for a password auth provider module
+
+    This class abstracts out all of the backwards-compatibility hacks for
+    password providers, to provide a consistent interface.
+    """
+
+    @classmethod
+    def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+        try:
+            pp = module(config=config, account_handler=module_api)
+        except Exception as e:
+            logger.error("Error while initializing %r: %s", module, e)
+            raise
+        return cls(pp, module_api)
+
+    def __init__(self, pp, module_api: ModuleApi):
+        self._pp = pp
+        self._module_api = module_api
+
+        self._supported_login_types = {}
+
+        # grandfather in check_password support
+        if hasattr(self._pp, "check_password"):
+            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+        g = getattr(self._pp, "get_supported_login_types", None)
+        if g:
+            self._supported_login_types.update(g())
+
+    def __str__(self):
+        return str(self._pp)
+
+    def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+        """Get the login types supported by this password provider
+
+        Returns a map from a login type identifier (such as m.login.password) to an
+        iterable giving the fields which must be provided by the user in the submission
+        to the /login API.
+
+        This wrapper adds m.login.password to the list if the underlying password
+        provider supports the check_password() api.
+        """
+        return self._supported_login_types
+
+    async def check_auth(
+        self, username: str, login_type: str, login_dict: JsonDict
+    ) -> Optional[Tuple[str, Optional[Callable]]]:
+        """Check if the user has presented valid login credentials
+
+        This wrapper also calls check_password() if the underlying password provider
+        supports the check_password() api and the login type is m.login.password.
+
+        Args:
+            username: user id presented by the client. Either an MXID or an unqualified
+                username.
+
+            login_type: the login type being attempted - one of the types returned by
+                get_supported_login_types()
+
+            login_dict: the dictionary of login secrets passed by the client.
+
+        Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+            user, and `callback` is an optional callback which will be called with the
+            result from the /login call (including access_token, device_id, etc.)
+        """
+        # first grandfather in a call to check_password
+        if login_type == LoginType.PASSWORD:
+            g = getattr(self._pp, "check_password", None)
+            if g:
+                qualified_user_id = self._module_api.get_qualified_user_id(username)
+                is_valid = await self._pp.check_password(
+                    qualified_user_id, login_dict["password"]
+                )
+                if is_valid:
+                    return qualified_user_id, None
+
+        g = getattr(self._pp, "check_auth", None)
+        if not g:
+            return None
+        result = await g(username, login_type, login_dict)
+
+        # Check if the return value is a str or a tuple
+        if isinstance(result, str):
+            # If it's a str, set callback function to None
+            return result, None
+
+        return result
+
+    async def check_3pid_auth(
+        self, medium: str, address: str, password: str
+    ) -> Optional[Tuple[str, Optional[Callable]]]:
+        g = getattr(self._pp, "check_3pid_auth", None)
+        if not g:
+            return None
+
+        # This function is able to return a deferred that either
+        # resolves None, meaning authentication failure, or upon
+        # success, to a str (which is the user_id) or a tuple of
+        # (user_id, callback_func), where callback_func should be run
+        # after we've finished everything else
+        result = await g(medium, address, password)
+
+        # Check if the return value is a str or a tuple
+        if isinstance(result, str):
+            # If it's a str, set callback function to None
+            return result, None
+
+        return result
+
+    async def on_logged_out(
+        self, user_id: str, device_id: Optional[str], access_token: str
+    ) -> None:
+        g = getattr(self._pp, "on_logged_out", None)
+        if not g:
+            return
+
+        # This might return an awaitable, if it does block the log out
+        # until it completes.
+        result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+        if inspect.isawaitable(result):
+            await result
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index 048a3b3c0b..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from xml.etree import ElementTree as ET
 
 from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,10 +34,10 @@ class CasHandler:
     Utility class for to handle the response from a CAS SSO service.
 
     Args:
-        hs (synapse.server.HomeServer)
+        hs
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
@@ -200,27 +203,57 @@ class CasHandler:
             args["session"] = session
         username, user_display_name = await self._validate_ticket(ticket, args)
 
-        localpart = map_username_to_mxid_localpart(username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.get_user_agent("")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Get the matrix ID from the CAS username.
+        user_id = await self._map_cas_user_to_matrix_user(
+            username, user_display_name, user_agent, ip_address
+        )
 
         if session:
             await self._auth_handler.complete_sso_ui_auth(
-                registered_user_id, session, request,
+                user_id, session, request,
             )
-
         else:
-            if not registered_user_id:
-                # Pull out the user-agent and IP from the request.
-                user_agent = request.get_user_agent("")
-                ip_address = self.hs.get_ip_from_request(request)
-
-                registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart,
-                    default_display_name=user_display_name,
-                    user_agent_ips=(user_agent, ip_address),
-                )
+            # If this not a UI auth request than there must be a redirect URL.
+            assert client_redirect_url
 
             await self._auth_handler.complete_sso_login(
-                registered_user_id, request, client_redirect_url
+                user_id, request, client_redirect_url
             )
+
+    async def _map_cas_user_to_matrix_user(
+        self,
+        remote_user_id: str,
+        display_name: Optional[str],
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            remote_user_id: The username from the CAS response.
+            display_name: The display name from the CAS response.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+        """
+
+        localpart = map_username_to_mxid_localpart(remote_user_id)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+        # If the user does not exist, register it.
+        if not registered_user_id:
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=display_name,
+                user_agent_ips=[(user_agent, ip_address)],
+            )
+
+        return registered_user_id
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 4efe6c530a..e808142365 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -39,6 +39,7 @@ class DeactivateAccountHandler(BaseHandler):
         self._room_member_handler = hs.get_room_member_handler()
         self._identity_handler = hs.get_identity_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
+        self._server_name = hs.hostname
 
         # Flag that indicates whether the process to part users from rooms is running
         self._user_parter_running = False
@@ -152,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
         for room in pending_invites:
             try:
                 await self._room_member_handler.update_membership(
-                    create_requester(user),
+                    create_requester(user, authenticated_entity=self._server_name),
                     user,
                     room.room_id,
                     "leave",
@@ -208,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
             logger.info("User parter parting %r from %r", user_id, room_id)
             try:
                 await self._room_member_handler.update_membership(
-                    create_requester(user),
+                    create_requester(user, authenticated_entity=self._server_name),
                     user,
                     room_id,
                     "leave",
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e03caea251..b9799090f7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -68,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
 from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationFederationSendEventsRestServlet,
-    ReplicationStoreRoomOnInviteRestServlet,
+    ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
 from synapse.state import StateResolutionStore
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -153,12 +153,14 @@ class FederationHandler(BaseHandler):
             self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
                 hs
             )
-            self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+            self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
                 hs
             )
         else:
             self._device_list_updater = hs.get_device_handler().device_list_updater
-            self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+            self._maybe_store_room_on_outlier_membership = (
+                self.store.maybe_store_room_on_outlier_membership
+            )
 
         # When joining a room we need to queue any events for that room up.
         # For each room, a list of (pdu, origin) tuples.
@@ -1618,7 +1620,7 @@ class FederationHandler(BaseHandler):
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
-        await self._maybe_store_room_on_invite(
+        await self._maybe_store_room_on_outlier_membership(
             room_id=event.room_id, room_version=room_version
         )
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index bc3e9607ca..9b3c6b4551 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
             raise SynapseError(500, "An error was encountered when sending the email")
 
         token_expires = (
-            self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+            self.hs.get_clock().time_msec()
+            + self.hs.config.email_validation_token_lifetime
         )
 
         await self.store.start_or_continue_validation_session(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index c6791fb912..96843338ae 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -472,7 +472,7 @@ class EventCreationHandler:
         Returns:
             Tuple of created event, Context
         """
-        await self.auth.check_auth_blocking(requester.user.to_string())
+        await self.auth.check_auth_blocking(requester=requester)
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version = event_dict["content"]["room_version"]
@@ -619,7 +619,13 @@ class EventCreationHandler:
         if requester.app_service is not None:
             return
 
-        user_id = requester.user.to_string()
+        user_id = requester.authenticated_entity
+        if not user_id.startswith("@"):
+            # The authenticated entity might not be a user, e.g. if it's the
+            # server puppetting the user.
+            return
+
+        user = UserID.from_string(user_id)
 
         # exempt the system notices user
         if (
@@ -639,9 +645,7 @@ class EventCreationHandler:
         if u["consent_version"] == self.config.user_consent_version:
             return
 
-        consent_uri = self._consent_uri_builder.build_user_consent_uri(
-            requester.user.localpart
-        )
+        consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
         msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
         raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
 
@@ -1252,7 +1256,7 @@ class EventCreationHandler:
         for user_id in members:
             if not self.hs.is_mine_id(user_id):
                 continue
-            requester = create_requester(user_id)
+            requester = create_requester(user_id, authenticated_entity=self.server_name)
             try:
                 event, context = await self.create_event(
                     requester,
@@ -1273,11 +1277,6 @@ class EventCreationHandler:
                     requester, event, context, ratelimit=False, ignore_shadow_ban=True,
                 )
                 return True
-            except ConsentNotGivenError:
-                logger.info(
-                    "Failed to send dummy event into room %s for user %s due to "
-                    "lack of consent. Will try another user" % (room_id, user_id)
-                )
             except AuthError:
                 logger.info(
                     "Failed to send dummy event into room %s for user %s due to "
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 331d4e7e96..c605f7082a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -34,7 +35,8 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,17 +85,12 @@ class OidcError(Exception):
         return self.error
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the UserInfo object
-    """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
     """Handles requests related to the OpenID Connect login flow.
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
+        super().__init__(hs)
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
         self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
@@ -120,36 +117,13 @@ class OidcHandler:
         self._http_client = hs.get_proxied_http_client()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
-        self._datastore = hs.get_datastore()
-        self._clock = hs.get_clock()
-        self._hostname = hs.hostname  # type: str
         self._server_name = hs.config.server_name  # type: str
         self._macaroon_secret_key = hs.config.macaroon_secret_key
-        self._error_template = hs.config.sso_error_template
 
         # identifier for the external_ids table
         self._auth_provider_id = "oidc"
 
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
-
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
-
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error. Those include
-                well-known OAuth2/OIDC error types like invalid_request or
-                access_denied.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def _validate_metadata(self):
         """Verifies the provider metadata.
@@ -571,7 +545,7 @@ class OidcHandler:
 
         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call
-        ``self._render_error`` which displays an HTML page for the error.
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
 
         Most of the OpenID Connect logic happens here:
 
@@ -609,7 +583,7 @@ class OidcHandler:
             if error != "access_denied":
                 logger.error("Error from the OIDC provider: %s %s", error, description)
 
-            self._render_error(request, error, description)
+            self._sso_handler.render_error(request, error, description)
             return
 
         # otherwise, it is presumably a successful response. see:
@@ -619,7 +593,9 @@ class OidcHandler:
         session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
         if session is None:
             logger.info("No session cookie found")
-            self._render_error(request, "missing_session", "No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
             return
 
         # Remove the cookie. There is a good chance that if the callback failed
@@ -637,7 +613,9 @@ class OidcHandler:
         # Check for the state query parameter
         if b"state" not in request.args:
             logger.info("State parameter is missing")
-            self._render_error(request, "invalid_request", "State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
             return
 
         state = request.args[b"state"][0].decode()
@@ -651,17 +629,19 @@ class OidcHandler:
             ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
-            self._render_error(request, "invalid_session", str(e))
+            self._sso_handler.render_error(request, "invalid_session", str(e))
             return
         except MacaroonInvalidSignatureException as e:
             logger.exception("Could not verify session")
-            self._render_error(request, "mismatching_session", str(e))
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
         # Exchange the code with the provider
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
-            self._render_error(request, "invalid_request", "Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
             return
 
         logger.debug("Exchanging code")
@@ -670,7 +650,7 @@ class OidcHandler:
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
-            self._render_error(request, e.error, e.error_description)
+            self._sso_handler.render_error(request, e.error, e.error_description)
             return
 
         logger.debug("Successfully obtained OAuth2 access token")
@@ -683,7 +663,7 @@ class OidcHandler:
                 userinfo = await self._fetch_userinfo(token)
             except Exception as e:
                 logger.exception("Could not fetch userinfo")
-                self._render_error(request, "fetch_error", str(e))
+                self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
         else:
             logger.debug("Extracting userinfo from id_token")
@@ -691,7 +671,7 @@ class OidcHandler:
                 userinfo = await self._parse_id_token(token, nonce=nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
-                self._render_error(request, "invalid_token", str(e))
+                self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # Pull out the user-agent and IP from the request.
@@ -705,7 +685,7 @@ class OidcHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Mapping providers might not have get_extra_attributes: only call this
@@ -770,7 +750,7 @@ class OidcHandler:
             macaroon.add_first_party_caveat(
                 "ui_auth_session_id = %s" % (ui_auth_session_id,)
             )
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
 
@@ -845,7 +825,7 @@ class OidcHandler:
         if not caveat.startswith(prefix):
             return False
         expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         return now < expiry
 
     async def _map_userinfo_to_user(
@@ -885,71 +865,77 @@ class OidcHandler:
         # to be strings.
         remote_user_id = str(remote_user_id)
 
-        logger.info(
-            "Looking for existing mapping for user %s:%s",
-            self._auth_provider_id,
-            remote_user_id,
-        )
-
-        registered_user_id = await self._datastore.get_user_by_external_id(
-            self._auth_provider_id, remote_user_id,
+        # Older mapping providers don't accept the `failures` argument, so we
+        # try and detect support.
+        mapper_signature = inspect.signature(
+            self._user_mapping_provider.map_user_attributes
         )
+        supports_failures = "failures" in mapper_signature.parameters
 
-        if registered_user_id is not None:
-            logger.info("Found existing mapping %s", registered_user_id)
-            return registered_user_id
-
-        try:
-            attributes = await self._user_mapping_provider.map_user_attributes(
-                userinfo, token
-            )
-        except Exception as e:
-            raise MappingException(
-                "Could not extract user attributes from OIDC response: " + str(e)
-            )
+        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Call the mapping provider to map the OIDC userinfo and token to user attributes.
 
-        logger.debug(
-            "Retrieved user attributes from user mapping provider: %r", attributes
-        )
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            if supports_failures:
+                attributes = await self._user_mapping_provider.map_user_attributes(
+                    userinfo, token, failures
+                )
+            else:
+                # If the mapping provider does not support processing failures,
+                # do not continually generate the same Matrix ID since it will
+                # continue to already be in use. Note that the error raised is
+                # arbitrary and will get turned into a MappingException.
+                if failures:
+                    raise MappingException(
+                        "Mapping provider does not support de-duplicating Matrix IDs"
+                    )
 
-        if not attributes["localpart"]:
-            raise MappingException("localpart is empty")
+                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                    userinfo, token
+                )
 
-        localpart = map_username_to_mxid_localpart(attributes["localpart"])
+            return UserAttributes(**attributes)
 
-        user_id = UserID(localpart, self._hostname).to_string()
-        users = await self._datastore.get_users_by_id_case_insensitive(user_id)
-        if users:
+        async def grandfather_existing_users() -> Optional[str]:
             if self._allow_existing_users:
-                if len(users) == 1:
-                    registered_user_id = next(iter(users))
-                elif user_id in users:
-                    registered_user_id = user_id
-                else:
-                    raise MappingException(
-                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                            user_id, list(users.keys())
+                # If allowing existing users we want to generate a single localpart
+                # and attempt to match it.
+                attributes = await oidc_response_to_user_attributes(failures=0)
+
+                user_id = UserID(attributes.localpart, self.server_name).to_string()
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    # If an existing matrix ID is returned, then use it.
+                    if len(users) == 1:
+                        previously_registered_user_id = next(iter(users))
+                    elif user_id in users:
+                        previously_registered_user_id = user_id
+                    else:
+                        # Do not attempt to continue generating Matrix IDs.
+                        raise MappingException(
+                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                                user_id, users
+                            )
                         )
-                    )
-            else:
-                # This mxid is taken
-                raise MappingException("mxid '{}' is already taken".format(user_id))
-        else:
-            # It's the first time this user is logging in and the mapped mxid was
-            # not taken, register the user
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=attributes["display_name"],
-                user_agent_ips=(user_agent, ip_address),
-            )
-        await self._datastore.record_user_external_id(
-            self._auth_provider_id, remote_user_id, registered_user_id,
+
+                    return previously_registered_user_id
+
+            return None
+
+        return await self._sso_handler.get_mxid_from_sso(
+            self._auth_provider_id,
+            remote_user_id,
+            user_agent,
+            ip_address,
+            oidc_response_to_user_attributes,
+            grandfather_existing_users,
         )
-        return registered_user_id
 
 
-UserAttribute = TypedDict(
-    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
@@ -992,13 +978,15 @@ class OidcMappingProvider(Generic[C]):
         raise NotImplementedError()
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
             token: A dict with the tokens returned by the provider
+            failures: How many times a call to this function with this
+                UserInfo has resulted in a failure.
 
         Returns:
             A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1098,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         return userinfo[self._config.subject_claim]
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         localpart = self._config.localpart_template.render(user=userinfo).strip()
 
+        # Ensure only valid characters are included in the MXID.
+        localpart = map_username_to_mxid_localpart(localpart)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
+
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
             display_name = self._config.display_name_template.render(
@@ -1111,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             if display_name == "":
                 display_name = None
 
-        return UserAttribute(localpart=localpart, display_name=display_name)
+        return UserAttributeDict(localpart=localpart, display_name=display_name)
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 426b58da9e..5372753707 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -299,17 +299,22 @@ class PaginationHandler:
         """
         return self._purges_by_id.get(purge_id)
 
-    async def purge_room(self, room_id: str) -> None:
-        """Purge the given room from the database"""
+    async def purge_room(self, room_id: str, force: bool = False) -> None:
+        """Purge the given room from the database.
+
+        Args:
+            room_id: room to be purged
+            force: set true to skip checking for joined users.
+        """
         with await self.pagination_lock.write(room_id):
             # check we know about the room
             await self.store.get_room_version_id(room_id)
 
             # first check that we have no users in this room
-            joined = await self.store.is_host_joined(room_id, self._server_name)
-
-            if joined:
-                raise SynapseError(400, "Users are still joined to this room")
+            if not force:
+                joined = await self.store.is_host_joined(room_id, self._server_name)
+                if joined:
+                    raise SynapseError(400, "Users are still joined to this room")
 
             await self.storage.purge_events.purge_room(room_id)
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 8e014c9bb5..22d1e9d35c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
 import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
 
 from prometheus_client import Counter
 from typing_extensions import ContextManager
@@ -46,8 +46,7 @@ from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
-MYPY = False
-if MYPY:
+if TYPE_CHECKING:
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 74a1ddd780..dee0ef45e7 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -206,7 +206,9 @@ class ProfileHandler(BaseHandler):
         # the join event to update the displayname in the rooms.
         # This must be done by the target user himself.
         if by_admin:
-            requester = create_requester(target_user)
+            requester = create_requester(
+                target_user, authenticated_entity=requester.authenticated_entity,
+            )
 
         await self.store.set_profile_displayname(
             target_user.localpart, displayname_to_set
@@ -286,7 +288,9 @@ class ProfileHandler(BaseHandler):
 
         # Same like set_displayname
         if by_admin:
-            requester = create_requester(target_user)
+            requester = create_requester(
+                target_user, authenticated_entity=requester.authenticated_entity
+            )
 
         await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index c242c409cf..153cbae7b9 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -158,7 +158,8 @@ class ReceiptEventSource:
         if from_key == to_key:
             return [], to_key
 
-        # We first need to fetch all new receipts
+        # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
+        # by most recent.
         rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
             from_key=from_key, to_key=to_key
         )
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ed1ff62599..0d85fd0868 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
 
 """Contains functions for registering clients."""
 import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
 from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,16 +34,14 @@ from synapse.types import RoomAlias, UserID, create_requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RegistrationHandler(BaseHandler):
-    def __init__(self, hs):
-        """
-
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
@@ -52,6 +52,7 @@ class RegistrationHandler(BaseHandler):
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.macaroon_gen = hs.get_macaroon_generator()
         self._server_notices_mxid = hs.config.server_notices_mxid
+        self._server_name = hs.hostname
 
         self.spam_checker = hs.get_spam_checker()
 
@@ -70,7 +71,10 @@ class RegistrationHandler(BaseHandler):
         self.session_lifetime = hs.config.session_lifetime
 
     async def check_username(
-        self, localpart, guest_access_token=None, assigned_user_id=None
+        self,
+        localpart: str,
+        guest_access_token: Optional[str] = None,
+        assigned_user_id: Optional[str] = None,
     ):
         if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
@@ -139,39 +143,45 @@ class RegistrationHandler(BaseHandler):
 
     async def register_user(
         self,
-        localpart=None,
-        password_hash=None,
-        guest_access_token=None,
-        make_guest=False,
-        admin=False,
-        threepid=None,
-        user_type=None,
-        default_display_name=None,
-        address=None,
-        bind_emails=[],
-        by_admin=False,
-        user_agent_ips=None,
-    ):
+        localpart: Optional[str] = None,
+        password_hash: Optional[str] = None,
+        guest_access_token: Optional[str] = None,
+        make_guest: bool = False,
+        admin: bool = False,
+        threepid: Optional[dict] = None,
+        user_type: Optional[str] = None,
+        default_display_name: Optional[str] = None,
+        address: Optional[str] = None,
+        bind_emails: List[str] = [],
+        by_admin: bool = False,
+        user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+    ) -> str:
         """Registers a new client on the server.
 
         Args:
             localpart: The local part of the user ID to register. If None,
               one will be generated.
-            password_hash (str|None): The hashed password to assign to this user so they can
+            password_hash: The hashed password to assign to this user so they can
               login again. This can be None which means they cannot login again
               via a password (e.g. the user is an application service user).
-            user_type (str|None): type of user. One of the values from
+            guest_access_token: The access token used when this was a guest
+                account.
+            make_guest: True if the the new user should be guest,
+                false to add a regular user account.
+            admin: True if the user should be registered as a server admin.
+            threepid: The threepid used for registering, if any.
+            user_type: type of user. One of the values from
               api.constants.UserTypes, or None for a normal user.
-            default_display_name (unicode|None): if set, the new user's displayname
+            default_display_name: if set, the new user's displayname
               will be set to this. Defaults to 'localpart'.
-            address (str|None): the IP address used to perform the registration.
-            bind_emails (List[str]): list of emails to bind to this account.
-            by_admin (bool): True if this registration is being made via the
+            address: the IP address used to perform the registration.
+            bind_emails: list of emails to bind to this account.
+            by_admin: True if this registration is being made via the
               admin api, otherwise False.
-            user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+            user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
         Returns:
-            str: user_id
+            The registere user_id.
         Raises:
             SynapseError if there was a problem registering.
         """
@@ -235,8 +245,10 @@ class RegistrationHandler(BaseHandler):
         else:
             # autogen a sequential user ID
             fail_count = 0
-            user = None
-            while not user:
+            # If a default display name is not given, generate one.
+            generate_display_name = default_display_name is None
+            # This breaks on successful registration *or* errors after 10 failures.
+            while True:
                 # Fail after being unable to find a suitable ID a few times
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -245,7 +257,7 @@ class RegistrationHandler(BaseHandler):
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
                 self.check_user_id_not_appservice_exclusive(user_id)
-                if default_display_name is None:
+                if generate_display_name:
                     default_display_name = localpart
                 try:
                     await self.register_with_store(
@@ -261,8 +273,6 @@ class RegistrationHandler(BaseHandler):
                     break
                 except SynapseError:
                     # if user id is taken, just generate another
-                    user = None
-                    user_id = None
                     fail_count += 1
 
         if not self.hs.config.user_consent_at_registration:
@@ -294,7 +304,7 @@ class RegistrationHandler(BaseHandler):
 
         return user_id
 
-    async def _create_and_join_rooms(self, user_id: str):
+    async def _create_and_join_rooms(self, user_id: str) -> None:
         """
         Create the auto-join rooms and join or invite the user to them.
 
@@ -317,7 +327,8 @@ class RegistrationHandler(BaseHandler):
         requires_join = False
         if self.hs.config.registration.auto_join_user_id:
             fake_requester = create_requester(
-                self.hs.config.registration.auto_join_user_id
+                self.hs.config.registration.auto_join_user_id,
+                authenticated_entity=self._server_name,
             )
 
             # If the room requires an invite, add the user to the list of invites.
@@ -329,7 +340,9 @@ class RegistrationHandler(BaseHandler):
             # being necessary this will occur after the invite was sent.
             requires_join = True
         else:
-            fake_requester = create_requester(user_id)
+            fake_requester = create_requester(
+                user_id, authenticated_entity=self._server_name
+            )
 
         # Choose whether to federate the new room.
         if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -362,7 +375,9 @@ class RegistrationHandler(BaseHandler):
                     # created it, then ensure the first user joins it.
                     if requires_join:
                         await room_member_handler.update_membership(
-                            requester=create_requester(user_id),
+                            requester=create_requester(
+                                user_id, authenticated_entity=self._server_name
+                            ),
                             target=UserID.from_string(user_id),
                             room_id=info["room_id"],
                             # Since it was just created, there are no remote hosts.
@@ -370,15 +385,10 @@ class RegistrationHandler(BaseHandler):
                             action="join",
                             ratelimit=False,
                         )
-
-            except ConsentNotGivenError as e:
-                # Technically not necessary to pull out this error though
-                # moving away from bare excepts is a good thing to do.
-                logger.error("Failed to join new user to %r: %r", r, e)
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    async def _join_rooms(self, user_id: str):
+    async def _join_rooms(self, user_id: str) -> None:
         """
         Join or invite the user to the auto-join rooms.
 
@@ -424,9 +434,13 @@ class RegistrationHandler(BaseHandler):
 
                 # Send the invite, if necessary.
                 if requires_invite:
+                    # If an invite is required, there must be a auto-join user ID.
+                    assert self.hs.config.registration.auto_join_user_id
+
                     await room_member_handler.update_membership(
                         requester=create_requester(
-                            self.hs.config.registration.auto_join_user_id
+                            self.hs.config.registration.auto_join_user_id,
+                            authenticated_entity=self._server_name,
                         ),
                         target=UserID.from_string(user_id),
                         room_id=room_id,
@@ -437,7 +451,9 @@ class RegistrationHandler(BaseHandler):
 
                 # Send the join.
                 await room_member_handler.update_membership(
-                    requester=create_requester(user_id),
+                    requester=create_requester(
+                        user_id, authenticated_entity=self._server_name
+                    ),
                     target=UserID.from_string(user_id),
                     room_id=room_id,
                     remote_room_hosts=remote_room_hosts,
@@ -452,7 +468,7 @@ class RegistrationHandler(BaseHandler):
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    async def _auto_join_rooms(self, user_id: str):
+    async def _auto_join_rooms(self, user_id: str) -> None:
         """Automatically joins users to auto join rooms - creating the room in the first place
         if the user is the first to be created.
 
@@ -475,16 +491,16 @@ class RegistrationHandler(BaseHandler):
         else:
             await self._join_rooms(user_id)
 
-    async def post_consent_actions(self, user_id):
+    async def post_consent_actions(self, user_id: str) -> None:
         """A series of registration actions that can only be carried out once consent
         has been granted
 
         Args:
-            user_id (str): The user to join
+            user_id: The user to join
         """
         await self._auto_join_rooms(user_id)
 
-    async def appservice_register(self, user_localpart, as_token):
+    async def appservice_register(self, user_localpart: str, as_token: str) -> str:
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
@@ -509,7 +525,9 @@ class RegistrationHandler(BaseHandler):
         )
         return user_id
 
-    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+    def check_user_id_not_appservice_exclusive(
+        self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+    ) -> None:
         # don't allow people to register the server notices mxid
         if self._server_notices_mxid is not None:
             if user_id == self._server_notices_mxid:
@@ -533,12 +551,12 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    def check_registration_ratelimit(self, address):
+    def check_registration_ratelimit(self, address: Optional[str]) -> None:
         """A simple helper method to check whether the registration rate limit has been hit
         for a given IP address
 
         Args:
-            address (str|None): the IP address used to perform the registration. If this is
+            address: the IP address used to perform the registration. If this is
                 None, no ratelimiting will be performed.
 
         Raises:
@@ -549,42 +567,39 @@ class RegistrationHandler(BaseHandler):
 
         self.ratelimiter.ratelimit(address)
 
-    def register_with_store(
+    async def register_with_store(
         self,
-        user_id,
-        password_hash=None,
-        was_guest=False,
-        make_guest=False,
-        appservice_id=None,
-        create_profile_with_displayname=None,
-        admin=False,
-        user_type=None,
-        address=None,
-        shadow_banned=False,
-    ):
+        user_id: str,
+        password_hash: Optional[str] = None,
+        was_guest: bool = False,
+        make_guest: bool = False,
+        appservice_id: Optional[str] = None,
+        create_profile_with_displayname: Optional[str] = None,
+        admin: bool = False,
+        user_type: Optional[str] = None,
+        address: Optional[str] = None,
+        shadow_banned: bool = False,
+    ) -> None:
         """Register user in the datastore.
 
         Args:
-            user_id (str): The desired user ID to register.
-            password_hash (str|None): Optional. The password hash for this user.
-            was_guest (bool): Optional. Whether this is a guest account being
+            user_id: The desired user ID to register.
+            password_hash: Optional. The password hash for this user.
+            was_guest: Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
-            make_guest (boolean): True if the the new user should be guest,
+            make_guest: True if the the new user should be guest,
                 false to add a regular user account.
-            appservice_id (str|None): The ID of the appservice registering the user.
-            create_profile_with_displayname (unicode|None): Optionally create a
+            appservice_id: The ID of the appservice registering the user.
+            create_profile_with_displayname: Optionally create a
                 profile for the user, setting their displayname to the given value
-            admin (boolean): is an admin user?
-            user_type (str|None): type of user. One of the values from
+            admin: is an admin user?
+            user_type: type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
-            address (str|None): the IP address used to perform the registration.
-            shadow_banned (bool): Whether to shadow-ban the user
-
-        Returns:
-            Awaitable
+            address: the IP address used to perform the registration.
+            shadow_banned: Whether to shadow-ban the user
         """
         if self.hs.config.worker_app:
-            return self._register_client(
+            await self._register_client(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -597,7 +612,7 @@ class RegistrationHandler(BaseHandler):
                 shadow_banned=shadow_banned,
             )
         else:
-            return self.store.register_user(
+            await self.store.register_user(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -610,22 +625,24 @@ class RegistrationHandler(BaseHandler):
             )
 
     async def register_device(
-        self, user_id, device_id, initial_display_name, is_guest=False
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        is_guest: bool = False,
+    ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
 
         Args:
-            user_id (str): full canonical @user:id
-            device_id (str|None): The device ID to check, or None to generate
-                a new one.
-            initial_display_name (str|None): An optional display name for the
-                device.
-            is_guest (bool): Whether this is a guest account
+            user_id: full canonical @user:id
+            device_id: The device ID to check, or None to generate a new one.
+            initial_display_name: An optional display name for the device.
+            is_guest: Whether this is a guest account
 
         Returns:
-            tuple[str, str]: Tuple of device ID and access token
+            Tuple of device ID and access token
         """
 
         if self.hs.config.worker_app:
@@ -645,7 +662,7 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
-        device_id = await self.device_handler.check_device_registered(
+        registered_device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
         if is_guest:
@@ -655,20 +672,21 @@ class RegistrationHandler(BaseHandler):
             )
         else:
             access_token = await self._auth_handler.get_access_token_for_user_id(
-                user_id, device_id=device_id, valid_until_ms=valid_until_ms
+                user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
             )
 
-        return (device_id, access_token)
+        return (registered_device_id, access_token)
 
-    async def post_registration_actions(self, user_id, auth_result, access_token):
+    async def post_registration_actions(
+        self, user_id: str, auth_result: dict, access_token: Optional[str]
+    ) -> None:
         """A user has completed registration
 
         Args:
-            user_id (str): The user ID that consented
-            auth_result (dict): The authenticated credentials of the newly
-                registered user.
-            access_token (str|None): The access token of the newly logged in
-                device, or None if `inhibit_login` enabled.
+            user_id: The user ID that consented
+            auth_result: The authenticated credentials of the newly registered user.
+            access_token: The access token of the newly logged in device, or
+                None if `inhibit_login` enabled.
         """
         if self.hs.config.worker_app:
             await self._post_registration_client(
@@ -694,19 +712,20 @@ class RegistrationHandler(BaseHandler):
         if auth_result and LoginType.TERMS in auth_result:
             await self._on_user_consented(user_id, self.hs.config.user_consent_version)
 
-    async def _on_user_consented(self, user_id, consent_version):
+    async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
         """A user consented to the terms on registration
 
         Args:
-            user_id (str): The user ID that consented.
-            consent_version (str): version of the policy the user has
-                consented to.
+            user_id: The user ID that consented.
+            consent_version: version of the policy the user has consented to.
         """
         logger.info("%s has consented to the privacy policy", user_id)
         await self.store.user_set_consent_version(user_id, consent_version)
         await self.post_consent_actions(user_id)
 
-    async def _register_email_threepid(self, user_id, threepid, token):
+    async def _register_email_threepid(
+        self, user_id: str, threepid: dict, token: Optional[str]
+    ) -> None:
         """Add an email address as a 3pid identifier
 
         Also adds an email pusher for the email address, if configured in the
@@ -715,10 +734,9 @@ class RegistrationHandler(BaseHandler):
         Must be called on master.
 
         Args:
-            user_id (str): id of user
-            threepid (object): m.login.email.identity auth response
-            token (str|None): access_token for the user, or None if not logged
-                in.
+            user_id: id of user
+            threepid: m.login.email.identity auth response
+            token: access_token for the user, or None if not logged in.
         """
         reqd = ("medium", "address", "validated_at")
         if any(x not in threepid for x in reqd):
@@ -744,6 +762,8 @@ class RegistrationHandler(BaseHandler):
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
             user_tuple = await self.store.get_user_by_access_token(token)
+            # The token better still exist.
+            assert user_tuple
             token_id = user_tuple.token_id
 
             await self.pusher_pool.add_pusher(
@@ -758,14 +778,14 @@ class RegistrationHandler(BaseHandler):
                 data={},
             )
 
-    async def _register_msisdn_threepid(self, user_id, threepid):
+    async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
         """Add a phone number as a 3pid identifier
 
         Must be called on master.
 
         Args:
-            user_id (str): id of user
-            threepid (object): m.login.msisdn auth response
+            user_id: id of user
+            threepid: m.login.msisdn auth response
         """
         try:
             assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index e73031475f..930047e730 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        await self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(requester=requester)
 
         if (
             self._server_notices_mxid is not None
@@ -1257,7 +1257,9 @@ class RoomShutdownHandler:
                     400, "User must be our own: %s" % (new_room_user_id,)
                 )
 
-            room_creator_requester = create_requester(new_room_user_id)
+            room_creator_requester = create_requester(
+                new_room_user_id, authenticated_entity=requester_user_id
+            )
 
             info, stream_id = await self._room_creation_handler.create_room(
                 room_creator_requester,
@@ -1297,7 +1299,9 @@ class RoomShutdownHandler:
 
             try:
                 # Kick users from room
-                target_requester = create_requester(user_id)
+                target_requester = create_requester(
+                    user_id, authenticated_entity=requester_user_id
+                )
                 _, stream_id = await self.room_member_handler.update_membership(
                     requester=target_requester,
                     target=target_requester.user,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7cd858b7db..c002886324 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -31,7 +31,6 @@ from synapse.api.errors import (
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_left_room
@@ -347,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # later on.
             content = dict(content)
 
-        if not self.allow_per_room_profiles or requester.shadow_banned:
+        # allow the server notices mxid to set room-level profile
+        is_requester_server_notices_user = (
+            self._server_notices_mxid is not None
+            and requester.user.to_string() == self._server_notices_mxid
+        )
+
+        if (
+            not self.allow_per_room_profiles and not is_requester_server_notices_user
+        ) or requester.shadow_banned:
             # Strip profile data, knowing that new profile data will be added to the
             # event's content in event_creation_handler.create_event() using the target's
             # global profile.
@@ -515,10 +522,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         elif effective_membership_state == Membership.LEAVE:
             if not is_host_in_room:
                 # perhaps we've been invited
-                invite = await self.store.get_invite_for_local_user_in_room(
-                    user_id=target.to_string(), room_id=room_id
-                )  # type: Optional[RoomsForUser]
-                if not invite:
+                (
+                    current_membership_type,
+                    current_membership_event_id,
+                ) = await self.store.get_local_current_membership_for_user_in_room(
+                    target.to_string(), room_id
+                )
+                if (
+                    current_membership_type != Membership.INVITE
+                    or not current_membership_event_id
+                ):
                     logger.info(
                         "%s sent a leave request to %s, but that is not an active room "
                         "on this server, and there is no pending invite",
@@ -528,6 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
                     raise SynapseError(404, "Not a known room")
 
+                invite = await self.store.get_event(current_membership_event_id)
                 logger.info(
                     "%s rejects invite to %s from %s", target, room_id, invite.sender
                 )
@@ -965,6 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         self.distributor = hs.get_distributor()
         self.distributor.declare("user_left_room")
+        self._server_name = hs.hostname
 
     async def _is_remote_room_too_complex(
         self, room_id: str, remote_room_hosts: List[str]
@@ -1059,7 +1074,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                 return event_id, stream_id
 
             # The room is too large. Leave.
-            requester = types.create_requester(user, None, False, False, None)
+            requester = types.create_requester(
+                user, authenticated_entity=self._server_name
+            )
             await self.update_membership(
                 requester=requester, target=user, room_id=room_id, action="leave"
             )
@@ -1104,32 +1121,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             #
             logger.warning("Failed to reject invite: %s", e)
 
-            return await self._locally_reject_invite(
+            return await self._generate_local_out_of_band_leave(
                 invite_event, txn_id, requester, content
             )
 
-    async def _locally_reject_invite(
+    async def _generate_local_out_of_band_leave(
         self,
-        invite_event: EventBase,
+        previous_membership_event: EventBase,
         txn_id: Optional[str],
         requester: Requester,
         content: JsonDict,
     ) -> Tuple[str, int]:
-        """Generate a local invite rejection
+        """Generate a local leave event for a room
 
-        This is called after we fail to reject an invite via a remote server. It
-        generates an out-of-band membership event locally.
+        This can be called after we e.g fail to reject an invite via a remote server.
+        It generates an out-of-band membership event locally.
 
         Args:
-            invite_event: the invite to be rejected
+            previous_membership_event: the previous membership event for this user
             txn_id: optional transaction ID supplied by the client
-            requester:  user making the rejection request, according to the access token
-            content: additional content to include in the rejection event.
+            requester: user making the request, according to the access token
+            content: additional content to include in the leave event.
                Normally an empty dict.
-        """
 
-        room_id = invite_event.room_id
-        target_user = invite_event.state_key
+        Returns:
+            A tuple containing (event_id, stream_id of the leave event)
+        """
+        room_id = previous_membership_event.room_id
+        target_user = previous_membership_event.state_key
 
         content["membership"] = Membership.LEAVE
 
@@ -1141,12 +1160,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             "state_key": target_user,
         }
 
-        # the auth events for the new event are the same as that of the invite, plus
-        # the invite itself.
+        # the auth events for the new event are the same as that of the previous event, plus
+        # the event itself.
         #
-        # the prev_events are just the invite.
-        prev_event_ids = [invite_event.event_id]
-        auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+        # the prev_events consist solely of the previous membership event.
+        prev_event_ids = [previous_membership_event.event_id]
+        auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
 
         event, context = await self.event_creation_handler.create_event(
             requester,
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index fd6c5e9ea8..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the SAML2 response to a user."""
-
-
 @attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
     ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
-class SamlHandler:
-    def __init__(self, hs: "synapse.server.HomeServer"):
-        self.hs = hs
+class SamlHandler(BaseHandler):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
-        self._auth = hs.get_auth()
+        self._saml_idp_entityid = hs.config.saml2_idp_entityid
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
-        self._clock = hs.get_clock()
-        self._datastore = hs.get_datastore()
-        self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
-
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
 
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
             URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
-            relay_state=client_redirect_url
+            entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
 
         # Since SAML sessions timeout it is useful to log when they were created.
         logger.info("Initiating a new SAML session: %s" % (reqid,))
 
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         self._outstanding_requests_dict[reqid] = Saml2SessionData(
             creation_time=now, ui_auth_session_id=ui_auth_session_id,
         )
@@ -171,12 +148,12 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsolicited_response", "Unexpected SAML2 login."
             )
             return
         except Exception as e:
-            self._render_error(
+            self._sso_handler.render_error(
                 request,
                 "invalid_response",
                 "Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
             return
 
         if saml2_auth.not_signed:
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsigned_respond", "SAML2 response was not signed."
             )
             return
@@ -210,7 +187,7 @@ class SamlHandler:
         # attributes.
         for requirement in self._saml2_attribute_requirements:
             if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._render_error(
+                self._sso_handler.render_error(
                     request, "unauthorised", "You are not authorised to log in here."
                 )
                 return
@@ -226,7 +203,7 @@ class SamlHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Complete the interactive auth session or the login.
@@ -272,20 +249,26 @@ class SamlHandler:
                 "Failed to extract remote user id from SAML response"
             )
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            # first of all, check if we already have a mapping for this user
-            logger.info(
-                "Looking for existing mapping for user %s:%s",
-                self._auth_provider_id,
-                remote_user_id,
+        async def saml_response_to_remapped_user_attributes(
+            failures: int,
+        ) -> UserAttributes:
+            """
+            Call the mapping provider to map a SAML response to user attributes and coerce the result into the standard form.
+
+            This is backwards compatibility for abstraction for the SSO handler.
+            """
+            # Call the mapping provider.
+            result = self._user_mapping_provider.saml_response_to_user_attributes(
+                saml2_auth, failures, client_redirect_url
             )
-            registered_user_id = await self._datastore.get_user_by_external_id(
-                self._auth_provider_id, remote_user_id
+            # Remap some of the results.
+            return UserAttributes(
+                localpart=result.get("mxid_localpart"),
+                display_name=result.get("displayname"),
+                emails=result.get("emails", []),
             )
-            if registered_user_id is not None:
-                logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id
 
+        async def grandfather_existing_users() -> Optional[str]:
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
             if (
@@ -294,75 +277,35 @@ class SamlHandler:
             ):
                 attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
                 user_id = UserID(
-                    map_username_to_mxid_localpart(attrval), self._hostname
+                    map_username_to_mxid_localpart(attrval), self.server_name
                 ).to_string()
-                logger.info(
+
+                logger.debug(
                     "Looking for existing account based on mapped %s %s",
                     self._grandfathered_mxid_source_attribute,
                     user_id,
                 )
 
-                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     registered_user_id = list(users.keys())[0]
                     logger.info("Grandfathering mapping to %s", registered_user_id)
-                    await self._datastore.record_user_external_id(
-                        self._auth_provider_id, remote_user_id, registered_user_id
-                    )
                     return registered_user_id
 
-            # Map saml response to user attributes using the configured mapping provider
-            for i in range(1000):
-                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i, client_redirect_url=client_redirect_url,
-                )
-
-                logger.debug(
-                    "Retrieved SAML attributes from user mapping provider: %s "
-                    "(attempt %d)",
-                    attribute_dict,
-                    i,
-                )
-
-                localpart = attribute_dict.get("mxid_localpart")
-                if not localpart:
-                    raise MappingException(
-                        "Error parsing SAML2 response: SAML mapping provider plugin "
-                        "did not return a mxid_localpart value"
-                    )
-
-                displayname = attribute_dict.get("displayname")
-                emails = attribute_dict.get("emails", [])
-
-                # Check if this mxid already exists
-                if not await self._datastore.get_users_by_id_case_insensitive(
-                    UserID(localpart, self._hostname).to_string()
-                ):
-                    # This mxid is free
-                    break
-            else:
-                # Unable to generate a username in 1000 iterations
-                # Break and return error to the user
-                raise MappingException(
-                    "Unable to generate a Matrix ID from the SAML response"
-                )
+            return None
 
-            logger.info("Mapped SAML user to local part %s", localpart)
-
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                bind_emails=emails,
-                user_agent_ips=(user_agent, ip_address),
-            )
-
-            await self._datastore.record_user_external_id(
-                self._auth_provider_id, remote_user_id, registered_user_id
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
+            return await self._sso_handler.get_mxid_from_sso(
+                self._auth_provider_id,
+                remote_user_id,
+                user_agent,
+                ip_address,
+                saml_response_to_remapped_user_attributes,
+                grandfather_existing_users,
             )
-            return registered_user_id
 
     def expire_sessions(self):
-        expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+        expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
         for reqid, data in self._outstanding_requests_dict.items():
             if data.creation_time < expire_before:
@@ -474,11 +417,11 @@ class DefaultSamlMappingProvider:
             )
 
         # Use the configured mapper for this mxid_source
-        base_mxid_localpart = self._mxid_mapper(mxid_source)
+        localpart = self._mxid_mapper(mxid_source)
 
         # Append suffix integer if last call to this function failed to produce
-        # a usable mxid
-        localpart = base_mxid_localpart + (str(failures) if failures else "")
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
 
         # Retrieve the display name from the saml response
         # If displayname is None, the mxid_localpart will be used instead
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..47ad96f97e
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
+
+from synapse.api.errors import RedirectException
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+    """Used to catch errors when mapping an SSO response to user attributes.
+
+    Note that the msg that is raised is shown to end-users.
+    """
+
+
+@attr.s
+class UserAttributes:
+    localpart = attr.ib(type=str)
+    display_name = attr.ib(type=Optional[str], default=None)
+    emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
+class SsoHandler(BaseHandler):
+    # The number of attempts to ask the mapping provider for when generating an MXID.
+    _MAP_USERNAME_RETRIES = 1000
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+        self._registration_handler = hs.get_registration_handler()
+        self._error_template = hs.config.sso_error_template
+
+    def render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Renders the error template and responds with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
+    async def get_sso_user_by_remote_user_id(
+        self, auth_provider_id: str, remote_user_id: str
+    ) -> Optional[str]:
+        """
+        Maps the user ID of a remote IdP to a mxid for a previously seen user.
+
+        If the user has not been seen yet, this will return None.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The user ID according to the remote IdP. This might
+                be an e-mail address, a GUID, or some other form. It must be
+                unique and immutable.
+
+        Returns:
+            The mxid of a previously seen user.
+        """
+        logger.debug(
+            "Looking for existing mapping for user %s:%s",
+            auth_provider_id,
+            remote_user_id,
+        )
+
+        # Check if we already have a mapping for this user.
+        previously_registered_user_id = await self.store.get_user_by_external_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        # A match was found, return the user ID.
+        if previously_registered_user_id is not None:
+            logger.info(
+                "Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
+                auth_provider_id,
+                remote_user_id,
+                previously_registered_user_id,
+            )
+            return previously_registered_user_id
+
+        # No match.
+        return None
+
+    async def get_mxid_from_sso(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+        sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
+    ) -> str:
+        """
+        Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+        This first checks if the SSO ID has previously been linked to a matrix ID,
+        if it has that matrix ID is returned regardless of the current mapping
+        logic.
+
+        If a callable is provided for grandfathering users, it is called and can
+        potentially return a matrix ID to use. If it does, the SSO ID is linked to
+        this matrix ID for subsequent calls.
+
+        The mapping function is called (potentially multiple times) to generate
+        a localpart for the user.
+
+        If an unused localpart is generated, the user is registered from the
+        given user-agent and IP address and the SSO ID is linked to this matrix
+        ID for subsequent calls.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+            sso_to_matrix_id_mapper: A callable to generate the user attributes.
+                The only parameter is an integer which represents the amount of
+                times the returned mxid localpart mapping has failed.
+
+                It is expected that the mapper can raise two exceptions, which
+                will get passed through to the caller:
+
+                    MappingException if there was a problem mapping the response
+                        to the user.
+                    RedirectException to redirect to an additional page (e.g.
+                        to prompt the user for more information).
+            grandfather_existing_users: A callable which can return an previously
+                existing matrix ID. The SSO ID is then linked to the returned
+                matrix ID.
+
+        Returns:
+             The user ID associated with the SSO response.
+
+        Raises:
+            MappingException if there was a problem mapping the response to a user.
+            RedirectException: if the mapping provider needs to redirect the user
+                to an additional page. (e.g. to prompt for more information)
+
+        """
+        # first of all, check if we already have a mapping for this user
+        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+        if previously_registered_user_id:
+            return previously_registered_user_id
+
+        # Check for grandfathering of users.
+        if grandfather_existing_users:
+            previously_registered_user_id = await grandfather_existing_users()
+            if previously_registered_user_id:
+                # Future logins should also match this user ID.
+                await self.store.record_user_external_id(
+                    auth_provider_id, remote_user_id, previously_registered_user_id
+                )
+                return previously_registered_user_id
+
+        # Otherwise, generate a new user.
+        for i in range(self._MAP_USERNAME_RETRIES):
+            try:
+                attributes = await sso_to_matrix_id_mapper(i)
+            except (RedirectException, MappingException):
+                # Mapping providers are allowed to issue a redirect (e.g. to ask
+                # the user for more information) and can issue a mapping exception
+                # if a name cannot be generated.
+                raise
+            except Exception as e:
+                # Any other exception is unexpected.
+                raise MappingException(
+                    "Could not extract user attributes from SSO response."
+                ) from e
+
+            logger.debug(
+                "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+                attributes,
+                i,
+            )
+
+            if not attributes.localpart:
+                raise MappingException(
+                    "Error parsing SSO response: SSO mapping provider plugin "
+                    "did not return a localpart value"
+                )
+
+            # Check if this mxid already exists
+            user_id = UserID(attributes.localpart, self.server_name).to_string()
+            if not await self.store.get_users_by_id_case_insensitive(user_id):
+                # This mxid is free
+                break
+        else:
+            # Unable to generate a username in 1000 iterations
+            # Break and return error to the user
+            raise MappingException(
+                "Unable to generate a Matrix ID from the SSO response"
+            )
+
+        # Since the localpart is provided via a potentially untrusted module,
+        # ensure the MXID is valid before registering.
+        if contains_invalid_mxid_characters(attributes.localpart):
+            raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+        logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+        registered_user_id = await self._registration_handler.register_user(
+            localpart=attributes.localpart,
+            default_display_name=attributes.display_name,
+            bind_emails=attributes.emails,
+            user_agent_ips=[(user_agent, ip_address)],
+        )
+
+        await self.store.record_user_external_id(
+            auth_provider_id, remote_user_id, registered_user_id
+        )
+        return registered_user_id
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 32e53c2d25..9827c7eb8d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,6 +31,7 @@ from synapse.types import (
     Collection,
     JsonDict,
     MutableStateMap,
+    Requester,
     RoomStreamToken,
     StateMap,
     StreamToken,
@@ -260,6 +261,7 @@ class SyncHandler:
 
     async def wait_for_sync_for_user(
         self,
+        requester: Requester,
         sync_config: SyncConfig,
         since_token: Optional[StreamToken] = None,
         timeout: int = 0,
@@ -273,7 +275,7 @@ class SyncHandler:
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        await self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(requester=requester)
 
         res = await self.response_cache.wrap(
             sync_config.request_key,