Diffstat (limited to 'synapse/handlers')
 synapse/handlers/__init__.py           |  33
 synapse/handlers/_base.py              |  24
 synapse/handlers/account_data.py       |  14
 synapse/handlers/account_validity.py   |  42
 synapse/handlers/admin.py              |   4
 synapse/handlers/appservice.py         | 222
 synapse/handlers/auth.py               | 453
 synapse/handlers/cas_handler.py        |  73
 synapse/handlers/deactivate_account.py |  18
 synapse/handlers/device.py             | 171
 synapse/handlers/devicemessage.py      |  25
 synapse/handlers/directory.py          |  12
 synapse/handlers/e2e_keys.py           |  43
 synapse/handlers/federation.py         | 112
 synapse/handlers/groups_local.py       |   9
 synapse/handlers/identity.py           |   3
 synapse/handlers/initial_sync.py       |  18
 synapse/handlers/message.py            | 381
 synapse/handlers/oidc_handler.py       | 224
 synapse/handlers/pagination.py         |  21
 synapse/handlers/password_policy.py    |  10
 synapse/handlers/presence.py           |  21
 synapse/handlers/profile.py            | 136
 synapse/handlers/read_marker.py        |  10
 synapse/handlers/receipts.py           |  36
 synapse/handlers/register.py           | 247
 synapse/handlers/room.py               |  62
 synapse/handlers/room_member.py        | 185
 synapse/handlers/saml_handler.py       | 161
 synapse/handlers/search.py             |   2
 synapse/handlers/sso.py                | 244
 synapse/handlers/state_deltas.py       |   2
 synapse/handlers/stats.py              |   2
 synapse/handlers/sync.py               |  49
 synapse/handlers/typing.py             |  54
 synapse/handlers/ui_auth/checkers.py   |   2
 synapse/handlers/user_directory.py     |   2
 37 files changed, 2063 insertions(+), 1064 deletions(-)
diff --git a/synapse/handlers/__init__.py b/synapse/handlers/__init__.py
index 286f0054be..bfebb0f644 100644
--- a/synapse/handlers/__init__.py
+++ b/synapse/handlers/__init__.py
@@ -12,36 +12,3 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-from .admin import AdminHandler
-from .directory import DirectoryHandler
-from .federation import FederationHandler
-from .identity import IdentityHandler
-from .search import SearchHandler
-
-
-class Handlers:
-
-    """ Deprecated. A collection of handlers.
-
-    At some point most of the classes whose name ended "Handler" were
-    accessed through this class.
-
-    However this makes it painful to unit test the handlers and to run cut
-    down versions of synapse that only use specific handlers because using a
-    single handler required creating all of the handlers. So some of the
-    handlers have been lifted out of the Handlers object and are now accessed
-    directly through the homeserver object itself.
-
-    Any new handlers should follow the new pattern of being accessed through
-    the homeserver object and should not be added to the Handlers object.
-
-    The remaining handlers should be moved out of the handlers object.
-    """
-
-    def __init__(self, hs):
-        self.federation_handler = FederationHandler(hs)
-        self.directory_handler = DirectoryHandler(hs)
-        self.admin_handler = AdminHandler(hs)
-        self.identity_handler = IdentityHandler(hs)
-        self.search_handler = SearchHandler(hs)
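
The Handlers class deleted above embodied the old access pattern; its
replacement, as its own docstring recommends, is a dedicated accessor per
handler on the homeserver object. A minimal sketch of the new pattern, where
the lazy-caching detail is an assumption for illustration and only the
accessor name comes from this diff:

    from synapse.handlers.identity import IdentityHandler

    class HomeServer:
        _identity_handler = None  # constructed lazily, on first use

        def get_identity_handler(self) -> IdentityHandler:
            # One getter per handler, so unit tests and cut-down deployments
            # only construct the handlers they actually need.
            if self._identity_handler is None:
                self._identity_handler = IdentityHandler(self)
            return self._identity_handler

Call sites migrate from hs.get_handlers().identity_handler to
hs.get_identity_handler(), as the auth.py and deactivate_account.py hunks
below show.
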
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 0206320e96..bb81c0e81d 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING, Optional
 
 import synapse.state
 import synapse.storage
@@ -22,6 +23,9 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.types import UserID
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -30,11 +34,7 @@ class BaseHandler:
     Common base class for the event handlers.
     """
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()  # type: synapse.storage.DataStore
         self.auth = hs.get_auth()
         self.notifier = hs.get_notifier()
@@ -56,7 +56,7 @@ class BaseHandler:
                 clock=self.clock,
                 rate_hz=self.hs.config.rc_admin_redaction.per_second,
                 burst_count=self.hs.config.rc_admin_redaction.burst_count,
-            )
+            )  # type: Optional[Ratelimiter]
         else:
             self.admin_redaction_ratelimiter = None
 
@@ -127,15 +127,15 @@ class BaseHandler:
             if guest_access != "can_join":
                 if context:
                     current_state_ids = await context.get_current_state_ids()
-                    current_state = await self.store.get_events(
+                    current_state_dict = await self.store.get_events(
                         list(current_state_ids.values())
                     )
+                    current_state = list(current_state_dict.values())
                 else:
-                    current_state = await self.state_handler.get_current_state(
+                    current_state_map = await self.state_handler.get_current_state(
                         event.room_id
                     )
-
-                current_state = list(current_state.values())
+                    current_state = list(current_state_map.values())
 
                 logger.info("maybe_kick_guest_users %r", current_state)
                 await self.kick_guest_users(current_state)
@@ -169,7 +169,9 @@ class BaseHandler:
                 # and having homeservers have their own users leave keeps more
                 # of that decision-making and control local to the guest-having
                 # homeserver.
-                requester = synapse.types.create_requester(target_user, is_guest=True)
+                requester = synapse.types.create_requester(
+                    target_user, is_guest=True, authenticated_entity=self.server_name
+                )
                 handler = self.hs.get_room_member_handler()
                 await handler.update_membership(
                     requester,
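
Several files in this diff adopt the same typing idiom seen here in _base.py:
import HomeServer only under TYPE_CHECKING and annotate the constructor with a
string forward reference, so the annotation costs nothing at runtime. A
self-contained sketch of the idiom (the handler class here is hypothetical):

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Evaluated only by type checkers such as mypy. At runtime this import
        # would be circular, since the homeserver module imports the handlers.
        from synapse.app.homeserver import HomeServer

    class ExampleHandler:
        def __init__(self, hs: "HomeServer"):  # quoted: never imported at runtime
            self.store = hs.get_datastore()
            self.clock = hs.get_clock()
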
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 9112a0ab86..341135822e 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -12,16 +12,24 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import TYPE_CHECKING, List, Tuple
+
+from synapse.types import JsonDict, UserID
+
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
 
 
 class AccountDataEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
 
-    def get_current_key(self, direction="f"):
+    def get_current_key(self, direction: str = "f") -> int:
         return self.store.get_max_account_data_stream_id()
 
-    async def get_new_events(self, user, from_key, **kwargs):
+    async def get_new_events(
+        self, user: UserID, from_key: int, **kwargs
+    ) -> Tuple[List[JsonDict], int]:
         user_id = user.to_string()
         last_stream_id = from_key
 
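
The annotations added here pin down the event-source contract: get_new_events
returns a batch of events together with the stream position the caller should
resume from. A hedged consumer sketch (the polling function is hypothetical):

    async def poll_account_data(source, user, from_key: int) -> int:
        # (events, next_key): JSON-able events since from_key, plus the
        # position to pass back in as from_key on the next poll.
        events, next_key = await source.get_new_events(user=user, from_key=from_key)
        for event in events:
            ...  # hand each event to the sync machinery
        return next_key
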
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 4caf6d591a..664d09da1c 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -18,19 +18,22 @@ import email.utils
 import logging
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
-from typing import List
+from typing import TYPE_CHECKING, List
 
-from synapse.api.errors import StoreError
+from synapse.api.errors import StoreError, SynapseError
 from synapse.logging.context import make_deferred_yieldable
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import wrap_as_background_process
 from synapse.types import UserID
 from synapse.util import stringutils
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class AccountValidityHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self.config = hs.config
         self.store = self.hs.get_datastore()
@@ -63,16 +66,11 @@ class AccountValidityHandler:
             self._raw_from = email.utils.parseaddr(self._from_string)[1]
 
             # Check the renewal emails to send and send them every 30min.
-            def send_emails():
-                # run as a background process to make sure that the database transactions
-                # have a logcontext to report to
-                return run_as_background_process(
-                    "send_renewals", self._send_renewal_emails
-                )
-
-            self.clock.looping_call(send_emails, 30 * 60 * 1000)
+            if hs.config.run_background_tasks:
+                self.clock.looping_call(self._send_renewal_emails, 30 * 60 * 1000)
 
-    async def _send_renewal_emails(self):
+    @wrap_as_background_process("send_renewals")
+    async def _send_renewal_emails(self) -> None:
         """Gets the list of users whose account is expiring in the amount of time
         configured in the ``renew_at`` parameter from the ``account_validity``
         configuration, and sends renewal emails to all of these users as long as they
@@ -86,11 +84,25 @@ class AccountValidityHandler:
                     user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
                 )
 
-    async def send_renewal_email_to_user(self, user_id: str):
+    async def send_renewal_email_to_user(self, user_id: str) -> None:
+        """
+        Send a renewal email for a specific user.
+
+        Args:
+            user_id: The user ID to send a renewal email for.
+
+        Raises:
+            SynapseError if the user is not set to renew.
+        """
         expiration_ts = await self.store.get_expiration_ts_for_user(user_id)
+
+        # If this user isn't set to be expired, raise an error.
+        if expiration_ts is None:
+            raise SynapseError(400, "User has no expiration time: %s" % (user_id,))
+
         await self._send_renewal_email(user_id, expiration_ts)
 
-    async def _send_renewal_email(self, user_id: str, expiration_ts: int):
+    async def _send_renewal_email(self, user_id: str, expiration_ts: int) -> None:
         """Sends out a renewal email to every email address attached to the given user
         with a unique link allowing them to renew their account.
 
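
The account-validity change above shows a pattern repeated throughout this
diff: the hand-written send_emails closure around run_as_background_process
becomes a wrap_as_background_process decorator on the coroutine itself, and
the looping call is gated on run_background_tasks so only the designated
worker runs it. A sketch under those assumptions, with a hypothetical handler
and task name:

    from synapse.metrics.background_process_metrics import wrap_as_background_process

    class ExampleHandler:
        def __init__(self, hs: "HomeServer"):
            self.clock = hs.get_clock()
            # Only the worker configured to run background tasks starts the loop.
            if hs.config.run_background_tasks:
                self.clock.looping_call(self._do_periodic_work, 30 * 60 * 1000)

        @wrap_as_background_process("example_periodic_work")
        async def _do_periodic_work(self) -> None:
            # The decorator supplies the logcontext and metrics that the old
            # run_as_background_process shim provided by hand.
            ...
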
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 1ce2091b46..a703944543 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -88,7 +88,7 @@ class AdminHandler(BaseHandler):
 
         # We only try and fetch events for rooms the user has been in. If
         # they've been e.g. invited to a room without joining then we handle
-        # those seperately.
+        # those separately.
         rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
 
         for index, room in enumerate(rooms):
@@ -226,7 +226,7 @@ class ExfiltrationWriter:
         """
 
     def finished(self):
-        """Called when all data has succesfully been exported and written.
+        """Called when all data has successfully been exported and written.
 
         This function's return value is passed to the caller of
         `export_user_data`.
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 9d4e87dad6..5c6458eb52 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,8 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
+from typing import TYPE_CHECKING, Dict, List, Optional, Union
 
 from prometheus_client import Counter
 
@@ -21,21 +21,32 @@ from twisted.internet import defer
 
 import synapse
 from synapse.api.constants import EventTypes
+from synapse.appservice import ApplicationService
+from synapse.events import EventBase
+from synapse.handlers.presence import format_user_presence_state
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics import (
     event_processing_loop_counter,
     event_processing_loop_room_count,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.metrics.background_process_metrics import (
+    run_as_background_process,
+    wrap_as_background_process,
+)
+from synapse.storage.databases.main.directory import RoomAliasMapping
+from synapse.types import Collection, JsonDict, RoomAlias, RoomStreamToken, UserID
 from synapse.util.metrics import Measure
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 events_processed_counter = Counter("synapse_handlers_appservice_events_processed", "")
 
 
 class ApplicationServicesHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
         self.is_mine_id = hs.is_mine_id
         self.appservice_api = hs.get_application_service_api()
@@ -43,19 +54,22 @@ class ApplicationServicesHandler:
         self.started_scheduler = False
         self.clock = hs.get_clock()
         self.notify_appservices = hs.config.notify_appservices
+        self.event_sources = hs.get_event_sources()
 
         self.current_max = 0
         self.is_processing = False
 
-    async def notify_interested_services(self, current_id):
+    def notify_interested_services(self, max_token: RoomStreamToken):
         """Notifies (pushes) all application services interested in this event.
 
         Pushing is done asynchronously, so this method won't block for any
         prolonged length of time.
-
-        Args:
-            current_id(int): The current maximum ID.
         """
+        # We just use the minimum stream ordering and ignore the vector clock
+        # component. This is safe to do as long as we *always* ignore the vector
+        # clock components.
+        current_id = max_token.stream
+
         services = self.store.get_app_services()
         if not services or not self.notify_appservices:
             return
@@ -64,6 +78,12 @@ class ApplicationServicesHandler:
         if self.is_processing:
             return
 
+        # We only start a new background process if necessary rather than
+        # optimistically (to cut down on overhead).
+        self._notify_interested_services(max_token)
+
+    @wrap_as_background_process("notify_interested_services")
+    async def _notify_interested_services(self, max_token: RoomStreamToken):
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
@@ -79,7 +99,7 @@ class ApplicationServicesHandler:
                     if not events:
                         break
 
-                    events_by_room = {}
+                    events_by_room = {}  # type: Dict[str, List[EventBase]]
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)
 
@@ -158,11 +178,139 @@ class ApplicationServicesHandler:
             finally:
                 self.is_processing = False
 
-    async def query_user_exists(self, user_id):
+    def notify_interested_services_ephemeral(
+        self,
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[Union[str, UserID]] = [],
+    ):
+        """This is called by the notifier in the background
+        when a ephemeral event handled by the homeserver.
+
+        This will determine which appservices
+        are interested in the event, and submit them.
+
+        Events will only be pushed to appservices
+        that have opted into ephemeral events
+
+        Args:
+            stream_key: The stream the event came from.
+            new_token: The latest stream token
+            users: The user(s) involved with the event.
+        """
+        if not self.notify_appservices:
+            return
+
+        if stream_key not in ("typing_key", "receipt_key", "presence_key"):
+            return
+
+        services = [
+            service
+            for service in self.store.get_app_services()
+            if service.supports_ephemeral
+        ]
+        if not services:
+            return
+
+        # We only start a new background process if necessary rather than
+        # optimistically (to cut down on overhead).
+        self._notify_interested_services_ephemeral(
+            services, stream_key, new_token, users
+        )
+
+    @wrap_as_background_process("notify_interested_services_ephemeral")
+    async def _notify_interested_services_ephemeral(
+        self,
+        services: List[ApplicationService],
+        stream_key: str,
+        new_token: Optional[int],
+        users: Collection[Union[str, UserID]],
+    ):
+        logger.debug("Checking interested services for %s" % (stream_key))
+        with Measure(self.clock, "notify_interested_services_ephemeral"):
+            for service in services:
+                # Only handle typing if we have the latest token
+                if stream_key == "typing_key" and new_token is not None:
+                    events = await self._handle_typing(service, new_token)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    # We don't persist the token for typing_key for performance reasons
+                elif stream_key == "receipt_key":
+                    events = await self._handle_receipts(service)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "read_receipt", new_token
+                    )
+                elif stream_key == "presence_key":
+                    events = await self._handle_presence(service, users)
+                    if events:
+                        self.scheduler.submit_ephemeral_events_for_as(service, events)
+                    await self.store.set_type_stream_id_for_appservice(
+                        service, "presence", new_token
+                    )
+
+    async def _handle_typing(
+        self, service: ApplicationService, new_token: int
+    ) -> List[JsonDict]:
+        typing_source = self.event_sources.sources["typing"]
+        # Get the typing events from just before current
+        typing, _ = await typing_source.get_new_events_as(
+            service=service,
+            # For performance reasons, we don't persist the previous
+            # token in the DB and instead fetch the latest typing information
+            # for appservices.
+            from_key=new_token - 1,
+        )
+        return typing
+
+    async def _handle_receipts(self, service: ApplicationService) -> List[JsonDict]:
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "read_receipt"
+        )
+        receipts_source = self.event_sources.sources["receipt"]
+        receipts, _ = await receipts_source.get_new_events_as(
+            service=service, from_key=from_key
+        )
+        return receipts
+
+    async def _handle_presence(
+        self, service: ApplicationService, users: Collection[Union[str, UserID]]
+    ) -> List[JsonDict]:
+        events = []  # type: List[JsonDict]
+        presence_source = self.event_sources.sources["presence"]
+        from_key = await self.store.get_type_stream_id_for_appservice(
+            service, "presence"
+        )
+        for user in users:
+            if isinstance(user, str):
+                user = UserID.from_string(user)
+
+            interested = await service.is_interested_in_presence(user, self.store)
+            if not interested:
+                continue
+            presence_events, _ = await presence_source.get_new_events(
+                user=user, service=service, from_key=from_key,
+            )
+            time_now = self.clock.time_msec()
+            events.extend(
+                {
+                    "type": "m.presence",
+                    "sender": event.user_id,
+                    "content": format_user_presence_state(
+                        event, time_now, include_user_id=False
+                    ),
+                }
+                for event in presence_events
+            )
+
+        return events
+
+    async def query_user_exists(self, user_id: str) -> bool:
         """Check if any application service knows this user_id exists.
 
         Args:
-            user_id(str): The user to query if they exist on any AS.
+            user_id: The user to query if they exist on any AS.
         Returns:
             True if this user exists on at least one application service.
         """
@@ -173,11 +321,13 @@ class ApplicationServicesHandler:
                 return True
         return False
 
-    async def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(
+        self, room_alias: RoomAlias
+    ) -> Optional[RoomAliasMapping]:
         """Check if an application service knows this room alias exists.
 
         Args:
-            room_alias(RoomAlias): The room alias to query.
+            room_alias: The room alias to query.
         Returns:
             namedtuple: with keys "room_id" and "servers" or None if no
             association can be found.
@@ -193,10 +343,13 @@ class ApplicationServicesHandler:
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = await self.store.get_association_from_room_alias(room_alias)
-                return result
+                return await self.store.get_association_from_room_alias(room_alias)
 
-    async def query_3pe(self, kind, protocol, fields):
+        return None
+
+    async def query_3pe(
+        self, kind: str, protocol: str, fields: Dict[bytes, List[bytes]]
+    ) -> List[JsonDict]:
         services = self._get_services_for_3pn(protocol)
 
         results = await make_deferred_yieldable(
@@ -218,9 +371,11 @@ class ApplicationServicesHandler:
 
         return ret
 
-    async def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(
+        self, only_protocol: Optional[str] = None
+    ) -> Dict[str, JsonDict]:
         services = self.store.get_app_services()
-        protocols = {}
+        protocols = {}  # type: Dict[str, List[JsonDict]]
 
         # Collect up all the individual protocol responses out of the ASes
         for s in services:
@@ -236,7 +391,7 @@ class ApplicationServicesHandler:
                 if info is not None:
                     protocols[p].append(info)
 
-        def _merge_instances(infos):
+        def _merge_instances(infos: List[JsonDict]) -> JsonDict:
             if not infos:
                 return {}
 
@@ -251,19 +406,17 @@ class ApplicationServicesHandler:
 
             return combined
 
-        for p in protocols.keys():
-            protocols[p] = _merge_instances(protocols[p])
-
-        return protocols
+        return {p: _merge_instances(protocols[p]) for p in protocols.keys()}
 
-    async def _get_services_for_event(self, event):
+    async def _get_services_for_event(
+        self, event: EventBase
+    ) -> List[ApplicationService]:
         """Retrieve a list of application services interested in this event.
 
         Args:
-            event(Event): The event to check. Can be None if alias_list is not.
+            event: The event to check.
         Returns:
-            list<ApplicationService>: A list of services interested in this
-            event based on the service regex.
+            A list of services interested in this event based on the service regex.
         """
         services = self.store.get_app_services()
 
@@ -277,17 +430,15 @@ class ApplicationServicesHandler:
 
         return interested_list
 
-    def _get_services_for_user(self, user_id):
+    def _get_services_for_user(self, user_id: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return interested_list
+        return [s for s in services if (s.is_interested_in_user(user_id))]
 
-    def _get_services_for_3pn(self, protocol):
+    def _get_services_for_3pn(self, protocol: str) -> List[ApplicationService]:
         services = self.store.get_app_services()
-        interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return interested_list
+        return [s for s in services if s.is_interested_in_protocol(protocol)]
 
-    async def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id: str) -> bool:
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
@@ -302,9 +453,8 @@ class ApplicationServicesHandler:
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    async def _check_user_exists(self, user_id):
+    async def _check_user_exists(self, user_id: str) -> bool:
         unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = await self.query_user_exists(user_id)
-            return exists
+            return await self.query_user_exists(user_id)
         return True
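
Both notification entry points in this file now share the same shape: a
synchronous method performs the cheap filtering (feature flag, stream key,
interested services) and only then spawns the metrics-wrapped coroutine,
"rather than optimistically" as the comments put it. Reduced to a skeleton
with hypothetical names (the decorator is the one used in this diff):

    from typing import Optional

    class ExampleNotifier:
        def notify(self, stream_key: str, new_token: Optional[int]) -> None:
            # Cheap synchronous guards first: bail out before creating a
            # background process that would have nothing to do.
            if stream_key not in ("typing_key", "receipt_key", "presence_key"):
                return
            self._notify(stream_key, new_token)

        @wrap_as_background_process("example_notify")
        async def _notify(self, stream_key: str, new_token: Optional[int]) -> None:
            ...  # the per-service fan-out happens here
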
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 00eae92052..c7dc07008a 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014 - 2016 OpenMarket Ltd
 # Copyright 2017 Vector Creations Ltd
+# Copyright 2019 - 2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,10 +19,21 @@ import logging
 import time
 import unicodedata
 import urllib.parse
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import attr
-import bcrypt  # type: ignore[import]
+import bcrypt
 import pymacaroons
 
 from synapse.api.constants import LoginType
@@ -49,6 +61,9 @@ from synapse.util.threepids import canonicalise_email
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -149,11 +164,7 @@ class SsoLoginExtraAttributes:
 class AuthHandler(BaseHandler):
     SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000
 
-    def __init__(self, hs):
-        """
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
@@ -164,13 +175,20 @@ class AuthHandler(BaseHandler):
 
         self.bcrypt_rounds = hs.config.bcrypt_rounds
 
+        # we can't use hs.get_module_api() here, because to do so will create an
+        # import loop.
+        #
+        # TODO: refactor this class to separate the lower-level stuff that
+        #   ModuleApi can use from the higher-level stuff that uses ModuleApi, as
+        #   a better way to break the loop
         account_handler = ModuleApi(hs, self)
+
         self.password_providers = [
-            module(config=config, account_handler=account_handler)
+            PasswordProvider.load(module, config, account_handler)
             for module, config in hs.config.password_providers
         ]
 
-        logger.info("Extra password_providers: %r", self.password_providers)
+        logger.info("Extra password_providers: %s", self.password_providers)
 
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
@@ -184,15 +202,23 @@ class AuthHandler(BaseHandler):
         # type in the list. (NB that the spec doesn't require us to do so and
         # clients which favour types that they don't understand over those that
         # they do are technically broken)
+
+        # start out by assuming PASSWORD is enabled; we will remove it later if not.
         login_types = []
-        if self._password_enabled:
+        if hs.config.password_localdb_enabled:
             login_types.append(LoginType.PASSWORD)
+
         for provider in self.password_providers:
             if hasattr(provider, "get_supported_login_types"):
                 for t in provider.get_supported_login_types().keys():
                     if t not in login_types:
                         login_types.append(t)
+
+        if not self._password_enabled:
+            login_types.remove(LoginType.PASSWORD)
+
         self._supported_login_types = login_types
+
         # Login types and UI Auth types have a heavy overlap, but are not
         # necessarily identical. Login types have SSO (and other login types)
         # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServlet.on_GET.
@@ -209,10 +235,17 @@ class AuthHandler(BaseHandler):
             burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
         )
 
+        # Ratelimiter for failed /login attempts
+        self._failed_login_attempts_ratelimiter = Ratelimiter(
+            clock=hs.get_clock(),
+            rate_hz=self.hs.config.rc_login_failed_attempts.per_second,
+            burst_count=self.hs.config.rc_login_failed_attempts.burst_count,
+        )
+
         self._clock = self.hs.get_clock()
 
         # Expire old UI auth sessions after a period of time.
-        if hs.config.worker_app is None:
+        if hs.config.run_background_tasks:
             self._clock.looping_call(
                 run_as_background_process,
                 5 * 60 * 1000,
@@ -463,9 +496,7 @@ class AuthHandler(BaseHandler):
             # authentication flow.
             await self.store.set_ui_auth_clientdict(sid, clientdict)
 
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
+        user_agent = request.get_user_agent("")
 
         await self.store.add_user_agent_ip_to_ui_auth_session(
             session.session_id, user_agent, clientip
@@ -623,14 +654,8 @@ class AuthHandler(BaseHandler):
             res = await checker.check_auth(authdict, clientip=clientip)
             return res
 
-        # build a v1-login-style dict out of the authdict and fall back to the
-        # v1 code
-        user_id = authdict.get("user")
-
-        if user_id is None:
-            raise SynapseError(400, "", Codes.MISSING_PARAM)
-
-        (canonical_id, callback) = await self.validate_login(user_id, authdict)
+        # fall back to the v1 login flow
+        canonical_id, _ = await self.validate_login(authdict)
         return canonical_id
 
     def _get_params_recaptcha(self) -> dict:
@@ -679,13 +704,17 @@ class AuthHandler(BaseHandler):
         }
 
     async def get_access_token_for_user_id(
-        self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        valid_until_ms: Optional[int],
+        puppets_user_id: Optional[str] = None,
+    ) -> str:
         """
         Creates a new access token for the user with the given user ID.
 
         The user is assumed to have been authenticated by some other
-        machanism (e.g. CAS), and the user_id converted to the canonical case.
+        mechanism (e.g. CAS), and the user_id converted to the canonical case.
 
         The device will be recorded in the table if it is not there already.
 
@@ -706,13 +735,25 @@ class AuthHandler(BaseHandler):
             fmt_expiry = time.strftime(
                 " until %Y-%m-%d %H:%M:%S", time.localtime(valid_until_ms / 1000.0)
             )
-        logger.info("Logging in user %s on device %s%s", user_id, device_id, fmt_expiry)
+
+        if puppets_user_id:
+            logger.info(
+                "Logging in user %s as %s%s", user_id, puppets_user_id, fmt_expiry
+            )
+        else:
+            logger.info(
+                "Logging in user %s on device %s%s", user_id, device_id, fmt_expiry
+            )
 
         await self.auth.check_auth_blocking(user_id)
 
         access_token = self.macaroon_gen.generate_access_token(user_id)
         await self.store.add_access_token_to_user(
-            user_id, access_token, device_id, valid_until_ms
+            user_id=user_id,
+            token=access_token,
+            device_id=device_id,
+            valid_until_ms=valid_until_ms,
+            puppets_user_id=puppets_user_id,
         )
 
         # the device *should* have been registered before we got here; however,
@@ -789,17 +830,17 @@ class AuthHandler(BaseHandler):
         return self._supported_login_types
 
     async def validate_login(
-        self, username: str, login_submission: Dict[str, Any]
+        self, login_submission: Dict[str, Any], ratelimit: bool = False,
     ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
         """Authenticates the user for the /login API
 
-        Also used by the user-interactive auth flow to validate
-        m.login.password auth types.
+        Also used by the user-interactive auth flow to validate auth types which don't
+        have an explicit UIA handler, including m.login.password.
 
         Args:
-            username: username supplied by the user
             login_submission: the whole of the login submission
                 (including 'type' and other relevant fields)
+            ratelimit: whether to apply the failed_login_attempt ratelimiter
         Returns:
             A tuple of the canonical user id, and optional callback
                 to be called once the access token and device id are issued
@@ -808,38 +849,160 @@ class AuthHandler(BaseHandler):
             SynapseError if there was a problem with the request
             LoginError if there was an authentication problem.
         """
-
-        if username.startswith("@"):
-            qualified_user_id = username
-        else:
-            qualified_user_id = UserID(username, self.hs.hostname).to_string()
-
         login_type = login_submission.get("type")
-        known_login_type = False
+        if not isinstance(login_type, str):
+            raise SynapseError(400, "Bad parameter: type", Codes.INVALID_PARAM)
+
+        # ideally, we wouldn't be checking the identifier unless we know we have a login
+        # method which uses it (https://github.com/matrix-org/synapse/issues/8836)
+        #
+        # But the auth providers' check_auth interface requires a username, so in
+        # practice we can only support login methods which we can map to a username
+        # anyway.
 
         # special case to check for "password" for the check_password interface
         # for the auth providers
         password = login_submission.get("password")
-
         if login_type == LoginType.PASSWORD:
             if not self._password_enabled:
                 raise SynapseError(400, "Password login has been disabled.")
-            if not password:
-                raise SynapseError(400, "Missing parameter: password")
+            if not isinstance(password, str):
+                raise SynapseError(400, "Bad parameter: password", Codes.INVALID_PARAM)
 
-        for provider in self.password_providers:
-            if hasattr(provider, "check_password") and login_type == LoginType.PASSWORD:
-                known_login_type = True
-                is_valid = await provider.check_password(qualified_user_id, password)
-                if is_valid:
-                    return qualified_user_id, None
+        # map old-school login fields into new-school "identifier" fields.
+        identifier_dict = convert_client_dict_legacy_fields_to_identifier(
+            login_submission
+        )
 
-            if not hasattr(provider, "get_supported_login_types") or not hasattr(
-                provider, "check_auth"
-            ):
-                # this password provider doesn't understand custom login types
-                continue
+        # convert phone type identifiers to generic threepids
+        if identifier_dict["type"] == "m.id.phone":
+            identifier_dict = login_id_phone_to_thirdparty(identifier_dict)
+
+        # convert threepid identifiers to user IDs
+        if identifier_dict["type"] == "m.id.thirdparty":
+            address = identifier_dict.get("address")
+            medium = identifier_dict.get("medium")
+
+            if medium is None or address is None:
+                raise SynapseError(400, "Invalid thirdparty identifier")
+
+            # For emails, canonicalise the address.
+            # We store all email addresses canonicalised in the DB.
+            # (See add_threepid in synapse/handlers/auth.py)
+            if medium == "email":
+                try:
+                    address = canonicalise_email(address)
+                except ValueError as e:
+                    raise SynapseError(400, str(e))
+
+            # We also apply account rate limiting using the 3PID as a key, as
+            # otherwise using 3PID bypasses the ratelimiting based on user ID.
+            if ratelimit:
+                self._failed_login_attempts_ratelimiter.ratelimit(
+                    (medium, address), update=False
+                )
+
+            # Check for login providers that support 3pid login types
+            if login_type == LoginType.PASSWORD:
+                # we've already checked that there is a (valid) password field
+                assert isinstance(password, str)
+                (
+                    canonical_user_id,
+                    callback_3pid,
+                ) = await self.check_password_provider_3pid(medium, address, password)
+                if canonical_user_id:
+                    # Authentication through password provider and 3pid succeeded
+                    return canonical_user_id, callback_3pid
+
+            # No password providers were able to handle this 3pid
+            # Check local store
+            user_id = await self.hs.get_datastore().get_user_id_by_threepid(
+                medium, address
+            )
+            if not user_id:
+                logger.warning(
+                    "unknown 3pid identifier medium %s, address %r", medium, address
+                )
+                # We mark that we've failed to log in here, as
+                # `check_password_provider_3pid` might have returned `None` due
+                # to an incorrect password, rather than the account not
+                # existing.
+                #
+                # If it returned None but the 3PID was bound then we won't hit
+                # this code path, which is fine as then the per-user ratelimit
+                # will kick in below.
+                if ratelimit:
+                    self._failed_login_attempts_ratelimiter.can_do_action(
+                        (medium, address)
+                    )
+                raise LoginError(403, "", errcode=Codes.FORBIDDEN)
+
+            identifier_dict = {"type": "m.id.user", "user": user_id}
+
+        # by this point, the identifier should be an m.id.user: if it's anything
+        # else, we haven't understood it.
+        if identifier_dict["type"] != "m.id.user":
+            raise SynapseError(400, "Unknown login identifier type")
+
+        username = identifier_dict.get("user")
+        if not username:
+            raise SynapseError(400, "User identifier is missing 'user' key")
+
+        if username.startswith("@"):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+        # Check if we've hit the failed ratelimit (but don't update it)
+        if ratelimit:
+            self._failed_login_attempts_ratelimiter.ratelimit(
+                qualified_user_id.lower(), update=False
+            )
 
+        try:
+            return await self._validate_userid_login(username, login_submission)
+        except LoginError:
+            # The user has failed to log in, so we need to update the rate
+            # limiter. Using `can_do_action` avoids us raising a ratelimit
+            # exception and masking the LoginError. The actual ratelimiting
+            # should have happened above.
+            if ratelimit:
+                self._failed_login_attempts_ratelimiter.can_do_action(
+                    qualified_user_id.lower()
+                )
+            raise
+
+    async def _validate_userid_login(
+        self, username: str, login_submission: Dict[str, Any],
+    ) -> Tuple[str, Optional[Callable[[Dict[str, str]], None]]]:
+        """Helper for validate_login
+
+        Handles login, once we've mapped 3pids onto userids
+
+        Args:
+            username: the username, from the identifier dict
+            login_submission: the whole of the login submission
+                (including 'type' and other relevant fields)
+        Returns:
+            A tuple of the canonical user id, and optional callback
+                to be called once the access token and device id are issued
+        Raises:
+            StoreError if there was a problem accessing the database
+            SynapseError if there was a problem with the request
+            LoginError if there was an authentication problem.
+        """
+        if username.startswith("@"):
+            qualified_user_id = username
+        else:
+            qualified_user_id = UserID(username, self.hs.hostname).to_string()
+
+        login_type = login_submission.get("type")
+        # we already checked that we have a valid login type
+        assert isinstance(login_type, str)
+
+        known_login_type = False
+
+        for provider in self.password_providers:
             supported_login_types = provider.get_supported_login_types()
             if login_type not in supported_login_types:
                 # this password provider doesn't understand this login type
@@ -864,15 +1027,17 @@ class AuthHandler(BaseHandler):
 
             result = await provider.check_auth(username, login_type, login_dict)
             if result:
-                if isinstance(result, str):
-                    result = (result, None)
                 return result
 
         if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
             known_login_type = True
 
+            # we've already checked that there is a (valid) password field
+            password = login_submission["password"]
+            assert isinstance(password, str)
+
             canonical_user_id = await self._check_local_password(
-                qualified_user_id, password  # type: ignore
+                qualified_user_id, password
             )
 
             if canonical_user_id:
@@ -903,19 +1068,9 @@ class AuthHandler(BaseHandler):
             unsuccessful, `user_id` and `callback` are both `None`.
         """
         for provider in self.password_providers:
-            if hasattr(provider, "check_3pid_auth"):
-                # This function is able to return a deferred that either
-                # resolves None, meaning authentication failure, or upon
-                # success, to a str (which is the user_id) or a tuple of
-                # (user_id, callback_func), where callback_func should be run
-                # after we've finished everything else
-                result = await provider.check_3pid_auth(medium, address, password)
-                if result:
-                    # Check if the return value is a str or a tuple
-                    if isinstance(result, str):
-                        # If it's a str, set callback function to None
-                        result = (result, None)
-                    return result
+            result = await provider.check_3pid_auth(medium, address, password)
+            if result:
+                return result
 
         return None, None
 
@@ -973,21 +1128,16 @@ class AuthHandler(BaseHandler):
 
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
-            if hasattr(provider, "on_logged_out"):
-                # This might return an awaitable, if it does block the log out
-                # until it completes.
-                result = provider.on_logged_out(
-                    user_id=str(user_info["user"]),
-                    device_id=user_info["device_id"],
-                    access_token=access_token,
-                )
-                if inspect.isawaitable(result):
-                    await result
+            await provider.on_logged_out(
+                user_id=user_info.user_id,
+                device_id=user_info.device_id,
+                access_token=access_token,
+            )
 
         # delete pushers associated with this access token
-        if user_info["token_id"] is not None:
+        if user_info.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                str(user_info["user"]), (user_info["token_id"],)
+                user_info.user_id, (user_info.token_id,)
             )
 
     async def delete_access_tokens_for_user(
@@ -1011,11 +1161,10 @@ class AuthHandler(BaseHandler):
 
         # see if any of our auth providers want to know about this
         for provider in self.password_providers:
-            if hasattr(provider, "on_logged_out"):
-                for token, token_id, device_id in tokens_and_devices:
-                    await provider.on_logged_out(
-                        user_id=user_id, device_id=device_id, access_token=token
-                    )
+            for token, token_id, device_id in tokens_and_devices:
+                await provider.on_logged_out(
+                    user_id=user_id, device_id=device_id, access_token=token
+                )
 
         # delete pushers associated with the access tokens
         await self.hs.get_pusherpool().remove_pushers_by_access_token(
@@ -1073,7 +1222,7 @@ class AuthHandler(BaseHandler):
         if medium == "email":
             address = canonicalise_email(address)
 
-        identity_handler = self.hs.get_handlers().identity_handler
+        identity_handler = self.hs.get_identity_handler()
         result = await identity_handler.try_unbind_threepid(
             user_id, {"medium": medium, "address": address, "id_server": id_server}
         )
@@ -1115,20 +1264,22 @@ class AuthHandler(BaseHandler):
             Whether self.hash(password) == stored_hash.
         """
 
-        def _do_validate_hash():
+        def _do_validate_hash(checked_hash: bytes):
             # Normalise the Unicode in the password
             pw = unicodedata.normalize("NFKC", password)
 
             return bcrypt.checkpw(
                 pw.encode("utf8") + self.hs.config.password_pepper.encode("utf8"),
-                stored_hash,
+                checked_hash,
             )
 
         if stored_hash:
             if not isinstance(stored_hash, bytes):
                 stored_hash = stored_hash.encode("ascii")
 
-            return await defer_to_thread(self.hs.get_reactor(), _do_validate_hash)
+            return await defer_to_thread(
+                self.hs.get_reactor(), _do_validate_hash, stored_hash
+            )
         else:
             return False
 
@@ -1337,3 +1488,127 @@ class MacaroonGenerator:
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("user_id = %s" % (user_id,))
         return macaroon
+
+
+class PasswordProvider:
+    """Wrapper for a password auth provider module
+
+    This class abstracts out all of the backwards-compatibility hacks for
+    password providers, to provide a consistent interface.
+    """
+
+    @classmethod
+    def load(cls, module, config, module_api: ModuleApi) -> "PasswordProvider":
+        try:
+            pp = module(config=config, account_handler=module_api)
+        except Exception as e:
+            logger.error("Error while initializing %r: %s", module, e)
+            raise
+        return cls(pp, module_api)
+
+    def __init__(self, pp, module_api: ModuleApi):
+        self._pp = pp
+        self._module_api = module_api
+
+        self._supported_login_types = {}
+
+        # grandfather in check_password support
+        if hasattr(self._pp, "check_password"):
+            self._supported_login_types[LoginType.PASSWORD] = ("password",)
+
+        g = getattr(self._pp, "get_supported_login_types", None)
+        if g:
+            self._supported_login_types.update(g())
+
+    def __str__(self):
+        return str(self._pp)
+
+    def get_supported_login_types(self) -> Mapping[str, Iterable[str]]:
+        """Get the login types supported by this password provider
+
+        Returns a map from a login type identifier (such as m.login.password) to an
+        iterable giving the fields which must be provided by the user in the submission
+        to the /login API.
+
+        This wrapper adds m.login.password to the list if the underlying password
+        provider supports the check_password() api.
+        """
+        return self._supported_login_types
+
+    async def check_auth(
+        self, username: str, login_type: str, login_dict: JsonDict
+    ) -> Optional[Tuple[str, Optional[Callable]]]:
+        """Check if the user has presented valid login credentials
+
+        This wrapper also calls check_password() if the underlying password provider
+        supports the check_password() api and the login type is m.login.password.
+
+        Args:
+            username: user id presented by the client. Either an MXID or an unqualified
+                username.
+
+            login_type: the login type being attempted - one of the types returned by
+                get_supported_login_types()
+
+            login_dict: the dictionary of login secrets passed by the client.
+
+        Returns: (user_id, callback) where `user_id` is the fully-qualified mxid of the
+            user, and `callback` is an optional callback which will be called with the
+            result from the /login call (including access_token, device_id, etc.)
+        """
+        # first grandfather in a call to check_password
+        if login_type == LoginType.PASSWORD:
+            g = getattr(self._pp, "check_password", None)
+            if g:
+                qualified_user_id = self._module_api.get_qualified_user_id(username)
+                is_valid = await self._pp.check_password(
+                    qualified_user_id, login_dict["password"]
+                )
+                if is_valid:
+                    return qualified_user_id, None
+
+        g = getattr(self._pp, "check_auth", None)
+        if not g:
+            return None
+        result = await g(username, login_type, login_dict)
+
+        # Check if the return value is a str or a tuple
+        if isinstance(result, str):
+            # If it's a str, set callback function to None
+            return result, None
+
+        return result
+
+    async def check_3pid_auth(
+        self, medium: str, address: str, password: str
+    ) -> Optional[Tuple[str, Optional[Callable]]]:
+        g = getattr(self._pp, "check_3pid_auth", None)
+        if not g:
+            return None
+
+        # This function is able to return a deferred that either
+        # resolves None, meaning authentication failure, or upon
+        # success, to a str (which is the user_id) or a tuple of
+        # (user_id, callback_func), where callback_func should be run
+        # after we've finished everything else
+        result = await g(medium, address, password)
+
+        # Check if the return value is a str or a tuple
+        if isinstance(result, str):
+            # If it's a str, set callback function to None
+            return result, None
+
+        return result
+
+    async def on_logged_out(
+        self, user_id: str, device_id: Optional[str], access_token: str
+    ) -> None:
+        g = getattr(self._pp, "on_logged_out", None)
+        if not g:
+            return
+
+        # This might return an awaitable, if it does block the log out
+        # until it completes.
+        result = g(user_id=user_id, device_id=device_id, access_token=access_token,)
+        if inspect.isawaitable(result):
+            await result
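
The PasswordProvider wrapper above concentrates the getattr-based capability
probing and return-value coercion that was previously scattered through
AuthHandler. Its core normalisation, reduced to a standalone helper (a sketch
for illustration, not the class above):

    import inspect
    from typing import Any

    async def call_provider_hook(provider: Any, name: str, *args, **kwargs):
        method = getattr(provider, name, None)
        if method is None:
            return None  # legacy module does not implement this hook
        result = method(*args, **kwargs)
        if inspect.isawaitable(result):
            # legacy modules may return plain values rather than awaitables
            result = await result
        if isinstance(result, str):
            return result, None  # a bare user_id means "no post-login callback"
        return result
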
diff --git a/synapse/handlers/cas_handler.py b/synapse/handlers/cas_handler.py
index a4cc4b9a5a..f4ea0a9767 100644
--- a/synapse/handlers/cas_handler.py
+++ b/synapse/handlers/cas_handler.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import urllib
-from typing import Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, Optional, Tuple
 from xml.etree import ElementTree as ET
 
 from twisted.web.client import PartialDownloadError
@@ -23,6 +23,9 @@ from synapse.api.errors import Codes, LoginError
 from synapse.http.site import SynapseRequest
 from synapse.types import UserID, map_username_to_mxid_localpart
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -31,10 +34,10 @@ class CasHandler:
     Utility class to handle the response from a CAS SSO service.
 
     Args:
-        hs (synapse.server.HomeServer)
+        hs
     """
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.hs = hs
         self._hostname = hs.hostname
         self._auth_handler = hs.get_auth_handler()
@@ -200,29 +203,57 @@ class CasHandler:
             args["session"] = session
         username, user_display_name = await self._validate_ticket(ticket, args)
 
-        localpart = map_username_to_mxid_localpart(username)
-        user_id = UserID(localpart, self._hostname).to_string()
-        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+        # Pull out the user-agent and IP from the request.
+        user_agent = request.get_user_agent("")
+        ip_address = self.hs.get_ip_from_request(request)
+
+        # Get the matrix ID from the CAS username.
+        user_id = await self._map_cas_user_to_matrix_user(
+            username, user_display_name, user_agent, ip_address
+        )
 
         if session:
             await self._auth_handler.complete_sso_ui_auth(
-                registered_user_id, session, request,
+                user_id, session, request,
             )
-
         else:
-            if not registered_user_id:
-                # Pull out the user-agent and IP from the request.
-                user_agent = request.requestHeaders.getRawHeaders(
-                    b"User-Agent", default=[b""]
-                )[0].decode("ascii", "surrogateescape")
-                ip_address = self.hs.get_ip_from_request(request)
-
-                registered_user_id = await self._registration_handler.register_user(
-                    localpart=localpart,
-                    default_display_name=user_display_name,
-                    user_agent_ips=(user_agent, ip_address),
-                )
+            # If this is not a UI auth request then there must be a redirect URL.
+            assert client_redirect_url
 
             await self._auth_handler.complete_sso_login(
-                registered_user_id, request, client_redirect_url
+                user_id, request, client_redirect_url
+            )
+
+    async def _map_cas_user_to_matrix_user(
+        self,
+        remote_user_id: str,
+        display_name: Optional[str],
+        user_agent: str,
+        ip_address: str,
+    ) -> str:
+        """
+        Given a CAS username, retrieve the user ID for it and possibly register the user.
+
+        Args:
+            remote_user_id: The username from the CAS response.
+            display_name: The display name from the CAS response.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+
+        Returns:
+             The user ID associated with this response.
+        """
+
+        localpart = map_username_to_mxid_localpart(remote_user_id)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = await self._auth_handler.check_user_exists(user_id)
+
+        # If the user does not exist, register it.
+        if not registered_user_id:
+            registered_user_id = await self._registration_handler.register_user(
+                localpart=localpart,
+                default_display_name=display_name,
+                user_agent_ips=[(user_agent, ip_address)],
             )
+
+        return registered_user_id
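The new _map_cas_user_to_matrix_user helper is a check-then-register flow: derive a localpart, look the user up, and register on first login. A rough standalone sketch of that flow, with a hypothetical in-memory store standing in for the auth and registration handlers:

import asyncio

class FakeUserStore:
    """Hypothetical stand-in for the auth/registration handlers."""

    def __init__(self, hostname):
        self.hostname = hostname
        self.users = set()

    async def check_user_exists(self, user_id):
        return user_id if user_id in self.users else None

    async def register_user(self, localpart, display_name):
        user_id = "@%s:%s" % (localpart, self.hostname)
        self.users.add(user_id)
        return user_id

async def map_remote_user(store, remote_user_id, display_name):
    # The real code derives the localpart via map_username_to_mxid_localpart;
    # lower-casing is a simplification here.
    localpart = remote_user_id.lower()
    user_id = "@%s:%s" % (localpart, store.hostname)
    registered = await store.check_user_exists(user_id)
    if not registered:
        registered = await store.register_user(localpart, display_name)
    return registered

print(asyncio.run(map_remote_user(FakeUserStore("example.org"), "Alice", "Alice")))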
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 0635ad5708..e808142365 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -14,7 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Optional
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
 from synapse.metrics.background_process_metrics import run_as_background_process
@@ -22,27 +22,31 @@ from synapse.types import UserID, create_requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class DeactivateAccountHandler(BaseHandler):
     """Handler which deals with deactivating user accounts."""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self._auth_handler = hs.get_auth_handler()
         self._device_handler = hs.get_device_handler()
         self._room_member_handler = hs.get_room_member_handler()
-        self._identity_handler = hs.get_handlers().identity_handler
+        self._identity_handler = hs.get_identity_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
+        self._server_name = hs.hostname
 
         # Flag that indicates whether the process to part users from rooms is running
         self._user_parter_running = False
 
         # Start the user parter loop so it can resume parting users from rooms where
         # it left off (if it has work left to do).
-        if hs.config.worker_app is None:
+        if hs.config.run_background_tasks:
             hs.get_reactor().callWhenRunning(self._start_user_parting)
 
         self._account_validity_enabled = hs.config.account_validity.enabled
@@ -137,7 +141,7 @@ class DeactivateAccountHandler(BaseHandler):
 
         return identity_server_supports_unbinding
 
-    async def _reject_pending_invites_for_user(self, user_id: str):
+    async def _reject_pending_invites_for_user(self, user_id: str) -> None:
         """Reject pending invites addressed to a given user ID.
 
         Args:
@@ -149,7 +153,7 @@ class DeactivateAccountHandler(BaseHandler):
         for room in pending_invites:
             try:
                 await self._room_member_handler.update_membership(
-                    create_requester(user),
+                    create_requester(user, authenticated_entity=self._server_name),
                     user,
                     room.room_id,
                     "leave",
@@ -205,7 +209,7 @@ class DeactivateAccountHandler(BaseHandler):
             logger.info("User parter parting %r from %r", user_id, room_id)
             try:
                 await self._room_member_handler.update_membership(
-                    create_requester(user),
+                    create_requester(user, authenticated_entity=self._server_name),
                     user,
                     room_id,
                     "leave",
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b9d9098104..debb1b4f29 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -1,7 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2016 OpenMarket Ltd
 # Copyright 2019 New Vector Ltd
-# Copyright 2019 The Matrix.org Foundation C.I.C.
+# Copyright 2019,2020 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Tuple
 
 from synapse.api import errors
 from synapse.api.constants import EventTypes
@@ -29,7 +29,10 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, set_tag, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import (
+    Collection,
+    JsonDict,
     StreamToken,
+    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
@@ -41,13 +44,16 @@ from synapse.util.retryutils import NotRetryingDestination
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 MAX_DEVICE_DISPLAY_NAME_LEN = 100
 
 
 class DeviceWorkerHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.hs = hs
@@ -105,7 +111,9 @@ class DeviceWorkerHandler(BaseHandler):
 
     @trace
     @measure_func("device.get_user_ids_changed")
-    async def get_user_ids_changed(self, user_id: str, from_token: StreamToken):
+    async def get_user_ids_changed(
+        self, user_id: str, from_token: StreamToken
+    ) -> JsonDict:
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
         """
@@ -221,8 +229,8 @@ class DeviceWorkerHandler(BaseHandler):
             possibly_joined = possibly_changed & users_who_share_room
             possibly_left = (possibly_changed | possibly_left) - users_who_share_room
         else:
-            possibly_joined = []
-            possibly_left = []
+            possibly_joined = set()
+            possibly_left = set()
 
         result = {"changed": list(possibly_joined), "left": list(possibly_left)}
 
@@ -230,7 +238,7 @@ class DeviceWorkerHandler(BaseHandler):
 
         return result
 
-    async def on_federation_query_user_devices(self, user_id):
+    async def on_federation_query_user_devices(self, user_id: str) -> JsonDict:
         stream_id, devices = await self.store.get_e2e_device_keys_for_federation_query(
             user_id
         )
@@ -249,7 +257,7 @@ class DeviceWorkerHandler(BaseHandler):
 
 
 class DeviceHandler(DeviceWorkerHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
@@ -264,7 +272,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         hs.get_distributor().observe("user_left_room", self.user_left_room)
 
-    def _check_device_name_length(self, name: str):
+    def _check_device_name_length(self, name: Optional[str]):
         """
         Checks whether a device name is longer than the maximum allowed length.
 
@@ -283,8 +291,11 @@ class DeviceHandler(DeviceWorkerHandler):
             )
 
     async def check_device_registered(
-        self, user_id, device_id, initial_device_display_name=None
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_device_display_name: Optional[str] = None,
+    ) -> str:
         """
         If the given device has not been registered, register it with the
         supplied display name.
@@ -292,12 +303,11 @@ class DeviceHandler(DeviceWorkerHandler):
         If no device_id is supplied, we make one up.
 
         Args:
-            user_id (str):  @user:id
-            device_id (str | None): device id supplied by client
-            initial_device_display_name (str | None): device display name from
-                 client
+            user_id:  @user:id
+            device_id: device id supplied by client
+            initial_device_display_name: device display name from client
         Returns:
-            str: device id (generated if none was supplied)
+            device id (generated if none was supplied)
         """
 
         self._check_device_name_length(initial_device_display_name)
@@ -316,15 +326,15 @@ class DeviceHandler(DeviceWorkerHandler):
         # times in case of a clash.
         attempts = 0
         while attempts < 5:
-            device_id = stringutils.random_string(10).upper()
+            new_device_id = stringutils.random_string(10).upper()
             new_device = await self.store.store_device(
                 user_id=user_id,
-                device_id=device_id,
+                device_id=new_device_id,
                 initial_device_display_name=initial_device_display_name,
             )
             if new_device:
-                await self.notify_device_update(user_id, [device_id])
-                return device_id
+                await self.notify_device_update(user_id, [new_device_id])
+                return new_device_id
             attempts += 1
 
         raise errors.StoreError(500, "Couldn't generate a device ID.")
@@ -433,7 +443,9 @@ class DeviceHandler(DeviceWorkerHandler):
 
     @trace
     @measure_func("notify_device_update")
-    async def notify_device_update(self, user_id, device_ids):
+    async def notify_device_update(
+        self, user_id: str, device_ids: Collection[str]
+    ) -> None:
         """Notify that a user's device(s) has changed. Pokes the notifier, and
         remote servers if the user is local.
         """
@@ -445,7 +457,7 @@ class DeviceHandler(DeviceWorkerHandler):
             user_id
         )
 
-        hosts = set()
+        hosts = set()  # type: Set[str]
         if self.hs.is_mine_id(user_id):
             hosts.update(get_domain_from_id(u) for u in users_who_share_room)
             hosts.discard(self.server_name)
@@ -497,7 +509,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         self.notifier.on_new_event("device_list_key", position, users=[from_user_id])
 
-    async def user_left_room(self, user, room_id):
+    async def user_left_room(self, user: UserID, room_id: str) -> None:
         user_id = user.to_string()
         room_ids = await self.store.get_rooms_for_user(user_id)
         if not room_ids:
@@ -505,8 +517,89 @@ class DeviceHandler(DeviceWorkerHandler):
             # receive device updates. Mark this in DB.
             await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
 
+    async def store_dehydrated_device(
+        self,
+        user_id: str,
+        device_data: JsonDict,
+        initial_device_display_name: Optional[str] = None,
+    ) -> str:
+        """Store a dehydrated device for a user.  If the user had a previous
+        dehydrated device, it is removed.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_data: the dehydrated device information
+            initial_device_display_name: The display name to use for the device
+        Returns:
+            device id of the dehydrated device
+        """
+        device_id = await self.check_device_registered(
+            user_id, None, initial_device_display_name,
+        )
+        old_device_id = await self.store.store_dehydrated_device(
+            user_id, device_id, device_data
+        )
+        if old_device_id is not None:
+            await self.delete_device(user_id, old_device_id)
+        return device_id
+
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
+    async def rehydrate_device(
+        self, user_id: str, access_token: str, device_id: str
+    ) -> dict:
+        """Process a rehydration request from the user.
+
+        Args:
+            user_id: the user who is rehydrating the device
+            access_token: the access token used for the request
+            device_id: the ID of the device that will be rehydrated
+        Returns:
+            a dict containing {"success": True}
+        """
+        success = await self.store.remove_dehydrated_device(user_id, device_id)
+
+        if not success:
+            raise errors.NotFoundError()
+
+        # If the dehydrated device was successfully deleted (the device ID
+        # matched the stored dehydrated device), then point the access token
+        # at the dehydrated device's ID, copy the old device's display name
+        # across to the dehydrated device, and finally destroy the old
+        # device.
+        old_device_id = await self.store.set_device_for_access_token(
+            access_token, device_id
+        )
+        old_device = await self.store.get_device(user_id, old_device_id)
+        await self.store.update_device(user_id, device_id, old_device["display_name"])
+        # We can't call self.delete_device because that would clobber the
+        # access token, so call the storage layer directly.
+        await self.store.delete_device(user_id, old_device_id)
+        await self.store.delete_e2e_keys_by_device(
+            user_id=user_id, device_id=old_device_id
+        )
 
-def _update_device_from_client_ips(device, client_ips):
+        # tell everyone that the old device is gone and that the dehydrated
+        # device has a new display name
+        await self.notify_device_update(user_id, [old_device_id, device_id])
+
+        return {"success": True}
+
+
+def _update_device_from_client_ips(
+    device: Dict[str, Any], client_ips: Dict[Tuple[str, str], Dict[str, Any]]
+) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
     device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
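Rehydration is a careful sequence of storage operations whose ordering matters: claim the dehydrated device first, notify last. A simplified sketch of that sequence, with a hypothetical store object in place of the real data store:

async def rehydrate(store, notify_device_update, user_id, access_token, device_id):
    # 1. Atomically claim the dehydrated device; bail out if the ID does
    #    not match what is stored.
    if not await store.remove_dehydrated_device(user_id, device_id):
        raise LookupError("no such dehydrated device")
    # 2. Point the access token at the dehydrated device's ID.
    old_device_id = await store.set_device_for_access_token(access_token, device_id)
    # 3. Copy the old device's display name across.
    old_device = await store.get_device(user_id, old_device_id)
    await store.update_device(user_id, device_id, old_device["display_name"])
    # 4. Delete the old device and its keys at the storage layer, not via
    #    the handler, which would clobber the access token.
    await store.delete_device(user_id, old_device_id)
    await store.delete_e2e_keys_by_device(user_id=user_id, device_id=old_device_id)
    # 5. Announce both changes in a single device-list update.
    await notify_device_update(user_id, [old_device_id, device_id])
    return {"success": True}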
 
@@ -514,7 +607,7 @@ def _update_device_from_client_ips(device, client_ips):
 class DeviceListUpdater:
     "Handles incoming device list updates from federation and updates the DB"
 
-    def __init__(self, hs, device_handler):
+    def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
         self.store = hs.get_datastore()
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
@@ -523,7 +616,9 @@ class DeviceListUpdater:
         self._remote_edu_linearizer = Linearizer(name="remote_device_list")
 
         # user_id -> list of updates waiting to be handled.
-        self._pending_updates = {}
+        self._pending_updates = (
+            {}
+        )  # type: Dict[str, List[Tuple[str, str, Iterable[str], JsonDict]]]
 
         # Recently seen stream ids. We don't bother keeping these in the DB,
         # but they're useful to have them about to reduce the number of spurious
@@ -546,7 +641,9 @@ class DeviceListUpdater:
         )
 
     @trace
-    async def incoming_device_list_update(self, origin, edu_content):
+    async def incoming_device_list_update(
+        self, origin: str, edu_content: JsonDict
+    ) -> None:
         """Called on incoming device list update from federation. Responsible
         for parsing the EDU and adding to pending updates list.
         """
@@ -607,7 +704,7 @@ class DeviceListUpdater:
         await self._handle_device_updates(user_id)
 
     @measure_func("_incoming_device_list_update")
-    async def _handle_device_updates(self, user_id):
+    async def _handle_device_updates(self, user_id: str) -> None:
         "Actually handle pending updates."
 
         with (await self._remote_edu_linearizer.queue(user_id)):
@@ -655,7 +752,9 @@ class DeviceListUpdater:
                     stream_id for _, stream_id, _, _ in pending_updates
                 )
 
-    async def _need_to_do_resync(self, user_id, updates):
+    async def _need_to_do_resync(
+        self, user_id: str, updates: Iterable[Tuple[str, str, Iterable[str], JsonDict]]
+    ) -> bool:
         """Given a list of updates for a user figure out if we need to do a full
         resync, or whether we have enough data that we can just apply the delta.
         """
@@ -686,7 +785,7 @@ class DeviceListUpdater:
         return False
 
     @trace
-    async def _maybe_retry_device_resync(self):
+    async def _maybe_retry_device_resync(self) -> None:
         """Retry to resync device lists that are out of sync, except if another retry is
         in progress.
         """
@@ -729,7 +828,7 @@ class DeviceListUpdater:
 
     async def user_device_resync(
         self, user_id: str, mark_failed_as_stale: bool = True
-    ) -> Optional[dict]:
+    ) -> Optional[JsonDict]:
         """Fetches all devices for a user and updates the device cache with them.
 
         Args:
@@ -753,7 +852,7 @@ class DeviceListUpdater:
                 # it later.
                 await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
-            return
+            return None
         except (RequestSendFailed, HttpResponseException) as e:
             logger.warning(
                 "Failed to handle device list update for %s: %s", user_id, e,
@@ -770,12 +869,12 @@ class DeviceListUpdater:
             # next time we get a device list update for this user_id.
             # This makes it more likely that the device lists will
             # eventually become consistent.
-            return
+            return None
         except FederationDeniedError as e:
             set_tag("error", True)
             log_kv({"reason": "FederationDeniedError"})
             logger.info(e)
-            return
+            return None
         except Exception as e:
             set_tag("error", True)
             log_kv(
@@ -788,7 +887,7 @@ class DeviceListUpdater:
                 # it later.
                 await self.store.mark_remote_user_device_cache_as_stale(user_id)
 
-            return
+            return None
         log_kv({"result": result})
         stream_id = result["stream_id"]
         devices = result["devices"]
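The error handling in user_device_resync follows one rule throughout: on a transient failure, mark the remote user's device cache as stale and return None instead of raising, so a later device-list update or the retry loop can resync. In outline, with the collaborators as parameters:

async def resync_or_mark_stale(fetch_remote_devices, mark_cache_stale, user_id):
    # The real handler distinguishes NotRetryingDestination,
    # HttpResponseException and so on; here any failure just marks
    # the cache as stale and bails.
    try:
        return await fetch_remote_devices(user_id)
    except Exception:
        await mark_cache_stale(user_id)
        return None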
@@ -849,7 +948,7 @@ class DeviceListUpdater:
         user_id: str,
         master_key: Optional[Dict[str, Any]],
         self_signing_key: Optional[Dict[str, Any]],
-    ) -> list:
+    ) -> List[str]:
         """Process the given new master and self-signing key for the given remote user.
 
         Args:
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 64ef7f63ab..9cac5a8463 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict
+from typing import TYPE_CHECKING, Any, Dict
 
 from synapse.api.errors import SynapseError
 from synapse.logging.context import run_in_background
@@ -24,18 +24,22 @@ from synapse.logging.opentracing import (
     set_tag,
     start_active_span,
 )
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util import json_encoder
 from synapse.util.stringutils import random_string
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
+
 logger = logging.getLogger(__name__)
 
 
 class DeviceMessageHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         """
         Args:
-            hs (synapse.server.HomeServer): server
+            hs: server
         """
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
@@ -48,7 +52,7 @@ class DeviceMessageHandler:
 
         self._device_list_updater = hs.get_device_handler().device_list_updater
 
-    async def on_direct_to_device_edu(self, origin, content):
+    async def on_direct_to_device_edu(self, origin: str, content: JsonDict) -> None:
         local_messages = {}
         sender_user_id = content["sender"]
         if origin != get_domain_from_id(sender_user_id):
@@ -95,7 +99,7 @@ class DeviceMessageHandler:
         message_type: str,
         sender_user_id: str,
         by_device: Dict[str, Dict[str, Any]],
-    ):
+    ) -> None:
         """Checks inbound device messages for unknown remote devices, and if
         found marks the remote cache for the user as stale.
         """
@@ -138,11 +142,16 @@ class DeviceMessageHandler:
                 self._device_list_updater.user_device_resync, sender_user_id
             )
 
-    async def send_device_message(self, sender_user_id, message_type, messages):
+    async def send_device_message(
+        self,
+        sender_user_id: str,
+        message_type: str,
+        messages: Dict[str, Dict[str, JsonDict]],
+    ) -> None:
         set_tag("number_of_messages", len(messages))
         set_tag("sender", sender_user_id)
         local_messages = {}
-        remote_messages = {}
+        remote_messages = {}  # type: Dict[str, Dict[str, Dict[str, JsonDict]]]
         for user_id, by_device in messages.items():
             # we use UserID.from_string to catch invalid user ids
             if self.is_mine(UserID.from_string(user_id)):
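send_device_message fans messages out by homeserver: the domain of each target user decides whether a message is delivered locally or queued for federation. The split, reduced to a runnable sketch with simplified domain parsing:

def split_messages_by_destination(messages, my_domain):
    # messages: {user_id: {device_id: content}}
    local = {}
    remote = {}  # destination domain -> {user_id: {device_id: content}}
    for user_id, by_device in messages.items():
        domain = user_id.split(":", 1)[1]
        if domain == my_domain:
            local[user_id] = by_device
        else:
            remote.setdefault(domain, {})[user_id] = by_device
    return local, remote

local, remote = split_messages_by_destination(
    {"@a:example.org": {"D1": {"body": "hi"}}, "@b:remote.org": {"D2": {"body": "hi"}}},
    "example.org",
)
assert "@a:example.org" in local and "remote.org" in remote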
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 62aa9a2da8..ad5683d251 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -46,6 +46,7 @@ class DirectoryHandler(BaseHandler):
         self.config = hs.config
         self.enable_room_list_search = hs.config.enable_room_list_search
         self.require_membership = hs.config.require_membership_for_aliases
+        self.third_party_event_rules = hs.get_third_party_event_rules()
 
         self.federation = hs.get_federation_client()
         hs.get_federation_registry().register_query_handler(
@@ -383,7 +384,7 @@ class DirectoryHandler(BaseHandler):
         """
         creator = await self.store.get_room_alias_creator(alias.to_string())
 
-        if creator is not None and creator == user_id:
+        if creator == user_id:
             return True
 
         # Resolve the alias to the corresponding room.
@@ -454,6 +455,15 @@ class DirectoryHandler(BaseHandler):
                 # per alias creation rule?
                 raise SynapseError(403, "Not allowed to publish room")
 
+            # Check if publishing is blocked by a third party module
+            allowed_by_third_party_rules = await (
+                self.third_party_event_rules.check_visibility_can_be_modified(
+                    room_id, visibility
+                )
+            )
+            if not allowed_by_third_party_rules:
+                raise SynapseError(403, "Not allowed to publish room")
+
         await self.store.set_room_is_public(room_id, making_public)
 
     async def edit_published_appservice_room_list(
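The new directory check gives a third-party rules module a veto over publishing a room. A hypothetical module implementing just this hook; the exact signature the module loader expects may differ, so treat this as a sketch:

class AllowListPublishRules:
    """Hypothetical rules module: only allow-listed rooms may be
    published to the public room directory."""

    def __init__(self, allowed_room_ids):
        self._allowed = set(allowed_room_ids)

    async def check_visibility_can_be_modified(self, room_id, new_visibility):
        if new_visibility != "public":
            # Un-publishing a room is always permitted.
            return True
        return room_id in self._allowed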
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index dd40fd1299..929752150d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -129,6 +129,11 @@ class E2eKeysHandler:
                 if user_id in local_query:
                     results[user_id] = keys
 
+        # Get cached cross-signing keys
+        cross_signing_keys = await self.get_cross_signing_keys_from_cache(
+            device_keys_query, from_user_id
+        )
+
         # Now attempt to get any remote devices from our local cache.
         remote_queries_not_in_cache = {}
         if remote_queries:
@@ -155,16 +160,28 @@ class E2eKeysHandler:
                             unsigned["device_display_name"] = device_display_name
                         user_devices[device_id] = result
 
+            # check for missing cross-signing keys.
+            for user_id in remote_queries.keys():
+                cached_cross_master = user_id in cross_signing_keys["master_keys"]
+                cached_cross_selfsigning = (
+                    user_id in cross_signing_keys["self_signing_keys"]
+                )
+
+                # Check whether we are missing only one of the cross-signing
+                # master or self-signing keys while the other one is cached.
+                # Since we need both, this will trigger a federation request.
+                # If we have neither key, either the user doesn't have
+                # cross-signing set up, or the cached device list is not
+                # (yet) up to date.
+                if cached_cross_master ^ cached_cross_selfsigning:
+                    user_ids_not_in_cache.add(user_id)
+
+            # add those users to the list to fetch over federation.
             for user_id in user_ids_not_in_cache:
                 domain = get_domain_from_id(user_id)
                 r = remote_queries_not_in_cache.setdefault(domain, {})
                 r[user_id] = remote_queries[user_id]
 
-        # Get cached cross-signing keys
-        cross_signing_keys = await self.get_cross_signing_keys_from_cache(
-            device_keys_query, from_user_id
-        )
-
         # Now fetch any devices that we don't have in our cache
         @trace
         async def do_remote_query(destination):
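The `^` (XOR) above is the interesting part of the cache check: having exactly one of the two cross-signing keys cached is treated as an inconsistency worth a federation round-trip, while having neither is a normal state. In isolation:

def needs_federation_fetch(master_cached: bool, self_signing_cached: bool) -> bool:
    # Exactly one key cached: the cache is partially stale, fetch both.
    # Both cached: nothing to do. Neither cached: the user may simply
    # not have cross-signing set up, so don't spam federation.
    return master_cached ^ self_signing_cached

assert needs_federation_fetch(True, False)
assert not needs_federation_fetch(True, True)
assert not needs_federation_fetch(False, False)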
@@ -496,6 +513,22 @@ class E2eKeysHandler:
             log_kv(
                 {"message": "Did not update one_time_keys", "reason": "no keys given"}
             )
+        fallback_keys = keys.get("org.matrix.msc2732.fallback_keys", None)
+        if fallback_keys and isinstance(fallback_keys, dict):
+            log_kv(
+                {
+                    "message": "Updating fallback_keys for device.",
+                    "user_id": user_id,
+                    "device_id": device_id,
+                }
+            )
+            await self.store.set_e2e_fallback_keys(user_id, device_id, fallback_keys)
+        elif fallback_keys:
+            log_kv({"message": "Did not update fallback_keys", "reason": "not a dict"})
+        else:
+            log_kv(
+                {"message": "Did not update fallback_keys", "reason": "no keys given"}
+            )
 
         # the device should have been registered already, but it may have been
         # deleted due to a race with a DELETE request. Or we may be using an
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1a8144405a..b9799090f7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -55,6 +55,7 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.handlers._base import BaseHandler
+from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import (
     make_deferred_yieldable,
     nested_logging_context,
@@ -67,7 +68,7 @@ from synapse.replication.http.devices import ReplicationUserDevicesResyncRestSer
 from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationFederationSendEventsRestServlet,
-    ReplicationStoreRoomOnInviteRestServlet,
+    ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
 from synapse.state import StateResolutionStore
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
@@ -112,7 +113,7 @@ class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
         a) handling received Pdus before handing them on as Events to the rest
-        of the homeserver (including auth and state conflict resoultion)
+        of the homeserver (including auth and state conflict resolution)
         b) converting events that were produced by local clients that may need
         to be sent to remote homeservers.
         c) doing the necessary dances to invite remote users and join remote
@@ -152,12 +153,14 @@ class FederationHandler(BaseHandler):
             self._user_device_resync = ReplicationUserDevicesResyncRestServlet.make_client(
                 hs
             )
-            self._maybe_store_room_on_invite = ReplicationStoreRoomOnInviteRestServlet.make_client(
+            self._maybe_store_room_on_outlier_membership = ReplicationStoreRoomOnOutlierMembershipRestServlet.make_client(
                 hs
             )
         else:
             self._device_list_updater = hs.get_device_handler().device_list_updater
-            self._maybe_store_room_on_invite = self.store.maybe_store_room_on_invite
+            self._maybe_store_room_on_outlier_membership = (
+                self.store.maybe_store_room_on_outlier_membership
+            )
 
         # When joining a room we need to queue any events for that room up.
         # For each room, a list of (pdu, origin) tuples.
@@ -477,7 +480,7 @@ class FederationHandler(BaseHandler):
         # ----
         #
         # Update richvdh 2018/09/18: There are a number of problems with timing this
-        # request out agressively on the client side:
+        # request out aggressively on the client side:
         #
         # - it plays badly with the server-side rate-limiter, which starts tarpitting you
         #   if you send too many requests at once, so you end up with the server carefully
@@ -495,13 +498,13 @@ class FederationHandler(BaseHandler):
         #   we'll end up back here for the *next* PDU in the list, which exacerbates the
         #   problem.
         #
-        # - the agressive 10s timeout was introduced to deal with incoming federation
+        # - the aggressive 10s timeout was introduced to deal with incoming federation
         #   requests taking 8 hours to process. It's not entirely clear why that was going
         #   on; certainly there were other issues causing traffic storms which are now
         #   resolved, and I think in any case we may be more sensible about our locking
         #   now. We're *certainly* more sensible about our logging.
         #
-        # All that said: Let's try increasing the timout to 60s and see what happens.
+        # All that said: Let's try increasing the timeout to 60s and see what happens.
 
         try:
             missing_events = await self.federation_client.get_missing_events(
@@ -1120,7 +1123,7 @@ class FederationHandler(BaseHandler):
                     logger.info(str(e))
                     continue
                 except RequestSendFailed as e:
-                    logger.info("Falied to get backfill from %s because %s", dom, e)
+                    logger.info("Failed to get backfill from %s because %s", dom, e)
                     continue
                 except FederationDeniedError as e:
                     logger.info(e)
@@ -1507,18 +1510,9 @@ class FederationHandler(BaseHandler):
             event, context = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
-        except AuthError as e:
+        except SynapseError as e:
             logger.warning("Failed to create join to %s because %s", room_id, e)
-            raise e
-
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Creation of join %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
+            raise
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
@@ -1554,7 +1548,7 @@ class FederationHandler(BaseHandler):
         #
         # The reasons we have the destination server rather than the origin
         # server send it are slightly mysterious: the origin server should have
-        # all the neccessary state once it gets the response to the send_join,
+        # all the necessary state once it gets the response to the send_join,
         # so it could send the event itself if it wanted to. It may be that
         # doing it this way reduces failure modes, or avoids certain attacks
         # where a new server selectively tells a subset of the federation that
@@ -1567,15 +1561,6 @@ class FederationHandler(BaseHandler):
 
         context = await self._handle_new_event(origin, event)
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Sending of join %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
             event.event_id,
@@ -1635,7 +1620,7 @@ class FederationHandler(BaseHandler):
         # keep a record of the room version, if we don't yet know it.
         # (this may get overwritten if we later get a different room version in a
         # join dance).
-        await self._maybe_store_room_on_invite(
+        await self._maybe_store_room_on_outlier_membership(
             room_id=event.room_id, room_version=room_version
         )
 
@@ -1667,7 +1652,7 @@ class FederationHandler(BaseHandler):
         event.internal_metadata.outlier = True
         event.internal_metadata.out_of_band_membership = True
 
-        # Try the host that we succesfully called /make_leave/ on first for
+        # Try the host that we successfully called /make_leave/ on first for
         # the /send_leave/ request.
         host_list = list(target_hosts)
         try:
@@ -1748,15 +1733,6 @@ class FederationHandler(BaseHandler):
             builder=builder
         )
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.warning("Creation of leave %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
@@ -1789,16 +1765,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = await self._handle_new_event(origin, event)
-
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.info("Sending of leave %s forbidden by third-party rules", event)
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
+        await self._handle_new_event(origin, event)
 
         logger.debug(
             "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@@ -2694,18 +2661,6 @@ class FederationHandler(BaseHandler):
                 builder=builder
             )
 
-            event_allowed = await self.third_party_event_rules.check_event_allowed(
-                event, context
-            )
-            if not event_allowed:
-                logger.info(
-                    "Creation of threepid invite %s forbidden by third-party rules",
-                    event,
-                )
-                raise SynapseError(
-                    403, "This event is not allowed in this context", Codes.FORBIDDEN
-                )
-
             event, context = await self.add_display_name_to_third_party_invite(
                 room_version, event_dict, event, context
             )
@@ -2734,7 +2689,7 @@ class FederationHandler(BaseHandler):
             )
 
     async def on_exchange_third_party_invite_request(
-        self, room_id: str, event_dict: JsonDict
+        self, event_dict: JsonDict
     ) -> None:
         """Handle an exchange_third_party_invite request from a remote server
 
@@ -2742,12 +2697,11 @@ class FederationHandler(BaseHandler):
         into a normal m.room.member invite.
 
         Args:
-            room_id: The ID of the room.
-
-            event_dict (dict[str, Any]): Dictionary containing the event body.
+            event_dict: Dictionary containing the event body.
 
         """
-        room_version = await self.store.get_room_version_id(room_id)
+        assert_params_in_dict(event_dict, ["room_id"])
+        room_version = await self.store.get_room_version_id(event_dict["room_id"])
 
         # NB: event_dict has a particular specced format we might need to fudge
         # if we change event formats too much.
@@ -2756,18 +2710,6 @@ class FederationHandler(BaseHandler):
         event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
-
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            logger.warning(
-                "Exchange of threepid invite %s forbidden by third-party rules", event
-            )
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         event, context = await self.add_display_name_to_third_party_invite(
             room_version, event_dict, event, context
         )
@@ -2966,17 +2908,20 @@ class FederationHandler(BaseHandler):
             return result["max_stream_id"]
         else:
             assert self.storage.persistence
-            max_stream_token = await self.storage.persistence.persist_events(
+
+            # Note that this returns the events that were persisted, which may not be
+            # the same as were passed in if some were deduplicated due to transaction IDs.
+            events, max_stream_token = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
             if self._ephemeral_messages_enabled:
-                for (event, context) in event_and_contexts:
+                for event in events:
                     # If there's an expiry timestamp on the event, schedule its expiry.
                     self._message_handler.maybe_schedule_expiry(event)
 
             if not backfilled:  # Never notify for backfilled events
-                for event, _ in event_and_contexts:
+                for event in events:
                     await self._notify_persisted_event(event, max_stream_token)
 
             return max_stream_token.stream
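The change to persist_events matters because, with transaction-ID deduplication, the events actually persisted may not be the ones passed in; everything downstream (expiry scheduling, notification) must iterate the returned list. In outline, with stand-in parameters:

async def persist_and_notify(persistence, notify, event_and_contexts, backfilled):
    # persist_events returns the persisted events, which may differ from
    # the inputs when some were deduplicated by transaction ID.
    events, max_stream_token = await persistence.persist_events(
        event_and_contexts, backfilled=backfilled
    )
    if not backfilled:
        for event in events:  # not `for event, _ in event_and_contexts`
            await notify(event, max_stream_token)
    return max_stream_token.stream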
@@ -3008,6 +2953,9 @@ class FederationHandler(BaseHandler):
         elif event.internal_metadata.is_outlier():
             return
 
+        # the event has been persisted so it should have a stream ordering.
+        assert event.internal_metadata.stream_ordering
+
         event_pos = PersistedEventPosition(
             self._instance_name, event.internal_metadata.stream_ordering
         )
diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py
index 489a7b885d..7c06cc529e 100644
--- a/synapse/handlers/groups_local.py
+++ b/synapse/handlers/groups_local.py
@@ -17,7 +17,7 @@
 import logging
 
 from synapse.api.errors import HttpResponseException, RequestSendFailed, SynapseError
-from synapse.types import get_domain_from_id
+from synapse.types import GroupID, get_domain_from_id
 
 logger = logging.getLogger(__name__)
 
@@ -28,6 +28,9 @@ def _create_rerouter(func_name):
     """
 
     async def f(self, group_id, *args, **kwargs):
+        if not GroupID.is_valid(group_id):
+            raise SynapseError(400, "%s is not a legal group ID" % (group_id,))
+
         if self.is_mine_id(group_id):
             return await getattr(self.groups_server_handler, func_name)(
                 group_id, *args, **kwargs
@@ -346,7 +349,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
                 server_name=get_domain_from_id(group_id),
             )
 
-        # TODO: Check that the group is public and we're being added publically
+        # TODO: Check that the group is public and we're being added publicly
         is_publicised = content.get("publicise", False)
 
         token = await self.store.register_user_group_membership(
@@ -391,7 +394,7 @@ class GroupsLocalHandler(GroupsLocalWorkerHandler):
                 server_name=get_domain_from_id(group_id),
             )
 
-        # TODO: Check that the group is public and we're being added publically
+        # TODO: Check that the group is public and we're being added publicly
         is_publicised = content.get("publicise", False)
 
         token = await self.store.register_user_group_membership(
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index bc3e9607ca..9b3c6b4551 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -354,7 +354,8 @@ class IdentityHandler(BaseHandler):
             raise SynapseError(500, "An error was encountered when sending the email")
 
         token_expires = (
-            self.hs.clock.time_msec() + self.hs.config.email_validation_token_lifetime
+            self.hs.get_clock().time_msec()
+            + self.hs.config.email_validation_token_lifetime
         )
 
         await self.store.start_or_continue_validation_session(
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 39a85801c1..cb11754bf8 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional, Tuple
 
 from twisted.internet import defer
 
@@ -47,12 +47,14 @@ class InitialSyncHandler(BaseHandler):
         self.state = hs.get_state_handler()
         self.clock = hs.get_clock()
         self.validator = EventValidator()
-        self.snapshot_cache = ResponseCache(hs, "initial_sync_cache")
+        self.snapshot_cache = ResponseCache(
+            hs, "initial_sync_cache"
+        )  # type: ResponseCache[Tuple[str, Optional[StreamToken], Optional[StreamToken], str, Optional[int], bool, bool]]
         self._event_serializer = hs.get_event_client_serializer()
         self.storage = hs.get_storage()
         self.state_store = self.storage.state
 
-    def snapshot_all_rooms(
+    async def snapshot_all_rooms(
         self,
         user_id: str,
         pagin_config: PaginationConfig,
@@ -84,7 +86,7 @@ class InitialSyncHandler(BaseHandler):
             include_archived,
         )
 
-        return self.snapshot_cache.wrap(
+        return await self.snapshot_cache.wrap(
             key,
             self._snapshot_all_rooms,
             user_id,
@@ -291,6 +293,10 @@ class InitialSyncHandler(BaseHandler):
                 user_id, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
+            # The member_event_id will always be available if membership is set
+            # to leave.
+            assert member_event_id
+
             result = await self._room_initial_sync_parted(
                 user_id, room_id, pagin_config, membership, member_event_id, is_peeking
             )
@@ -313,7 +319,7 @@ class InitialSyncHandler(BaseHandler):
         user_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
-        membership: Membership,
+        membership: str,
         member_event_id: str,
         is_peeking: bool,
     ) -> JsonDict:
@@ -365,7 +371,7 @@ class InitialSyncHandler(BaseHandler):
         user_id: str,
         room_id: str,
         pagin_config: PaginationConfig,
-        membership: Membership,
+        membership: str,
         is_peeking: bool,
     ) -> JsonDict:
         current_state = await self.state.get_current_state(room_id=room_id)
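snapshot_all_rooms is now a coroutine that awaits ResponseCache.wrap, so concurrent identical requests share a single computation keyed by the request parameters. The general shape of such a cache, as a self-contained sketch rather than Synapse's implementation:

import asyncio

class SimpleResponseCache:
    """Minimal analogue of a response cache: deduplicate in-flight work."""

    def __init__(self):
        self._pending = {}  # key -> asyncio.Task

    async def wrap(self, key, callback, *args, **kwargs):
        # The first caller for a key starts the computation; concurrent
        # callers with the same key await the same in-flight task.
        task = self._pending.get(key)
        if task is None:
            task = asyncio.ensure_future(callback(*args, **kwargs))
            self._pending[key] = task
            task.add_done_callback(lambda _: self._pending.pop(key, None))
        return await task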
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index ee271e85e5..96843338ae 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,15 +50,15 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
-from synapse.util import json_decoder
+from synapse.util import json_decoder, json_encoder
 from synapse.util.async_helpers import Linearizer
-from synapse.util.frozenutils import frozendict_json_encoder
 from synapse.util.metrics import measure_func
 from synapse.visibility import filter_events_for_client
 
 from ._base import BaseHandler
 
 if TYPE_CHECKING:
+    from synapse.events.third_party_rules import ThirdPartyEventRules
     from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
@@ -393,27 +393,31 @@ class EventCreationHandler:
         self.action_generator = hs.get_action_generator()
 
         self.spam_checker = hs.get_spam_checker()
-        self.third_party_event_rules = hs.get_third_party_event_rules()
+        self.third_party_event_rules = (
+            self.hs.get_third_party_event_rules()
+        )  # type: ThirdPartyEventRules
 
         self._block_events_without_consent_error = (
             self.config.block_events_without_consent_error
         )
 
+        # we need to construct a ConsentURIBuilder here, as it checks that the necessary
+        # config options are set, but *only* if we have a configuration for which we are
+        # going to need it.
+        if self._block_events_without_consent_error:
+            self._consent_uri_builder = ConsentURIBuilder(self.config)
+
         # Rooms which should be excluded from dummy insertion. (For instance,
         # those without local users who can send events into the room).
         #
         # map from room id to time-of-last-attempt.
         #
         self._rooms_to_exclude_from_dummy_event_insertion = {}  # type: Dict[str, int]
-
-        # we need to construct a ConsentURIBuilder here, as it checks that the necessary
-        # config options, but *only* if we have a configuration for which we are
-        # going to need it.
-        if self._block_events_without_consent_error:
-            self._consent_uri_builder = ConsentURIBuilder(self.config)
+        # The number of forward extremities before a dummy event is sent.
+        self._dummy_events_threshold = hs.config.dummy_events_threshold
 
         if (
-            not self.config.worker_app
+            self.config.run_background_tasks
             and self.config.cleanup_extremities_with_dummy_events
         ):
             self.clock.looping_call(
@@ -428,15 +432,13 @@ class EventCreationHandler:
 
         self._ephemeral_events_enabled = hs.config.enable_ephemeral_messages
 
-        self._dummy_events_threshold = hs.config.dummy_events_threshold
-
     async def create_event(
         self,
         requester: Requester,
         event_dict: dict,
-        token_id: Optional[str] = None,
         txn_id: Optional[str] = None,
         prev_event_ids: Optional[List[str]] = None,
+        auth_event_ids: Optional[List[str]] = None,
         require_consent: bool = True,
     ) -> Tuple[EventBase, EventContext]:
         """
@@ -450,13 +452,18 @@ class EventCreationHandler:
         Args:
             requester
             event_dict: An entire event
-            token_id
             txn_id
             prev_event_ids:
                 the forward extremities to use as the prev_events for the
                 new event.
 
                 If None, they will be requested from the database.
+
+            auth_event_ids:
+                The event ids to use as the auth_events for the new event.
+                Should normally be left as None, which will cause them to be calculated
+                based on the room state at the prev_events.
+
             require_consent: Whether to check if the requester has
                 consented to the privacy policy.
         Raises:
@@ -465,7 +472,7 @@ class EventCreationHandler:
         Returns:
             Tuple of created event, Context
         """
-        await self.auth.check_auth_blocking(requester.user.to_string())
+        await self.auth.check_auth_blocking(requester=requester)
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version = event_dict["content"]["room_version"]
@@ -508,14 +515,17 @@ class EventCreationHandler:
         if require_consent and not is_exempt:
             await self.assert_accepted_privacy_policy(requester)
 
-        if token_id is not None:
-            builder.internal_metadata.token_id = token_id
+        if requester.access_token_id is not None:
+            builder.internal_metadata.token_id = requester.access_token_id
 
         if txn_id is not None:
             builder.internal_metadata.txn_id = txn_id
 
         event, context = await self.create_new_client_event(
-            builder=builder, requester=requester, prev_event_ids=prev_event_ids,
+            builder=builder,
+            requester=requester,
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=auth_event_ids,
         )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -609,7 +619,13 @@ class EventCreationHandler:
         if requester.app_service is not None:
             return
 
-        user_id = requester.user.to_string()
+        user_id = requester.authenticated_entity
+        if not user_id.startswith("@"):
+            # The authenticated entity might not be a user, e.g. if it's the
+            # server puppetting the user.
+            return
+
+        user = UserID.from_string(user_id)
 
         # exempt the system notices user
         if (
@@ -629,65 +645,10 @@ class EventCreationHandler:
         if u["consent_version"] == self.config.user_consent_version:
             return
 
-        consent_uri = self._consent_uri_builder.build_user_consent_uri(
-            requester.user.localpart
-        )
+        consent_uri = self._consent_uri_builder.build_user_consent_uri(user.localpart)
         msg = self._block_events_without_consent_error % {"consent_uri": consent_uri}
         raise ConsentNotGivenError(msg=msg, consent_uri=consent_uri)
 
-    async def send_nonmember_event(
-        self,
-        requester: Requester,
-        event: EventBase,
-        context: EventContext,
-        ratelimit: bool = True,
-        ignore_shadow_ban: bool = False,
-    ) -> int:
-        """
-        Persists and notifies local clients and federation of an event.
-
-        Args:
-            requester: The requester sending the event.
-            event: The event to send.
-            context: The context of the event.
-            ratelimit: Whether to rate limit this send.
-            ignore_shadow_ban: True if shadow-banned users should be allowed to
-                send this event.
-
-        Return:
-            The stream_id of the persisted event.
-
-        Raises:
-            ShadowBanError if the requester has been shadow-banned.
-        """
-        if event.type == EventTypes.Member:
-            raise SynapseError(
-                500, "Tried to send member event through non-member codepath"
-            )
-
-        if not ignore_shadow_ban and requester.shadow_banned:
-            # We randomly sleep a bit just to annoy the requester.
-            await self.clock.sleep(random.randint(1, 10))
-            raise ShadowBanError()
-
-        user = UserID.from_string(event.sender)
-
-        assert self.hs.is_mine(user), "User must be our own: %s" % (user,)
-
-        if event.is_state():
-            prev_event = await self.deduplicate_state_event(event, context)
-            if prev_event is not None:
-                logger.info(
-                    "Not bothering to persist state event %s duplicated by %s",
-                    event.event_id,
-                    prev_event.event_id,
-                )
-                return await self.store.get_stream_id_for_event(prev_event.event_id)
-
-        return await self.handle_new_client_event(
-            requester=requester, event=event, context=context, ratelimit=ratelimit
-        )
-
     async def deduplicate_state_event(
         self, event: EventBase, context: EventContext
     ) -> Optional[EventBase]:
@@ -699,7 +660,7 @@ class EventCreationHandler:
             context: The event context.
 
         Returns:
-            The previous verion of the event is returned, if it is found in the
+            The previous version of the event is returned, if it is found in the
             event context. Otherwise, None is returned.
         """
         prev_state_ids = await context.get_prev_state_ids()
@@ -728,7 +689,7 @@ class EventCreationHandler:
         """
         Creates an event, then sends it.
 
-        See self.create_event and self.send_nonmember_event.
+        See self.create_event and self.handle_new_client_event.
 
         Args:
             requester: The requester sending the event.
@@ -738,9 +699,19 @@ class EventCreationHandler:
             ignore_shadow_ban: True if shadow-banned users should be allowed to
                 send this event.
 
+        Returns:
+            The event and its stream ordering (if deduplication happened, this
+            is the previously-sent duplicate event and its ordering).
+
         Raises:
             ShadowBanError if the requester has been shadow-banned.
         """
+
+        if event_dict["type"] == EventTypes.Member:
+            raise SynapseError(
+                500, "Tried to send member event through non-member codepath"
+            )
+
         if not ignore_shadow_ban and requester.shadow_banned:
             # We randomly sleep a bit just to annoy the requester.
             await self.clock.sleep(random.randint(1, 10))
@@ -752,8 +723,25 @@ class EventCreationHandler:
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
         with (await self.limiter.queue(event_dict["room_id"])):
+            if txn_id and requester.access_token_id:
+                existing_event_id = await self.store.get_event_id_from_transaction_id(
+                    event_dict["room_id"],
+                    requester.user.to_string(),
+                    requester.access_token_id,
+                    txn_id,
+                )
+                if existing_event_id:
+                    event = await self.store.get_event(existing_event_id)
+                    # we know it was persisted, so must have a stream ordering
+                    assert event.internal_metadata.stream_ordering
+                    return event, event.internal_metadata.stream_ordering
+
             event, context = await self.create_event(
-                requester, event_dict, token_id=requester.access_token_id, txn_id=txn_id
+                requester, event_dict, txn_id=txn_id
+            )
+
+            assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
+                event.sender,
             )
 
             spam_error = self.spam_checker.check_event_for_spam(event)
@@ -762,14 +750,17 @@ class EventCreationHandler:
                     spam_error = "Spam is not permitted here"
                 raise SynapseError(403, spam_error, Codes.FORBIDDEN)
 
-            stream_id = await self.send_nonmember_event(
-                requester,
-                event,
-                context,
+            ev = await self.handle_new_client_event(
+                requester=requester,
+                event=event,
+                context=context,
                 ratelimit=ratelimit,
                 ignore_shadow_ban=ignore_shadow_ban,
             )
-        return event, stream_id
+
+        # we know it was persisted, so must have a stream ordering
+        assert ev.internal_metadata.stream_ordering
+        return ev, ev.internal_metadata.stream_ordering
 
     @measure_func("create_new_client_event")
     async def create_new_client_event(
@@ -777,6 +768,7 @@ class EventCreationHandler:
         builder: EventBuilder,
         requester: Optional[Requester] = None,
         prev_event_ids: Optional[List[str]] = None,
+        auth_event_ids: Optional[List[str]] = None,
     ) -> Tuple[EventBase, EventContext]:
         """Create a new event for a local client
 
@@ -789,6 +781,11 @@ class EventCreationHandler:
 
                 If None, they will be requested from the database.
 
+            auth_event_ids:
+                The event ids to use as the auth_events for the new event.
+                Should normally be left as None, which will cause them to be calculated
+                based on the room state at the prev_events.
+
         Returns:
             Tuple of created event, context
         """
@@ -810,11 +807,30 @@ class EventCreationHandler:
             builder.type == EventTypes.Create or len(prev_event_ids) > 0
         ), "Attempting to create an event with no prev_events"
 
-        event = await builder.build(prev_event_ids=prev_event_ids)
+        event = await builder.build(
+            prev_event_ids=prev_event_ids, auth_event_ids=auth_event_ids
+        )
         context = await self.state.compute_event_context(event)
         if requester:
             context.app_service = requester.app_service
 
+        third_party_result = await self.third_party_event_rules.check_event_allowed(
+            event, context
+        )
+        if not third_party_result:
+            logger.info(
+                "Event %s forbidden by third-party rules", event,
+            )
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
+        elif isinstance(third_party_result, dict):
+            # the third-party rules want to replace the event. We'll need to build a new
+            # event.
+            event, context = await self._rebuild_event_after_third_party_rules(
+                third_party_result, event
+            )
+
         self.validator.validate_new(event, self.config)
 
         # If this event is an annotation then we check that the sender
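The rewritten third-party rules hook above has three outcomes: a falsy return forbids the event, a dict return asks the server to rebuild the event from that dict, and any other truthy value allows it unchanged. A hedged sketch of a module-side callback using this contract, with the event shape simplified to a plain dict:

    from typing import Union

    async def check_event_allowed(event: dict, context: object) -> Union[bool, dict]:
        # Example module-side callback. Returns False to forbid the event,
        # True to allow it unchanged, or a dict to have the server rebuild
        # the event from the returned content.
        if event.get("type") == "m.room.message":
            body = event.get("content", {}).get("body", "")
            if "forbidden-word" in body:
                return False  # the caller turns this into a 403
            if body.endswith("!!!"):
                # Ask the server to rebuild the event with softened content.
                replacement = dict(event)
                replacement["content"] = dict(event["content"], body=body.rstrip("!"))
                return replacement
        return True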
@@ -843,8 +859,11 @@ class EventCreationHandler:
         context: EventContext,
         ratelimit: bool = True,
         extra_users: List[UserID] = [],
-    ) -> int:
-        """Processes a new event. This includes checking auth, persisting it,
+        ignore_shadow_ban: bool = False,
+    ) -> EventBase:
+        """Processes a new event.
+
+        This includes deduplicating, checking auth, persisting,
         notifying users, sending to remote servers, etc.
 
         If called from a worker will hit out to the master process for final
@@ -857,10 +876,39 @@ class EventCreationHandler:
             ratelimit
             extra_users: Any extra users to notify about event
 
+            ignore_shadow_ban: True if shadow-banned users should be allowed to
+                send this event.
+
         Return:
-            The stream_id of the persisted event.
+            If the event was deduplicated, the previous, duplicate event. Otherwise,
+            `event`.
+
+        Raises:
+            ShadowBanError if the requester has been shadow-banned.
         """
 
+        # we don't apply shadow-banning to membership events here. Invites are blocked
+        # higher up the stack, and we allow shadow-banned users to send join and leave
+        # events as normal.
+        if (
+            event.type != EventTypes.Member
+            and not ignore_shadow_ban
+            and requester.shadow_banned
+        ):
+            # We randomly sleep a bit just to annoy the requester.
+            await self.clock.sleep(random.randint(1, 10))
+            raise ShadowBanError()
+
+        if event.is_state():
+            prev_event = await self.deduplicate_state_event(event, context)
+            if prev_event is not None:
+                logger.info(
+                    "Not bothering to persist state event %s duplicated by %s",
+                    event.event_id,
+                    prev_event.event_id,
+                )
+                return prev_event
+
         if event.is_state() and (event.type, event.state_key) == (
             EventTypes.Create,
             "",
@@ -869,14 +917,6 @@ class EventCreationHandler:
         else:
             room_version = await self.store.get_room_version_id(event.room_id)
 
-        event_allowed = await self.third_party_event_rules.check_event_allowed(
-            event, context
-        )
-        if not event_allowed:
-            raise SynapseError(
-                403, "This event is not allowed in this context", Codes.FORBIDDEN
-            )
-
         if event.internal_metadata.is_out_of_band_membership():
             # the only sort of out-of-band-membership events we expect to see here
             # are invite rejections we have generated ourselves.
@@ -891,7 +931,7 @@ class EventCreationHandler:
 
         # Ensure that we can round trip before trying to persist in db
         try:
-            dump = frozendict_json_encoder.encode(event.content)
+            dump = json_encoder.encode(event.content)
             json_decoder.decode(dump)
         except Exception:
             logger.exception("Failed to encode content: %r", event.content)
@@ -914,14 +954,24 @@ class EventCreationHandler:
                     extra_users=extra_users,
                 )
                 stream_id = result["stream_id"]
-                event.internal_metadata.stream_ordering = stream_id
-                return stream_id
-
-            stream_id = await self.persist_and_notify_client_event(
+                event_id = result["event_id"]
+                if event_id != event.event_id:
+                    # If we get a different event back then it means that it has
+                    # been de-duplicated, so we replace the given event with the
+                    # one already persisted.
+                    event = await self.store.get_event(event_id)
+                else:
+                    # If we newly persisted the event then we need to update its
+                    # stream_ordering entry manually (as it was persisted on
+                    # another worker).
+                    event.internal_metadata.stream_ordering = stream_id
+                return event
+
+            event = await self.persist_and_notify_client_event(
                 requester, event, context, ratelimit=ratelimit, extra_users=extra_users
             )
 
-            return stream_id
+            return event
         except Exception:
             # Ensure that we actually remove the entries in the push actions
             # staging area, if we calculated them.
@@ -966,11 +1016,16 @@ class EventCreationHandler:
         context: EventContext,
         ratelimit: bool = True,
         extra_users: List[UserID] = [],
-    ) -> int:
+    ) -> EventBase:
         """Called when we have fully built the event, have already
         calculated the push actions for the event, and checked auth.
 
         This should only be run on the instance in charge of persisting events.
+
+        Returns:
+            The persisted event. This may be different from the given event if
+            it was de-duplicated (e.g. because we had already persisted an
+            event with the same transaction ID).
         """
         assert self.storage.persistence is not None
         assert self._events_shard_config.should_handle(
@@ -1018,7 +1073,7 @@ class EventCreationHandler:
 
             # Check the alias is currently valid (if it has changed).
             room_alias_str = event.content.get("alias", None)
-            directory_handler = self.hs.get_handlers().directory_handler
+            directory_handler = self.hs.get_directory_handler()
             if room_alias_str and room_alias_str != original_alias:
                 await self._validate_canonical_alias(
                     directory_handler, room_alias_str, event.room_id
@@ -1044,38 +1099,17 @@ class EventCreationHandler:
                         directory_handler, alias_str, event.room_id
                     )
 
-        federation_handler = self.hs.get_handlers().federation_handler
+        federation_handler = self.hs.get_federation_handler()
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.INVITE:
-
-                def is_inviter_member_event(e):
-                    return e.type == EventTypes.Member and e.sender == event.sender
-
-                current_state_ids = await context.get_current_state_ids()
-
-                # We know this event is not an outlier, so this must be
-                # non-None.
-                assert current_state_ids is not None
-
-                state_to_include_ids = [
-                    e_id
-                    for k, e_id in current_state_ids.items()
-                    if k[0] in self.room_invite_state_types
-                    or k == (EventTypes.Member, event.sender)
-                ]
-
-                state_to_include = await self.store.get_events(state_to_include_ids)
-
-                event.unsigned["invite_room_state"] = [
-                    {
-                        "type": e.type,
-                        "state_key": e.state_key,
-                        "content": e.content,
-                        "sender": e.sender,
-                    }
-                    for e in state_to_include.values()
-                ]
+                event.unsigned[
+                    "invite_room_state"
+                ] = await self.store.get_stripped_room_state_from_event_context(
+                    context,
+                    self.room_invite_state_types,
+                    membership_user_id=event.sender,
+                )
 
                 invitee = UserID.from_string(event.state_key)
                 if not self.hs.is_mine(invitee):
@@ -1108,6 +1142,9 @@ class EventCreationHandler:
                 if original_event.room_id != event.room_id:
                     raise SynapseError(400, "Cannot redact event from a different room")
 
+                if original_event.type == EventTypes.ServerACL:
+                    raise AuthError(403, "Redacting server ACL events is not permitted")
+
             prev_state_ids = await context.get_prev_state_ids()
             auth_events_ids = self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
@@ -1138,9 +1175,13 @@ class EventCreationHandler:
             if prev_state_ids:
                 raise AuthError(403, "Changing the room create event is forbidden")
 
-        event_pos, max_stream_token = await self.storage.persistence.persist_event(
-            event, context=context
-        )
+        # Note that this returns the event that was persisted, which may not be
+        # the same as the one we passed in if it was deduplicated due to transaction IDs.
+        (
+            event,
+            event_pos,
+            max_stream_token,
+        ) = await self.storage.persistence.persist_event(event, context=context)
 
         if self._ephemeral_events_enabled:
             # If there's an expiry timestamp on the event, schedule its expiry.
@@ -1161,7 +1202,7 @@ class EventCreationHandler:
             # matters as sometimes presence code can take a while.
             run_in_background(self._bump_active_time, requester.user)
 
-        return event_pos.stream
+        return event
 
     async def _bump_active_time(self, user: UserID) -> None:
         try:
@@ -1215,7 +1256,7 @@ class EventCreationHandler:
         for user_id in members:
             if not self.hs.is_mine_id(user_id):
                 continue
-            requester = create_requester(user_id)
+            requester = create_requester(user_id, authenticated_entity=self.server_name)
             try:
                 event, context = await self.create_event(
                     requester,
@@ -1232,15 +1273,10 @@ class EventCreationHandler:
 
                 # Since this is a dummy-event it is OK if it is sent by a
                 # shadow-banned user.
-                await self.send_nonmember_event(
+                await self.handle_new_client_event(
                     requester, event, context, ratelimit=False, ignore_shadow_ban=True,
                 )
                 return True
-            except ConsentNotGivenError:
-                logger.info(
-                    "Failed to send dummy event into room %s for user %s due to "
-                    "lack of consent. Will try another user" % (room_id, user_id)
-                )
             except AuthError:
                 logger.info(
                     "Failed to send dummy event into room %s for user %s due to "
@@ -1260,3 +1296,62 @@ class EventCreationHandler:
                 room_id,
             )
             del self._rooms_to_exclude_from_dummy_event_insertion[room_id]
+
+    async def _rebuild_event_after_third_party_rules(
+        self, third_party_result: dict, original_event: EventBase
+    ) -> Tuple[EventBase, EventContext]:
+        # the third_party_event_rules want to replace the event.
+        # we do some basic checks, and then return the replacement event and context.
+
+        # Construct a new EventBuilder and validate it, which helps with the
+        # rest of these checks.
+        try:
+            builder = self.event_builder_factory.for_room_version(
+                original_event.room_version, third_party_result
+            )
+            self.validator.validate_builder(builder)
+        except SynapseError as e:
+            raise Exception(
+                "Third party rules module created an invalid event: " + e.msg,
+            )
+
+        immutable_fields = [
+            # changing the room is going to break things: we've already checked that the
+            # room exists, and are holding a concurrency limiter token for that room.
+            # Also, we might need to use a different room version.
+            "room_id",
+            # changing the type or state key might work, but we'd need to check that the
+            # calling functions aren't making assumptions about them.
+            "type",
+            "state_key",
+        ]
+
+        for k in immutable_fields:
+            if getattr(builder, k, None) != original_event.get(k):
+                raise Exception(
+                    "Third party rules module created an invalid event: "
+                    "cannot change field " + k
+                )
+
+        # check that the new sender belongs to this HS
+        if not self.hs.is_mine_id(builder.sender):
+            raise Exception(
+                "Third party rules module created an invalid event: "
+                "invalid sender " + builder.sender
+            )
+
+        # copy over the original internal metadata
+        for k, v in original_event.internal_metadata.get_dict().items():
+            setattr(builder.internal_metadata, k, v)
+
+        # the event type hasn't changed, so there's no point in re-calculating the
+        # auth events.
+        event = await builder.build(
+            prev_event_ids=original_event.prev_event_ids(),
+            auth_event_ids=original_event.auth_event_ids(),
+        )
+
+        # we rebuild the event context, to be on the safe side. If nothing else,
+        # delta_ids might need an update.
+        context = await self.state.compute_event_context(event)
+        return event, context
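The immutable-fields check in _rebuild_event_after_third_party_rules boils down to a field-by-field comparison between the original event and the module's replacement. A standalone sketch of that check:

    IMMUTABLE_FIELDS = ("room_id", "type", "state_key")

    def check_replacement(original: dict, replacement: dict) -> None:
        # Changing the room would break the concurrency limiter and possibly
        # the room version; type and state_key are assumed stable by callers.
        for field in IMMUTABLE_FIELDS:
            if replacement.get(field) != original.get(field):
                raise Exception(
                    "Third party rules module created an invalid event: "
                    "cannot change field " + field
                )

    original = {"room_id": "!r:example.org", "type": "m.room.message", "state_key": None}
    ok = dict(original, content={"body": "edited"})
    check_replacement(original, ok)  # passes; content changes are allowed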
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index 19cd652675..c605f7082a 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import inspect
 import logging
 from typing import TYPE_CHECKING, Dict, Generic, List, Optional, Tuple, TypeVar
 from urllib.parse import urlencode
@@ -34,7 +35,8 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 
 from synapse.config import ConfigError
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -83,19 +85,15 @@ class OidcError(Exception):
         return self.error
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the UserInfo object
-    """
-
-
-class OidcHandler:
+class OidcHandler(BaseHandler):
     """Handles requests related to the OpenID Connect login flow.
     """
 
     def __init__(self, hs: "HomeServer"):
-        self.hs = hs
+        super().__init__(hs)
         self._callback_url = hs.config.oidc_callback_url  # type: str
         self._scopes = hs.config.oidc_scopes  # type: List[str]
+        self._user_profile_method = hs.config.oidc_user_profile_method  # type: str
         self._client_auth = ClientAuth(
             hs.config.oidc_client_id,
             hs.config.oidc_client_secret,
@@ -119,36 +117,13 @@ class OidcHandler:
         self._http_client = hs.get_proxied_http_client()
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
-        self._datastore = hs.get_datastore()
-        self._clock = hs.get_clock()
-        self._hostname = hs.hostname  # type: str
         self._server_name = hs.config.server_name  # type: str
         self._macaroon_secret_key = hs.config.macaroon_secret_key
-        self._error_template = hs.config.sso_error_template
 
         # identifier for the external_ids table
         self._auth_provider_id = "oidc"
 
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
-
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
-
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error. Those include
-                well-known OAuth2/OIDC error types like invalid_request or
-                access_denied.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def _validate_metadata(self):
         """Verifies the provider metadata.
@@ -196,11 +171,11 @@ class OidcHandler:
                     % (m["response_types_supported"],)
                 )
 
-        # If the openid scope was not requested, we need a userinfo endpoint to fetch user infos
+        # Ensure there's a userinfo endpoint to fetch from if it is required.
         if self._uses_userinfo:
             if m.get("userinfo_endpoint") is None:
                 raise ValueError(
-                    'provider has no "userinfo_endpoint", even though it is required because the "openid" scope is not requested'
+                    'provider has no "userinfo_endpoint", even though it is required'
                 )
         else:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
@@ -216,12 +191,14 @@ class OidcHandler:
 
         This is based on the requested scopes: if the scopes include
         ``openid``, the provider should give use an ID token containing the
-        user informations. If not, we should fetch them using the
+        user information. If not, we should fetch them using the
         ``access_token`` with the ``userinfo_endpoint``.
         """
 
-        # Maybe that should be user-configurable and not inferred?
-        return "openid" not in self._scopes
+        return (
+            "openid" not in self._scopes
+            or self._user_profile_method == "userinfo_endpoint"
+        )
 
     async def load_metadata(self) -> OpenIDProviderMetadata:
         """Load and validate the provider metadata.
@@ -423,7 +400,7 @@ class OidcHandler:
         return resp
 
     async def _fetch_userinfo(self, token: Token) -> UserInfo:
-        """Fetch user informations from the ``userinfo_endpoint``.
+        """Fetch user information from the ``userinfo_endpoint``.
 
         Args:
             token: the token given by the ``token_endpoint``.
@@ -568,7 +545,7 @@ class OidcHandler:
 
         Since we might want to display OIDC-related errors in a user-friendly
         way, we don't raise SynapseError from here. Instead, we call
-        ``self._render_error`` which displays an HTML page for the error.
+        ``self._sso_handler.render_error`` which displays an HTML page for the error.
 
         Most of the OpenID Connect logic happens here:
 
@@ -606,7 +583,7 @@ class OidcHandler:
             if error != "access_denied":
                 logger.error("Error from the OIDC provider: %s %s", error, description)
 
-            self._render_error(request, error, description)
+            self._sso_handler.render_error(request, error, description)
             return
 
         # otherwise, it is presumably a successful response. see:
@@ -616,7 +593,9 @@ class OidcHandler:
         session = request.getCookie(SESSION_COOKIE_NAME)  # type: Optional[bytes]
         if session is None:
             logger.info("No session cookie found")
-            self._render_error(request, "missing_session", "No session cookie found")
+            self._sso_handler.render_error(
+                request, "missing_session", "No session cookie found"
+            )
             return
 
         # Remove the cookie. There is a good chance that if the callback failed
@@ -634,7 +613,9 @@ class OidcHandler:
         # Check for the state query parameter
         if b"state" not in request.args:
             logger.info("State parameter is missing")
-            self._render_error(request, "invalid_request", "State parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "State parameter is missing"
+            )
             return
 
         state = request.args[b"state"][0].decode()
@@ -648,17 +629,19 @@ class OidcHandler:
             ) = self._verify_oidc_session_token(session, state)
         except MacaroonDeserializationException as e:
             logger.exception("Invalid session")
-            self._render_error(request, "invalid_session", str(e))
+            self._sso_handler.render_error(request, "invalid_session", str(e))
             return
         except MacaroonInvalidSignatureException as e:
             logger.exception("Could not verify session")
-            self._render_error(request, "mismatching_session", str(e))
+            self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
         # Exchange the code with the provider
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
-            self._render_error(request, "invalid_request", "Code parameter is missing")
+            self._sso_handler.render_error(
+                request, "invalid_request", "Code parameter is missing"
+            )
             return
 
         logger.debug("Exchanging code")
@@ -667,7 +650,7 @@ class OidcHandler:
             token = await self._exchange_code(code)
         except OidcError as e:
             logger.exception("Could not exchange code")
-            self._render_error(request, e.error, e.error_description)
+            self._sso_handler.render_error(request, e.error, e.error_description)
             return
 
         logger.debug("Successfully obtained OAuth2 access token")
@@ -680,7 +663,7 @@ class OidcHandler:
                 userinfo = await self._fetch_userinfo(token)
             except Exception as e:
                 logger.exception("Could not fetch userinfo")
-                self._render_error(request, "fetch_error", str(e))
+                self._sso_handler.render_error(request, "fetch_error", str(e))
                 return
         else:
             logger.debug("Extracting userinfo from id_token")
@@ -688,13 +671,11 @@ class OidcHandler:
                 userinfo = await self._parse_id_token(token, nonce=nonce)
             except Exception as e:
                 logger.exception("Invalid id_token")
-                self._render_error(request, "invalid_token", str(e))
+                self._sso_handler.render_error(request, "invalid_token", str(e))
                 return
 
         # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
+        user_agent = request.get_user_agent("")
         ip_address = self.hs.get_ip_from_request(request)
 
         # Call the mapper to register/login the user
@@ -704,7 +685,7 @@ class OidcHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Mapping providers might not have get_extra_attributes: only call this
@@ -753,7 +734,7 @@ class OidcHandler:
                 Defaults to an hour.
 
         Returns:
-            A signed macaroon token with the session informations.
+            A signed macaroon token with the session information.
         """
         macaroon = pymacaroons.Macaroon(
             location=self._server_name, identifier="key", key=self._macaroon_secret_key,
@@ -769,7 +750,7 @@ class OidcHandler:
             macaroon.add_first_party_caveat(
                 "ui_auth_session_id = %s" % (ui_auth_session_id,)
             )
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         expiry = now + duration_in_ms
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
 
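The session token built above is a macaroon whose first-party caveats pin, among other things, an expiry of the form "time < N"; verification later checks each caveat, as in the expiry helper further down. A self-contained sketch with pymacaroons (the secret key and location are placeholders):

    import time
    import pymacaroons

    key = "macaroon-secret-key"  # placeholder; Synapse uses its configured secret

    macaroon = pymacaroons.Macaroon(location="example.org", identifier="key", key=key)
    expiry_ms = int(time.time() * 1000) + 15 * 60 * 1000  # e.g. a 15-minute session
    macaroon.add_first_party_caveat("time < %d" % (expiry_ms,))

    def satisfy_expiry(caveat: str) -> bool:
        # Mirrors the expiry check in the handler: parse "time < N" and
        # compare against the current time in milliseconds.
        prefix = "time < "
        if not caveat.startswith(prefix):
            return False
        return int(time.time() * 1000) < int(caveat[len(prefix):])

    verifier = pymacaroons.Verifier()
    verifier.satisfy_general(satisfy_expiry)
    verifier.verify(macaroon, key)  # raises if any caveat is unsatisfied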
@@ -844,7 +825,7 @@ class OidcHandler:
         if not caveat.startswith(prefix):
             return False
         expiry = int(caveat[len(prefix) :])
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         return now < expiry
 
     async def _map_userinfo_to_user(
@@ -884,71 +865,77 @@ class OidcHandler:
         # to be strings.
         remote_user_id = str(remote_user_id)
 
-        logger.info(
-            "Looking for existing mapping for user %s:%s",
-            self._auth_provider_id,
-            remote_user_id,
-        )
-
-        registered_user_id = await self._datastore.get_user_by_external_id(
-            self._auth_provider_id, remote_user_id,
+        # Older mapping providers don't accept the `failures` argument, so we
+        # try and detect support.
+        mapper_signature = inspect.signature(
+            self._user_mapping_provider.map_user_attributes
         )
+        supports_failures = "failures" in mapper_signature.parameters
 
-        if registered_user_id is not None:
-            logger.info("Found existing mapping %s", registered_user_id)
-            return registered_user_id
+        async def oidc_response_to_user_attributes(failures: int) -> UserAttributes:
+            """
+            Call the mapping provider to map the OIDC userinfo and token to user attributes.
 
-        try:
-            attributes = await self._user_mapping_provider.map_user_attributes(
-                userinfo, token
-            )
-        except Exception as e:
-            raise MappingException(
-                "Could not extract user attributes from OIDC response: " + str(e)
-            )
-
-        logger.debug(
-            "Retrieved user attributes from user mapping provider: %r", attributes
-        )
+            This is a backwards-compatibility shim for the SSO handler abstraction.
+            """
+            if supports_failures:
+                attributes = await self._user_mapping_provider.map_user_attributes(
+                    userinfo, token, failures
+                )
+            else:
+                # If the mapping provider does not support processing failures,
+                # there is no point in retrying: each attempt would generate the
+                # same Matrix ID, which is already in use. Note that the error
+                # raised is arbitrary and will get turned into a MappingException.
+                if failures:
+                    raise MappingException(
+                        "Mapping provider does not support de-duplicating Matrix IDs"
+                    )
 
-        if not attributes["localpart"]:
-            raise MappingException("localpart is empty")
+                attributes = await self._user_mapping_provider.map_user_attributes(  # type: ignore
+                    userinfo, token
+                )
 
-        localpart = map_username_to_mxid_localpart(attributes["localpart"])
+            return UserAttributes(**attributes)
 
-        user_id = UserID(localpart, self._hostname).to_string()
-        users = await self._datastore.get_users_by_id_case_insensitive(user_id)
-        if users:
+        async def grandfather_existing_users() -> Optional[str]:
             if self._allow_existing_users:
-                if len(users) == 1:
-                    registered_user_id = next(iter(users))
-                elif user_id in users:
-                    registered_user_id = user_id
-                else:
-                    raise MappingException(
-                        "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
-                            user_id, list(users.keys())
+                # If allowing existing users we want to generate a single localpart
+                # and attempt to match it.
+                attributes = await oidc_response_to_user_attributes(failures=0)
+
+                user_id = UserID(attributes.localpart, self.server_name).to_string()
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
+                if users:
+                    # If an existing matrix ID is returned, then use it.
+                    if len(users) == 1:
+                        previously_registered_user_id = next(iter(users))
+                    elif user_id in users:
+                        previously_registered_user_id = user_id
+                    else:
+                        # Do not attempt to continue generating Matrix IDs.
+                        raise MappingException(
+                            "Attempted to login as '{}' but it matches more than one user inexactly: {}".format(
+                                user_id, users
+                            )
                         )
-                    )
-            else:
-                # This mxid is taken
-                raise MappingException("mxid '{}' is already taken".format(user_id))
-        else:
-            # It's the first time this user is logging in and the mapped mxid was
-            # not taken, register the user
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=attributes["display_name"],
-                user_agent_ips=(user_agent, ip_address),
-            )
-        await self._datastore.record_user_external_id(
-            self._auth_provider_id, remote_user_id, registered_user_id,
+
+                    return previously_registered_user_id
+
+            return None
+
+        return await self._sso_handler.get_mxid_from_sso(
+            self._auth_provider_id,
+            remote_user_id,
+            user_agent,
+            ip_address,
+            oidc_response_to_user_attributes,
+            grandfather_existing_users,
         )
-        return registered_user_id
 
 
-UserAttribute = TypedDict(
-    "UserAttribute", {"localpart": str, "display_name": Optional[str]}
+UserAttributeDict = TypedDict(
+    "UserAttributeDict", {"localpart": str, "display_name": Optional[str]}
 )
 C = TypeVar("C")
 
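Support for the new failures argument is detected at runtime via inspect.signature, so older mapping providers keep working unchanged. The detection is just a parameter-name check:

    import inspect

    class LegacyProvider:
        async def map_user_attributes(self, userinfo, token):
            return {"localpart": "alice", "display_name": None}

    class ModernProvider:
        async def map_user_attributes(self, userinfo, token, failures):
            suffix = str(failures) if failures else ""
            return {"localpart": "alice" + suffix, "display_name": None}

    def supports_failures(provider) -> bool:
        sig = inspect.signature(provider.map_user_attributes)
        return "failures" in sig.parameters

    assert not supports_failures(LegacyProvider())
    assert supports_failures(ModernProvider())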
@@ -991,13 +978,15 @@ class OidcMappingProvider(Generic[C]):
         raise NotImplementedError()
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         """Map a `UserInfo` object into user attributes.
 
         Args:
             userinfo: An object representing the user given by the OIDC provider
             token: A dict with the tokens returned by the provider
+            failures: How many times a call to this function with this
+                UserInfo has resulted in a failure.
 
         Returns:
             A dict containing the ``localpart`` and (optionally) the ``display_name``
@@ -1097,10 +1086,17 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         return userinfo[self._config.subject_claim]
 
     async def map_user_attributes(
-        self, userinfo: UserInfo, token: Token
-    ) -> UserAttribute:
+        self, userinfo: UserInfo, token: Token, failures: int
+    ) -> UserAttributeDict:
         localpart = self._config.localpart_template.render(user=userinfo).strip()
 
+        # Ensure only valid characters are included in the MXID.
+        localpart = map_username_to_mxid_localpart(localpart)
+
+        # Append suffix integer if last call to this function failed to produce
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
+
         display_name = None  # type: Optional[str]
         if self._config.display_name_template is not None:
             display_name = self._config.display_name_template.render(
@@ -1110,7 +1106,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
             if display_name == "":
                 display_name = None
 
-        return UserAttribute(localpart=localpart, display_name=display_name)
+        return UserAttributeDict(localpart=localpart, display_name=display_name)
 
     async def get_extra_attributes(self, userinfo: UserInfo, token: Token) -> JsonDict:
         extras = {}  # type: Dict[str, str]
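The suffixing above is what lets the SSO handler step past taken localparts: each failed attempt increments failures, yielding alice, alice1, alice2, and so on. A small sketch of that retry loop (the taken set stands in for the registration check):

    def candidate_localpart(base: str, failures: int) -> str:
        # First attempt uses the template output as-is; later attempts append
        # the failure count to step past localparts that are already taken.
        return base + (str(failures) if failures else "")

    taken = {"alice", "alice1"}  # stand-in for "already registered" lookups
    failures = 0
    localpart = candidate_localpart("alice", failures)
    while localpart in taken:
        failures += 1
        localpart = candidate_localpart("alice", failures)
    assert localpart == "alice2"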
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 2c2a633938..5372753707 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -92,7 +92,7 @@ class PaginationHandler:
         self._retention_allowed_lifetime_min = hs.config.retention_allowed_lifetime_min
         self._retention_allowed_lifetime_max = hs.config.retention_allowed_lifetime_max
 
-        if hs.config.retention_enabled:
+        if hs.config.run_background_tasks and hs.config.retention_enabled:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
@@ -299,17 +299,22 @@ class PaginationHandler:
         """
         return self._purges_by_id.get(purge_id)
 
-    async def purge_room(self, room_id: str) -> None:
-        """Purge the given room from the database"""
+    async def purge_room(self, room_id: str, force: bool = False) -> None:
+        """Purge the given room from the database.
+
+        Args:
+            room_id: room to be purged
+            force: set to True to skip checking for joined users.
+        """
         with await self.pagination_lock.write(room_id):
             # check we know about the room
             await self.store.get_room_version_id(room_id)
 
             # first check that we have no users in this room
-            joined = await self.store.is_host_joined(room_id, self._server_name)
-
-            if joined:
-                raise SynapseError(400, "Users are still joined to this room")
+            if not force:
+                joined = await self.store.is_host_joined(room_id, self._server_name)
+                if joined:
+                    raise SynapseError(400, "Users are still joined to this room")
 
             await self.storage.purge_events.purge_room(room_id)
 
@@ -383,7 +388,7 @@ class PaginationHandler:
                             "room_key", leave_token
                         )
 
-                await self.hs.get_handlers().federation_handler.maybe_backfill(
+                await self.hs.get_federation_handler().maybe_backfill(
                     room_id, curr_topo, limit=pagin_config.limit,
                 )
 
diff --git a/synapse/handlers/password_policy.py b/synapse/handlers/password_policy.py
index 88e2f87200..6c635cc31b 100644
--- a/synapse/handlers/password_policy.py
+++ b/synapse/handlers/password_policy.py
@@ -16,14 +16,18 @@
 
 import logging
 import re
+from typing import TYPE_CHECKING
 
 from synapse.api.errors import Codes, PasswordRefusedError
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class PasswordPolicyHandler:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.policy = hs.config.password_policy
         self.enabled = hs.config.password_policy_enabled
 
@@ -33,11 +37,11 @@ class PasswordPolicyHandler:
         self.regexp_uppercase = re.compile("[A-Z]")
         self.regexp_lowercase = re.compile("[a-z]")
 
-    def validate_password(self, password):
+    def validate_password(self, password: str) -> None:
         """Checks whether a given password complies with the server's policy.
 
         Args:
-            password (str): The password to check against the server's policy.
+            password: The password to check against the server's policy.
 
         Raises:
             PasswordRefusedError: The password doesn't comply with the server's policy.
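validate_password applies the configured policy using the compiled regexes shown in the constructor above. A minimal standalone version of the same checks, with policy keys modelled loosely on the config options:

    import re

    class PasswordRefusedError(Exception):
        pass

    CHECKS = {
        "require_digit": re.compile("[0-9]"),
        "require_symbol": re.compile("[^a-zA-Z0-9]"),
        "require_uppercase": re.compile("[A-Z]"),
        "require_lowercase": re.compile("[a-z]"),
    }

    def validate_password(password: str, policy: dict) -> None:
        # Raise on the first policy violation; return None if all checks pass.
        if len(password) < policy.get("minimum_length", 0):
            raise PasswordRefusedError("Password is too short")
        for name, regexp in CHECKS.items():
            if policy.get(name) and not regexp.search(password):
                raise PasswordRefusedError("Password fails check: " + name)

    validate_password("Str0ng-passw0rd", {"minimum_length": 10, "require_digit": True})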
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 1000ac95ff..22d1e9d35c 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -25,7 +25,7 @@ The methods that define policy are:
 import abc
 import logging
 from contextlib import contextmanager
-from typing import Dict, Iterable, List, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Set, Tuple
 
 from prometheus_client import Counter
 from typing_extensions import ContextManager
@@ -46,9 +46,8 @@ from synapse.util.caches.descriptors import cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
-MYPY = False
-if MYPY:
-    import synapse.server
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
@@ -101,7 +100,7 @@ assert LAST_ACTIVE_GRANULARITY < IDLE_TIMER
 class BasePresenceHandler(abc.ABC):
     """Parts of the PresenceHandler that are shared between workers and master"""
 
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
 
@@ -199,7 +198,7 @@ class BasePresenceHandler(abc.ABC):
 
 
 class PresenceHandler(BasePresenceHandler):
-    def __init__(self, hs: "synapse.server.HomeServer"):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self.is_mine_id = hs.is_mine_id
@@ -802,7 +801,7 @@ class PresenceHandler(BasePresenceHandler):
             between the requested tokens due to the limit.
 
             The token returned can be used in a subsequent call to this
-            function to get further updatees.
+            function to get further updates.
 
             The updates are a list of 2-tuples of stream ID and the row data
         """
@@ -977,7 +976,7 @@ def should_notify(old_state, new_state):
             new_state.last_active_ts - old_state.last_active_ts
             > LAST_ACTIVE_GRANULARITY
         ):
-            # Only notify about last active bumps if we're not currently acive
+            # Only notify about last active bumps if we're not currently active
             if not new_state.currently_active:
                 notify_reason_counter.labels("last_active_change_online").inc()
                 return True
@@ -1011,7 +1010,7 @@ def format_user_presence_state(state, now, include_user_id=True):
 
 
 class PresenceEventSource:
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         # We can't call get_presence_handler here because there's a cycle:
         #
         #   Presence -> Notifier -> PresenceEventSource -> Presence
@@ -1071,12 +1070,14 @@ class PresenceEventSource:
 
             users_interested_in = await self._get_interested_in(user, explicit_room_id)
 
-            user_ids_changed = set()
+            user_ids_changed = set()  # type: Collection[str]
             changed = None
             if from_key:
                 changed = stream_change_cache.get_all_entities_changed(from_key)
 
             if changed is not None and len(changed) < 500:
+                assert isinstance(user_ids_changed, set)
+
                 # For small deltas, it's quicker to get all changes and then
                 # work out if we share a room or they're in our presence list
                 get_updates_counter.labels("stream").inc()
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 5453e6dfc8..dee0ef45e7 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -12,9 +12,9 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import random
+from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import (
     AuthError,
@@ -24,26 +24,37 @@ from synapse.api.errors import (
     StoreError,
     SynapseError,
 )
-from synapse.metrics.background_process_metrics import run_as_background_process
-from synapse.types import UserID, create_requester, get_domain_from_id
+from synapse.metrics.background_process_metrics import wrap_as_background_process
+from synapse.types import (
+    JsonDict,
+    Requester,
+    UserID,
+    create_requester,
+    get_domain_from_id,
+)
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 MAX_DISPLAYNAME_LEN = 256
 MAX_AVATAR_URL_LEN = 1000
 
 
-class BaseProfileHandler(BaseHandler):
+class ProfileHandler(BaseHandler):
     """Handles fetching and updating user profile information.
 
-    BaseProfileHandler can be instantiated directly on workers and will
-    delegate to master when necessary. The master process should use the
-    subclass MasterProfileHandler
+    ProfileHandler can be instantiated directly on workers and will
+    delegate to master when necessary.
     """
 
-    def __init__(self, hs):
+    PROFILE_UPDATE_MS = 60 * 1000
+    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
+
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
         self.federation = hs.get_federation_client()
@@ -53,7 +64,12 @@ class BaseProfileHandler(BaseHandler):
 
         self.user_directory_handler = hs.get_user_directory_handler()
 
-    async def get_profile(self, user_id):
+        if hs.config.run_background_tasks:
+            self.clock.looping_call(
+                self._update_remote_profile_cache, self.PROFILE_UPDATE_MS
+            )
+
+    async def get_profile(self, user_id: str) -> JsonDict:
         target_user = UserID.from_string(user_id)
 
         if self.hs.is_mine(target_user):
@@ -82,11 +98,18 @@ class BaseProfileHandler(BaseHandler):
             except RequestSendFailed as e:
                 raise SynapseError(502, "Failed to fetch profile") from e
             except HttpResponseException as e:
+                if e.code < 500 and e.code != 404:
+                    # Other codes are not allowed in c2s API
+                    logger.info(
+                        "Server replied with wrong response: %s %s", e.code, e.msg
+                    )
+
+                    raise SynapseError(502, "Failed to fetch profile")
                 raise e.to_synapse_error()
 
-    async def get_profile_from_cache(self, user_id):
+    async def get_profile_from_cache(self, user_id: str) -> JsonDict:
         """Get the profile information from our local cache. If the user is
-        ours then the profile information will always be corect. Otherwise,
+        ours then the profile information will always be correct. Otherwise,
         it may be out of date/missing.
         """
         target_user = UserID.from_string(user_id)
@@ -108,7 +131,7 @@ class BaseProfileHandler(BaseHandler):
             profile = await self.store.get_from_remote_profile_cache(user_id)
             return profile or {}
 
-    async def get_displayname(self, target_user):
+    async def get_displayname(self, target_user: UserID) -> Optional[str]:
         if self.hs.is_mine(target_user):
             try:
                 displayname = await self.store.get_profile_displayname(
@@ -136,15 +159,19 @@ class BaseProfileHandler(BaseHandler):
             return result["displayname"]
 
     async def set_displayname(
-        self, target_user, requester, new_displayname, by_admin=False
-    ):
+        self,
+        target_user: UserID,
+        requester: Requester,
+        new_displayname: str,
+        by_admin: bool = False,
+    ) -> None:
         """Set the displayname of a user
 
         Args:
-            target_user (UserID): the user whose displayname is to be changed.
-            requester (Requester): The user attempting to make this change.
-            new_displayname (str): The displayname to give this user.
-            by_admin (bool): Whether this change was made by an administrator.
+            target_user: the user whose displayname is to be changed.
+            requester: The user attempting to make this change.
+            new_displayname: The displayname to give this user.
+            by_admin: Whether this change was made by an administrator.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -162,23 +189,30 @@ class BaseProfileHandler(BaseHandler):
                 )
 
         if not isinstance(new_displayname, str):
-            raise SynapseError(400, "Invalid displayname")
+            raise SynapseError(
+                400, "'displayname' must be a string", errcode=Codes.INVALID_PARAM
+            )
 
         if len(new_displayname) > MAX_DISPLAYNAME_LEN:
             raise SynapseError(
                 400, "Displayname is too long (max %i)" % (MAX_DISPLAYNAME_LEN,)
             )
 
+        displayname_to_set = new_displayname  # type: Optional[str]
         if new_displayname == "":
-            new_displayname = None
+            displayname_to_set = None
 
         # If the admin changes the display name of a user, the requesting user cannot send
         # the join event to update the displayname in the rooms.
         # This must be done by the target user themselves.
         if by_admin:
-            requester = create_requester(target_user)
+            requester = create_requester(
+                target_user, authenticated_entity=requester.authenticated_entity,
+            )
 
-        await self.store.set_profile_displayname(target_user.localpart, new_displayname)
+        await self.store.set_profile_displayname(
+            target_user.localpart, displayname_to_set
+        )
 
         if self.hs.config.user_directory_search_all_users:
             profile = await self.store.get_profileinfo(target_user.localpart)
@@ -188,7 +222,7 @@ class BaseProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def get_avatar_url(self, target_user):
+    async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
         if self.hs.is_mine(target_user):
             try:
                 avatar_url = await self.store.get_profile_avatar_url(
@@ -215,15 +249,19 @@ class BaseProfileHandler(BaseHandler):
             return result["avatar_url"]
 
     async def set_avatar_url(
-        self, target_user, requester, new_avatar_url, by_admin=False
+        self,
+        target_user: UserID,
+        requester: Requester,
+        new_avatar_url: str,
+        by_admin: bool = False,
     ):
         """Set a new avatar URL for a user.
 
         Args:
-            target_user (UserID): the user whose avatar URL is to be changed.
-            requester (Requester): The user attempting to make this change.
-            new_avatar_url (str): The avatar URL to give this user.
-            by_admin (bool): Whether this change was made by an administrator.
+            target_user: the user whose avatar URL is to be changed.
+            requester: The user attempting to make this change.
+            new_avatar_url: The avatar URL to give this user.
+            by_admin: Whether this change was made by an administrator.
         """
         if not self.hs.is_mine(target_user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -239,7 +277,9 @@ class BaseProfileHandler(BaseHandler):
                 )
 
         if not isinstance(new_avatar_url, str):
-            raise SynapseError(400, "Invalid displayname")
+            raise SynapseError(
+                400, "'avatar_url' must be a string", errcode=Codes.INVALID_PARAM
+            )
 
         if len(new_avatar_url) > MAX_AVATAR_URL_LEN:
             raise SynapseError(
@@ -248,7 +288,9 @@ class BaseProfileHandler(BaseHandler):
 
         # Same as set_displayname
         if by_admin:
-            requester = create_requester(target_user)
+            requester = create_requester(
+                target_user, authenticated_entity=requester.authenticated_entity
+            )
 
         await self.store.set_profile_avatar_url(target_user.localpart, new_avatar_url)
 
@@ -260,7 +302,7 @@ class BaseProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    async def on_profile_query(self, args):
+    async def on_profile_query(self, args: JsonDict) -> JsonDict:
         user = UserID.from_string(args["user_id"])
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -285,7 +327,9 @@ class BaseProfileHandler(BaseHandler):
 
         return response
 
-    async def _update_join_states(self, requester, target_user):
+    async def _update_join_states(
+        self, requester: Requester, target_user: UserID
+    ) -> None:
         if not self.hs.is_mine(target_user):
             return
 
@@ -316,15 +360,17 @@ class BaseProfileHandler(BaseHandler):
                     "Failed to update join event for room %s - %s", room_id, str(e)
                 )
 
-    async def check_profile_query_allowed(self, target_user, requester=None):
+    async def check_profile_query_allowed(
+        self, target_user: UserID, requester: Optional[UserID] = None
+    ) -> None:
         """Checks whether a profile query is allowed. If the
         'require_auth_for_profile_requests' config flag is set to True and a
         'requester' is provided, the query is only allowed if the two users
         share a room.
 
         Args:
-            target_user (UserID): The owner of the queried profile.
-            requester (None|UserID): The user querying for the profile.
+            target_user: The owner of the queried profile.
+            requester: The user querying for the profile.
 
         Raises:
             SynapseError(403): The two users share no room, or one user couldn't
@@ -363,25 +409,7 @@ class BaseProfileHandler(BaseHandler):
                 raise SynapseError(403, "Profile isn't available", Codes.FORBIDDEN)
             raise
 
-
-class MasterProfileHandler(BaseProfileHandler):
-    PROFILE_UPDATE_MS = 60 * 1000
-    PROFILE_UPDATE_EVERY_MS = 24 * 60 * 60 * 1000
-
-    def __init__(self, hs):
-        super().__init__(hs)
-
-        assert hs.config.worker_app is None
-
-        self.clock.looping_call(
-            self._start_update_remote_profile_cache, self.PROFILE_UPDATE_MS
-        )
-
-    def _start_update_remote_profile_cache(self):
-        return run_as_background_process(
-            "Update remote profile", self._update_remote_profile_cache
-        )
-
+    @wrap_as_background_process("Update remote profile")
     async def _update_remote_profile_cache(self):
         """Called periodically to check profiles of remote users we haven't
         checked in a while.
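The diff replaces the explicit run_as_background_process trampoline with the wrap_as_background_process decorator. A simplified sketch of what such a decorator does, assuming only that the wrapped coroutine should be scheduled rather than awaited (the real Synapse helper also manages logging contexts and metrics):

    import asyncio
    import functools

    def wrap_as_background_process(desc: str):
        # Simplified: schedule the coroutine as a fire-and-forget task so
        # that a looping timer can call the method directly.
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return asyncio.ensure_future(func(*args, **kwargs))
            return wrapper
        return decorator

    class ProfileHandler:
        @wrap_as_background_process("Update remote profile")
        async def _update_remote_profile_cache(self):
            ...  # periodic cache refresh, as in the handler above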
diff --git a/synapse/handlers/read_marker.py b/synapse/handlers/read_marker.py
index c32f314a1c..a7550806e6 100644
--- a/synapse/handlers/read_marker.py
+++ b/synapse/handlers/read_marker.py
@@ -14,23 +14,29 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.util.async_helpers import Linearizer
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class ReadMarkerHandler(BaseHandler):
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.server_name = hs.config.server_name
         self.store = hs.get_datastore()
         self.read_marker_linearizer = Linearizer(name="read_marker")
         self.notifier = hs.get_notifier()
 
-    async def received_client_read_marker(self, room_id, user_id, event_id):
+    async def received_client_read_marker(
+        self, room_id: str, user_id: str, event_id: str
+    ) -> None:
         """Updates the read marker for a given user in a given room if the event ID given
         is ahead in the stream relative to the current read marker.
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 7225923757..153cbae7b9 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -13,9 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
+from typing import List, Tuple
 
+from synapse.appservice import ApplicationService
 from synapse.handlers._base import BaseHandler
-from synapse.types import ReadReceipt, get_domain_from_id
+from synapse.types import JsonDict, ReadReceipt, get_domain_from_id
 from synapse.util.async_helpers import maybe_awaitable
 
 logger = logging.getLogger(__name__)
@@ -140,5 +142,37 @@ class ReceiptEventSource:
 
         return (events, to_key)
 
+    async def get_new_events_as(
+        self, from_key: int, service: ApplicationService
+    ) -> Tuple[List[JsonDict], int]:
+        """Returns a set of new receipt events that an appservice
+        may be interested in.
+
+        Args:
+            from_key: the stream position at which events should be fetched from
+            service: The appservice which may be interested
+        """
+        from_key = int(from_key)
+        to_key = self.get_current_key()
+
+        if from_key == to_key:
+            return [], to_key
+
+        # Fetch all read receipts for all rooms, up to a limit of 100. This is ordered
+        # by most recent.
+        rooms_to_events = await self.store.get_linearized_receipts_for_all_rooms(
+            from_key=from_key, to_key=to_key
+        )
+
+        # Then filter down to rooms that the AS can read
+        events = []
+        for room_id, event in rooms_to_events.items():
+            if not await service.matches_user_in_member_list(room_id, self.store):
+                continue
+
+            events.append(event)
+
+        return (events, to_key)
+
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
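get_new_events_as fetches receipts across all rooms and then narrows them to rooms the appservice can see via its member-list matcher. The filtering shape, with the membership check stubbed out:

    import asyncio
    from typing import Dict, List, Set

    ROOM_MEMBERS = {
        "!a:example.org": {"@asbot:example.org"},
        "!b:example.org": {"@alice:example.org"},
    }

    async def matches_user_in_member_list(room_id: str, as_users: Set[str]) -> bool:
        # Stand-in for ApplicationService.matches_user_in_member_list: the
        # appservice "sees" a room if any of its users is a member.
        return bool(ROOM_MEMBERS.get(room_id, set()) & as_users)

    async def filter_receipts(
        rooms_to_events: Dict[str, dict], as_users: Set[str]
    ) -> List[dict]:
        events = []
        for room_id, event in rooms_to_events.items():
            if not await matches_user_in_member_list(room_id, as_users):
                continue
            events.append(event)
        return events

    receipts = {
        "!a:example.org": {"type": "m.receipt", "room_id": "!a:example.org"},
        "!b:example.org": {"type": "m.receipt", "room_id": "!b:example.org"},
    }
    events = asyncio.run(filter_receipts(receipts, {"@asbot:example.org"}))
    assert len(events) == 1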
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 538f4b2a61..0d85fd0868 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -15,10 +15,12 @@
 
 """Contains functions for registering clients."""
 import logging
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from synapse import types
 from synapse.api.constants import MAX_USERID_LENGTH, EventTypes, JoinRules, LoginType
 from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
@@ -32,26 +34,25 @@ from synapse.types import RoomAlias, UserID, create_requester
 
 from ._base import BaseHandler
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RegistrationHandler(BaseHandler):
-    def __init__(self, hs):
-        """
-
-        Args:
-            hs (synapse.server.HomeServer):
-        """
+    def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
         self.hs = hs
         self.auth = hs.get_auth()
         self._auth_handler = hs.get_auth_handler()
         self.profile_handler = hs.get_profile_handler()
         self.user_directory_handler = hs.get_user_directory_handler()
-        self.identity_handler = self.hs.get_handlers().identity_handler
+        self.identity_handler = self.hs.get_identity_handler()
         self.ratelimiter = hs.get_registration_ratelimiter()
         self.macaroon_gen = hs.get_macaroon_generator()
         self._server_notices_mxid = hs.config.server_notices_mxid
+        self._server_name = hs.hostname
 
         self.spam_checker = hs.get_spam_checker()
 
@@ -70,7 +71,10 @@ class RegistrationHandler(BaseHandler):
         self.session_lifetime = hs.config.session_lifetime
 
     async def check_username(
-        self, localpart, guest_access_token=None, assigned_user_id=None
+        self,
+        localpart: str,
+        guest_access_token: Optional[str] = None,
+        assigned_user_id: Optional[str] = None,
     ):
         if types.contains_invalid_mxid_characters(localpart):
             raise SynapseError(
@@ -115,7 +119,10 @@ class RegistrationHandler(BaseHandler):
                     400, "User ID already taken.", errcode=Codes.USER_IN_USE
                 )
             user_data = await self.auth.get_user_by_access_token(guest_access_token)
-            if not user_data["is_guest"] or user_data["user"].localpart != localpart:
+            if (
+                not user_data.is_guest
+                or UserID.from_string(user_data.user_id).localpart != localpart
+            ):
                 raise AuthError(
                     403,
                     "Cannot register taken user ID without valid guest "
@@ -136,39 +143,45 @@ class RegistrationHandler(BaseHandler):
 
     async def register_user(
         self,
-        localpart=None,
-        password_hash=None,
-        guest_access_token=None,
-        make_guest=False,
-        admin=False,
-        threepid=None,
-        user_type=None,
-        default_display_name=None,
-        address=None,
-        bind_emails=[],
-        by_admin=False,
-        user_agent_ips=None,
-    ):
+        localpart: Optional[str] = None,
+        password_hash: Optional[str] = None,
+        guest_access_token: Optional[str] = None,
+        make_guest: bool = False,
+        admin: bool = False,
+        threepid: Optional[dict] = None,
+        user_type: Optional[str] = None,
+        default_display_name: Optional[str] = None,
+        address: Optional[str] = None,
+        bind_emails: List[str] = [],
+        by_admin: bool = False,
+        user_agent_ips: Optional[List[Tuple[str, str]]] = None,
+    ) -> str:
         """Registers a new client on the server.
 
         Args:
             localpart: The local part of the user ID to register. If None,
               one will be generated.
-            password_hash (str|None): The hashed password to assign to this user so they can
+            password_hash: The hashed password to assign to this user so they can
               login again. This can be None which means they cannot login again
               via a password (e.g. the user is an application service user).
-            user_type (str|None): type of user. One of the values from
+            guest_access_token: The access token used when this was a guest
+                account.
+            make_guest: True if the new user should be a guest,
+                false to add a regular user account.
+            admin: True if the user should be registered as a server admin.
+            threepid: The threepid used for registering, if any.
+            user_type: type of user. One of the values from
               api.constants.UserTypes, or None for a normal user.
-            default_display_name (unicode|None): if set, the new user's displayname
+            default_display_name: if set, the new user's displayname
               will be set to this. Defaults to 'localpart'.
-            address (str|None): the IP address used to perform the registration.
-            bind_emails (List[str]): list of emails to bind to this account.
-            by_admin (bool): True if this registration is being made via the
+            address: the IP address used to perform the registration.
+            bind_emails: list of emails to bind to this account.
+            by_admin: True if this registration is being made via the
               admin api, otherwise False.
-            user_agent_ips (List[(str, str)]): Tuples of IP addresses and user-agents used
+            user_agent_ips: Tuples of IP addresses and user-agents used
                 during the registration process.
         Returns:
-            str: user_id
+            The registered user_id.
         Raises:
             SynapseError if there was a problem registering.
         """
@@ -232,8 +245,10 @@ class RegistrationHandler(BaseHandler):
         else:
             # autogen a sequential user ID
             fail_count = 0
-            user = None
-            while not user:
+            # If a default display name is not given, generate one.
+            generate_display_name = default_display_name is None
+            # This breaks on successful registration *or* errors after 10 failures.
+            while True:
                 # Fail after being unable to find a suitable ID a few times
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
@@ -242,7 +257,7 @@ class RegistrationHandler(BaseHandler):
                 user = UserID(localpart, self.hs.hostname)
                 user_id = user.to_string()
                 self.check_user_id_not_appservice_exclusive(user_id)
-                if default_display_name is None:
+                if generate_display_name:
                     default_display_name = localpart
                 try:
                     await self.register_with_store(
@@ -258,8 +273,6 @@ class RegistrationHandler(BaseHandler):
                     break
                 except SynapseError:
                     # if user id is taken, just generate another
-                    user = None
-                    user_id = None
                     fail_count += 1
 
         if not self.hs.config.user_consent_at_registration:
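The restructured loop also fixes a subtle bug: once a failed attempt had assigned `default_display_name = localpart`, the old `if default_display_name is None` test could never fire again, so the display name stayed pinned to the first (already-taken) localpart. Capturing the decision up front keeps the display name tracking each retried localpart. As a standalone sketch:

    def allocate_guest_id(next_localpart, is_taken, default_display_name=None):
        # next_localpart: callable yielding candidate localparts in sequence.
        # is_taken: callable stand-in for the uniqueness/exclusivity checks.
        generate_display_name = default_display_name is None
        fail_count = 0
        while True:  # exits on success, or raises after 10 failures
            if fail_count > 10:
                raise RuntimeError("Unable to find a suitable guest user ID")
            localpart = next_localpart()
            if generate_display_name:
                # regenerated on every retry, so it tracks the final localpart
                default_display_name = localpart
            if not is_taken(localpart):
                return localpart, default_display_name
            fail_count += 1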
@@ -291,7 +304,7 @@ class RegistrationHandler(BaseHandler):
 
         return user_id
 
-    async def _create_and_join_rooms(self, user_id: str):
+    async def _create_and_join_rooms(self, user_id: str) -> None:
         """
         Create the auto-join rooms and join or invite the user to them.
 
@@ -314,7 +327,8 @@ class RegistrationHandler(BaseHandler):
         requires_join = False
         if self.hs.config.registration.auto_join_user_id:
             fake_requester = create_requester(
-                self.hs.config.registration.auto_join_user_id
+                self.hs.config.registration.auto_join_user_id,
+                authenticated_entity=self._server_name,
             )
 
             # If the room requires an invite, add the user to the list of invites.
@@ -326,7 +340,9 @@ class RegistrationHandler(BaseHandler):
             # being necessary this will occur after the invite was sent.
             requires_join = True
         else:
-            fake_requester = create_requester(user_id)
+            fake_requester = create_requester(
+                user_id, authenticated_entity=self._server_name
+            )
 
         # Choose whether to federate the new room.
         if not self.hs.config.registration.autocreate_auto_join_rooms_federated:
@@ -359,7 +375,9 @@ class RegistrationHandler(BaseHandler):
                     # created it, then ensure the first user joins it.
                     if requires_join:
                         await room_member_handler.update_membership(
-                            requester=create_requester(user_id),
+                            requester=create_requester(
+                                user_id, authenticated_entity=self._server_name
+                            ),
                             target=UserID.from_string(user_id),
                             room_id=info["room_id"],
                             # Since it was just created, there are no remote hosts.
@@ -367,15 +385,10 @@ class RegistrationHandler(BaseHandler):
                             action="join",
                             ratelimit=False,
                         )
-
-            except ConsentNotGivenError as e:
-                # Technically not necessary to pull out this error though
-                # moving away from bare excepts is a good thing to do.
-                logger.error("Failed to join new user to %r: %r", r, e)
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    async def _join_rooms(self, user_id: str):
+    async def _join_rooms(self, user_id: str) -> None:
         """
         Join or invite the user to the auto-join rooms.
 
@@ -421,9 +434,13 @@ class RegistrationHandler(BaseHandler):
 
                 # Send the invite, if necessary.
                 if requires_invite:
+                    # If an invite is required, there must be an auto-join user ID.
+                    assert self.hs.config.registration.auto_join_user_id
+
                     await room_member_handler.update_membership(
                         requester=create_requester(
-                            self.hs.config.registration.auto_join_user_id
+                            self.hs.config.registration.auto_join_user_id,
+                            authenticated_entity=self._server_name,
                         ),
                         target=UserID.from_string(user_id),
                         room_id=room_id,
@@ -434,7 +451,9 @@ class RegistrationHandler(BaseHandler):
 
                 # Send the join.
                 await room_member_handler.update_membership(
-                    requester=create_requester(user_id),
+                    requester=create_requester(
+                        user_id, authenticated_entity=self._server_name
+                    ),
                     target=UserID.from_string(user_id),
                     room_id=room_id,
                     remote_room_hosts=remote_room_hosts,
@@ -449,7 +468,7 @@ class RegistrationHandler(BaseHandler):
             except Exception as e:
                 logger.error("Failed to join new user to %r: %r", r, e)
 
-    async def _auto_join_rooms(self, user_id: str):
+    async def _auto_join_rooms(self, user_id: str) -> None:
         """Automatically joins users to auto join rooms - creating the room in the first place
         if the user is the first to be created.
 
@@ -472,16 +491,16 @@ class RegistrationHandler(BaseHandler):
         else:
             await self._join_rooms(user_id)
 
-    async def post_consent_actions(self, user_id):
+    async def post_consent_actions(self, user_id: str) -> None:
         """A series of registration actions that can only be carried out once consent
         has been granted
 
         Args:
-            user_id (str): The user to join
+            user_id: The user to join to the auto-join rooms.
         """
         await self._auto_join_rooms(user_id)
 
-    async def appservice_register(self, user_localpart, as_token):
+    async def appservice_register(self, user_localpart: str, as_token: str) -> str:
         user = UserID(user_localpart, self.hs.hostname)
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
@@ -506,7 +525,9 @@ class RegistrationHandler(BaseHandler):
         )
         return user_id
 
-    def check_user_id_not_appservice_exclusive(self, user_id, allowed_appservice=None):
+    def check_user_id_not_appservice_exclusive(
+        self, user_id: str, allowed_appservice: Optional[ApplicationService] = None
+    ) -> None:
         # don't allow people to register the server notices mxid
         if self._server_notices_mxid is not None:
             if user_id == self._server_notices_mxid:
@@ -530,12 +551,12 @@ class RegistrationHandler(BaseHandler):
                     errcode=Codes.EXCLUSIVE,
                 )
 
-    def check_registration_ratelimit(self, address):
+    def check_registration_ratelimit(self, address: Optional[str]) -> None:
         """A simple helper method to check whether the registration rate limit has been hit
         for a given IP address
 
         Args:
-            address (str|None): the IP address used to perform the registration. If this is
+            address: the IP address used to perform the registration. If this is
                 None, no ratelimiting will be performed.
 
         Raises:
@@ -546,42 +567,39 @@ class RegistrationHandler(BaseHandler):
 
         self.ratelimiter.ratelimit(address)
 
-    def register_with_store(
+    async def register_with_store(
         self,
-        user_id,
-        password_hash=None,
-        was_guest=False,
-        make_guest=False,
-        appservice_id=None,
-        create_profile_with_displayname=None,
-        admin=False,
-        user_type=None,
-        address=None,
-        shadow_banned=False,
-    ):
+        user_id: str,
+        password_hash: Optional[str] = None,
+        was_guest: bool = False,
+        make_guest: bool = False,
+        appservice_id: Optional[str] = None,
+        create_profile_with_displayname: Optional[str] = None,
+        admin: bool = False,
+        user_type: Optional[str] = None,
+        address: Optional[str] = None,
+        shadow_banned: bool = False,
+    ) -> None:
         """Register user in the datastore.
 
         Args:
-            user_id (str): The desired user ID to register.
-            password_hash (str|None): Optional. The password hash for this user.
-            was_guest (bool): Optional. Whether this is a guest account being
+            user_id: The desired user ID to register.
+            password_hash: Optional. The password hash for this user.
+            was_guest: Optional. Whether this is a guest account being
                 upgraded to a non-guest account.
-            make_guest (boolean): True if the the new user should be guest,
+            make_guest: True if the new user should be a guest,
                 false to add a regular user account.
-            appservice_id (str|None): The ID of the appservice registering the user.
-            create_profile_with_displayname (unicode|None): Optionally create a
+            appservice_id: The ID of the appservice registering the user.
+            create_profile_with_displayname: Optionally create a
                 profile for the user, setting their displayname to the given value
-            admin (boolean): is an admin user?
-            user_type (str|None): type of user. One of the values from
+            admin: whether to register the user as a server admin.
+            user_type: type of user. One of the values from
                 api.constants.UserTypes, or None for a normal user.
-            address (str|None): the IP address used to perform the registration.
-            shadow_banned (bool): Whether to shadow-ban the user
-
-        Returns:
-            Awaitable
+            address: the IP address used to perform the registration.
+            shadow_banned: Whether to shadow-ban the user
         """
         if self.hs.config.worker_app:
-            return self._register_client(
+            await self._register_client(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
@@ -594,7 +612,7 @@ class RegistrationHandler(BaseHandler):
                 shadow_banned=shadow_banned,
             )
         else:
-            return self.store.register_user(
+            await self.store.register_user(
                 user_id=user_id,
                 password_hash=password_hash,
                 was_guest=was_guest,
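Making register_with_store `async` and awaiting both branches removes the old "returns an Awaitable" contract from the caller's view. The dispatch shape, reduced to its skeleton (names as in the diff, argument plumbing elided):

    async def register_with_store_sketch(self, **kwargs) -> None:
        if self.hs.config.worker_app:
            # worker process: proxy the write to the master via replication HTTP
            await self._register_client(**kwargs)
        else:
            # master process: write straight to the datastore
            await self.store.register_user(**kwargs)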
@@ -607,22 +625,24 @@ class RegistrationHandler(BaseHandler):
             )
 
     async def register_device(
-        self, user_id, device_id, initial_display_name, is_guest=False
-    ):
+        self,
+        user_id: str,
+        device_id: Optional[str],
+        initial_display_name: Optional[str],
+        is_guest: bool = False,
+    ) -> Tuple[str, str]:
         """Register a device for a user and generate an access token.
 
         The access token will be limited by the homeserver's session_lifetime config.
 
         Args:
-            user_id (str): full canonical @user:id
-            device_id (str|None): The device ID to check, or None to generate
-                a new one.
-            initial_display_name (str|None): An optional display name for the
-                device.
-            is_guest (bool): Whether this is a guest account
+            user_id: full canonical @user:id
+            device_id: The device ID to check, or None to generate a new one.
+            initial_display_name: An optional display name for the device.
+            is_guest: Whether this is a guest account
 
         Returns:
-            tuple[str, str]: Tuple of device ID and access token
+            Tuple of device ID and access token
         """
 
         if self.hs.config.worker_app:
@@ -642,7 +662,7 @@ class RegistrationHandler(BaseHandler):
                 )
             valid_until_ms = self.clock.time_msec() + self.session_lifetime
 
-        device_id = await self.device_handler.check_device_registered(
+        registered_device_id = await self.device_handler.check_device_registered(
             user_id, device_id, initial_display_name
         )
         if is_guest:
@@ -652,20 +672,21 @@ class RegistrationHandler(BaseHandler):
             )
         else:
             access_token = await self._auth_handler.get_access_token_for_user_id(
-                user_id, device_id=device_id, valid_until_ms=valid_until_ms
+                user_id, device_id=registered_device_id, valid_until_ms=valid_until_ms
             )
 
-        return (device_id, access_token)
+        return (registered_device_id, access_token)
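The rename to registered_device_id avoids re-binding the `device_id: Optional[str]` parameter: check_device_registered always returns a concrete device ID (generating one when passed None), and giving that result its own name keeps the Optional annotation truthful. In miniature, with the handler objects passed in as hypothetical handles:

    from typing import Optional, Tuple

    async def register_device_sketch(
        device_handler, auth_handler, user_id: str, device_id: Optional[str]
    ) -> Tuple[str, str]:
        # Always a str, even when the caller passed device_id=None.
        registered_device_id = await device_handler.check_device_registered(
            user_id, device_id, None
        )
        access_token = await auth_handler.get_access_token_for_user_id(
            user_id, device_id=registered_device_id, valid_until_ms=None
        )
        return registered_device_id, access_token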
 
-    async def post_registration_actions(self, user_id, auth_result, access_token):
+    async def post_registration_actions(
+        self, user_id: str, auth_result: dict, access_token: Optional[str]
+    ) -> None:
         """A user has completed registration
 
         Args:
-            user_id (str): The user ID that consented
-            auth_result (dict): The authenticated credentials of the newly
-                registered user.
-            access_token (str|None): The access token of the newly logged in
-                device, or None if `inhibit_login` enabled.
+            user_id: The user ID that consented
+            auth_result: The authenticated credentials of the newly registered user.
+            access_token: The access token of the newly logged in device, or
+                None if `inhibit_login` is enabled.
         """
         if self.hs.config.worker_app:
             await self._post_registration_client(
@@ -691,19 +712,20 @@ class RegistrationHandler(BaseHandler):
         if auth_result and LoginType.TERMS in auth_result:
             await self._on_user_consented(user_id, self.hs.config.user_consent_version)
 
-    async def _on_user_consented(self, user_id, consent_version):
+    async def _on_user_consented(self, user_id: str, consent_version: str) -> None:
         """A user consented to the terms on registration
 
         Args:
-            user_id (str): The user ID that consented.
-            consent_version (str): version of the policy the user has
-                consented to.
+            user_id: The user ID that consented.
+            consent_version: version of the policy the user has consented to.
         """
         logger.info("%s has consented to the privacy policy", user_id)
         await self.store.user_set_consent_version(user_id, consent_version)
         await self.post_consent_actions(user_id)
 
-    async def _register_email_threepid(self, user_id, threepid, token):
+    async def _register_email_threepid(
+        self, user_id: str, threepid: dict, token: Optional[str]
+    ) -> None:
         """Add an email address as a 3pid identifier
 
         Also adds an email pusher for the email address, if configured in the
@@ -712,10 +734,9 @@ class RegistrationHandler(BaseHandler):
         Must be called on master.
 
         Args:
-            user_id (str): id of user
-            threepid (object): m.login.email.identity auth response
-            token (str|None): access_token for the user, or None if not logged
-                in.
+            user_id: id of user
+            threepid: m.login.email.identity auth response
+            token: access_token for the user, or None if not logged in.
         """
         reqd = ("medium", "address", "validated_at")
         if any(x not in threepid for x in reqd):
@@ -741,7 +762,9 @@ class RegistrationHandler(BaseHandler):
             # up when the access token is saved, but that's quite an
             # invasive change I'd rather do separately.
             user_tuple = await self.store.get_user_by_access_token(token)
-            token_id = user_tuple["token_id"]
+            # The token better still exist.
+            assert user_tuple
+            token_id = user_tuple.token_id
 
             await self.pusher_pool.add_pusher(
                 user_id=user_id,
@@ -755,14 +778,14 @@ class RegistrationHandler(BaseHandler):
                 data={},
             )
 
-    async def _register_msisdn_threepid(self, user_id, threepid):
+    async def _register_msisdn_threepid(self, user_id: str, threepid: dict) -> None:
         """Add a phone number as a 3pid identifier
 
         Must be called on master.
 
         Args:
-            user_id (str): id of user
-            threepid (object): m.login.msisdn auth response
+            user_id: id of user
+            threepid: m.login.msisdn auth response
         """
         try:
             assert_params_in_dict(threepid, ["medium", "address", "validated_at"])
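Both threepid helpers validate the auth-response dict before touching the store; the email variant checks the keys by hand, the msisdn variant via assert_params_in_dict. The check itself is simple, sketched standalone:

    REQUIRED_THREEPID_FIELDS = ("medium", "address", "validated_at")

    def validate_threepid(threepid: dict) -> None:
        # Raise if the m.login.* auth response is missing any required field.
        missing = [f for f in REQUIRED_THREEPID_FIELDS if f not in threepid]
        if missing:
            raise ValueError("Threepid is missing fields: " + ", ".join(missing))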
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index d5f7c78edf..930047e730 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -120,7 +120,7 @@ class RoomCreationHandler(BaseHandler):
         # subsequent requests
         self._upgrade_response_cache = ResponseCache(
             hs, "room_upgrade", timeout_ms=FIVE_MINUTES_IN_MS
-        )
+        )  # type: ResponseCache[Tuple[str, str]]
         self._server_notices_mxid = hs.config.server_notices_mxid
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
@@ -185,6 +185,7 @@ class RoomCreationHandler(BaseHandler):
             ShadowBanError if the requester is shadow-banned.
         """
         user_id = requester.user.to_string()
+        assert self.hs.is_mine_id(user_id), "User must be our own: %s" % (user_id,)
 
         # start by allocating a new room id
         r = await self.store.get_room(old_room_id)
@@ -213,7 +214,6 @@ class RoomCreationHandler(BaseHandler):
                     "replacement_room": new_room_id,
                 },
             },
-            token_id=requester.access_token_id,
         )
         old_room_version = await self.store.get_room_version_id(old_room_id)
         await self.auth.check_from_context(
@@ -229,8 +229,8 @@ class RoomCreationHandler(BaseHandler):
         )
 
         # now send the tombstone
-        await self.event_creation_handler.send_nonmember_event(
-            requester, tombstone_event, tombstone_context
+        await self.event_creation_handler.handle_new_client_event(
+            requester=requester, event=tombstone_event, context=tombstone_context,
         )
 
         old_room_state = await tombstone_context.get_current_state_ids()
@@ -587,7 +587,7 @@ class RoomCreationHandler(BaseHandler):
         """
         user_id = requester.user.to_string()
 
-        await self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(requester=requester)
 
         if (
             self._server_notices_mxid is not None
@@ -681,7 +681,16 @@ class RoomCreationHandler(BaseHandler):
             creator_id=user_id, is_public=is_public, room_version=room_version,
         )
 
-        directory_handler = self.hs.get_handlers().directory_handler
+        # Check whether this visibility value is blocked by a third party module
+        allowed_by_third_party_rules = await (
+            self.third_party_event_rules.check_visibility_can_be_modified(
+                room_id, visibility
+            )
+        )
+        if not allowed_by_third_party_rules:
+            raise SynapseError(403, "Room visibility value not allowed.")
+
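check_visibility_can_be_modified gives third-party rules modules a veto over publishing a room to the public directory at creation time. A module-side sketch (the callback name comes from this diff; the exact module signature and the policy shown are assumptions):

    class ExampleThirdPartyRules:
        async def check_visibility_can_be_modified(
            self, room_id: str, new_visibility: str
        ) -> bool:
            # Invented policy: never allow rooms to be published publicly.
            return new_visibility != "public"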
+        directory_handler = self.hs.get_directory_handler()
         if room_alias:
             await directory_handler.create_association(
                 requester=requester,
@@ -762,22 +771,29 @@ class RoomCreationHandler(BaseHandler):
                 ratelimit=False,
             )
 
-        for invitee in invite_list:
+        # we avoid dropping the lock between invites, as otherwise joins can
+        # start coming in and making the createRoom slow.
+        #
+        # we also don't need to check the requester's shadow-ban here, as we
+        # have already done so above (and potentially emptied invite_list).
+        with (await self.room_member_handler.member_linearizer.queue((room_id,))):
             content = {}
             is_direct = config.get("is_direct", None)
             if is_direct:
                 content["is_direct"] = is_direct
 
-            # Note that update_membership with an action of "invite" can raise a
-            # ShadowBanError, but this was handled above by emptying invite_list.
-            _, last_stream_id = await self.room_member_handler.update_membership(
-                requester,
-                UserID.from_string(invitee),
-                room_id,
-                "invite",
-                ratelimit=False,
-                content=content,
-            )
+            for invitee in invite_list:
+                (
+                    _,
+                    last_stream_id,
+                ) = await self.room_member_handler.update_membership_locked(
+                    requester,
+                    UserID.from_string(invitee),
+                    room_id,
+                    "invite",
+                    ratelimit=False,
+                    content=content,
+                )
 
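Hoisting the loop inside the linearizer queue means the room's member lock is taken once for the whole batch of invites rather than once per invitee, and no join can interleave mid-batch. The locking shape, with stand-in callables:

    async def invite_batch(member_linearizer, room_id, invitees, invite_one_locked):
        # invite_one_locked: a stand-in for update_membership_locked, which
        # assumes the caller already holds the room's member lock.
        with (await member_linearizer.queue((room_id,))):
            for invitee in invitees:
                await invite_one_locked(room_id, invitee)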
         for invite_3pid in invite_3pid_list:
             id_server = invite_3pid["id_server"]
@@ -962,8 +978,6 @@ class RoomCreationHandler(BaseHandler):
             try:
                 random_string = stringutils.random_string(18)
                 gen_room_id = RoomID(random_string, self.hs.hostname).to_string()
-                if isinstance(gen_room_id, bytes):
-                    gen_room_id = gen_room_id.decode("utf-8")
                 await self.store.store_room(
                     room_id=gen_room_id,
                     room_creator_user_id=creator_id,
@@ -1243,7 +1257,9 @@ class RoomShutdownHandler:
                     400, "User must be our own: %s" % (new_room_user_id,)
                 )
 
-            room_creator_requester = create_requester(new_room_user_id)
+            room_creator_requester = create_requester(
+                new_room_user_id, authenticated_entity=requester_user_id
+            )
 
             info, stream_id = await self._room_creation_handler.create_room(
                 room_creator_requester,
@@ -1261,7 +1277,7 @@ class RoomShutdownHandler:
             )
 
             # We now wait for the create room to come back in via replication so
-            # that we can assume that all the joins/invites have propogated before
+            # that we can assume that all the joins/invites have propagated before
             # we try and auto join below.
             await self._replication.wait_for_stream_position(
                 self.hs.config.worker.events_shard_config.get_instance(new_room_id),
@@ -1283,7 +1299,9 @@ class RoomShutdownHandler:
 
             try:
                 # Kick users from room
-                target_requester = create_requester(user_id)
+                target_requester = create_requester(
+                    user_id, authenticated_entity=requester_user_id
+                )
                 _, stream_id = await self.room_member_handler.update_membership(
                     requester=target_requester,
                     target=target_requester.user,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 8feba8c90a..c002886324 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -17,12 +17,10 @@ import abc
 import logging
 import random
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
-
-from unpaddedbase64 import encode_base64
+from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
 
 from synapse import types
-from synapse.api.constants import MAX_DEPTH, EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -31,13 +29,8 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.api.room_versions import EventFormatVersions
-from synapse.crypto.event_signing import compute_event_reference_hash
 from synapse.events import EventBase
-from synapse.events.builder import create_local_event_from_event_dict
 from synapse.events.snapshot import EventContext
-from synapse.events.validator import EventValidator
-from synapse.storage.roommember import RoomsForUser
 from synapse.types import JsonDict, Requester, RoomAlias, RoomID, StateMap, UserID
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_left_room
@@ -64,9 +57,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.state_handler = hs.get_state_handler()
         self.config = hs.config
 
-        self.federation_handler = hs.get_handlers().federation_handler
-        self.directory_handler = hs.get_handlers().directory_handler
-        self.identity_handler = hs.get_handlers().identity_handler
+        self.federation_handler = hs.get_federation_handler()
+        self.directory_handler = hs.get_directory_handler()
+        self.identity_handler = hs.get_identity_handler()
         self.registration_handler = hs.get_registration_handler()
         self.profile_handler = hs.get_profile_handler()
         self.event_creation_handler = hs.get_event_creation_handler()
@@ -171,6 +164,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if requester.is_guest:
             content["kind"] = "guest"
 
+        # Check if we already have an event with a matching transaction ID. (We
+        # do this check just before we persist an event as well, but may as well
+        # do it up front for efficiency.)
+        if txn_id and requester.access_token_id:
+            existing_event_id = await self.store.get_event_id_from_transaction_id(
+                room_id, requester.user.to_string(), requester.access_token_id, txn_id,
+            )
+            if existing_event_id:
+                event_pos = await self.store.get_position_for_event(existing_event_id)
+                return existing_event_id, event_pos.stream
+
         event, context = await self.event_creation_handler.create_event(
             requester,
             {
@@ -182,21 +186,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 # For backwards compatibility:
                 "membership": membership,
             },
-            token_id=requester.access_token_id,
             txn_id=txn_id,
             prev_event_ids=prev_event_ids,
             require_consent=require_consent,
         )
 
-        # Check if this event matches the previous membership event for the user.
-        duplicate = await self.event_creation_handler.deduplicate_state_event(
-            event, context
-        )
-        if duplicate is not None:
-            # Discard the new event since this membership change is a no-op.
-            _, stream_id = await self.store.get_event_ordering(duplicate.event_id)
-            return duplicate.event_id, stream_id
-
         prev_state_ids = await context.get_prev_state_ids()
 
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
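The transaction-ID check above makes retried membership requests idempotent before any event is built: the same room, user, access token, and txn_id yield the already-persisted event. Sketched against a hypothetical store exposing the two calls used in the diff:

    async def dedupe_by_txn_id(store, room_id, user_id, token_id, txn_id):
        # Returns (event_id, stream_ordering) for an already-handled request,
        # or None if this txn_id has not been seen before.
        if txn_id and token_id:
            existing = await store.get_event_id_from_transaction_id(
                room_id, user_id, token_id, txn_id
            )
            if existing:
                pos = await store.get_position_for_event(existing)
                return existing, pos.stream
        return None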
@@ -221,7 +215,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                         retry_after_ms=int(1000 * (time_allowed - time_now_s))
                     )
 
-        stream_id = await self.event_creation_handler.handle_new_client_event(
+        result_event = await self.event_creation_handler.handle_new_client_event(
             requester, event, context, extra_users=[target], ratelimit=ratelimit,
         )
 
@@ -231,7 +225,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target, room_id)
 
-        return event.event_id, stream_id
+        # we know it was persisted, so should have a stream ordering
+        assert result_event.internal_metadata.stream_ordering
+        return result_event.event_id, result_event.internal_metadata.stream_ordering
 
     async def copy_room_tags_and_direct_to_room(
         self, old_room_id, new_room_id, user_id
@@ -247,7 +243,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         user_account_data, _ = await self.store.get_account_data_for_user(user_id)
 
         # Copy direct message state if applicable
-        direct_rooms = user_account_data.get("m.direct", {})
+        direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
 
         # Check which key this room is under
         if isinstance(direct_rooms, dict):
@@ -258,7 +254,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
                     # Save back to user's m.direct account data
                     await self.store.add_account_data_for_user(
-                        user_id, "m.direct", direct_rooms
+                        user_id, AccountDataTypes.DIRECT, direct_rooms
                     )
                     break
 
@@ -310,7 +306,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         key = (room_id,)
 
         with (await self.member_linearizer.queue(key)):
-            result = await self._update_membership(
+            result = await self.update_membership_locked(
                 requester,
                 target,
                 room_id,
@@ -325,7 +321,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         return result
 
-    async def _update_membership(
+    async def update_membership_locked(
         self,
         requester: Requester,
         target: UserID,
@@ -338,6 +334,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         content: Optional[dict] = None,
         require_consent: bool = True,
     ) -> Tuple[str, int]:
+        """Helper for update_membership.
+
+        Assumes that the membership linearizer is already held for the room.
+        """
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -346,7 +346,15 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # later on.
             content = dict(content)
 
-        if not self.allow_per_room_profiles or requester.shadow_banned:
+        # allow the server notices mxid to set room-level profile
+        is_requester_server_notices_user = (
+            self._server_notices_mxid is not None
+            and requester.user.to_string() == self._server_notices_mxid
+        )
+
+        if (
+            not self.allow_per_room_profiles and not is_requester_server_notices_user
+        ) or requester.shadow_banned:
             # Strip profile data, knowing that new profile data will be added to the
             # event's content in event_creation_handler.create_event() using the target's
             # global profile.
@@ -441,12 +449,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 same_membership = old_membership == effective_membership_state
                 same_sender = requester.user.to_string() == old_state.sender
                 if same_sender and same_membership and same_content:
-                    _, stream_id = await self.store.get_event_ordering(
-                        old_state.event_id
-                    )
+                    # duplicate event.
+                    # we know it was persisted, so must have a stream ordering.
+                    assert old_state.internal_metadata.stream_ordering
                     return (
                         old_state.event_id,
-                        stream_id,
+                        old_state.internal_metadata.stream_ordering,
                     )
 
             if old_membership in ["ban", "leave"] and action == "kick":
@@ -514,10 +522,16 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         elif effective_membership_state == Membership.LEAVE:
             if not is_host_in_room:
                 # perhaps we've been invited
-                invite = await self.store.get_invite_for_local_user_in_room(
-                    user_id=target.to_string(), room_id=room_id
-                )  # type: Optional[RoomsForUser]
-                if not invite:
+                (
+                    current_membership_type,
+                    current_membership_event_id,
+                ) = await self.store.get_local_current_membership_for_user_in_room(
+                    target.to_string(), room_id
+                )
+                if (
+                    current_membership_type != Membership.INVITE
+                    or not current_membership_event_id
+                ):
                     logger.info(
                         "%s sent a leave request to %s, but that is not an active room "
                         "on this server, and there is no pending invite",
@@ -527,6 +541,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
                     raise SynapseError(404, "Not a known room")
 
+                invite = await self.store.get_event(current_membership_event_id)
                 logger.info(
                     "%s rejects invite to %s from %s", target, room_id, invite.sender
                 )
@@ -642,7 +657,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
     async def send_membership_event(
         self,
-        requester: Requester,
+        requester: Optional[Requester],
         event: EventBase,
         context: EventContext,
         ratelimit: bool = True,
@@ -672,12 +687,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         else:
             requester = types.create_requester(target_user)
 
-        prev_event = await self.event_creation_handler.deduplicate_state_event(
-            event, context
-        )
-        if prev_event is not None:
-            return
-
         prev_state_ids = await context.get_prev_state_ids()
         if event.membership == Membership.JOIN:
             if requester.is_guest:
@@ -692,7 +701,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if is_blocked:
                 raise SynapseError(403, "This room has been blocked on this server")
 
-        await self.event_creation_handler.handle_new_client_event(
+        event = await self.event_creation_handler.handle_new_client_event(
             requester, event, context, extra_users=[target_user], ratelimit=ratelimit
         )
 
@@ -970,6 +979,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         self.distributor = hs.get_distributor()
         self.distributor.declare("user_left_room")
+        self._server_name = hs.hostname
 
     async def _is_remote_room_too_complex(
         self, room_id: str, remote_room_hosts: List[str]
@@ -1064,7 +1074,9 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                 return event_id, stream_id
 
             # The room is too large. Leave.
-            requester = types.create_requester(user, None, False, False, None)
+            requester = types.create_requester(
+                user, authenticated_entity=self._server_name
+            )
             await self.update_membership(
                 requester=requester, target=user, room_id=room_id, action="leave"
             )
@@ -1109,57 +1121,38 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             #
             logger.warning("Failed to reject invite: %s", e)
 
-            return await self._locally_reject_invite(
+            return await self._generate_local_out_of_band_leave(
                 invite_event, txn_id, requester, content
             )
 
-    async def _locally_reject_invite(
+    async def _generate_local_out_of_band_leave(
         self,
-        invite_event: EventBase,
+        previous_membership_event: EventBase,
         txn_id: Optional[str],
         requester: Requester,
         content: JsonDict,
     ) -> Tuple[str, int]:
-        """Generate a local invite rejection
+        """Generate a local leave event for a room
 
-        This is called after we fail to reject an invite via a remote server. It
-        generates an out-of-band membership event locally.
+        This can be called after we, e.g., fail to reject an invite via a remote server.
+        It generates an out-of-band membership event locally.
 
         Args:
-            invite_event: the invite to be rejected
+            previous_membership_event: the previous membership event for this user
             txn_id: optional transaction ID supplied by the client
-            requester:  user making the rejection request, according to the access token
-            content: additional content to include in the rejection event.
+            requester: user making the request, according to the access token
+            content: additional content to include in the leave event.
                Normally an empty dict.
-        """
 
-        room_id = invite_event.room_id
-        target_user = invite_event.state_key
-        room_version = await self.store.get_room_version(room_id)
+        Returns:
+            A tuple of (event_id, stream_id) for the generated leave event.
+        """
+        room_id = previous_membership_event.room_id
+        target_user = previous_membership_event.state_key
 
         content["membership"] = Membership.LEAVE
 
-        # the auth events for the new event are the same as that of the invite, plus
-        # the invite itself.
-        #
-        # the prev_events are just the invite.
-        invite_hash = invite_event.event_id  # type: Union[str, Tuple]
-        if room_version.event_format == EventFormatVersions.V1:
-            alg, h = compute_event_reference_hash(invite_event)
-            invite_hash = (invite_event.event_id, {alg: encode_base64(h)})
-
-        auth_events = tuple(invite_event.auth_events) + (invite_hash,)
-        prev_events = (invite_hash,)
-
-        # we cap depth of generated events, to ensure that they are not
-        # rejected by other servers (and so that they can be persisted in
-        # the db)
-        depth = min(invite_event.depth + 1, MAX_DEPTH)
-
         event_dict = {
-            "depth": depth,
-            "auth_events": auth_events,
-            "prev_events": prev_events,
             "type": EventTypes.Member,
             "room_id": room_id,
             "sender": target_user,
@@ -1167,28 +1160,30 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             "state_key": target_user,
         }
 
-        event = create_local_event_from_event_dict(
-            clock=self.clock,
-            hostname=self.hs.hostname,
-            signing_key=self.hs.signing_key,
-            room_version=room_version,
-            event_dict=event_dict,
+        # the auth events for the new event are the same as those of the
+        # previous event, plus the event itself.
+        #
+        # the prev_events consist solely of the previous membership event.
+        prev_event_ids = [previous_membership_event.event_id]
+        auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
+
+        event, context = await self.event_creation_handler.create_event(
+            requester,
+            event_dict,
+            txn_id=txn_id,
+            prev_event_ids=prev_event_ids,
+            auth_event_ids=auth_event_ids,
         )
         event.internal_metadata.outlier = True
         event.internal_metadata.out_of_band_membership = True
-        if txn_id is not None:
-            event.internal_metadata.txn_id = txn_id
-        if requester.access_token_id is not None:
-            event.internal_metadata.token_id = requester.access_token_id
 
-        EventValidator().validate_new(event, self.config)
-
-        context = await self.state_handler.compute_event_context(event)
-        context.app_service = requester.app_service
-        stream_id = await self.event_creation_handler.handle_new_client_event(
+        result_event = await self.event_creation_handler.handle_new_client_event(
             requester, event, context, extra_users=[UserID.from_string(target_user)],
         )
-        return event.event_id, stream_id
+        # we know it was persisted, so must have a stream ordering
+        assert result_event.internal_metadata.stream_ordering
+
+        return result_event.event_id, result_event.internal_metadata.stream_ordering
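The event relationships for the out-of-band leave are now computed rather than hand-signed, so the usual create_event path (validation, txn_id, token_id) applies. The new event's prev_events are just the prior membership event, and its auth events are that event's auth chain plus the event itself. As a small helper-style sketch:

    from typing import List, Tuple

    def out_of_band_leave_relations(prev_membership_event) -> Tuple[List[str], List[str]]:
        # prev_events: solely the previous membership event.
        prev_event_ids = [prev_membership_event.event_id]
        # auth_events: its auth chain, plus the previous event itself.
        auth_event_ids = list(prev_membership_event.auth_event_ids()) + prev_event_ids
        return prev_event_ids, auth_event_ids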
 
     async def _user_left_room(self, target: UserID, room_id: str) -> None:
         """Implements RoomMemberHandler._user_left_room
diff --git a/synapse/handlers/saml_handler.py b/synapse/handlers/saml_handler.py
index 285c481a96..76d4169fe2 100644
--- a/synapse/handlers/saml_handler.py
+++ b/synapse/handlers/saml_handler.py
@@ -24,7 +24,8 @@ from saml2.client import Saml2Client
 from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.saml2_config import SamlAttributeRequirement
-from synapse.http.server import respond_with_html
+from synapse.handlers._base import BaseHandler
+from synapse.handlers.sso import MappingException, UserAttributes
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.module_api import ModuleApi
@@ -37,15 +38,11 @@ from synapse.util.async_helpers import Linearizer
 from synapse.util.iterutils import chunk_seq
 
 if TYPE_CHECKING:
-    import synapse.server
+    from synapse.server import HomeServer
 
 logger = logging.getLogger(__name__)
 
 
-class MappingException(Exception):
-    """Used to catch errors when mapping the SAML2 response to a user."""
-
-
 @attr.s(slots=True)
 class Saml2SessionData:
     """Data we track about SAML2 sessions"""
@@ -57,17 +54,14 @@ class Saml2SessionData:
     ui_auth_session_id = attr.ib(type=Optional[str], default=None)
 
 
-class SamlHandler:
-    def __init__(self, hs: "synapse.server.HomeServer"):
-        self.hs = hs
+class SamlHandler(BaseHandler):
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
         self._saml_client = Saml2Client(hs.config.saml2_sp_config)
-        self._auth = hs.get_auth()
+        self._saml_idp_entityid = hs.config.saml2_idp_entityid
         self._auth_handler = hs.get_auth_handler()
         self._registration_handler = hs.get_registration_handler()
 
-        self._clock = hs.get_clock()
-        self._datastore = hs.get_datastore()
-        self._hostname = hs.hostname
         self._saml2_session_lifetime = hs.config.saml2_session_lifetime
         self._grandfathered_mxid_source_attribute = (
             hs.config.saml2_grandfathered_mxid_source_attribute
@@ -88,26 +82,9 @@ class SamlHandler:
         self._outstanding_requests_dict = {}  # type: Dict[str, Saml2SessionData]
 
         # a lock on the mappings
-        self._mapping_lock = Linearizer(name="saml_mapping", clock=self._clock)
-
-    def _render_error(
-        self, request, error: str, error_description: Optional[str] = None
-    ) -> None:
-        """Render the error template and respond to the request with it.
-
-        This is used to show errors to the user. The template of this page can
-        be found under `synapse/res/templates/sso_error.html`.
+        self._mapping_lock = Linearizer(name="saml_mapping", clock=self.clock)
 
-        Args:
-            request: The incoming request from the browser.
-                We'll respond with an HTML page describing the error.
-            error: A technical identifier for this error.
-            error_description: A human-readable description of the error.
-        """
-        html = self._error_template.render(
-            error=error, error_description=error_description
-        )
-        respond_with_html(request, 400, html)
+        self._sso_handler = hs.get_sso_handler()
 
     def handle_redirect_request(
         self, client_redirect_url: bytes, ui_auth_session_id: Optional[str] = None
@@ -124,13 +101,13 @@ class SamlHandler:
             URL to redirect to
         """
         reqid, info = self._saml_client.prepare_for_authenticate(
-            relay_state=client_redirect_url
+            entityid=self._saml_idp_entityid, relay_state=client_redirect_url
         )
 
         # Since SAML sessions timeout it is useful to log when they were created.
         logger.info("Initiating a new SAML session: %s" % (reqid,))
 
-        now = self._clock.time_msec()
+        now = self.clock.time_msec()
         self._outstanding_requests_dict[reqid] = Saml2SessionData(
             creation_time=now, ui_auth_session_id=ui_auth_session_id,
         )
@@ -171,12 +148,12 @@ class SamlHandler:
             # in the (user-visible) exception message, so let's log the exception here
             # so we can track down the session IDs later.
             logger.warning(str(e))
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsolicited_response", "Unexpected SAML2 login."
             )
             return
         except Exception as e:
-            self._render_error(
+            self._sso_handler.render_error(
                 request,
                 "invalid_response",
                 "Unable to parse SAML2 response: %s." % (e,),
@@ -184,7 +161,7 @@ class SamlHandler:
             return
 
         if saml2_auth.not_signed:
-            self._render_error(
+            self._sso_handler.render_error(
                 request, "unsigned_respond", "SAML2 response was not signed."
             )
             return
@@ -210,15 +187,13 @@ class SamlHandler:
         # attributes.
         for requirement in self._saml2_attribute_requirements:
             if not _check_attribute_requirement(saml2_auth.ava, requirement):
-                self._render_error(
+                self._sso_handler.render_error(
                     request, "unauthorised", "You are not authorised to log in here."
                 )
                 return
 
         # Pull out the user-agent and IP from the request.
-        user_agent = request.requestHeaders.getRawHeaders(b"User-Agent", default=[b""])[
-            0
-        ].decode("ascii", "surrogateescape")
+        user_agent = request.get_user_agent("")
         ip_address = self.hs.get_ip_from_request(request)
 
         # Call the mapper to register/login the user
@@ -228,7 +203,7 @@ class SamlHandler:
             )
         except MappingException as e:
             logger.exception("Could not map user")
-            self._render_error(request, "mapping_error", str(e))
+            self._sso_handler.render_error(request, "mapping_error", str(e))
             return
 
         # Complete the interactive auth session or the login.
@@ -274,20 +249,26 @@ class SamlHandler:
                 "Failed to extract remote user id from SAML response"
             )
 
-        with (await self._mapping_lock.queue(self._auth_provider_id)):
-            # first of all, check if we already have a mapping for this user
-            logger.info(
-                "Looking for existing mapping for user %s:%s",
-                self._auth_provider_id,
-                remote_user_id,
+        async def saml_response_to_remapped_user_attributes(
+            failures: int,
+        ) -> UserAttributes:
+            """
+            Call the mapping provider to map a SAML response to user
+            attributes and coerce the result into the standard form.
+
+            This is a backwards-compatibility shim for the SSO handler abstraction.
+            """
+            # Call the mapping provider.
+            result = self._user_mapping_provider.saml_response_to_user_attributes(
+                saml2_auth, failures, client_redirect_url
             )
-            registered_user_id = await self._datastore.get_user_by_external_id(
-                self._auth_provider_id, remote_user_id
+            # Remap some of the results.
+            return UserAttributes(
+                localpart=result.get("mxid_localpart"),
+                display_name=result.get("displayname"),
+                emails=result.get("emails", []),
             )
-            if registered_user_id is not None:
-                logger.info("Found existing mapping %s", registered_user_id)
-                return registered_user_id
 
+        async def grandfather_existing_users() -> Optional[str]:
             # backwards-compatibility hack: see if there is an existing user with a
             # suitable mapping from the uid
             if (
@@ -296,75 +277,35 @@ class SamlHandler:
             ):
                 attrval = saml2_auth.ava[self._grandfathered_mxid_source_attribute][0]
                 user_id = UserID(
-                    map_username_to_mxid_localpart(attrval), self._hostname
+                    map_username_to_mxid_localpart(attrval), self.server_name
                 ).to_string()
-                logger.info(
+
+                logger.debug(
                     "Looking for existing account based on mapped %s %s",
                     self._grandfathered_mxid_source_attribute,
                     user_id,
                 )
 
-                users = await self._datastore.get_users_by_id_case_insensitive(user_id)
+                users = await self.store.get_users_by_id_case_insensitive(user_id)
                 if users:
                     registered_user_id = list(users.keys())[0]
                     logger.info("Grandfathering mapping to %s", registered_user_id)
-                    await self._datastore.record_user_external_id(
-                        self._auth_provider_id, remote_user_id, registered_user_id
-                    )
                     return registered_user_id
 
-            # Map saml response to user attributes using the configured mapping provider
-            for i in range(1000):
-                attribute_dict = self._user_mapping_provider.saml_response_to_user_attributes(
-                    saml2_auth, i, client_redirect_url=client_redirect_url,
-                )
-
-                logger.debug(
-                    "Retrieved SAML attributes from user mapping provider: %s "
-                    "(attempt %d)",
-                    attribute_dict,
-                    i,
-                )
-
-                localpart = attribute_dict.get("mxid_localpart")
-                if not localpart:
-                    raise MappingException(
-                        "Error parsing SAML2 response: SAML mapping provider plugin "
-                        "did not return a mxid_localpart value"
-                    )
-
-                displayname = attribute_dict.get("displayname")
-                emails = attribute_dict.get("emails", [])
-
-                # Check if this mxid already exists
-                if not await self._datastore.get_users_by_id_case_insensitive(
-                    UserID(localpart, self._hostname).to_string()
-                ):
-                    # This mxid is free
-                    break
-            else:
-                # Unable to generate a username in 1000 iterations
-                # Break and return error to the user
-                raise MappingException(
-                    "Unable to generate a Matrix ID from the SAML response"
-                )
+            return None
 
-            logger.info("Mapped SAML user to local part %s", localpart)
-
-            registered_user_id = await self._registration_handler.register_user(
-                localpart=localpart,
-                default_display_name=displayname,
-                bind_emails=emails,
-                user_agent_ips=(user_agent, ip_address),
-            )
-
-            await self._datastore.record_user_external_id(
-                self._auth_provider_id, remote_user_id, registered_user_id
+        with (await self._mapping_lock.queue(self._auth_provider_id)):
+            return await self._sso_handler.get_mxid_from_sso(
+                self._auth_provider_id,
+                remote_user_id,
+                user_agent,
+                ip_address,
+                saml_response_to_remapped_user_attributes,
+                grandfather_existing_users,
             )
-            return registered_user_id
 
     def expire_sessions(self):
-        expire_before = self._clock.time_msec() - self._saml2_session_lifetime
+        expire_before = self.clock.time_msec() - self._saml2_session_lifetime
         to_expire = set()
         for reqid, data in self._outstanding_requests_dict.items():
             if data.creation_time < expire_before:
@@ -476,11 +417,11 @@ class DefaultSamlMappingProvider:
             )
 
         # Use the configured mapper for this mxid_source
-        base_mxid_localpart = self._mxid_mapper(mxid_source)
+        localpart = self._mxid_mapper(mxid_source)
 
         # Append suffix integer if last call to this function failed to produce
-        # a usable mxid
-        localpart = base_mxid_localpart + (str(failures) if failures else "")
+        # a usable mxid.
+        localpart += str(failures) if failures else ""
 
         # Retrieve the display name from the saml response
         # If displayname is None, the mxid_localpart will be used instead
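
A minimal sketch of the retry-suffix behaviour above, using a simplified stand-in for the configured mxid mapper (generate_localpart is hypothetical, not the real DefaultSamlMappingProvider method):

    def generate_localpart(mxid_source: str, failures: int) -> str:
        # Stand-in for the configured _mxid_mapper (e.g. dotreplace/hexencode).
        base = mxid_source.lower().replace(" ", "_")
        # Append the failure count so a clashing localpart gets a numeric
        # suffix on the next attempt: "jdoe" -> "jdoe1" -> "jdoe2" ...
        return base + (str(failures) if failures else "")

    assert generate_localpart("JDoe", 0) == "jdoe"
    assert generate_localpart("JDoe", 1) == "jdoe1"
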
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index e9402e6e2e..66f1bbcfc4 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -139,7 +139,7 @@ class SearchHandler(BaseHandler):
             # Filter to apply to results
             filter_dict = room_cat.get("filter", {})
 
-            # What to order results by (impacts whether pagination can be doen)
+            # What to order results by (impacts whether pagination can be done)
             order_by = room_cat.get("order_by", "rank")
 
             # Return the current state of the rooms?
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
new file mode 100644
index 0000000000..47ad96f97e
--- /dev/null
+++ b/synapse/handlers/sso.py
@@ -0,0 +1,244 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import TYPE_CHECKING, Awaitable, Callable, List, Optional
+
+import attr
+
+from synapse.api.errors import RedirectException
+from synapse.handlers._base import BaseHandler
+from synapse.http.server import respond_with_html
+from synapse.types import UserID, contains_invalid_mxid_characters
+
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
+logger = logging.getLogger(__name__)
+
+
+class MappingException(Exception):
+    """Used to catch errors when mapping an SSO response to user attributes.
+
+    Note that the message that is raised is shown to end-users.
+    """
+
+
+@attr.s
+class UserAttributes:
+    localpart = attr.ib(type=str)
+    display_name = attr.ib(type=Optional[str], default=None)
+    emails = attr.ib(type=List[str], default=attr.Factory(list))
+
+
+class SsoHandler(BaseHandler):
+    # The number of times to ask the mapping provider for user attributes
+    # when generating an MXID.
+    _MAP_USERNAME_RETRIES = 1000
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+        self._registration_handler = hs.get_registration_handler()
+        self._error_template = hs.config.sso_error_template
+
+    def render_error(
+        self, request, error: str, error_description: Optional[str] = None
+    ) -> None:
+        """Renders the error template and responds with it.
+
+        This is used to show errors to the user. The template of this page can
+        be found under `synapse/res/templates/sso_error.html`.
+
+        Args:
+            request: The incoming request from the browser.
+                We'll respond with an HTML page describing the error.
+            error: A technical identifier for this error.
+            error_description: A human-readable description of the error.
+        """
+        html = self._error_template.render(
+            error=error, error_description=error_description
+        )
+        respond_with_html(request, 400, html)
+
+    async def get_sso_user_by_remote_user_id(
+        self, auth_provider_id: str, remote_user_id: str
+    ) -> Optional[str]:
+        """
+        Maps the user ID of a remote IdP to an mxid for a previously seen user.
+
+        If the user has not been seen yet, this will return None.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The user ID according to the remote IdP. This might
+                be an e-mail address, a GUID, or some other form. It must be
+                unique and immutable.
+
+        Returns:
+            The mxid of a previously seen user.
+        """
+        logger.debug(
+            "Looking for existing mapping for user %s:%s",
+            auth_provider_id,
+            remote_user_id,
+        )
+
+        # Check if we already have a mapping for this user.
+        previously_registered_user_id = await self.store.get_user_by_external_id(
+            auth_provider_id, remote_user_id,
+        )
+
+        # A match was found, return the user ID.
+        if previously_registered_user_id is not None:
+            logger.info(
+                "Found existing mapping for IdP '%s' and remote_user_id '%s': %s",
+                auth_provider_id,
+                remote_user_id,
+                previously_registered_user_id,
+            )
+            return previously_registered_user_id
+
+        # No match.
+        return None
+
+    async def get_mxid_from_sso(
+        self,
+        auth_provider_id: str,
+        remote_user_id: str,
+        user_agent: str,
+        ip_address: str,
+        sso_to_matrix_id_mapper: Callable[[int], Awaitable[UserAttributes]],
+        grandfather_existing_users: Optional[Callable[[], Awaitable[Optional[str]]]],
+    ) -> str:
+        """
+        Given an SSO ID, retrieve the user ID for it and possibly register the user.
+
+        This first checks if the SSO ID has previously been linked to a matrix
+        ID; if it has, that matrix ID is returned regardless of the current
+        mapping logic.
+
+        If a callable is provided for grandfathering users, it is called and can
+        potentially return a matrix ID to use. If it does, the SSO ID is linked to
+        this matrix ID for subsequent calls.
+
+        The mapping function is called (potentially multiple times) to generate
+        a localpart for the user.
+
+        If an unused localpart is generated, the user is registered from the
+        given user-agent and IP address and the SSO ID is linked to this matrix
+        ID for subsequent calls.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            remote_user_id: The unique identifier from the SSO provider.
+            user_agent: The user agent of the client making the request.
+            ip_address: The IP address of the client making the request.
+            sso_to_matrix_id_mapper: A callable to generate the user attributes.
+                The only parameter is an integer which represents the number
+                of times the returned mxid localpart mapping has failed.
+
+                It is expected that the mapper can raise two exceptions, which
+                will get passed through to the caller:
+
+                    MappingException if there was a problem mapping the response
+                        to the user.
+                    RedirectException to redirect to an additional page (e.g.
+                        to prompt the user for more information).
+            grandfather_existing_users: A callable which can return a previously
+                existing matrix ID. The SSO ID is then linked to the returned
+                matrix ID.
+
+        Returns:
+            The user ID associated with the SSO response.
+
+        Raises:
+            MappingException: if there was a problem mapping the response to
+                a user.
+            RedirectException: if the mapping provider needs to redirect the
+                user to an additional page (e.g. to prompt for more
+                information).
+        """
+        # first of all, check if we already have a mapping for this user
+        previously_registered_user_id = await self.get_sso_user_by_remote_user_id(
+            auth_provider_id, remote_user_id,
+        )
+        if previously_registered_user_id:
+            return previously_registered_user_id
+
+        # Check for grandfathering of users.
+        if grandfather_existing_users:
+            previously_registered_user_id = await grandfather_existing_users()
+            if previously_registered_user_id:
+                # Future logins should also match this user ID.
+                await self.store.record_user_external_id(
+                    auth_provider_id, remote_user_id, previously_registered_user_id
+                )
+                return previously_registered_user_id
+
+        # Otherwise, generate a new user.
+        for i in range(self._MAP_USERNAME_RETRIES):
+            try:
+                attributes = await sso_to_matrix_id_mapper(i)
+            except (RedirectException, MappingException):
+                # Mapping providers are allowed to issue a redirect (e.g. to ask
+                # the user for more information) and can issue a mapping exception
+                # if a name cannot be generated.
+                raise
+            except Exception as e:
+                # Any other exception is unexpected.
+                raise MappingException(
+                    "Could not extract user attributes from SSO response."
+                ) from e
+
+            logger.debug(
+                "Retrieved user attributes from user mapping provider: %r (attempt %d)",
+                attributes,
+                i,
+            )
+
+            if not attributes.localpart:
+                raise MappingException(
+                    "Error parsing SSO response: SSO mapping provider plugin "
+                    "did not return a localpart value"
+                )
+
+            # Check if this mxid already exists
+            user_id = UserID(attributes.localpart, self.server_name).to_string()
+            if not await self.store.get_users_by_id_case_insensitive(user_id):
+                # This mxid is free
+                break
+        else:
+            # Unable to generate a username in _MAP_USERNAME_RETRIES
+            # iterations; raise an error to the user.
+            raise MappingException(
+                "Unable to generate a Matrix ID from the SSO response"
+            )
+
+        # Since the localpart is provided via a potentially untrusted module,
+        # ensure the MXID is valid before registering.
+        if contains_invalid_mxid_characters(attributes.localpart):
+            raise MappingException("localpart is invalid: %s" % (attributes.localpart,))
+
+        logger.debug("Mapped SSO user to local part %s", attributes.localpart)
+        registered_user_id = await self._registration_handler.register_user(
+            localpart=attributes.localpart,
+            default_display_name=attributes.display_name,
+            bind_emails=attributes.emails,
+            user_agent_ips=[(user_agent, ip_address)],
+        )
+
+        await self.store.record_user_external_id(
+            auth_provider_id, remote_user_id, registered_user_id
+        )
+        return registered_user_id
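
A minimal sketch of how a protocol-specific handler might drive get_mxid_from_sso, assuming it sits alongside the UserAttributes class above; the profile dict, "example_idp" provider ID, and login_via_sso name are illustrative, not the real OIDC/SAML call sites:

    async def login_via_sso(sso_handler, profile, user_agent, ip_address):
        async def mapper(failures: int) -> UserAttributes:
            # Mirror the default providers: suffix the localpart on retries.
            suffix = str(failures) if failures else ""
            return UserAttributes(
                localpart=profile["username"] + suffix,
                display_name=profile.get("name"),
                emails=profile.get("emails", []),
            )

        return await sso_handler.get_mxid_from_sso(
            "example_idp",      # auth_provider_id
            profile["sub"],     # remote_user_id: unique and immutable
            user_agent,
            ip_address,
            mapper,
            None,               # no grandfathering callable
        )
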
diff --git a/synapse/handlers/state_deltas.py b/synapse/handlers/state_deltas.py
index 7a4ae0727a..fb4f70e8e2 100644
--- a/synapse/handlers/state_deltas.py
+++ b/synapse/handlers/state_deltas.py
@@ -32,7 +32,7 @@ class StateDeltasHandler:
         Returns:
             None if the field in the events either both match `public_value`
             or if neither do, i.e. there has been no change.
-            True if it didnt match `public_value` but now does
+            True if it didn't match `public_value` but now does
             False if it did match `public_value` but now doesn't
         """
         prev_event = None
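
The tri-state contract described in that docstring can be summarised as a small standalone helper (a sketch, not the actual StateDeltasHandler method):

    from typing import Any, Optional

    def key_change(prev_val: Any, new_val: Any, public_value: Any) -> Optional[bool]:
        prev_matches = prev_val == public_value
        new_matches = new_val == public_value
        if prev_matches == new_matches:
            return None       # both match, or neither: no change
        return new_matches    # True: now matches; False: no longer matches

    assert key_change("private", "public", "public") is True
    assert key_change("public", "private", "public") is False
    assert key_change("public", "public", "public") is None
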
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index 249ffe2a55..dc62b21c06 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -49,7 +49,7 @@ class StatsHandler:
         # Guard to ensure we only process deltas one at a time
         self._is_processing = False
 
-        if hs.config.stats_enabled:
+        if self.stats_enabled and hs.config.run_background_tasks:
             self.notifier.add_replication_callback(self.notify_new_event)
 
             # We kick this off so that we don't have to wait for a change before
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index bfe2583002..9827c7eb8d 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import itertools
 import logging
 from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
@@ -21,7 +20,7 @@ from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tup
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.events import EventBase
 from synapse.logging.context import current_context
@@ -32,6 +31,7 @@ from synapse.types import (
     Collection,
     JsonDict,
     MutableStateMap,
+    Requester,
     RoomStreamToken,
     StateMap,
     StreamToken,
@@ -87,7 +87,7 @@ class SyncConfig:
 class TimelineBatch:
     prev_batch = attr.ib(type=StreamToken)
     events = attr.ib(type=List[EventBase])
-    limited = attr.ib(bool)
+    limited = attr.ib(type=bool)
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
@@ -201,6 +201,8 @@ class SyncResult:
         device_lists: List of user_ids whose devices have changed
         device_one_time_keys_count: Dict of algorithm to count for one time keys
             for this device
+        device_unused_fallback_key_types: List of key types that have an unused fallback
+            key
         groups: Group updates, if any
     """
 
@@ -213,6 +215,7 @@ class SyncResult:
     to_device = attr.ib(type=List[JsonDict])
     device_lists = attr.ib(type=DeviceLists)
     device_one_time_keys_count = attr.ib(type=JsonDict)
+    device_unused_fallback_key_types = attr.ib(type=List[str])
     groups = attr.ib(type=Optional[GroupsSyncResult])
 
     def __bool__(self) -> bool:
@@ -240,7 +243,9 @@ class SyncHandler:
         self.presence_handler = hs.get_presence_handler()
         self.event_sources = hs.get_event_sources()
         self.clock = hs.get_clock()
-        self.response_cache = ResponseCache(hs, "sync")
+        self.response_cache = ResponseCache(
+            hs, "sync"
+        )  # type: ResponseCache[Tuple[Any, ...]]
         self.state = hs.get_state_handler()
         self.auth = hs.get_auth()
         self.storage = hs.get_storage()
@@ -256,6 +261,7 @@ class SyncHandler:
 
     async def wait_for_sync_for_user(
         self,
+        requester: Requester,
         sync_config: SyncConfig,
         since_token: Optional[StreamToken] = None,
         timeout: int = 0,
@@ -269,7 +275,7 @@ class SyncHandler:
         # not been exceeded (if not part of the group by this point, almost certain
         # auth_blocking will occur)
         user_id = sync_config.user.to_string()
-        await self.auth.check_auth_blocking(user_id)
+        await self.auth.check_auth_blocking(requester=requester)
 
         res = await self.response_cache.wrap(
             sync_config.request_key,
@@ -457,8 +463,13 @@ class SyncHandler:
                 recents = []
 
             if not limited or block_all_timeline:
+                prev_batch_token = now_token
+                if recents:
+                    room_key = recents[0].internal_metadata.before
+                    prev_batch_token = now_token.copy_and_replace("room_key", room_key)
+
                 return TimelineBatch(
-                    events=recents, prev_batch=now_token, limited=False
+                    events=recents, prev_batch=prev_batch_token, limited=False
                 )
 
             filtering_factor = 2
@@ -745,7 +756,7 @@ class SyncHandler:
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
-        # updates even if they occured logically before the previous event.
+        # updates even if they occurred logically before the previous event.
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
@@ -1014,10 +1025,14 @@ class SyncHandler:
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
         one_time_key_counts = {}  # type: JsonDict
+        unused_fallback_key_types = []  # type: List[str]
         if device_id:
             one_time_key_counts = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
+            unused_fallback_key_types = await self.store.get_e2e_unused_fallback_key_types(
+                user_id, device_id
+            )
 
         logger.debug("Fetching group data")
         await self._generate_sync_entry_for_groups(sync_result_builder)
@@ -1041,6 +1056,7 @@ class SyncHandler:
             device_lists=device_lists,
             groups=sync_result_builder.groups,
             device_one_time_keys_count=one_time_key_counts,
+            device_unused_fallback_key_types=unused_fallback_key_types,
             next_batch=sync_result_builder.now_token,
         )
 
@@ -1378,13 +1394,16 @@ class SyncHandler:
                         return set(), set(), set(), set()
 
         ignored_account_data = await self.store.get_global_account_data_by_type_for_user(
-            "m.ignored_user_list", user_id=user_id
+            AccountDataTypes.IGNORED_USER_LIST, user_id=user_id
         )
 
+        # If there is ignored users account data and it matches the proper type,
+        # then use it.
+        ignored_users = frozenset()  # type: FrozenSet[str]
         if ignored_account_data:
-            ignored_users = ignored_account_data.get("ignored_users", {}).keys()
-        else:
-            ignored_users = frozenset()
+            ignored_users_data = ignored_account_data.get("ignored_users", {})
+            if isinstance(ignored_users_data, dict):
+                ignored_users = frozenset(ignored_users_data.keys())
 
         if since_token:
             room_changes = await self._get_rooms_changed(
@@ -1478,7 +1497,7 @@ class SyncHandler:
         return False
 
     async def _get_rooms_changed(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
     ) -> _RoomChanges:
         """Gets the the changes that have happened since the last sync.
         """
@@ -1690,7 +1709,7 @@ class SyncHandler:
         return _RoomChanges(room_entries, invited, newly_joined_rooms, newly_left_rooms)
 
     async def _get_all_rooms(
-        self, sync_result_builder: "SyncResultBuilder", ignored_users: Set[str]
+        self, sync_result_builder: "SyncResultBuilder", ignored_users: FrozenSet[str]
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1764,7 +1783,7 @@ class SyncHandler:
     async def _generate_room_entry(
         self,
         sync_result_builder: "SyncResultBuilder",
-        ignored_users: Set[str],
+        ignored_users: FrozenSet[str],
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
         tags: Optional[Dict[str, Dict[str, Any]]],
@@ -1865,7 +1884,7 @@ class SyncHandler:
         # members (as the client otherwise doesn't have enough info to form
         # the name itself).
         if sync_config.filter_collection.lazy_load_members() and (
-            # we recalulate the summary:
+            # we recalculate the summary:
             #   if there are membership changes in the timeline, or
             #   if membership has changed during a gappy sync, or
             #   if this is an initial sync.
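
The ignored-users change earlier in this diff guards against malformed account data; as a standalone sketch (parse_ignored_users is a hypothetical helper, not the handler code itself):

    from typing import Any, FrozenSet, Mapping, Optional

    def parse_ignored_users(account_data: Optional[Mapping[str, Any]]) -> FrozenSet[str]:
        ignored_users = frozenset()  # type: FrozenSet[str]
        if account_data:
            # Account data is client-supplied, so "ignored_users" may have
            # any type; only accept the expected dict-of-user-ids form.
            ignored_users_data = account_data.get("ignored_users", {})
            if isinstance(ignored_users_data, dict):
                ignored_users = frozenset(ignored_users_data.keys())
        return ignored_users

    assert parse_ignored_users({"ignored_users": {"@spam:example.org": {}}}) == frozenset({"@spam:example.org"})
    assert parse_ignored_users({"ignored_users": ["not-a-dict"]}) == frozenset()
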
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index 3cbfc2d780..e919a8f9ed 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -12,16 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 import random
 from collections import namedtuple
 from typing import TYPE_CHECKING, List, Set, Tuple
 
 from synapse.api.errors import AuthError, ShadowBanError, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.tcp.streams import TypingStream
-from synapse.types import UserID, get_domain_from_id
+from synapse.types import JsonDict, UserID, get_domain_from_id
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -167,20 +167,25 @@ class FollowerTypingHandler:
             now_typing = set(row.user_ids)
             self._room_typing[row.room_id] = row.user_ids
 
-            run_as_background_process(
-                "_handle_change_in_typing",
-                self._handle_change_in_typing,
-                row.room_id,
-                prev_typing,
-                now_typing,
-            )
+            if self.federation:
+                run_as_background_process(
+                    "_send_changes_in_typing_to_remotes",
+                    self._send_changes_in_typing_to_remotes,
+                    row.room_id,
+                    prev_typing,
+                    now_typing,
+                )
 
-    async def _handle_change_in_typing(
+    async def _send_changes_in_typing_to_remotes(
         self, room_id: str, prev_typing: Set[str], now_typing: Set[str]
     ):
         """Process a change in typing of a room from replication, sending EDUs
         for any local users.
         """
+
+        if not self.federation:
+            return
+
         for user_id in now_typing - prev_typing:
             if self.is_mine_id(user_id):
                 await self._push_remote(RoomMember(room_id, user_id), True)
@@ -371,7 +376,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             between the requested tokens due to the limit.
 
             The token returned can be used in a subsequent call to this
-            function to get further updatees.
+            function to get further updates.
 
             The updates are a list of 2-tuples of stream ID and the row data
         """
@@ -430,6 +435,33 @@ class TypingNotificationEventSource:
             "content": {"user_ids": list(typing)},
         }
 
+    async def get_new_events_as(
+        self, from_key: int, service: ApplicationService
+    ) -> Tuple[List[JsonDict], int]:
+        """Returns a set of new typing events that an appservice
+        may be interested in.
+
+        Args:
+            from_key: the stream position from which events should be fetched
+            service: The appservice which may be interested
+        """
+        with Measure(self.clock, "typing.get_new_events_as"):
+            from_key = int(from_key)
+            handler = self.get_typing_handler()
+
+            events = []
+            for room_id in handler._room_serials.keys():
+                if handler._room_serials[room_id] <= from_key:
+                    continue
+                if not await service.matches_user_in_member_list(
+                    room_id, handler.store
+                ):
+                    continue
+
+                events.append(self._make_event_for(room_id))
+
+            return (events, handler._latest_room_serial)
+
     async def get_new_events(self, from_key, room_ids, **kwargs):
         with Measure(self.clock, "typing.get_new_events"):
             from_key = int(from_key)
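
The room filter in get_new_events_as above reduces to two checks per room; a simplified synchronous sketch (the names and the precomputed appservice_rooms set are illustrative, not the handler's API):

    from typing import Dict, Iterable, Set

    def rooms_with_new_typing(
        room_serials: Dict[str, int], from_key: int, appservice_rooms: Set[str]
    ) -> Iterable[str]:
        for room_id, serial in room_serials.items():
            if serial <= from_key:
                continue  # no typing change since the given stream position
            if room_id not in appservice_rooms:
                continue  # appservice has no user in this room
            yield room_id

    serials = {"!a:hs": 5, "!b:hs": 9}
    assert list(rooms_with_new_typing(serials, 5, {"!b:hs"})) == ["!b:hs"]
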
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 9146dc1a3b..3d66bf305e 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -143,7 +143,7 @@ class _BaseThreepidAuthChecker:
 
         threepid_creds = authdict["threepid_creds"]
 
-        identity_handler = self.hs.get_handlers().identity_handler
+        identity_handler = self.hs.get_identity_handler()
 
         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
 
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 79393c8829..afbebfc200 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -31,7 +31,7 @@ class UserDirectoryHandler(StateDeltasHandler):
     N.B.: ASSUMES IT IS THE ONLY THING THAT MODIFIES THE USER DIRECTORY
 
     The user directory is filled with users who this server can see are joined to a
-    world_readable or publically joinable room. We keep a database table up to date
+    world_readable or publicly joinable room. We keep a database table up to date
     by streaming changes of the current state and recalculating whether users should
     be in the directory or not when necessary.
     """