summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/admin.py1
-rw-r--r--synapse/handlers/appservice.py16
-rw-r--r--synapse/handlers/auth.py64
-rw-r--r--synapse/handlers/deactivate_account.py4
-rw-r--r--synapse/handlers/device.py100
-rw-r--r--synapse/handlers/devicemessage.py36
-rw-r--r--synapse/handlers/directory.py23
-rw-r--r--synapse/handlers/e2e_keys.py100
-rw-r--r--synapse/handlers/e2e_room_keys.py5
-rw-r--r--synapse/handlers/event_auth.py46
-rw-r--r--synapse/handlers/federation.py149
-rw-r--r--synapse/handlers/federation_event.py144
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/initial_sync.py27
-rw-r--r--synapse/handlers/message.py199
-rw-r--r--synapse/handlers/oidc.py395
-rw-r--r--synapse/handlers/pagination.py13
-rw-r--r--synapse/handlers/presence.py30
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/handlers/receipts.py3
-rw-r--r--synapse/handlers/register.py6
-rw-r--r--synapse/handlers/relations.py494
-rw-r--r--synapse/handlers/room.py216
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/handlers/saml.py4
-rw-r--r--synapse/handlers/search.py2
-rw-r--r--synapse/handlers/set_password.py6
-rw-r--r--synapse/handlers/sso.py193
-rw-r--r--synapse/handlers/sync.py60
-rw-r--r--synapse/handlers/typing.py10
31 files changed, 1676 insertions, 682 deletions
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 0478448b47..fc21d58001 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
         self,
         user: UserID,
         from_key: int,
-        limit: Optional[int],
+        limit: int,
         room_ids: Collection[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index f2989cc4a2..5bf8e86387 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -100,6 +100,7 @@ class AdminHandler:
         user_info_dict["avatar_url"] = profile.avatar_url
         user_info_dict["threepids"] = threepids
         user_info_dict["external_ids"] = external_ids
+        user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
 
         return user_info_dict
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 203b62e015..5d1d21cdc8 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -109,10 +109,13 @@ class ApplicationServicesHandler:
                     last_token = await self.store.get_appservice_last_pos()
                     (
                         upper_bound,
-                        events,
                         event_to_received_ts,
-                    ) = await self.store.get_all_new_events_stream(
-                        last_token, self.current_max, limit=100, get_prev_content=True
+                    ) = await self.store.get_all_new_event_ids_stream(
+                        last_token, self.current_max, limit=100
+                    )
+
+                    events = await self.store.get_events_as_list(
+                        event_to_received_ts.keys(), get_prev_content=True
                     )
 
                     events_by_room: Dict[str, List[EventBase]] = {}
@@ -575,9 +578,6 @@ class ApplicationServicesHandler:
             device_id,
         ), messages in recipient_device_to_messages.items():
             for message_json in messages:
-                # Remove 'message_id' from the to-device message, as it's an internal ID
-                message_json.pop("message_id", None)
-
                 message_payload.append(
                     {
                         "to_user_id": user_id,
@@ -612,8 +612,8 @@ class ApplicationServicesHandler:
         )
 
         # Fetch the users who have modified their device list since then.
-        users_with_changed_device_lists = (
-            await self.store.get_users_whose_devices_changed(from_key, to_key=new_key)
+        users_with_changed_device_lists = await self.store.get_all_devices_changed(
+            from_key, to_key=new_key
         )
 
         # Filter out any users the application service is not interested in
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index f5f0e0e7a7..8b9ef25d29 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -38,6 +38,7 @@ from typing import (
 import attr
 import bcrypt
 import unpaddedbase64
+from prometheus_client import Counter
 
 from twisted.internet.defer import CancelledError
 from twisted.web.server import Request
@@ -48,6 +49,7 @@ from synapse.api.errors import (
     Codes,
     InteractiveAuthIncompleteError,
     LoginError,
+    NotFoundError,
     StoreError,
     SynapseError,
     UserDeactivatedError,
@@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.registration import (
+    LoginTokenExpired,
+    LoginTokenLookupResult,
+    LoginTokenReused,
+)
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.macaroons import LoginTokenAttributes
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
@@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
 
 INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
 
+invalid_login_token_counter = Counter(
+    "synapse_user_login_invalid_login_tokens",
+    "Counts the number of rejected m.login.token on /login",
+    ["reason"],
+)
+
 
 def convert_client_dict_legacy_fields_to_identifier(
     submission: JsonDict,
@@ -883,6 +895,25 @@ class AuthHandler:
 
         return True
 
+    async def create_login_token_for_user_id(
+        self,
+        user_id: str,
+        duration_ms: int = (2 * 60 * 1000),
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
+    ) -> str:
+        login_token = self.generate_login_token()
+        now = self._clock.time_msec()
+        expiry_ts = now + duration_ms
+        await self.store.add_login_token_to_user(
+            user_id=user_id,
+            token=login_token,
+            expiry_ts=expiry_ts,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+        return login_token
+
     async def create_refresh_token_for_user_id(
         self,
         user_id: str,
@@ -1401,6 +1432,18 @@ class AuthHandler:
             return None
         return user_id
 
+    def generate_login_token(self) -> str:
+        """Generates an opaque string, for use as an short-term login token"""
+
+        # we use the following format for access tokens:
+        #    syl_<random string>_<base62 crc check>
+
+        random_string = stringutils.random_string(20)
+        base = f"syl_{random_string}"
+
+        crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        return f"{base}_{crc}"
+
     def generate_access_token(self, for_user: UserID) -> str:
         """Generates an opaque string, for use as an access token"""
 
@@ -1427,16 +1470,17 @@ class AuthHandler:
         crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
         return f"{base}_{crc}"
 
-    async def validate_short_term_login_token(
-        self, login_token: str
-    ) -> LoginTokenAttributes:
+    async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
         try:
-            res = self.macaroon_gen.verify_short_term_login_token(login_token)
-        except Exception:
-            raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
+            return await self.store.consume_login_token(login_token)
+        except LoginTokenExpired:
+            invalid_login_token_counter.labels("expired").inc()
+        except LoginTokenReused:
+            invalid_login_token_counter.labels("reused").inc()
+        except NotFoundError:
+            invalid_login_token_counter.labels("not found").inc()
 
-        await self.auth_blocking.check_auth_blocking(res.user_id)
-        return res
+        raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
     async def delete_access_token(self, access_token: str) -> None:
         """Invalidate a single access token
@@ -1711,7 +1755,7 @@ class AuthHandler:
             )
 
         # Create a login token
-        login_token = self.macaroon_gen.generate_short_term_login_token(
+        login_token = await self.create_login_token_for_user_id(
             registered_user_id,
             auth_provider_id=auth_provider_id,
             auth_provider_session_id=auth_provider_session_id,
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index 816e1a6d79..d74d135c0c 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -16,6 +16,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.types import Codes, Requester, UserID, create_requester
 
@@ -76,6 +77,9 @@ class DeactivateAccountHandler:
             True if identity server supports removing threepids, otherwise False.
         """
 
+        # This can only be called on the main process.
+        assert isinstance(self._device_handler, DeviceHandler)
+
         # Check if this user can be deactivated
         if not await self._third_party_rules.check_can_deactivate_user(
             user_id, by_admin
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f9cc5bddbc..d4750a32e6 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -65,6 +65,8 @@ DELETE_STALE_DEVICES_INTERVAL_MS = 24 * 60 * 60 * 1000
 
 
 class DeviceWorkerHandler:
+    device_list_updater: "DeviceListWorkerUpdater"
+
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.hs = hs
@@ -76,6 +78,8 @@ class DeviceWorkerHandler:
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
+        self.device_list_updater = DeviceListWorkerUpdater(hs)
+
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
         """
@@ -99,6 +103,19 @@ class DeviceWorkerHandler:
         log_kv(device_map)
         return devices
 
+    async def get_dehydrated_device(
+        self, user_id: str
+    ) -> Optional[Tuple[str, JsonDict]]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
+        return await self.store.get_dehydrated_device(user_id)
+
     @trace
     async def get_device(self, user_id: str, device_id: str) -> JsonDict:
         """Retrieve the given device
@@ -127,7 +144,7 @@ class DeviceWorkerHandler:
     @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
-    ) -> Collection[str]:
+    ) -> Set[str]:
         """Get the set of users whose devices have changed who share a room with
         the given user.
         """
@@ -320,6 +337,8 @@ class DeviceWorkerHandler:
 
 
 class DeviceHandler(DeviceWorkerHandler):
+    device_list_updater: "DeviceListUpdater"
+
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
@@ -606,19 +625,6 @@ class DeviceHandler(DeviceWorkerHandler):
             await self.delete_devices(user_id, [old_device_id])
         return device_id
 
-    async def get_dehydrated_device(
-        self, user_id: str
-    ) -> Optional[Tuple[str, JsonDict]]:
-        """Retrieve the information for a dehydrated device.
-
-        Args:
-            user_id: the user whose dehydrated device we are looking for
-        Returns:
-            a tuple whose first item is the device ID, and the second item is
-            the dehydrated device information
-        """
-        return await self.store.get_dehydrated_device(user_id)
-
     async def rehydrate_device(
         self, user_id: str, access_token: str, device_id: str
     ) -> dict:
@@ -682,13 +688,33 @@ class DeviceHandler(DeviceWorkerHandler):
         hosts_already_sent_to: Set[str] = set()
 
         try:
+            stream_id, room_id = await self.store.get_device_change_last_converted_pos()
+
             while True:
                 self._handle_new_device_update_new_data = False
-                rows = await self.store.get_uncoverted_outbound_room_pokes()
+                max_stream_id = self.store.get_device_stream_token()
+                rows = await self.store.get_uncoverted_outbound_room_pokes(
+                    stream_id, room_id
+                )
                 if not rows:
                     # If the DB returned nothing then there is nothing left to
                     # do, *unless* a new device list update happened during the
                     # DB query.
+
+                    # Advance `(stream_id, room_id)`.
+                    # `max_stream_id` comes from *before* the query for unconverted
+                    # rows, which means that any unconverted rows must have a larger
+                    # stream ID.
+                    if max_stream_id > stream_id:
+                        stream_id, room_id = max_stream_id, ""
+                        await self.store.set_device_change_last_converted_pos(
+                            stream_id, room_id
+                        )
+                    else:
+                        assert max_stream_id == stream_id
+                        # Avoid moving `room_id` backwards.
+                        pass
+
                     if self._handle_new_device_update_new_data:
                         continue
                     else:
@@ -718,7 +744,6 @@ class DeviceHandler(DeviceWorkerHandler):
                         user_id=user_id,
                         device_id=device_id,
                         room_id=room_id,
-                        stream_id=stream_id,
                         hosts=hosts,
                         context=opentracing_context,
                     )
@@ -752,6 +777,12 @@ class DeviceHandler(DeviceWorkerHandler):
                     hosts_already_sent_to.update(hosts)
                     current_stream_id = stream_id
 
+                # Advance `(stream_id, room_id)`.
+                _, _, room_id, stream_id, _ = rows[-1]
+                await self.store.set_device_change_last_converted_pos(
+                    stream_id, room_id
+                )
+
         finally:
             self._handle_new_device_update_is_processing = False
 
@@ -834,7 +865,6 @@ class DeviceHandler(DeviceWorkerHandler):
                 user_id=user_id,
                 device_id=device_id,
                 room_id=room_id,
-                stream_id=None,
                 hosts=potentially_changed_hosts,
                 context=None,
             )
@@ -858,7 +888,36 @@ def _update_device_from_client_ips(
     )
 
 
-class DeviceListUpdater:
+class DeviceListWorkerUpdater:
+    "Handles incoming device list updates from federation and contacts the main process over replication"
+
+    def __init__(self, hs: "HomeServer"):
+        from synapse.replication.http.devices import (
+            ReplicationUserDevicesResyncRestServlet,
+        )
+
+        self._user_device_resync_client = (
+            ReplicationUserDevicesResyncRestServlet.make_client(hs)
+        )
+
+    async def user_device_resync(
+        self, user_id: str, mark_failed_as_stale: bool = True
+    ) -> Optional[JsonDict]:
+        """Fetches all devices for a user and updates the device cache with them.
+
+        Args:
+            user_id: The user's id whose device_list will be updated.
+            mark_failed_as_stale: Whether to mark the user's device list as stale
+                if the attempt to resync failed.
+        Returns:
+            A dict with device info as under the "devices" in the result of this
+            request:
+            https://matrix.org/docs/spec/server_server/r0.1.2#get-matrix-federation-v1-user-devices-userid
+        """
+        return await self._user_device_resync_client(user_id=user_id)
+
+
+class DeviceListUpdater(DeviceListWorkerUpdater):
     "Handles incoming device list updates from federation and updates the DB"
 
     def __init__(self, hs: "HomeServer", device_handler: DeviceHandler):
@@ -937,7 +996,10 @@ class DeviceListUpdater:
         # Check if we are partially joining any rooms. If so we need to store
         # all device list updates so that we can handle them correctly once we
         # know who is in the room.
-        partial_rooms = await self.store.get_partial_state_rooms_and_servers()
+        # TODO(faster_joins): this fetches and processes a bunch of data that we don't
+        # use. Could be replaced by a tighter query e.g.
+        #   SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+        partial_rooms = await self.store.get_partial_state_room_resync_info()
         if partial_rooms:
             await self.store.add_remote_device_list_to_pending(
                 user_id,
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 444c08bc2e..75e89850f5 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -15,7 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, Any, Dict
 
-from synapse.api.constants import EduTypes, ToDeviceEventTypes
+from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
 from synapse.api.errors import SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
@@ -216,14 +216,24 @@ class DeviceMessageHandler:
         """
         sender_user_id = requester.user.to_string()
 
-        message_id = random_string(16)
-        set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
-
-        log_kv({"number_of_to_device_messages": len(messages)})
-        set_tag("sender", sender_user_id)
+        set_tag(SynapseTags.TO_DEVICE_TYPE, message_type)
+        set_tag(SynapseTags.TO_DEVICE_SENDER, sender_user_id)
         local_messages = {}
         remote_messages: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
         for user_id, by_device in messages.items():
+            # add an opentracing log entry for each message
+            for device_id, message_content in by_device.items():
+                log_kv(
+                    {
+                        "event": "send_to_device_message",
+                        "user_id": user_id,
+                        "device_id": device_id,
+                        EventContentFields.TO_DEVICE_MSGID: message_content.get(
+                            EventContentFields.TO_DEVICE_MSGID
+                        ),
+                    }
+                )
+
             # Ratelimit local cross-user key requests by the sending device.
             if (
                 message_type == ToDeviceEventTypes.RoomKeyRequest
@@ -233,6 +243,7 @@ class DeviceMessageHandler:
                     requester, (sender_user_id, requester.device_id)
                 )
                 if not allowed:
+                    log_kv({"message": f"dropping key requests to {user_id}"})
                     logger.info(
                         "Dropping room_key_request from %s to %s due to rate limit",
                         sender_user_id,
@@ -247,18 +258,11 @@ class DeviceMessageHandler:
                         "content": message_content,
                         "type": message_type,
                         "sender": sender_user_id,
-                        "message_id": message_id,
                     }
                     for device_id, message_content in by_device.items()
                 }
                 if messages_by_device:
                     local_messages[user_id] = messages_by_device
-                    log_kv(
-                        {
-                            "user_id": user_id,
-                            "device_id": list(messages_by_device),
-                        }
-                    )
             else:
                 destination = get_domain_from_id(user_id)
                 remote_messages.setdefault(destination, {})[user_id] = by_device
@@ -267,7 +271,11 @@ class DeviceMessageHandler:
 
         remote_edu_contents = {}
         for destination, messages in remote_messages.items():
-            log_kv({"destination": destination})
+            # The EDU contains a "message_id" property which is used for
+            # idempotence. Make up a random one.
+            message_id = random_string(16)
+            log_kv({"destination": destination, "message_id": message_id})
+
             remote_edu_contents[destination] = {
                 "messages": messages,
                 "sender": sender_user_id,
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7127d5aefc..2ea52257cb 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -16,6 +16,8 @@ import logging
 import string
 from typing import TYPE_CHECKING, Iterable, List, Optional
 
+from typing_extensions import Literal
+
 from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
 from synapse.api.errors import (
     AuthError,
@@ -83,7 +85,7 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
                 room_id
             )
 
@@ -288,7 +290,7 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
             room_id
         )
         servers_set = set(extra_servers) | set(servers)
@@ -429,7 +431,10 @@ class DirectoryHandler:
         return await self.auth.check_can_change_room_list(room_id, requester)
 
     async def edit_published_room_list(
-        self, requester: Requester, room_id: str, visibility: str
+        self,
+        requester: Requester,
+        room_id: str,
+        visibility: Literal["public", "private"],
     ) -> None:
         """Edit the entry of the room in the published room list.
 
@@ -451,9 +456,6 @@ class DirectoryHandler:
         if requester.is_guest:
             raise AuthError(403, "Guests cannot edit the published room list")
 
-        if visibility not in ["public", "private"]:
-            raise SynapseError(400, "Invalid visibility setting")
-
         if visibility == "public" and not self.enable_room_list_search:
             # The room list has been disabled.
             raise AuthError(
@@ -505,7 +507,11 @@ class DirectoryHandler:
         await self.store.set_room_is_public(room_id, making_public)
 
     async def edit_published_appservice_room_list(
-        self, appservice_id: str, network_id: str, room_id: str, visibility: str
+        self,
+        appservice_id: str,
+        network_id: str,
+        room_id: str,
+        visibility: Literal["public", "private"],
     ) -> None:
         """Add or remove a room from the appservice/network specific public
         room list.
@@ -516,9 +522,6 @@ class DirectoryHandler:
             room_id
             visibility: either "public" or "private"
         """
-        if visibility not in ["public", "private"]:
-            raise SynapseError(400, "Invalid visibility setting")
-
         await self.store.set_room_is_public_appservice(
             room_id, appservice_id, network_id, visibility == "public"
         )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 09a2492afc..5fe102e2f2 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -27,9 +27,9 @@ from twisted.internet import defer
 
 from synapse.api.constants import EduTypes
 from synapse.api.errors import CodeMessageException, Codes, NotFoundError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import log_kv, set_tag, tag_args, trace
-from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.types import (
     JsonDict,
     UserID,
@@ -49,33 +49,30 @@ logger = logging.getLogger(__name__)
 
 class E2eKeysHandler:
     def __init__(self, hs: "HomeServer"):
+        self.config = hs.config
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
-        self._edu_updater = SigningKeyEduUpdater(hs, self)
-
         federation_registry = hs.get_federation_registry()
 
-        self._is_master = hs.config.worker.worker_app is None
-        if not self._is_master:
-            self._user_device_resync_client = (
-                ReplicationUserDevicesResyncRestServlet.make_client(hs)
-            )
-        else:
+        is_master = hs.config.worker.worker_app is None
+        if is_master:
+            edu_updater = SigningKeyEduUpdater(hs)
+
             # Only register this edu handler on master as it requires writing
             # device updates to the db
             federation_registry.register_edu_handler(
                 EduTypes.SIGNING_KEY_UPDATE,
-                self._edu_updater.incoming_signing_key_update,
+                edu_updater.incoming_signing_key_update,
             )
             # also handle the unstable version
             # FIXME: remove this when enough servers have upgraded
             federation_registry.register_edu_handler(
                 EduTypes.UNSTABLE_SIGNING_KEY_UPDATE,
-                self._edu_updater.incoming_signing_key_update,
+                edu_updater.incoming_signing_key_update,
             )
 
         # doesn't really work as part of the generic query API, because the
@@ -318,14 +315,13 @@ class E2eKeysHandler:
             # probably be tracking their device lists. However, we haven't
             # done an initial sync on the device list so we do it now.
             try:
-                if self._is_master:
-                    resync_results = await self.device_handler.device_list_updater.user_device_resync(
+                resync_results = (
+                    await self.device_handler.device_list_updater.user_device_resync(
                         user_id
                     )
-                else:
-                    resync_results = await self._user_device_resync_client(
-                        user_id=user_id
-                    )
+                )
+                if resync_results is None:
+                    raise ValueError("Device resync failed")
 
                 # Add the device keys to the results.
                 user_devices = resync_results["devices"]
@@ -431,13 +427,17 @@ class E2eKeysHandler:
     @trace
     @cancellable
     async def query_local_devices(
-        self, query: Mapping[str, Optional[List[str]]]
+        self,
+        query: Mapping[str, Optional[List[str]]],
+        include_displaynames: bool = True,
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
         Args:
             query: map from user_id to a list
                  of devices to query (None for all devices)
+            include_displaynames: Whether to include device displaynames in the returned
+                device details.
 
         Returns:
             A map from user_id -> device_id -> device details
@@ -469,7 +469,9 @@ class E2eKeysHandler:
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
+        results = await self.store.get_e2e_device_keys_for_cs_api(
+            local_query, include_displaynames
+        )
 
         # Build the result structure
         for user_id, device_keys in results.items():
@@ -482,11 +484,33 @@ class E2eKeysHandler:
     async def on_federation_query_client_keys(
         self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
     ) -> JsonDict:
-        """Handle a device key query from a federated server"""
+        """Handle a device key query from a federated server:
+
+        Handles the path: GET /_matrix/federation/v1/users/keys/query
+
+        Args:
+            query_body: The body of the query request. Should contain a key
+                "device_keys" that map to a dictionary of user ID's -> list of
+                device IDs. If the list of device IDs is empty, all devices of
+                that user will be queried.
+
+        Returns:
+            A json dictionary containing the following:
+                - device_keys: A dictionary containing the requested device information.
+                - master_keys: An optional dictionary of user ID -> master cross-signing
+                   key info.
+                - self_signing_key: An optional dictionary of user ID -> self-signing
+                    key info.
+        """
         device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
             "device_keys", {}
         )
-        res = await self.query_local_devices(device_keys_query)
+        res = await self.query_local_devices(
+            device_keys_query,
+            include_displaynames=(
+                self.config.federation.allow_device_name_lookup_over_federation
+            ),
+        )
         ret = {"device_keys": res}
 
         # add in the cross-signing keys
@@ -576,6 +600,8 @@ class E2eKeysHandler:
     async def upload_keys_for_user(
         self, user_id: str, device_id: str, keys: JsonDict
     ) -> JsonDict:
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
 
         time_now = self.clock.time_msec()
 
@@ -703,6 +729,8 @@ class E2eKeysHandler:
             user_id: the user uploading the keys
             keys: the signing keys
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
 
         # if a master key is uploaded, then check it.  Otherwise, load the
         # stored master key, to check signatures on other keys
@@ -794,6 +822,9 @@ class E2eKeysHandler:
         Raises:
             SynapseError: if the signatures dict is not valid.
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         failures = {}
 
         # signatures to be stored.  Each item will be a SignatureListItem
@@ -841,7 +872,7 @@ class E2eKeysHandler:
         - signatures of the user's master key by the user's devices.
 
         Args:
-            user_id (string): the user uploading the keys
+            user_id: the user uploading the keys
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
@@ -1171,6 +1202,9 @@ class E2eKeysHandler:
             A tuple of the retrieved key content, the key's ID and the matching VerifyKey.
             If the key cannot be retrieved, all values in the tuple will instead be None.
         """
+        # This can only be called from the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         try:
             remote_result = await self.federation.query_user_devices(
                 user.domain, user.to_string()
@@ -1367,11 +1401,14 @@ class SignatureListItem:
 class SigningKeyEduUpdater:
     """Handles incoming signing key updates from federation and updates the DB"""
 
-    def __init__(self, hs: "HomeServer", e2e_keys_handler: E2eKeysHandler):
+    def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.clock = hs.get_clock()
-        self.e2e_keys_handler = e2e_keys_handler
+
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self._device_handler = device_handler
 
         self._remote_edu_linearizer = Linearizer(name="remote_signing_key")
 
@@ -1416,9 +1453,6 @@ class SigningKeyEduUpdater:
             user_id: the user whose updates we are processing
         """
 
-        device_handler = self.e2e_keys_handler.device_handler
-        device_list_updater = device_handler.device_list_updater
-
         async with self._remote_edu_linearizer.queue(user_id):
             pending_updates = self._pending_updates.pop(user_id, [])
             if not pending_updates:
@@ -1430,13 +1464,11 @@ class SigningKeyEduUpdater:
             logger.info("pending updates: %r", pending_updates)
 
             for master_key, self_signing_key in pending_updates:
-                new_device_ids = (
-                    await device_list_updater.process_cross_signing_key_update(
-                        user_id,
-                        master_key,
-                        self_signing_key,
-                    )
+                new_device_ids = await self._device_handler.device_list_updater.process_cross_signing_key_update(
+                    user_id,
+                    master_key,
+                    self_signing_key,
                 )
                 device_ids = device_ids + new_device_ids
 
-            await device_handler.notify_device_update(user_id, device_ids)
+            await self._device_handler.notify_device_update(user_id, device_ids)
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 28dc08c22a..83f53ceb88 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -377,8 +377,9 @@ class E2eRoomKeysHandler:
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
-            user_id(str): the user whose current backup version we're deleting
-            version(str): the version id of the backup being deleted
+            user_id: the user whose current backup version we're deleting
+            version: Optional. the version ID of the backup version we're deleting
+                If missing, we delete the current backup version info.
         Raises:
             NotFoundError: if this backup version doesn't exist
         """
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 8249ca1ed2..f91dbbecb7 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
 
 from synapse import event_auth
 from synapse.api.constants import (
@@ -29,7 +29,6 @@ from synapse.event_auth import (
 )
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
 from synapse.types import StateMap, get_domain_from_id
 
 if TYPE_CHECKING:
@@ -46,17 +45,27 @@ class EventAuthHandler:
     def __init__(self, hs: "HomeServer"):
         self._clock = hs.get_clock()
         self._store = hs.get_datastores().main
+        self._state_storage_controller = hs.get_storage_controllers().state
         self._server_name = hs.hostname
 
     async def check_auth_rules_from_context(
         self,
         event: EventBase,
-        context: EventContext,
+        batched_auth_events: Optional[Mapping[str, EventBase]] = None,
     ) -> None:
-        """Check an event passes the auth rules at its own auth events"""
-        await check_state_independent_auth_rules(self._store, event)
+        """Check an event passes the auth rules at its own auth events
+        Args:
+            event: event to be authed
+            batched_auth_events: if the event being authed is part of a batch, any events
+            from the same batch that may be necessary to auth the current event
+        """
+        await check_state_independent_auth_rules(
+            self._store, event, batched_auth_events
+        )
         auth_event_ids = event.auth_event_ids()
         auth_events_by_id = await self._store.get_events(auth_event_ids)
+        if batched_auth_events:
+            auth_events_by_id.update(batched_auth_events)
         check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
@@ -171,17 +180,22 @@ class EventAuthHandler:
         this function may return an incorrect result as we are not able to fully
         track server membership in a room without full state.
         """
-        if not allow_partial_state_rooms and await self._store.is_partial_state_room(
-            room_id
-        ):
-            raise AuthError(
-                403,
-                "Unable to authorise you right now; room is partial-stated here.",
-                errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
-            )
-
-        if not await self.is_host_in_room(room_id, host):
-            raise AuthError(403, "Host not in room.")
+        if await self._store.is_partial_state_room(room_id):
+            if allow_partial_state_rooms:
+                current_hosts = await self._state_storage_controller.get_current_hosts_in_room_or_partial_state_approximation(
+                    room_id
+                )
+                if host not in current_hosts:
+                    raise AuthError(403, "Host not in room (partial-state approx).")
+            else:
+                raise AuthError(
+                    403,
+                    "Unable to authorise you right now; room is partial-stated here.",
+                    errcode=Codes.UNABLE_DUE_TO_PARTIAL_STATE,
+                )
+        else:
+            if not await self.is_host_in_room(room_id, host):
+                raise AuthError(403, "Host not in room.")
 
     async def check_restricted_join_rules(
         self,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 986ffed3d5..b2784d7333 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -45,6 +45,7 @@ from synapse.api.errors import (
     Codes,
     FederationDeniedError,
     FederationError,
+    FederationPullAttemptBackoffError,
     HttpResponseException,
     LimitExceededError,
     NotFoundError,
@@ -69,8 +70,8 @@ from synapse.replication.http.federation import (
 )
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, get_domain_from_id
+from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -151,6 +152,7 @@ class FederationHandler:
         self._federation_event_handler = hs.get_federation_event_handler()
         self._device_handler = hs.get_device_handler()
         self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
+        self._notifier = hs.get_notifier()
 
         self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
             hs
@@ -378,6 +380,7 @@ class FederationHandler:
             filtered_extremities = await filter_events_for_server(
                 self._storage_controllers,
                 self.server_name,
+                self.server_name,
                 events_to_check,
                 redact=False,
                 check_history_visibility_only=True,
@@ -441,6 +444,15 @@ class FederationHandler:
                     # appropriate stuff.
                     # TODO: We can probably do something more intelligent here.
                     return True
+                except NotRetryingDestination as e:
+                    logger.info("_maybe_backfill_inner: %s", e)
+                    continue
+                except FederationDeniedError:
+                    logger.info(
+                        "_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist",
+                        dom,
+                    )
+                    continue
                 except (SynapseError, InvalidResponseError) as e:
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
@@ -476,15 +488,9 @@ class FederationHandler:
 
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
-                except NotRetryingDestination as e:
-                    logger.info(str(e))
-                    continue
                 except RequestSendFailed as e:
                     logger.info("Failed to get backfill from %s because %s", dom, e)
                     continue
-                except FederationDeniedError as e:
-                    logger.info(e)
-                    continue
                 except Exception as e:
                     logger.exception("Failed to backfill from %s because %s", dom, e)
                     continue
@@ -631,6 +637,7 @@ class FederationHandler:
                     room_id=room_id,
                     servers=ret.servers_in_room,
                     device_lists_stream_id=self.store.get_device_stream_token(),
+                    joined_via=origin,
                 )
 
             try:
@@ -781,15 +788,27 @@ class FederationHandler:
 
         # Send the signed event back to the room, and potentially receive some
         # further information about the room in the form of partial state events
-        stripped_room_state = await self.federation_client.send_knock(
-            target_hosts, event
-        )
+        knock_response = await self.federation_client.send_knock(target_hosts, event)
 
         # Store any stripped room state events in the "unsigned" key of the event.
         # This is a bit of a hack and is cribbing off of invites. Basically we
         # store the room state here and retrieve it again when this event appears
         # in the invitee's sync stream. It is stripped out for all other local users.
-        event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+        stripped_room_state = (
+            knock_response.get("knock_room_state")
+            # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+            # Thus, we also check for a 'knock_state_events' to support old instances.
+            # See https://github.com/matrix-org/synapse/issues/14088.
+            or knock_response.get("knock_state_events")
+        )
+
+        if stripped_room_state is None:
+            raise KeyError(
+                "Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
+                "send_knock response"
+            )
+
+        event.unsigned["knock_room_state"] = stripped_room_state
 
         context = EventContext.for_outlier(self._storage_controllers)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
@@ -928,7 +947,7 @@ class FederationHandler:
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        await self._event_auth_handler.check_auth_rules_from_context(event, context)
+        await self._event_auth_handler.check_auth_rules_from_context(event)
         return event
 
     async def on_invite_request(
@@ -1003,7 +1022,9 @@ class FederationHandler:
 
         context = EventContext.for_outlier(self._storage_controllers)
 
-        await self._bulk_push_rule_evaluator.action_for_event_by_user(event, context)
+        await self._bulk_push_rule_evaluator.action_for_events_by_user(
+            [(event, context)]
+        )
         try:
             await self._federation_event_handler.persist_events_and_notify(
                 event.room_id, [(event, context)]
@@ -1109,7 +1130,7 @@ class FederationHandler:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            await self._event_auth_handler.check_auth_rules_from_context(event, context)
+            await self._event_auth_handler.check_auth_rules_from_context(event)
         except AuthError as e:
             logger.warning("Failed to create new leave %r because %s", event, e)
             raise e
@@ -1168,7 +1189,7 @@ class FederationHandler:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_knock_request`
-            await self._event_auth_handler.check_auth_rules_from_context(event, context)
+            await self._event_auth_handler.check_auth_rules_from_context(event)
         except AuthError as e:
             logger.warning("Failed to create new knock %r because %s", event, e)
             raise e
@@ -1212,7 +1233,9 @@ class FederationHandler:
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int
     ) -> List[EventBase]:
-        await self._event_auth_handler.assert_host_in_room(room_id, origin)
+        # We allow partially joined rooms since in this case we are filtering out
+        # non-local events in `filter_events_for_server`.
+        await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
 
         # Synapse asks for 100 events per backfill request. Do not allow more.
         limit = min(limit, 100)
@@ -1233,7 +1256,7 @@ class FederationHandler:
         )
 
         events = await filter_events_for_server(
-            self._storage_controllers, origin, events
+            self._storage_controllers, origin, self.server_name, events
         )
 
         return events
@@ -1264,7 +1287,7 @@ class FederationHandler:
         await self._event_auth_handler.assert_host_in_room(event.room_id, origin)
 
         events = await filter_events_for_server(
-            self._storage_controllers, origin, [event]
+            self._storage_controllers, origin, self.server_name, [event]
         )
         event = events[0]
         return event
@@ -1277,7 +1300,9 @@ class FederationHandler:
         latest_events: List[str],
         limit: int,
     ) -> List[EventBase]:
-        await self._event_auth_handler.assert_host_in_room(room_id, origin)
+        # We allow partially joined rooms since in this case we are filtering out
+        # non-local events in `filter_events_for_server`.
+        await self._event_auth_handler.assert_host_in_room(room_id, origin, True)
 
         # Only allow up to 20 events to be retrieved per request.
         limit = min(limit, 20)
@@ -1290,7 +1315,7 @@ class FederationHandler:
         )
 
         missing_events = await filter_events_for_server(
-            self._storage_controllers, origin, missing_events
+            self._storage_controllers, origin, self.server_name, missing_events
         )
 
         return missing_events
@@ -1334,9 +1359,7 @@ class FederationHandler:
 
             try:
                 validate_event_for_room_version(event)
-                await self._event_auth_handler.check_auth_rules_from_context(
-                    event, context
-                )
+                await self._event_auth_handler.check_auth_rules_from_context(event)
             except AuthError as e:
                 logger.warning("Denying new third party invite %r because %s", event, e)
                 raise e
@@ -1386,7 +1409,7 @@ class FederationHandler:
 
         try:
             validate_event_for_room_version(event)
-            await self._event_auth_handler.check_auth_rules_from_context(event, context)
+            await self._event_auth_handler.check_auth_rules_from_context(event)
         except AuthError as e:
             logger.warning("Denying third party invite %r because %s", event, e)
             raise e
@@ -1579,8 +1602,8 @@ class FederationHandler:
         Fetch the complexity of a remote room over federation.
 
         Args:
-            remote_room_hosts (list[str]): The remote servers to ask.
-            room_id (str): The room ID to ask about.
+            remote_room_hosts: The remote servers to ask.
+            room_id: The room ID to ask about.
 
         Returns:
             Dict contains the complexity
@@ -1602,13 +1625,13 @@ class FederationHandler:
         """Resumes resyncing of all partial-state rooms after a restart."""
         assert not self.config.worker.worker_app
 
-        partial_state_rooms = await self.store.get_partial_state_rooms_and_servers()
-        for room_id, servers_in_room in partial_state_rooms.items():
+        partial_state_rooms = await self.store.get_partial_state_room_resync_info()
+        for room_id, resync_info in partial_state_rooms.items():
             run_as_background_process(
                 desc="sync_partial_state_room",
                 func=self._sync_partial_state_room,
-                initial_destination=None,
-                other_destinations=servers_in_room,
+                initial_destination=resync_info.joined_via,
+                other_destinations=resync_info.servers_in_room,
                 room_id=room_id,
             )
 
@@ -1637,28 +1660,12 @@ class FederationHandler:
         #   really leave, that might mean we have difficulty getting the room state over
         #   federation.
         #   https://github.com/matrix-org/synapse/issues/12802
-        #
-        # TODO(faster_joins): we need some way of prioritising which homeservers in
-        #   `other_destinations` to try first, otherwise we'll spend ages trying dead
-        #   homeservers for large rooms.
-        #   https://github.com/matrix-org/synapse/issues/12999
-
-        if initial_destination is None and len(other_destinations) == 0:
-            raise ValueError(
-                f"Cannot resync state of {room_id}: no destinations provided"
-            )
 
         # Make an infinite iterator of destinations to try. Once we find a working
         # destination, we'll stick with it until it flakes.
-        destinations: Collection[str]
-        if initial_destination is not None:
-            # Move `initial_destination` to the front of the list.
-            destinations = list(other_destinations)
-            if initial_destination in destinations:
-                destinations.remove(initial_destination)
-            destinations = [initial_destination] + destinations
-        else:
-            destinations = other_destinations
+        destinations = _prioritise_destinations_for_partial_state_resync(
+            initial_destination, other_destinations, room_id
+        )
         destination_iter = itertools.cycle(destinations)
 
         # `destination` is the current remote homeserver we're pulling from.
@@ -1686,6 +1693,9 @@ class FederationHandler:
                     self._storage_controllers.state.notify_room_un_partial_stated(
                         room_id
                     )
+                    # Poke the notifier so that other workers see the write to
+                    # the un-partial-stated rooms stream.
+                    self._notifier.notify_replication()
 
                     # TODO(faster_joins) update room stats and user directory?
                     #   https://github.com/matrix-org/synapse/issues/12814
@@ -1708,7 +1718,22 @@ class FederationHandler:
                             destination, event
                         )
                         break
+                    except FederationPullAttemptBackoffError as exc:
+                        # Log a warning about why we failed to process the event (the error message
+                        # for `FederationPullAttemptBackoffError` is pretty good)
+                        logger.warning("_sync_partial_state_room: %s", exc)
+                        # We do not record a failed pull attempt when we backoff fetching a missing
+                        # `prev_event` because not being able to fetch the `prev_events` just means
+                        # we won't be able to de-outlier the pulled event. But we can still use an
+                        # `outlier` in the state/auth chain for another event. So we shouldn't stop
+                        # a downstream event from trying to pull it.
+                        #
+                        # This avoids a cascade of backoff for all events in the DAG downstream from
+                        # one event backoff upstream.
                     except FederationError as e:
+                        # TODO: We should `record_event_failed_pull_attempt` here,
+                        #   see https://github.com/matrix-org/synapse/issues/13700
+
                         if attempt == len(destinations) - 1:
                             # We have tried every remote server for this event. Give up.
                             # TODO(faster_joins) giving up isn't the right thing to do
@@ -1741,3 +1766,29 @@ class FederationHandler:
                             room_id,
                             destination,
                         )
+
+
+def _prioritise_destinations_for_partial_state_resync(
+    initial_destination: Optional[str],
+    other_destinations: Collection[str],
+    room_id: str,
+) -> Collection[str]:
+    """Work out the order in which we should ask servers to resync events.
+
+    If an `initial_destination` is given, it takes top priority. Otherwise
+    all servers are treated equally.
+
+    :raises ValueError: if no destination is provided at all.
+    """
+    if initial_destination is None and len(other_destinations) == 0:
+        raise ValueError(f"Cannot resync state of {room_id}: no destinations provided")
+
+    if initial_destination is None:
+        return other_destinations
+
+    # Move `initial_destination` to the front of the list.
+    destinations = list(other_destinations)
+    if initial_destination in destinations:
+        destinations.remove(initial_destination)
+    destinations = [initial_destination] + destinations
+    return destinations
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index da319943cc..66aca2f864 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -43,7 +43,9 @@ from synapse.api.constants import (
 from synapse.api.errors import (
     AuthError,
     Codes,
+    EventSizeError,
     FederationError,
+    FederationPullAttemptBackoffError,
     HttpResponseException,
     RequestSendFailed,
     SynapseError,
@@ -57,7 +59,7 @@ from synapse.event_auth import (
 )
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.federation.federation_client import InvalidResponseError
+from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
 from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import (
     SynapseTags,
@@ -74,7 +76,6 @@ from synapse.replication.http.federation import (
 from synapse.state import StateResolutionStore
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
 from synapse.types import (
     PersistedEventPosition,
     RoomStreamToken,
@@ -82,6 +83,7 @@ from synapse.types import (
     UserID,
     get_domain_from_id,
 )
+from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.iterutils import batch_iter
 from synapse.util.retryutils import NotRetryingDestination
@@ -414,7 +416,9 @@ class FederationEventHandler:
 
         # First, precalculate the joined hosts so that the federation sender doesn't
         # need to.
-        await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+        await self._event_creation_handler.cache_joined_hosts_for_events(
+            [(event, context)]
+        )
 
         await self._check_for_soft_fail(event, context=context, origin=origin)
         await self._run_push_actions_and_persist_event(event, context)
@@ -565,6 +569,9 @@ class FederationEventHandler:
             event: partial-state event to be de-partial-stated
 
         Raises:
+            FederationPullAttemptBackoffError if we are are deliberately not attempting
+                to pull the given event over federation because we've already done so
+                recently and are backing off.
             FederationError if we fail to request state from the remote server.
         """
         logger.info("Updating state for %s", event.event_id)
@@ -792,9 +799,42 @@ class FederationEventHandler:
             ],
         )
 
+        # Check if we already any of these have these events.
+        # Note: we currently make a lookup in the database directly here rather than
+        # checking the event cache, due to:
+        # https://github.com/matrix-org/synapse/issues/13476
+        existing_events_map = await self._store._get_events_from_db(
+            [event.event_id for event in events]
+        )
+
+        new_events = []
+        for event in events:
+            event_id = event.event_id
+
+            # If we've already seen this event ID...
+            if event_id in existing_events_map:
+                existing_event = existing_events_map[event_id]
+
+                # ...and the event itself was not previously stored as an outlier...
+                if not existing_event.event.internal_metadata.is_outlier():
+                    # ...then there's no need to persist it. We have it already.
+                    logger.info(
+                        "_process_pulled_event: Ignoring received event %s which we "
+                        "have already seen",
+                        event.event_id,
+                    )
+                    continue
+
+                # While we have seen this event before, it was stored as an outlier.
+                # We'll now persist it as a non-outlier.
+                logger.info("De-outliering event %s", event_id)
+
+            # Continue on with the events that are new to us.
+            new_events.append(event)
+
         # We want to sort these by depth so we process them and
         # tell clients about them in order.
-        sorted_events = sorted(events, key=lambda x: x.depth)
+        sorted_events = sorted(new_events, key=lambda x: x.depth)
         for ev in sorted_events:
             with nested_logging_context(ev.event_id):
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
@@ -846,18 +886,6 @@ class FederationEventHandler:
 
         event_id = event.event_id
 
-        existing = await self._store.get_event(
-            event_id, allow_none=True, allow_rejected=True
-        )
-        if existing:
-            if not existing.internal_metadata.is_outlier():
-                logger.info(
-                    "_process_pulled_event: Ignoring received event %s which we have already seen",
-                    event_id,
-                )
-                return
-            logger.info("De-outliering event %s", event_id)
-
         try:
             self._sanity_check_event(event)
         except SynapseError as err:
@@ -899,6 +927,18 @@ class FederationEventHandler:
                     context,
                     backfilled=backfilled,
                 )
+        except FederationPullAttemptBackoffError as exc:
+            # Log a warning about why we failed to process the event (the error message
+            # for `FederationPullAttemptBackoffError` is pretty good)
+            logger.warning("_process_pulled_event: %s", exc)
+            # We do not record a failed pull attempt when we backoff fetching a missing
+            # `prev_event` because not being able to fetch the `prev_events` just means
+            # we won't be able to de-outlier the pulled event. But we can still use an
+            # `outlier` in the state/auth chain for another event. So we shouldn't stop
+            # a downstream event from trying to pull it.
+            #
+            # This avoids a cascade of backoff for all events in the DAG downstream from
+            # one event backoff upstream.
         except FederationError as e:
             await self._store.record_event_failed_pull_attempt(
                 event.room_id, event_id, str(e)
@@ -945,6 +985,9 @@ class FederationEventHandler:
             The event context.
 
         Raises:
+            FederationPullAttemptBackoffError if we are are deliberately not attempting
+                to pull the given event over federation because we've already done so
+                recently and are backing off.
             FederationError if we fail to get the state from the remote server after any
                 missing `prev_event`s.
         """
@@ -955,6 +998,18 @@ class FederationEventHandler:
         seen = await self._store.have_events_in_timeline(prevs)
         missing_prevs = prevs - seen
 
+        # If we've already recently attempted to pull this missing event, don't
+        # try it again so soon. Since we have to fetch all of the prev_events, we can
+        # bail early here if we find any to ignore.
+        prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
+            room_id, missing_prevs
+        )
+        if len(prevs_to_ignore) > 0:
+            raise FederationPullAttemptBackoffError(
+                event_ids=prevs_to_ignore,
+                message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
+            )
+
         if not missing_prevs:
             return await self._state_handler.compute_event_context(event)
 
@@ -1011,10 +1066,9 @@ class FederationEventHandler:
                 state_res_store=StateResolutionStore(self._store),
             )
 
-        except Exception:
+        except Exception as e:
             logger.warning(
-                "Error attempting to resolve state at missing prev_events",
-                exc_info=True,
+                "Error attempting to resolve state at missing prev_events: %s", e
             )
             raise FederationError(
                 "ERROR",
@@ -1463,8 +1517,8 @@ class FederationEventHandler:
         )
 
     async def backfill_event_id(
-        self, destination: str, room_id: str, event_id: str
-    ) -> EventBase:
+        self, destinations: List[str], room_id: str, event_id: str
+    ) -> PulledPduInfo:
         """Backfill a single event and persist it as a non-outlier which means
         we also pull in all of the state and auth events necessary for it.
 
@@ -1476,24 +1530,21 @@ class FederationEventHandler:
         Raises:
             FederationError if we are unable to find the event from the destination
         """
-        logger.info(
-            "backfill_event_id: event_id=%s from destination=%s", event_id, destination
-        )
+        logger.info("backfill_event_id: event_id=%s", event_id)
 
         room_version = await self._store.get_room_version(room_id)
 
-        event_from_response = await self._federation_client.get_pdu(
-            [destination],
+        pulled_pdu_info = await self._federation_client.get_pdu(
+            destinations,
             event_id,
             room_version,
         )
 
-        if not event_from_response:
+        if not pulled_pdu_info:
             raise FederationError(
                 "ERROR",
                 404,
-                "Unable to find event_id=%s from destination=%s to backfill."
-                % (event_id, destination),
+                f"Unable to find event_id={event_id} from remote servers to backfill.",
                 affected=event_id,
             )
 
@@ -1501,13 +1552,13 @@ class FederationEventHandler:
         # and auth events to de-outlier it. This also sets up the necessary
         # `state_groups` for the event.
         await self._process_pulled_events(
-            destination,
-            [event_from_response],
+            pulled_pdu_info.pull_origin,
+            [pulled_pdu_info.pdu],
             # Prevent notifications going to clients
             backfilled=True,
         )
 
-        return event_from_response
+        return pulled_pdu_info
 
     @trace
     @tag_args
@@ -1530,19 +1581,19 @@ class FederationEventHandler:
         async def get_event(event_id: str) -> None:
             with nested_logging_context(event_id):
                 try:
-                    event = await self._federation_client.get_pdu(
+                    pulled_pdu_info = await self._federation_client.get_pdu(
                         [destination],
                         event_id,
                         room_version,
                     )
-                    if event is None:
+                    if pulled_pdu_info is None:
                         logger.warning(
                             "Server %s didn't return event %s",
                             destination,
                             event_id,
                         )
                         return
-                    events.append(event)
+                    events.append(pulled_pdu_info.pdu)
 
                 except Exception as e:
                     logger.warning(
@@ -1686,6 +1737,15 @@ class FederationEventHandler:
                 except AuthError as e:
                     logger.warning("Rejecting %r because %s", event, e)
                     context.rejected = RejectedReason.AUTH_ERROR
+                except EventSizeError as e:
+                    if e.unpersistable:
+                        # This event is completely unpersistable.
+                        raise e
+                    # Otherwise, we are somewhat lenient and just persist the event
+                    # as rejected, for moderate compatibility with older Synapse
+                    # versions.
+                    logger.warning("While validating received event %r: %s", event, e)
+                    context.rejected = RejectedReason.OVERSIZED_EVENT
 
             events_and_contexts_to_persist.append((event, context))
 
@@ -1731,6 +1791,16 @@ class FederationEventHandler:
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
             return
+        except EventSizeError as e:
+            if e.unpersistable:
+                # This event is completely unpersistable.
+                raise e
+            # Otherwise, we are somewhat lenient and just persist the event
+            # as rejected, for moderate compatibility with older Synapse
+            # versions.
+            logger.warning("While validating received event %r: %s", event, e)
+            context.rejected = RejectedReason.OVERSIZED_EVENT
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -2117,8 +2187,8 @@ class FederationEventHandler:
                     min_depth,
                 )
             else:
-                await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                    event, context
+                await self._bulk_push_rule_evaluator.action_for_events_by_user(
+                    [(event, context)]
                 )
 
         try:
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 93d09e9939..848e46eb9b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -711,7 +711,7 @@ class IdentityHandler:
             inviter_display_name: The current display name of the
                 inviter.
             inviter_avatar_url: The URL of the inviter's avatar.
-            id_access_token (str): The access token to authenticate to the identity
+            id_access_token: The access token to authenticate to the identity
                 server with
 
         Returns:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 860c82c110..9c335e6863 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -57,13 +57,7 @@ class InitialSyncHandler:
         self.validator = EventValidator()
         self.snapshot_cache: ResponseCache[
             Tuple[
-                str,
-                Optional[StreamToken],
-                Optional[StreamToken],
-                str,
-                Optional[int],
-                bool,
-                bool,
+                str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
             ]
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
@@ -154,11 +148,6 @@ class InitialSyncHandler:
 
         public_room_ids = await self.store.get_public_room_ids()
 
-        if pagin_config.limit is not None:
-            limit = pagin_config.limit
-        else:
-            limit = 10
-
         serializer_options = SerializeEventConfig(as_client_event=as_client_event)
 
         async def handle_room(event: RoomsForUser) -> None:
@@ -210,7 +199,7 @@ class InitialSyncHandler:
                             run_in_background(
                                 self.store.get_recent_events_for_room,
                                 event.room_id,
-                                limit=limit,
+                                limit=pagin_config.limit,
                                 end_token=room_end_token,
                             ),
                             deferred_room_state,
@@ -360,15 +349,11 @@ class InitialSyncHandler:
             member_event_id
         )
 
-        limit = pagin_config.limit if pagin_config else None
-        if limit is None:
-            limit = 10
-
         leave_position = await self.store.get_position_for_event(member_event_id)
         stream_token = leave_position.to_room_stream_token()
 
         messages, token = await self.store.get_recent_events_for_room(
-            room_id, limit=limit, end_token=stream_token
+            room_id, limit=pagin_config.limit, end_token=stream_token
         )
 
         messages = await filter_events_for_client(
@@ -420,10 +405,6 @@ class InitialSyncHandler:
 
         now_token = self.hs.get_event_sources().get_current_token()
 
-        limit = pagin_config.limit if pagin_config else None
-        if limit is None:
-            limit = 10
-
         room_members = [
             m
             for m in current_state.values()
@@ -467,7 +448,7 @@ class InitialSyncHandler:
                     run_in_background(
                         self.store.get_recent_events_for_room,
                         room_id,
-                        limit=limit,
+                        limit=pagin_config.limit,
                         end_token=now_token.room_key,
                     ),
                 ),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da1acea275..845f683358 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -50,6 +50,7 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
+from synapse.events.utils import maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
 from synapse.logging import opentracing
@@ -59,7 +60,6 @@ from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.state import StateFilter
 from synapse.types import (
     MutableStateMap,
     PersistedEventPosition,
@@ -70,6 +70,7 @@ from synapse.types import (
     UserID,
     create_requester,
 )
+from synapse.types.state import StateFilter
 from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -877,6 +878,36 @@ class EventCreationHandler:
                 return prev_event
         return None
 
+    async def get_event_from_transaction(
+        self,
+        requester: Requester,
+        txn_id: str,
+        room_id: str,
+    ) -> Optional[EventBase]:
+        """For the given transaction ID and room ID, check if there is a matching event.
+        If so, fetch it and return it.
+
+        Args:
+            requester: The requester making the request in the context of which we want
+                to fetch the event.
+            txn_id: The transaction ID.
+            room_id: The room ID.
+
+        Returns:
+            An event if one could be found, None otherwise.
+        """
+        if requester.access_token_id:
+            existing_event_id = await self.store.get_event_id_from_transaction_id(
+                room_id,
+                requester.user.to_string(),
+                requester.access_token_id,
+                txn_id,
+            )
+            if existing_event_id:
+                return await self.store.get_event(existing_event_id)
+
+        return None
+
     async def create_and_send_nonmember_event(
         self,
         requester: Requester,
@@ -956,18 +987,17 @@ class EventCreationHandler:
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
         async with self.limiter.queue(event_dict["room_id"]):
-            if txn_id and requester.access_token_id:
-                existing_event_id = await self.store.get_event_id_from_transaction_id(
-                    event_dict["room_id"],
-                    requester.user.to_string(),
-                    requester.access_token_id,
-                    txn_id,
+            if txn_id:
+                event = await self.get_event_from_transaction(
+                    requester, txn_id, event_dict["room_id"]
                 )
-                if existing_event_id:
-                    event = await self.store.get_event(existing_event_id)
+                if event:
                     # we know it was persisted, so must have a stream ordering
                     assert event.internal_metadata.stream_ordering
-                    return event, event.internal_metadata.stream_ordering
+                    return (
+                        event,
+                        event.internal_metadata.stream_ordering,
+                    )
 
             event, context = await self.create_event(
                 requester,
@@ -1106,11 +1136,13 @@ class EventCreationHandler:
             )
             state_events = await self.store.get_events_as_list(state_event_ids)
             # Create a StateMap[str]
-            state_map = {(e.type, e.state_key): e.event_id for e in state_events}
+            current_state_ids = {
+                (e.type, e.state_key): e.event_id for e in state_events
+            }
             # Actually strip down and only use the necessary auth events
             auth_event_ids = self._event_auth_handler.compute_auth_events(
                 event=temp_event,
-                current_state_ids=state_map,
+                current_state_ids=current_state_ids,
                 for_verification=False,
             )
 
@@ -1360,8 +1392,16 @@ class EventCreationHandler:
             else:
                 try:
                     validate_event_for_room_version(event)
+                    # If we are persisting a batch of events the event(s) needed to auth the
+                    # current event may be part of the batch and will not be in the DB yet
+                    event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+                    batched_auth_events = {}
+                    for event_id in event.auth_event_ids():
+                        auth_event = event_id_to_event.get(event_id)
+                        if auth_event:
+                            batched_auth_events[event_id] = auth_event
                     await self._event_auth_handler.check_auth_rules_from_context(
-                        event, context
+                        event, batched_auth_events
                     )
                 except AuthError as err:
                     logger.warning("Denying new event %r because %s", event, err)
@@ -1390,7 +1430,7 @@ class EventCreationHandler:
                             extra_users=extra_users,
                         ),
                         run_in_background(
-                            self.cache_joined_hosts_for_event, event, context
+                            self.cache_joined_hosts_for_events, events_and_context
                         ).addErrback(
                             log_failure, "cache_joined_hosts_for_event failed"
                         ),
@@ -1425,17 +1465,9 @@ class EventCreationHandler:
             a room that has been un-partial stated.
         """
 
-        for event, context in events_and_context:
-            # Skip push notification actions for historical messages
-            # because we don't want to notify people about old history back in time.
-            # The historical messages also do not have the proper `context.current_state_ids`
-            # and `state_groups` because they have `prev_events` that aren't persisted yet
-            # (historical messages persisted in reverse-chronological order).
-            if not event.internal_metadata.is_historical():
-                with opentracing.start_active_span("calculate_push_actions"):
-                    await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                        event, context
-                    )
+        await self._bulk_push_rule_evaluator.action_for_events_by_user(
+            events_and_context
+        )
 
         try:
             # If we're a worker we need to hit out to the master.
@@ -1491,62 +1523,65 @@ class EventCreationHandler:
                 await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
-    async def cache_joined_hosts_for_event(
-        self, event: EventBase, context: EventContext
+    async def cache_joined_hosts_for_events(
+        self, events_and_context: List[Tuple[EventBase, EventContext]]
     ) -> None:
-        """Precalculate the joined hosts at the event, when using Redis, so that
+        """Precalculate the joined hosts at each of the given events, when using Redis, so that
         external federation senders don't have to recalculate it themselves.
         """
 
-        if not self._external_cache.is_enabled():
-            return
-
-        # If external cache is enabled we should always have this.
-        assert self._external_cache_joined_hosts_updates is not None
+        for event, _ in events_and_context:
+            if not self._external_cache.is_enabled():
+                return
 
-        # We actually store two mappings, event ID -> prev state group,
-        # state group -> joined hosts, which is much more space efficient
-        # than event ID -> joined hosts.
-        #
-        # Note: We have to cache event ID -> prev state group, as we don't
-        # store that in the DB.
-        #
-        # Note: We set the state group -> joined hosts cache if it hasn't been
-        # set for a while, so that the expiry time is reset.
+            # If external cache is enabled we should always have this.
+            assert self._external_cache_joined_hosts_updates is not None
 
-        state_entry = await self.state.resolve_state_groups_for_events(
-            event.room_id, event_ids=event.prev_event_ids()
-        )
+            # We actually store two mappings, event ID -> prev state group,
+            # state group -> joined hosts, which is much more space efficient
+            # than event ID -> joined hosts.
+            #
+            # Note: We have to cache event ID -> prev state group, as we don't
+            # store that in the DB.
+            #
+            # Note: We set the state group -> joined hosts cache if it hasn't been
+            # set for a while, so that the expiry time is reset.
 
-        if state_entry.state_group:
-            await self._external_cache.set(
-                "event_to_prev_state_group",
-                event.event_id,
-                state_entry.state_group,
-                expiry_ms=60 * 60 * 1000,
+            state_entry = await self.state.resolve_state_groups_for_events(
+                event.room_id, event_ids=event.prev_event_ids()
             )
 
-            if state_entry.state_group in self._external_cache_joined_hosts_updates:
-                return
+            if state_entry.state_group:
+                await self._external_cache.set(
+                    "event_to_prev_state_group",
+                    event.event_id,
+                    state_entry.state_group,
+                    expiry_ms=60 * 60 * 1000,
+                )
 
-            state = await state_entry.get_state(
-                self._storage_controllers.state, StateFilter.all()
-            )
-            with opentracing.start_active_span("get_joined_hosts"):
-                joined_hosts = await self.store.get_joined_hosts(
-                    event.room_id, state, state_entry
+                if state_entry.state_group in self._external_cache_joined_hosts_updates:
+                    return
+
+                state = await state_entry.get_state(
+                    self._storage_controllers.state, StateFilter.all()
                 )
+                with opentracing.start_active_span("get_joined_hosts"):
+                    joined_hosts = await self.store.get_joined_hosts(
+                        event.room_id, state, state_entry
+                    )
 
-            # Note that the expiry times must be larger than the expiry time in
-            # _external_cache_joined_hosts_updates.
-            await self._external_cache.set(
-                "get_joined_hosts",
-                str(state_entry.state_group),
-                list(joined_hosts),
-                expiry_ms=60 * 60 * 1000,
-            )
+                # Note that the expiry times must be larger than the expiry time in
+                # _external_cache_joined_hosts_updates.
+                await self._external_cache.set(
+                    "get_joined_hosts",
+                    str(state_entry.state_group),
+                    list(joined_hosts),
+                    expiry_ms=60 * 60 * 1000,
+                )
 
-            self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+                self._external_cache_joined_hosts_updates[
+                    state_entry.state_group
+                ] = None
 
     async def _validate_canonical_alias(
         self,
@@ -1705,12 +1740,15 @@ class EventCreationHandler:
 
             if event.type == EventTypes.Member:
                 if event.content["membership"] == Membership.INVITE:
-                    event.unsigned[
-                        "invite_room_state"
-                    ] = await self.store.get_stripped_room_state_from_event_context(
-                        context,
-                        self.room_prejoin_state_types,
-                        membership_user_id=event.sender,
+                    maybe_upsert_event_field(
+                        event,
+                        event.unsigned,
+                        "invite_room_state",
+                        await self.store.get_stripped_room_state_from_event_context(
+                            context,
+                            self.room_prejoin_state_types,
+                            membership_user_id=event.sender,
+                        ),
                     )
 
                     invitee = UserID.from_string(event.state_key)
@@ -1728,11 +1766,14 @@ class EventCreationHandler:
                         event.signatures.update(returned_invite.signatures)
 
                 if event.content["membership"] == Membership.KNOCK:
-                    event.unsigned[
-                        "knock_room_state"
-                    ] = await self.store.get_stripped_room_state_from_event_context(
-                        context,
-                        self.room_prejoin_state_types,
+                    maybe_upsert_event_field(
+                        event,
+                        event.unsigned,
+                        "knock_room_state",
+                        await self.store.get_stripped_room_state_from_event_context(
+                            context,
+                            self.room_prejoin_state_types,
+                        ),
                     )
 
             if event.type == EventTypes.Redaction:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index d7a8226900..03de6a4ba6 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -12,14 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import binascii
 import inspect
+import json
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 from urllib.parse import urlencode, urlparse
 
 import attr
+import unpaddedbase64
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken, jwt
+from authlib.jose import JsonWebToken, JWTClaims
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, UserInfo
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 from twisted.web.http_headers import Headers
 
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -88,6 +105,8 @@ class Token(TypedDict):
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+C = TypeVar("C")
+
 
 #: A JWK Set, as per RFC7517 sec 5.
 class JWKS(TypedDict):
@@ -247,6 +266,80 @@ class OidcHandler:
 
         await oidc_provider.handle_oidc_callback(request, session_data, code)
 
+    async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        This extracts the logout_token from the request and tries to figure out
+        which OpenID Provider it is comming from. This works by matching the iss claim
+        with the issuer and the aud claim with the client_id.
+
+        Since at this point we don't know who signed the JWT, we can't just
+        decode it using authlib since it will always verifies the signature. We
+        have to decode it manually without validating the signature. The actual JWT
+        verification is done in the `OidcProvider.handler_backchannel_logout` method,
+        once we figured out which provider sent the request.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+        logout_token = parse_string(request, "logout_token")
+        if logout_token is None:
+            raise SynapseError(400, "Missing logout_token in request")
+
+        # A JWT looks like this:
+        #    header.payload.signature
+        # where all parts are encoded with urlsafe base64.
+        # The aud and iss claims we care about are in the payload part, which
+        # is a JSON object.
+        try:
+            # By destructuring the list after splitting, we ensure that we have
+            # exactly 3 segments
+            _, payload, _ = logout_token.split(".")
+        except ValueError:
+            raise SynapseError(400, "Invalid logout_token in request")
+
+        try:
+            payload_bytes = unpaddedbase64.decode_base64(payload)
+            claims = json_decoder.decode(payload_bytes.decode("utf-8"))
+        except (json.JSONDecodeError, binascii.Error, UnicodeError):
+            raise SynapseError(400, "Invalid logout_token payload in request")
+
+        try:
+            # Let's extract the iss and aud claims
+            iss = claims["iss"]
+            aud = claims["aud"]
+            # The aud claim can be either a string or a list of string. Here we
+            # normalize it as a list of strings.
+            if isinstance(aud, str):
+                aud = [aud]
+
+            # Check that we have the right types for the aud and the iss claims
+            if not isinstance(iss, str) or not isinstance(aud, list):
+                raise TypeError()
+            for a in aud:
+                if not isinstance(a, str):
+                    raise TypeError()
+
+            # At this point we properly checked both claims types
+            issuer: str = iss
+            audience: List[str] = aud
+        except (TypeError, KeyError):
+            raise SynapseError(400, "Invalid issuer/audience in logout_token")
+
+        # Now that we know the audience and the issuer, we can figure out from
+        # what provider it is coming from
+        oidc_provider: Optional[OidcProvider] = None
+        for provider in self._providers.values():
+            if provider.issuer == issuer and provider.client_id in audience:
+                oidc_provider = provider
+                break
+
+        if oidc_provider is None:
+            raise SynapseError(400, "Could not find the OP that issued this event")
+
+        # Ask the provider to handle the logout request.
+        await oidc_provider.handle_backchannel_logout(request, logout_token)
+
 
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint"""
@@ -275,6 +368,7 @@ class OidcProvider:
         provider: OidcProviderConfig,
     ):
         self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
 
         self._macaroon_generaton = macaroon_generator
 
@@ -341,6 +435,7 @@ class OidcProvider:
         self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
+        self._device_handler = hs.get_device_handler()
 
         self._sso_handler.register_identity_provider(self)
 
@@ -399,6 +494,41 @@ class OidcProvider:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
             m.validate_jwks_uri()
 
+        if self._config.backchannel_logout_enabled:
+            if not m.get("backchannel_logout_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled for issuer %r"
+                    "but it does not advertise support for it",
+                    self.issuer,
+                )
+
+            elif not m.get("backchannel_logout_session_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled and supported "
+                    "by issuer %r but it might not send a session ID with "
+                    "logout tokens, which is required for the logouts to work",
+                    self.issuer,
+                )
+
+            if not self._config.backchannel_logout_ignore_sub:
+                # If OIDC backchannel logouts are enabled, the provider mapping provider
+                # should use the `sub` claim. We verify that by mapping a dumb user and
+                # see if we get back the sub claim
+                user = UserInfo({"sub": "thisisasubject"})
+                try:
+                    subject = self._user_mapping_provider.get_remote_user_id(user)
+                    if subject != user["sub"]:
+                        raise ValueError("Unexpected subject")
+                except Exception:
+                    logger.warning(
+                        f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
+                        "but it looks like the configured `user_mapping_provider` "
+                        "does not use the `sub` claim as subject. If it is the case, "
+                        "and you want Synapse to ignore the `sub` claim in OIDC "
+                        "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
+                        "to `true` in the issuer config."
+                    )
+
     @property
     def _uses_userinfo(self) -> bool:
         """Returns True if the ``userinfo_endpoint`` should be used.
@@ -414,6 +544,16 @@ class OidcProvider:
             or self._user_profile_method == "userinfo_endpoint"
         )
 
+    @property
+    def issuer(self) -> str:
+        """The issuer identifying this provider."""
+        return self._config.issuer
+
+    @property
+    def client_id(self) -> str:
+        """The client_id used when interacting with this provider."""
+        return self._config.client_id
+
     async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
         """Return the provider metadata.
 
@@ -647,7 +787,7 @@ class OidcProvider:
                 Must include an ``access_token`` field.
 
         Returns:
-            UserInfo: an object representing the user.
+            an object representing the user.
         """
         logger.debug("Using the OAuth2 access_token to request userinfo")
         metadata = await self.load_metadata()
@@ -661,61 +801,99 @@ class OidcProvider:
 
         return UserInfo(resp)
 
-    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
-        """Return an instance of UserInfo from token's ``id_token``.
+    async def _verify_jwt(
+        self,
+        alg_values: List[str],
+        token: str,
+        claims_cls: Type[C],
+        claims_options: Optional[dict] = None,
+        claims_params: Optional[dict] = None,
+    ) -> C:
+        """Decode and validate a JWT, re-fetching the JWKS as needed.
 
         Args:
-            token: the token given by the ``token_endpoint``.
-                Must include an ``id_token`` field.
-            nonce: the nonce value originally sent in the initial authorization
-                request. This value should match the one inside the token.
+            alg_values: list of `alg` values allowed when verifying the JWT.
+            token: the JWT.
+            claims_cls: the JWTClaims class to use to validate the claims.
+            claims_options: dict of options passed to the `claims_cls` constructor.
+            claims_params: dict of params passed to the `claims_cls` constructor.
 
         Returns:
-            The decoded claims in the ID token.
+            The decoded claims in the JWT.
         """
-        metadata = await self.load_metadata()
-        claims_params = {
-            "nonce": nonce,
-            "client_id": self._client_auth.client_id,
-        }
-        if "access_token" in token:
-            # If we got an `access_token`, there should be an `at_hash` claim
-            # in the `id_token` that we can check against.
-            claims_params["access_token"] = token["access_token"]
-
-        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)
 
-        claim_options = {"iss": {"values": [metadata["issuer"]]}}
-
-        id_token = token["id_token"]
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
+        logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
 
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
         try:
             claims = jwt.decode(
-                id_token,
+                token,
                 key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
                 claims_params=claims_params,
             )
         except ValueError:
             logger.info("Reloading JWKS after decode error")
             jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
             claims = jwt.decode(
-                id_token,
+                token,
                 key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
                 claims_params=claims_params,
             )
 
-        logger.debug("Decoded id_token JWT %r; validating", claims)
+        logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
 
-        claims.validate(leeway=120)  # allows 2 min of clock skew
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
+        return claims
+
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
+        """Return an instance of UserInfo from token's ``id_token``.
+
+        Args:
+            token: the token given by the ``token_endpoint``.
+                Must include an ``id_token`` field.
+            nonce: the nonce value originally sent in the initial authorization
+                request. This value should match the one inside the token.
+
+        Returns:
+            The decoded claims in the ID token.
+        """
+        id_token = token.get("id_token")
+
+        # That has been theoritically been checked by the caller, so even though
+        # assertion are not enabled in production, it is mainly here to appease mypy
+        assert id_token is not None
+
+        metadata = await self.load_metadata()
+
+        claims_params = {
+            "nonce": nonce,
+            "client_id": self._client_auth.client_id,
+        }
+        if "access_token" in token:
+            # If we got an `access_token`, there should be an `at_hash` claim
+            # in the `id_token` that we can check against.
+            claims_params["access_token"] = token["access_token"]
+
+        claims_options = {"iss": {"values": [metadata["issuer"]]}}
+
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        claims = await self._verify_jwt(
+            alg_values=alg_values,
+            token=id_token,
+            claims_cls=CodeIDToken,
+            claims_options=claims_options,
+            claims_params=claims_params,
+        )
 
         return claims
 
@@ -1036,6 +1214,146 @@ class OidcProvider:
         # to be strings.
         return str(remote_user_id)
 
+    async def handle_backchannel_logout(
+        self, request: SynapseRequest, logout_token: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        The OIDC Provider posts a logout token to this endpoint when a user
+        session ends. That token is a JWT signed with the same keys as
+        ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
+        validate the JWT and figure out what session to end.
+
+        Args:
+            request: The request to respond to
+            logout_token: The logout token (a JWT) extracted from the request body
+        """
+        # Back-Channel Logout can be disabled in the config, hence this check.
+        # This is not that important for now since Synapse is registered
+        # manually to the OP, so not specifying the backchannel-logout URI is
+        # as effective than disabling it here. It might make more sense if we
+        # support dynamic registration in Synapse at some point.
+        if not self._config.backchannel_logout_enabled:
+            logger.warning(
+                f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
+            )
+
+            # TODO: this responds with a 400 status code, which is what the OIDC
+            # Back-Channel Logout spec expects, but spec also suggests answering with
+            # a JSON object, with the `error` and `error_description` fields set, which
+            # we are not doing here.
+            # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
+            raise SynapseError(
+                400, "OpenID Connect Back-Channel Logout is disabled for this provider"
+            )
+
+        metadata = await self.load_metadata()
+
+        # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
+        #   A Logout Token MUST be signed and MAY also be encrypted. The same
+        #   keys are used to sign and encrypt Logout Tokens as are used for ID
+        #   Tokens. If the Logout Token is encrypted, it SHOULD replicate the
+        #   iss (issuer) claim in the JWT Header Parameters, as specified in
+        #   Section 5.3 of [JWT].
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        # As per sec. 2.6:
+        #    3. Validate the iss, aud, and iat Claims in the same way they are
+        #       validated in ID Tokens.
+        # Which means the audience should contain Synapse's client_id and the
+        # issuer should be the IdP issuer
+        claims_options = {
+            "iss": {"values": [metadata["issuer"]]},
+            "aud": {"values": [self.client_id]},
+        }
+
+        try:
+            claims = await self._verify_jwt(
+                alg_values=alg_values,
+                token=logout_token,
+                claims_cls=LogoutToken,
+                claims_options=claims_options,
+            )
+        except JoseError:
+            logger.exception("Invalid logout_token")
+            raise SynapseError(400, "Invalid logout_token")
+
+        # As per sec. 2.6:
+        #    4. Verify that the Logout Token contains a sub Claim, a sid Claim,
+        #       or both.
+        #    5. Verify that the Logout Token contains an events Claim whose
+        #       value is JSON object containing the member name
+        #       http://schemas.openid.net/event/backchannel-logout.
+        #    6. Verify that the Logout Token does not contain a nonce Claim.
+        # This is all verified by the LogoutToken claims class, so at this
+        # point the `sid` claim exists and is a string.
+        sid: str = claims.get("sid")
+
+        # If the `sub` claim was included in the logout token, we check that it matches
+        # that it matches the right user. We can have cases where the `sub` claim is not
+        # the ID saved in database, so we let admins disable this check in config.
+        sub: Optional[str] = claims.get("sub")
+        expected_user_id: Optional[str] = None
+        if sub is not None and not self._config.backchannel_logout_ignore_sub:
+            expected_user_id = await self._store.get_user_by_external_id(
+                self.idp_id, sub
+            )
+
+        # Invalidate any running user-mapping sessions, in-flight login tokens and
+        # active devices
+        await self._sso_handler.revoke_sessions_for_provider_session_id(
+            auth_provider_id=self.idp_id,
+            auth_provider_session_id=sid,
+            expected_user_id=expected_user_id,
+        )
+
+        request.setResponseCode(200)
+        request.setHeader(b"Cache-Control", b"no-cache, no-store")
+        request.setHeader(b"Pragma", b"no-cache")
+        finish_request(request)
+
+
+class LogoutToken(JWTClaims):
+    """
+    Holds and verify claims of a logout token, as per
+    https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
+    """
+
+    REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
+
+    def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
+        """Validate everything in claims payload."""
+        super().validate(now, leeway)
+        self.validate_sid()
+        self.validate_events()
+        self.validate_nonce()
+
+    def validate_sid(self) -> None:
+        """Ensure the sid claim is present"""
+        sid = self.get("sid")
+        if not sid:
+            raise MissingClaimError("sid")
+
+        if not isinstance(sid, str):
+            raise InvalidClaimError("sid")
+
+    def validate_nonce(self) -> None:
+        """Ensure the nonce claim is absent"""
+        if "nonce" in self:
+            raise InvalidClaimError("nonce")
+
+    def validate_events(self) -> None:
+        """Ensure the events claim is present and with the right value"""
+        events = self.get("events")
+        if not events:
+            raise MissingClaimError("events")
+
+        if not isinstance(events, dict):
+            raise InvalidClaimError("events")
+
+        if "http://schemas.openid.net/event/backchannel-logout" not in events:
+            raise InvalidClaimError("events")
+
 
 # number of seconds a newly-generated client secret should be valid for
 CLIENT_SECRET_VALIDITY_SECONDS = 3600
@@ -1105,6 +1423,7 @@ class JwtClientSecret:
         logger.info(
             "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
         )
+        jwt = JsonWebToken(header["alg"])
         self._cached_secret = jwt.encode(header, payload, self._key.key)
         self._cached_secret_replacement_time = (
             expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@@ -1116,12 +1435,10 @@ class UserAttributeDict(TypedDict):
     localpart: Optional[str]
     confirm_localpart: bool
     display_name: Optional[str]
+    picture: Optional[str]  # may be omitted by older `OidcMappingProviders`
     emails: List[str]
 
 
-C = TypeVar("C")
-
-
 class OidcMappingProvider(Generic[C]):
     """A mapping provider maps a UserInfo object to user attributes.
 
@@ -1204,6 +1521,7 @@ env.filters.update(
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class JinjaOidcMappingConfig:
     subject_claim: str
+    picture_claim: str
     localpart_template: Optional[Template]
     display_name_template: Optional[Template]
     email_template: Optional[Template]
@@ -1223,6 +1541,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     @staticmethod
     def parse_config(config: dict) -> JinjaOidcMappingConfig:
         subject_claim = config.get("subject_claim", "sub")
+        picture_claim = config.get("picture_claim", "picture")
 
         def parse_template_config(option_name: str) -> Optional[Template]:
             if option_name not in config:
@@ -1256,6 +1575,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
 
         return JinjaOidcMappingConfig(
             subject_claim=subject_claim,
+            picture_claim=picture_claim,
             localpart_template=localpart_template,
             display_name_template=display_name_template,
             email_template=email_template,
@@ -1295,10 +1615,13 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
         if email:
             emails.append(email)
 
+        picture = userinfo.get("picture")
+
         return UserAttributeDict(
             localpart=localpart,
             display_name=display_name,
             emails=emails,
+            picture=picture,
             confirm_localpart=self._config.confirm_localpart,
         )
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 1f83bab836..8c8ff18a1a 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -27,9 +27,9 @@ from synapse.handlers.room import ShutdownRoomResponse
 from synapse.logging.opentracing import trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamKeyType
+from synapse.types.state import StateFilter
 from synapse.util.async_helpers import ReadWriteLock
 from synapse.util.stringutils import random_string
 from synapse.visibility import filter_events_for_client
@@ -448,6 +448,12 @@ class PaginationHandler:
 
         if pagin_config.from_token:
             from_token = pagin_config.from_token
+        elif pagin_config.direction == "f":
+            from_token = (
+                await self.hs.get_event_sources().get_start_token_for_pagination(
+                    room_id
+                )
+            )
         else:
             from_token = (
                 await self.hs.get_event_sources().get_current_token_for_pagination(
@@ -458,11 +464,6 @@ class PaginationHandler:
             # `/messages` should still works with live tokens when manually provided.
             assert from_token.room_key.topological is not None
 
-        if pagin_config.limit is None:
-            # This shouldn't happen as we've set a default limit before this
-            # gets called.
-            raise Exception("limit not set")
-
         room_token = from_token.room_key
 
         async with self.pagination_lock.read(room_id):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 4e575ffbaa..2af90b25a3 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC):
         """Get the current presence state for multiple users.
 
         Returns:
-            dict: `user_id` -> `UserPresenceState`
+            A mapping of `user_id` -> `UserPresenceState`
         """
         states = {}
         missing = []
@@ -256,7 +256,7 @@ class BasePresenceHandler(abc.ABC):
         with the app.
         """
 
-    async def update_external_syncs_row(
+    async def update_external_syncs_row(  # noqa: B027 (no-op by design)
         self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
     ) -> None:
         """Update the syncing users for an external process as a delta.
@@ -272,7 +272,9 @@ class BasePresenceHandler(abc.ABC):
             sync_time_msec: Time in ms when the user was last syncing
         """
 
-    async def update_external_syncs_clear(self, process_id: str) -> None:
+    async def update_external_syncs_clear(  # noqa: B027 (no-op by design)
+        self, process_id: str
+    ) -> None:
         """Marks all users that had been marked as syncing by a given process
         as offline.
 
@@ -476,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             return _NullContextManager()
 
         prev_state = await self.current_state_for_user(user_id)
-        if prev_state != PresenceState.BUSY:
+        if prev_state.state != PresenceState.BUSY:
             # We set state here but pass ignore_status_msg = True as we don't want to
             # cause the status message to be cleared.
             # Note that this causes last_active_ts to be incremented which is not
@@ -1596,7 +1598,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         self,
         user: UserID,
         from_key: Optional[int],
-        limit: Optional[int] = None,
+        # Having a default limit doesn't match the EventSource API, but some
+        # callers do not provide it. It is unused in this class.
+        limit: int = 0,
         room_ids: Optional[Collection[str]] = None,
         is_guest: bool = False,
         explicit_room_id: Optional[str] = None,
@@ -1688,10 +1692,12 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
 
             if from_key is not None:
                 # First get all users that have had a presence update
-                updated_users = stream_change_cache.get_all_entities_changed(from_key)
+                result = stream_change_cache.get_all_entities_changed(from_key)
 
                 # Cross-reference users we're interested in with those that have had updates.
-                if updated_users is not None:
+                if result.hit:
+                    updated_users = result.entities
+
                     # If we have the full list of changes for presence we can
                     # simply check which ones share a room with the user.
                     get_updates_counter.labels("stream").inc()
@@ -1760,14 +1766,14 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         Returns:
             A list of presence states for the given user to receive.
         """
+        updated_users = None
         if from_key:
             # Only return updates since the last sync
-            updated_users = self.store.presence_stream_cache.get_all_entities_changed(
-                from_key
-            )
-            if not updated_users:
-                updated_users = []
+            result = self.store.presence_stream_cache.get_all_entities_changed(from_key)
+            if result.hit:
+                updated_users = result.entities
 
+        if updated_users is not None:
             # Get the actual presence update for each change
             users_to_state = await self.get_presence_handler().current_state_for_users(
                 updated_users
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index d8ff5289b5..4bf9a047a3 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -307,7 +307,11 @@ class ProfileHandler:
         if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
             return True
 
-        server_name, _, media_id = parse_and_validate_mxc_uri(mxc)
+        host, port, media_id = parse_and_validate_mxc_uri(mxc)
+        if port is not None:
+            server_name = host + ":" + str(port)
+        else:
+            server_name = host
 
         if server_name == self.server_name:
             media_info = await self.store.get_local_media(media_id)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4a7ec9e426..6a4fed1156 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -92,7 +92,6 @@ class ReceiptsHandler:
                         continue
 
                     # Check if these receipts apply to a thread.
-                    thread_id = None
                     data = user_values.get("data", {})
                     thread_id = data.get("thread_id")
                     # If the thread ID is invalid, consider it missing.
@@ -257,7 +256,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         self,
         user: UserID,
         from_key: int,
-        limit: Optional[int],
+        limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index ca1c7a1866..c611efb760 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
 )
 from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
+from synapse.handlers.device import DeviceHandler
 from synapse.http.servlet import assert_params_in_dict
 from synapse.replication.http.login import RegisterDeviceReplicationServlet
 from synapse.replication.http.register import (
@@ -45,8 +46,8 @@ from synapse.replication.http.register import (
     ReplicationRegisterServlet,
 )
 from synapse.spam_checker_api import RegistrationBehaviour
-from synapse.storage.state import StateFilter
 from synapse.types import RoomAlias, UserID, create_requester
+from synapse.types.state import StateFilter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -841,6 +842,9 @@ class RegistrationHandler:
         refresh_token = None
         refresh_token_id = None
 
+        # This can only run on the main process.
+        assert isinstance(self.device_handler, DeviceHandler)
+
         registered_device_id = await self.device_handler.check_device_registered(
             user_id,
             device_id,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 63bc6a7aa5..e96f9999a8 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,17 +11,21 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Dict, FrozenSet, Iterable, List, Optional
 
 import attr
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import trace
-from synapse.storage.databases.main.relations import _RelatedEvent
-from synapse.types import JsonDict, Requester, StreamToken, UserID
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
+from synapse.types import JsonDict, Requester, UserID
+from synapse.util.async_helpers import gather_results
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -31,6 +35,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -66,19 +77,17 @@ class RelationsHandler:
         self._clock = hs.get_clock()
         self._event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
+        self._event_creation_handler = hs.get_event_creation_handler()
 
     async def get_relations(
         self,
         requester: Requester,
         event_id: str,
         room_id: str,
+        pagin_config: PaginationConfig,
+        include_original_event: bool,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        limit: int = 5,
-        direction: str = "b",
-        from_token: Optional[StreamToken] = None,
-        to_token: Optional[StreamToken] = None,
-        include_original_event: bool = False,
     ) -> JsonDict:
         """Get related events of a event, ordered by topological ordering.
 
@@ -88,14 +97,10 @@ class RelationsHandler:
             requester: The user requesting the relations.
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
+            pagin_config: The pagination config rules to apply, if any.
+            include_original_event: Whether to include the parent event.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            limit: Only fetch the most recent `limit` events.
-            direction: Whether to fetch the most recent first (`"b"`) or the
-                oldest first (`"f"`).
-            from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
-            include_original_event: Whether to include the parent event.
 
         Returns:
             The pagination chunk.
@@ -123,10 +128,10 @@ class RelationsHandler:
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
-            limit=limit,
-            direction=direction,
-            from_token=from_token,
-            to_token=to_token,
+            limit=pagin_config.limit,
+            direction=pagin_config.direction,
+            from_token=pagin_config.from_token,
+            to_token=pagin_config.to_token,
         )
 
         events = await self._main_store.get_events_as_list(
@@ -162,90 +167,167 @@ class RelationsHandler:
         if next_token:
             return_value["next_batch"] = await next_token.to_string(self._main_store)
 
-        if from_token:
-            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+        if pagin_config.from_token:
+            return_value["prev_batch"] = await pagin_config.from_token.to_string(
+                self._main_store
+            )
 
         return return_value
 
-    async def get_relations_for_event(
+    async def redact_events_related_to(
         self,
+        requester: Requester,
         event_id: str,
-        event: EventBase,
-        room_id: str,
-        relation_type: str,
-        ignored_users: FrozenSet[str] = frozenset(),
-    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
-        """Get a list of events which relate to an event, ordered by topological ordering.
+        initial_redaction_event: EventBase,
+        relation_types: List[str],
+    ) -> None:
+        """Redacts all events related to the given event ID with one of the given
+        relation types.
 
-        Args:
-            event_id: Fetch events that relate to this event ID.
-            event: The matching EventBase to event_id.
-            room_id: The room the event belongs to.
-            relation_type: The type of relation.
-            ignored_users: The users ignored by the requesting user.
+        This method is expected to be called when redacting the event referred to by
+        the given event ID.
 
-        Returns:
-            List of event IDs that match relations requested. The rows are of
-            the form `{"event_id": "..."}`.
-        """
+        If an event cannot be redacted (e.g. because of insufficient permissions), log
+        the error and try to redact the next one.
 
-        # Call the underlying storage method, which is cached.
-        related_events, next_token = await self._main_store.get_relations_for_event(
-            event_id, event, room_id, relation_type, direction="f"
+        Args:
+            requester: The requester to redact events on behalf of.
+            event_id: The event IDs to look and redact relations of.
+            initial_redaction_event: The redaction for the event referred to by
+                event_id.
+            relation_types: The types of relations to look for.
+
+        Raises:
+            ShadowBanError if the requester is shadow-banned
+        """
+        related_event_ids = (
+            await self._main_store.get_all_relations_for_event_with_types(
+                event_id, relation_types
+            )
         )
 
-        # Filter out ignored users and convert to the expected format.
-        related_events = [
-            event for event in related_events if event.sender not in ignored_users
-        ]
-
-        return related_events, next_token
+        for related_event_id in related_event_ids:
+            try:
+                await self._event_creation_handler.create_and_send_nonmember_event(
+                    requester,
+                    {
+                        "type": EventTypes.Redaction,
+                        "content": initial_redaction_event.content,
+                        "room_id": initial_redaction_event.room_id,
+                        "sender": requester.user.to_string(),
+                        "redacts": related_event_id,
+                    },
+                    ratelimit=False,
+                )
+            except SynapseError as e:
+                logger.warning(
+                    "Failed to redact event %s (related to event %s): %s",
+                    related_event_id,
+                    event_id,
+                    e.msg,
+                )
 
-    async def get_annotations_for_event(
-        self,
-        event_id: str,
-        room_id: str,
-        limit: int = 5,
-        ignored_users: FrozenSet[str] = frozenset(),
-    ) -> List[JsonDict]:
-        """Get a list of annotations on the event, grouped by event type and
+    async def get_annotations_for_events(
+        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+    ) -> Dict[str, List[JsonDict]]:
+        """Get a list of annotations to the given events, grouped by event type and
         aggregation key, sorted by count.
 
-        This is used e.g. to get the what and how many reactions have happend
+        This is used e.g. to get the what and how many reactions have happened
         on an event.
 
         Args:
-            event_id: Fetch events that relate to this event ID.
-            room_id: The room the event belongs to.
-            limit: Only fetch the `limit` groups.
+            event_ids: Fetch events that relate to these event IDs.
             ignored_users: The users ignored by the requesting user.
 
         Returns:
-            List of groups of annotations that match. Each row is a dict with
-            `type`, `key` and `count` fields.
+            A map of event IDs to a list of groups of annotations that match.
+            Each entry is a dict with `type`, `key` and `count` fields.
         """
         # Get the base results for all users.
-        full_results = await self._main_store.get_aggregation_groups_for_event(
-            event_id, room_id, limit
+        full_results = await self._main_store.get_aggregation_groups_for_events(
+            event_ids
         )
 
+        # Avoid additional logic if there are no ignored users.
+        if not ignored_users:
+            return {
+                event_id: results
+                for event_id, results in full_results.items()
+                if results
+            }
+
         # Then subtract off the results for any ignored users.
         ignored_results = await self._main_store.get_aggregation_groups_for_users(
-            event_id, room_id, limit, ignored_users
+            [event_id for event_id, results in full_results.items() if results],
+            ignored_users,
         )
 
-        filtered_results = []
-        for result in full_results:
-            key = (result["type"], result["key"])
-            if key in ignored_results:
-                result = result.copy()
-                result["count"] -= ignored_results[key]
-                if result["count"] <= 0:
-                    continue
-            filtered_results.append(result)
+        filtered_results = {}
+        for event_id, results in full_results.items():
+            # If no annotations, skip.
+            if not results:
+                continue
+
+            # If there are not ignored results for this event, copy verbatim.
+            if event_id not in ignored_results:
+                filtered_results[event_id] = results
+                continue
+
+            # Otherwise, subtract out the ignored results.
+            event_ignored_results = ignored_results[event_id]
+            for result in results:
+                key = (result["type"], result["key"])
+                if key in event_ignored_results:
+                    # Ensure to not modify the cache.
+                    result = result.copy()
+                    result["count"] -= event_ignored_results[key]
+                    if result["count"] <= 0:
+                        continue
+                filtered_results.setdefault(event_id, []).append(result)
 
         return filtered_results
 
+    async def get_references_for_events(
+        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+    ) -> Dict[str, List[_RelatedEvent]]:
+        """Get a list of references to the given events.
+
+        Args:
+            event_ids: Fetch events that relate to this event ID.
+            ignored_users: The users ignored by the requesting user.
+
+        Returns:
+            A map of event IDs to a list related events.
+        """
+
+        related_events = await self._main_store.get_references_for_events(event_ids)
+
+        # Avoid additional logic if there are no ignored users.
+        if not ignored_users:
+            return {
+                event_id: results
+                for event_id, results in related_events.items()
+                if results
+            }
+
+        # Filter out ignored users.
+        results = {}
+        for event_id, events in related_events.items():
+            # If no references, skip.
+            if not events:
+                continue
+
+            # Filter ignored users out.
+            events = [event for event in events if event.sender not in ignored_users]
+            # If there are no events left, skip this event.
+            if not events:
+                continue
+
+            results[event_id] = events
+
+        return results
+
     async def _get_threads_for_events(
         self,
         events_by_id: Dict[str, EventBase],
@@ -308,59 +390,66 @@ class RelationsHandler:
         results = {}
 
         for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event = summary
-
-                # Subtract off the count of any ignored users.
-                for ignored_user in ignored_users:
-                    thread_count -= ignored_results.get((event_id, ignored_user), 0)
-
-                # This is gnarly, but if the latest event is from an ignored user,
-                # attempt to find one that isn't from an ignored user.
-                if latest_thread_event.sender in ignored_users:
-                    room_id = latest_thread_event.room_id
-
-                    # If the root event is not found, something went wrong, do
-                    # not include a summary of the thread.
-                    event = await self._event_handler.get_event(user, room_id, event_id)
-                    if event is None:
-                        continue
+            # If no thread, skip.
+            if not summary:
+                continue
 
-                    potential_events, _ = await self.get_relations_for_event(
-                        event_id,
-                        event,
-                        room_id,
-                        RelationTypes.THREAD,
-                        ignored_users,
-                    )
+            thread_count, latest_thread_event = summary
 
-                    # If all found events are from ignored users, do not include
-                    # a summary of the thread.
-                    if not potential_events:
-                        continue
+            # Subtract off the count of any ignored users.
+            for ignored_user in ignored_users:
+                thread_count -= ignored_results.get((event_id, ignored_user), 0)
 
-                    # The *last* event returned is the one that is cared about.
-                    event = await self._event_handler.get_event(
-                        user, room_id, potential_events[-1].event_id
-                    )
-                    # It is unexpected that the event will not exist.
-                    if event is None:
-                        logger.warning(
-                            "Unable to fetch latest event in a thread with event ID: %s",
-                            potential_events[-1].event_id,
-                        )
-                        continue
-                    latest_thread_event = event
-
-                results[event_id] = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=events_by_id[event_id].sender == user_id
-                    or participated[event_id],
+            # This is gnarly, but if the latest event is from an ignored user,
+            # attempt to find one that isn't from an ignored user.
+            if latest_thread_event.sender in ignored_users:
+                room_id = latest_thread_event.room_id
+
+                # If the root event is not found, something went wrong, do
+                # not include a summary of the thread.
+                event = await self._event_handler.get_event(user, room_id, event_id)
+                if event is None:
+                    continue
+
+                # Attempt to find another event to use as the latest event.
+                potential_events, _ = await self._main_store.get_relations_for_event(
+                    event_id, event, room_id, RelationTypes.THREAD, direction="f"
                 )
 
+                # Filter out ignored users.
+                potential_events = [
+                    event
+                    for event in potential_events
+                    if event.sender not in ignored_users
+                ]
+
+                # If all found events are from ignored users, do not include
+                # a summary of the thread.
+                if not potential_events:
+                    continue
+
+                # The *last* event returned is the one that is cared about.
+                event = await self._event_handler.get_event(
+                    user, room_id, potential_events[-1].event_id
+                )
+                # It is unexpected that the event will not exist.
+                if event is None:
+                    logger.warning(
+                        "Unable to fetch latest event in a thread with event ID: %s",
+                        potential_events[-1].event_id,
+                    )
+                    continue
+                latest_thread_event = event
+
+            results[event_id] = _ThreadAggregation(
+                latest_event=latest_thread_event,
+                count=thread_count,
+                # If there's a thread summary it must also exist in the
+                # participated dictionary.
+                current_user_participated=events_by_id[event_id].sender == user_id
+                or participated[event_id],
+            )
+
         return results
 
     @trace
@@ -438,48 +527,131 @@ class RelationsHandler:
                 # (as that is what makes it part of the thread).
                 relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
 
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            # Fetch any annotations (ie, reactions) to bundle with this event.
-            annotations = await self.get_annotations_for_event(
-                event.event_id, event.room_id, ignored_users=ignored_users
+        async def _fetch_annotations() -> None:
+            """Fetch any annotations (ie, reactions) to bundle with this event."""
+            annotations_by_event_id = await self.get_annotations_for_events(
+                events_by_id.keys(), ignored_users=ignored_users
             )
-            if annotations:
-                results.setdefault(
-                    event.event_id, BundledAggregations()
-                ).annotations = {"chunk": annotations}
-
-            # Fetch any references to bundle with this event.
-            references, next_token = await self.get_relations_for_event(
-                event.event_id,
-                event,
-                event.room_id,
-                RelationTypes.REFERENCE,
-                ignored_users=ignored_users,
+            for event_id, annotations in annotations_by_event_id.items():
+                if annotations:
+                    results.setdefault(event_id, BundledAggregations()).annotations = {
+                        "chunk": annotations
+                    }
+
+        async def _fetch_references() -> None:
+            """Fetch any references to bundle with this event."""
+            references_by_event_id = await self.get_references_for_events(
+                events_by_id.keys(), ignored_users=ignored_users
+            )
+            for event_id, references in references_by_event_id.items():
+                if references:
+                    results.setdefault(event_id, BundledAggregations()).references = {
+                        "chunk": [{"event_id": ev.event_id} for ev in references]
+                    }
+
+        async def _fetch_edits() -> None:
+            """
+            Fetch any edits (but not for redacted events).
+
+            Note that there is no use in limiting edits by ignored users since the
+            parent event should be ignored in the first place if the user is ignored.
+            """
+            edits = await self._main_store.get_applicable_edits(
+                [
+                    event_id
+                    for event_id, event in events_by_id.items()
+                    if not event.internal_metadata.is_redacted()
+                ]
+            )
+            for event_id, edit in edits.items():
+                results.setdefault(event_id, BundledAggregations()).replace = edit
+
+        # Parallelize the calls for annotations, references, and edits since they
+        # are unrelated.
+        await make_deferred_yieldable(
+            gather_results(
+                (
+                    run_in_background(_fetch_annotations),
+                    run_in_background(_fetch_references),
+                    run_in_background(_fetch_edits),
+                )
             )
-            if references:
-                aggregations = results.setdefault(event.event_id, BundledAggregations())
-                aggregations.references = {
-                    "chunk": [{"event_id": ev.event_id} for ev in references]
-                }
-
-                if next_token:
-                    aggregations.references["next_batch"] = await next_token.to_string(
-                        self._main_store
-                    )
-
-        # Fetch any edits (but not for redacted events).
-        #
-        # Note that there is no use in limiting edits by ignored users since the
-        # parent event should be ignored in the first place if the user is ignored.
-        edits = await self._main_store.get_applicable_edits(
-            [
-                event_id
-                for event_id, event in events_by_id.items()
-                if not event.internal_metadata.is_redacted()
-            ]
         )
-        for event_id, edit in edits.items():
-            results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        include: ThreadsListInclude,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, requester, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_threads
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_batch = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        if include == ThreadsListInclude.participated:
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
+            events = [event for event in events if participated[event.event_id]]
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+
+        now = self._clock.time_msec()
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_batch:
+            return_value["next_batch"] = str(next_batch)
+
+        return return_value
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 57ab05ad25..f81241c2b3 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -49,7 +49,6 @@ from synapse.api.constants import (
 from synapse.api.errors import (
     AuthError,
     Codes,
-    HttpResponseException,
     LimitExceededError,
     NotFoundError,
     StoreError,
@@ -60,11 +59,9 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
-from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
@@ -79,6 +76,7 @@ from synapse.types import (
     UserID,
     create_requester,
 )
+from synapse.types.state import StateFilter
 from synapse.util import stringutils
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.util.stringutils import parse_and_validate_server_name
@@ -229,9 +227,7 @@ class RoomCreationHandler:
             },
         )
         validate_event_for_room_version(tombstone_event)
-        await self._event_auth_handler.check_auth_rules_from_context(
-            tombstone_event, tombstone_context
-        )
+        await self._event_auth_handler.check_auth_rules_from_context(tombstone_event)
 
         # Upgrade the room
         #
@@ -561,7 +557,6 @@ class RoomCreationHandler:
             invite_list=[],
             initial_state=initial_state,
             creation_content=creation_content,
-            ratelimit=False,
         )
 
         # Transfer membership events
@@ -755,6 +750,10 @@ class RoomCreationHandler:
                 )
 
         if ratelimit:
+            # Rate limit once in advance, but don't rate limit the individual
+            # events in the room — room creation isn't atomic and it's very
+            # janky if half the events in the initial state don't make it because
+            # of rate limiting.
             await self.request_ratelimiter.ratelimit(requester)
 
         room_version_id = config.get(
@@ -915,7 +914,6 @@ class RoomCreationHandler:
             room_alias=room_alias,
             power_level_content_override=power_level_content_override,
             creator_join_profile=creator_join_profile,
-            ratelimit=ratelimit,
         )
 
         if "name" in config:
@@ -1039,7 +1037,6 @@ class RoomCreationHandler:
         room_alias: Optional[RoomAlias] = None,
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
-        ratelimit: bool = True,
     ) -> Tuple[int, str, int]:
         """Sends the initial events into a new room. Sends the room creation, membership,
         and power level events into the room sequentially, then creates and batches up the
@@ -1048,6 +1045,8 @@ class RoomCreationHandler:
         `power_level_content_override` doesn't apply when initial state has
         power level state event content.
 
+        Rate limiting should already have been applied by this point.
+
         Returns:
             A tuple containing the stream ID, event ID and depth of the last
             event sent to the room.
@@ -1057,9 +1056,6 @@ class RoomCreationHandler:
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
         depth = 1
 
-        # the last event sent/persisted to the db
-        last_sent_event_id: Optional[str] = None
-
         # the most recently created event
         prev_event: List[str] = []
         # a map of event types, state keys -> event_ids. We collect these mappings this as events are
@@ -1084,6 +1080,19 @@ class RoomCreationHandler:
             for_batch: bool,
             **kwargs: Any,
         ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+            """
+            Creates an event and associated event context.
+            Args:
+                etype: the type of event to be created
+                content: content of the event
+                for_batch: whether the event is being created for batch persisting. If
+                bool for_batch is true, this will create an event using the prev_event_ids,
+                and will create an event context for the event using the parameters state_map
+                and current_state_group, thus these parameters must be provided in this
+                case if for_batch is True. The subsequently created event and context
+                are suitable for being batched up and bulk persisted to the database
+                with other similarly created events.
+            """
             nonlocal depth
             nonlocal prev_event
 
@@ -1104,26 +1113,6 @@ class RoomCreationHandler:
 
             return new_event, new_context
 
-        async def send(
-            event: EventBase,
-            context: synapse.events.snapshot.EventContext,
-            creator: Requester,
-        ) -> int:
-            nonlocal last_sent_event_id
-
-            ev = await self.event_creation_handler.handle_new_client_event(
-                requester=creator,
-                events_and_context=[(event, context)],
-                ratelimit=False,
-                ignore_shadow_ban=True,
-            )
-
-            last_sent_event_id = ev.event_id
-
-            # we know it was persisted, so must have a stream ordering
-            assert ev.internal_metadata.stream_ordering
-            return ev.internal_metadata.stream_ordering
-
         try:
             config = self._presets_dict[preset_config]
         except KeyError:
@@ -1137,16 +1126,20 @@ class RoomCreationHandler:
         )
 
         logger.debug("Sending %s in new room", EventTypes.Member)
-        await send(creation_event, creation_context, creator)
+        ev = await self.event_creation_handler.handle_new_client_event(
+            requester=creator,
+            events_and_context=[(creation_event, creation_context)],
+            ratelimit=False,
+            ignore_shadow_ban=True,
+        )
+        last_sent_event_id = ev.event_id
 
-        # Room create event must exist at this point
-        assert last_sent_event_id is not None
         member_event_id, _ = await self.room_member_handler.update_membership(
             creator,
             creator.user,
             room_id,
             "join",
-            ratelimit=ratelimit,
+            ratelimit=False,
             content=creator_join_profile,
             new_room=True,
             prev_event_ids=[last_sent_event_id],
@@ -1159,15 +1152,24 @@ class RoomCreationHandler:
         depth += 1
         state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id
 
+        # we need the state group of the membership event as it is the current state group
+        event_to_state = (
+            await self._storage_controllers.state.get_state_group_for_events(
+                [member_event_id]
+            )
+        )
+        current_state_group = event_to_state[member_event_id]
+
+        events_to_send = []
         # We treat the power levels override specially as this needs to be one
         # of the first events that get sent into a room.
         pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
         if pl_content is not None:
             power_event, power_context = await create_event(
-                EventTypes.PowerLevels, pl_content, False
+                EventTypes.PowerLevels, pl_content, True
             )
             current_state_group = power_context._state_group
-            await send(power_event, power_context, creator)
+            events_to_send.append((power_event, power_context))
         else:
             power_level_content: JsonDict = {
                 "users": {creator_id: 100},
@@ -1213,12 +1215,11 @@ class RoomCreationHandler:
             pl_event, pl_context = await create_event(
                 EventTypes.PowerLevels,
                 power_level_content,
-                False,
+                True,
             )
             current_state_group = pl_context._state_group
-            await send(pl_event, pl_context, creator)
+            events_to_send.append((pl_event, pl_context))
 
-        events_to_send = []
         if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
             room_alias_event, room_alias_context = await create_event(
                 EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
@@ -1271,7 +1272,10 @@ class RoomCreationHandler:
             events_to_send.append((encryption_event, encryption_context))
 
         last_event = await self.event_creation_handler.handle_new_client_event(
-            creator, events_to_send, ignore_shadow_ban=True
+            creator,
+            events_to_send,
+            ignore_shadow_ban=True,
+            ratelimit=False,
         )
         assert last_event.internal_metadata.stream_ordering is not None
         return last_event.internal_metadata.stream_ordering, last_event.event_id, depth
@@ -1447,7 +1451,7 @@ class RoomContextHandler:
             events_before=events_before,
             event=event,
             events_after=events_after,
-            state=await filter_evts(state_events),
+            state=state_events,
             aggregations=aggregations,
             start=await token.copy_and_replace(
                 StreamKeyType.ROOM, results.start
@@ -1493,7 +1497,12 @@ class TimestampLookupHandler:
         Raises:
             SynapseError if unable to find any event locally in the given direction
         """
-
+        logger.debug(
+            "get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...",
+            room_id,
+            timestamp,
+            direction,
+        )
         local_event_id = await self.store.get_event_id_for_timestamp(
             room_id, timestamp, direction
         )
@@ -1545,85 +1554,54 @@ class TimestampLookupHandler:
                 )
             )
 
-            # Loop through each homeserver candidate until we get a succesful response
-            for domain in likely_domains:
-                # We don't want to ask our own server for information we don't have
-                if domain == self.server_name:
-                    continue
+            remote_response = await self.federation_client.timestamp_to_event(
+                destinations=likely_domains,
+                room_id=room_id,
+                timestamp=timestamp,
+                direction=direction,
+            )
+            if remote_response is not None:
+                logger.debug(
+                    "get_event_for_timestamp: remote_response=%s",
+                    remote_response,
+                )
 
-                try:
-                    remote_response = await self.federation_client.timestamp_to_event(
-                        domain, room_id, timestamp, direction
-                    )
-                    logger.debug(
-                        "get_event_for_timestamp: response from domain(%s)=%s",
-                        domain,
-                        remote_response,
-                    )
+                remote_event_id = remote_response.event_id
+                remote_origin_server_ts = remote_response.origin_server_ts
 
-                    remote_event_id = remote_response.event_id
-                    remote_origin_server_ts = remote_response.origin_server_ts
-
-                    # Backfill this event so we can get a pagination token for
-                    # it with `/context` and paginate `/messages` from this
-                    # point.
-                    #
-                    # TODO: The requested timestamp may lie in a part of the
-                    #   event graph that the remote server *also* didn't have,
-                    #   in which case they will have returned another event
-                    #   which may be nowhere near the requested timestamp. In
-                    #   the future, we may need to reconcile that gap and ask
-                    #   other homeservers, and/or extend `/timestamp_to_event`
-                    #   to return events on *both* sides of the timestamp to
-                    #   help reconcile the gap faster.
-                    remote_event = (
-                        await self.federation_event_handler.backfill_event_id(
-                            domain, room_id, remote_event_id
-                        )
-                    )
+                # Backfill this event so we can get a pagination token for
+                # it with `/context` and paginate `/messages` from this
+                # point.
+                pulled_pdu_info = await self.federation_event_handler.backfill_event_id(
+                    likely_domains, room_id, remote_event_id
+                )
+                remote_event = pulled_pdu_info.pdu
 
-                    # XXX: When we see that the remote server is not trustworthy,
-                    # maybe we should not ask them first in the future.
-                    if remote_origin_server_ts != remote_event.origin_server_ts:
-                        logger.info(
-                            "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
-                            domain,
-                            remote_event_id,
-                            remote_origin_server_ts,
-                            remote_event.origin_server_ts,
-                        )
-
-                    # Only return the remote event if it's closer than the local event
-                    if not local_event or (
-                        abs(remote_event.origin_server_ts - timestamp)
-                        < abs(local_event.origin_server_ts - timestamp)
-                    ):
-                        logger.info(
-                            "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
-                            remote_event_id,
-                            remote_event.origin_server_ts,
-                            timestamp,
-                            local_event.event_id if local_event else None,
-                            local_event.origin_server_ts if local_event else None,
-                        )
-                        return remote_event_id, remote_origin_server_ts
-                except (HttpResponseException, InvalidResponseError) as ex:
-                    # Let's not put a high priority on some other homeserver
-                    # failing to respond or giving a random response
-                    logger.debug(
-                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
-                        domain,
-                        type(ex).__name__,
-                        ex,
-                        ex.args,
+                # XXX: When we see that the remote server is not trustworthy,
+                # maybe we should not ask them first in the future.
+                if remote_origin_server_ts != remote_event.origin_server_ts:
+                    logger.info(
+                        "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+                        pulled_pdu_info.pull_origin,
+                        remote_event_id,
+                        remote_origin_server_ts,
+                        remote_event.origin_server_ts,
                     )
-                except Exception:
-                    # But we do want to see some exceptions in our code
-                    logger.warning(
-                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
-                        domain,
-                        exc_info=True,
+
+                # Only return the remote event if it's closer than the local event
+                if not local_event or (
+                    abs(remote_event.origin_server_ts - timestamp)
+                    < abs(local_event.origin_server_ts - timestamp)
+                ):
+                    logger.info(
+                        "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+                        remote_event_id,
+                        remote_event.origin_server_ts,
+                        timestamp,
+                        local_event.event_id if local_event else None,
+                        local_event.origin_server_ts if local_event else None,
                     )
+                    return remote_event_id, remote_origin_server_ts
 
         # To appease mypy, we have to add both of these conditions to check for
         # `None`. We only expect `local_event` to be `None` when
@@ -1646,7 +1624,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
         self,
         user: UserID,
         from_key: RoomStreamToken,
-        limit: Optional[int],
+        limit: int,
         room_ids: Collection[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 6ad2b38b8f..0c39e852a1 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -34,7 +34,6 @@ from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.logging import opentracing
 from synapse.module_api import NOT_SPAM
-from synapse.storage.state import StateFilter
 from synapse.types import (
     JsonDict,
     Requester,
@@ -45,6 +44,7 @@ from synapse.types import (
     create_requester,
     get_domain_from_id,
 )
+from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.distributor import user_left_room
 
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 9602f0d0bb..874860d461 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -441,7 +441,7 @@ class DefaultSamlMappingProvider:
             client_redirect_url: where the client wants to redirect to
 
         Returns:
-            dict: A dict containing new user attributes. Possible keys:
+            A dict containing new user attributes. Possible keys:
                 * mxid_localpart (str): Required. The localpart of the user's mxid
                 * displayname (str): The displayname of the user
                 * emails (list[str]): Any emails for the user
@@ -483,7 +483,7 @@ class DefaultSamlMappingProvider:
         Args:
             config: A dictionary containing configuration options for this provider
         Returns:
-            SamlConfig: A custom config object for this module
+            A custom config object for this module
         """
         # Parse config options and use defaults where necessary
         mxid_source_attribute = config.get("mxid_source_attribute", "uid")
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index bcab98c6d5..33115ce488 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,8 +23,8 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
-from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, StreamKeyType, UserID
+from synapse.types.state import StateFilter
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
diff --git a/synapse/handlers/set_password.py b/synapse/handlers/set_password.py
index 73861bbd40..bd9d0bb34b 100644
--- a/synapse/handlers/set_password.py
+++ b/synapse/handlers/set_password.py
@@ -15,6 +15,7 @@ import logging
 from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import Codes, StoreError, SynapseError
+from synapse.handlers.device import DeviceHandler
 from synapse.types import Requester
 
 if TYPE_CHECKING:
@@ -29,7 +30,10 @@ class SetPasswordHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastores().main
         self._auth_handler = hs.get_auth_handler()
-        self._device_handler = hs.get_device_handler()
+        # This can only be instantiated on the main process.
+        device_handler = hs.get_device_handler()
+        assert isinstance(device_handler, DeviceHandler)
+        self._device_handler = device_handler
 
     async def set_password(
         self,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index e035677b8a..44e70fc4b8 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import abc
+import hashlib
+import io
 import logging
 from typing import (
     TYPE_CHECKING,
@@ -37,6 +39,7 @@ from twisted.web.server import Request
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.device import DeviceHandler
 from synapse.handlers.register import init_counters_for_auth_provider
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
@@ -137,6 +140,7 @@ class UserAttributes:
     localpart: Optional[str]
     confirm_localpart: bool = False
     display_name: Optional[str] = None
+    picture: Optional[str] = None
     emails: Collection[str] = attr.Factory(list)
 
 
@@ -191,9 +195,14 @@ class SsoHandler:
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
         self._error_template = hs.config.sso.sso_error_template
         self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
         self._profile_handler = hs.get_profile_handler()
+        self._media_repo = (
+            hs.get_media_repository() if hs.config.media.can_load_media_repo else None
+        )
+        self._http_client = hs.get_proxied_blacklisted_http_client()
 
         # The following template is shown after a successful user interactive
         # authentication session. It tells the user they can close the window.
@@ -493,6 +502,8 @@ class SsoHandler:
                         await self._profile_handler.set_displayname(
                             user_id_obj, requester, attributes.display_name, True
                         )
+                if attributes.picture:
+                    await self.set_avatar(user_id, attributes.picture)
 
         await self._auth_handler.complete_sso_login(
             user_id,
@@ -701,8 +712,110 @@ class SsoHandler:
         await self._store.record_user_external_id(
             auth_provider_id, remote_user_id, registered_user_id
         )
+
+        # Set avatar, if available
+        if attributes.picture:
+            await self.set_avatar(registered_user_id, attributes.picture)
+
         return registered_user_id
 
+    async def set_avatar(self, user_id: str, picture_https_url: str) -> bool:
+        """Set avatar of the user.
+
+        This downloads the image file from the URL provided, stores that in
+        the media repository and then sets the avatar on the user's profile.
+
+        It can detect if the same image is being saved again and bails early by storing
+        the hash of the file in the `upload_name` of the avatar image.
+
+        Currently, it only supports server configurations which run the media repository
+        within the same process.
+
+        It silently fails and logs a warning by raising an exception and catching it
+        internally if:
+         * it is unable to fetch the image itself (non 200 status code) or
+         * the image supplied is bigger than max allowed size or
+         * the image type is not one of the allowed image types.
+
+        Args:
+            user_id: matrix user ID in the form @localpart:domain as a string.
+
+            picture_https_url: HTTPS url for the picture image file.
+
+        Returns: `True` if the user's avatar has been successfully set to the image at
+            `picture_https_url`.
+        """
+        if self._media_repo is None:
+            logger.info(
+                "failed to set user avatar because out-of-process media repositories "
+                "are not supported yet "
+            )
+            return False
+
+        try:
+            uid = UserID.from_string(user_id)
+
+            def is_allowed_mime_type(content_type: str) -> bool:
+                if (
+                    self._profile_handler.allowed_avatar_mimetypes
+                    and content_type
+                    not in self._profile_handler.allowed_avatar_mimetypes
+                ):
+                    return False
+                return True
+
+            # download picture, enforcing size limit & mime type check
+            picture = io.BytesIO()
+
+            content_length, headers, uri, code = await self._http_client.get_file(
+                url=picture_https_url,
+                output_stream=picture,
+                max_size=self._profile_handler.max_avatar_size,
+                is_allowed_content_type=is_allowed_mime_type,
+            )
+
+            if code != 200:
+                raise Exception(
+                    "GET request to download sso avatar image returned {}".format(code)
+                )
+
+            # upload name includes hash of the image file's content so that we can
+            # easily check if it requires an update or not, the next time user logs in
+            upload_name = "sso_avatar_" + hashlib.sha256(picture.read()).hexdigest()
+
+            # bail if user already has the same avatar
+            profile = await self._profile_handler.get_profile(user_id)
+            if profile["avatar_url"] is not None:
+                server_name = profile["avatar_url"].split("/")[-2]
+                media_id = profile["avatar_url"].split("/")[-1]
+                if server_name == self._server_name:
+                    media = await self._media_repo.store.get_local_media(media_id)
+                    if media is not None and upload_name == media["upload_name"]:
+                        logger.info("skipping saving the user avatar")
+                        return True
+
+            # store it in media repository
+            avatar_mxc_url = await self._media_repo.create_content(
+                media_type=headers[b"Content-Type"][0].decode("utf-8"),
+                upload_name=upload_name,
+                content=picture,
+                content_length=content_length,
+                auth_user=uid,
+            )
+
+            # save it as user avatar
+            await self._profile_handler.set_avatar_url(
+                uid,
+                create_requester(uid),
+                str(avatar_mxc_url),
+            )
+
+            logger.info("successfully saved the user avatar")
+            return True
+        except Exception:
+            logger.warning("failed to save the user avatar")
+            return False
+
     async def complete_sso_ui_auth_request(
         self,
         auth_provider_id: str,
@@ -874,7 +987,7 @@ class SsoHandler:
         )
 
     async def handle_terms_accepted(
-        self, request: Request, session_id: str, terms_version: str
+        self, request: SynapseRequest, session_id: str, terms_version: str
     ) -> None:
         """Handle a request to the new-user 'consent' endpoint
 
@@ -1026,6 +1139,84 @@ class SsoHandler:
 
         return True
 
+    async def revoke_sessions_for_provider_session_id(
+        self,
+        auth_provider_id: str,
+        auth_provider_session_id: str,
+        expected_user_id: Optional[str] = None,
+    ) -> None:
+        """Revoke any devices and in-flight logins tied to a provider session.
+
+        Can only be called from the main process.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            auth_provider_session_id: The session ID from the provider to logout
+            expected_user_id: The user we're expecting to logout. If set, it will ignore
+                sessions belonging to other users and log an error.
+        """
+
+        # It is expected that this is the main process.
+        assert isinstance(
+            self._device_handler, DeviceHandler
+        ), "revoking SSO sessions can only be called on the main process"
+
+        # Invalidate any running user-mapping sessions
+        to_delete = []
+        for session_id, session in self._username_mapping_sessions.items():
+            if (
+                session.auth_provider_id == auth_provider_id
+                and session.auth_provider_session_id == auth_provider_session_id
+            ):
+                to_delete.append(session_id)
+
+        for session_id in to_delete:
+            logger.info("Revoking mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
+
+        # Invalidate any in-flight login tokens
+        await self._store.invalidate_login_tokens_by_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # Fetch any device(s) in the store associated with the session ID.
+        devices = await self._store.get_devices_by_auth_provider_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # We have no guarantee that all the devices of that session are for the same
+        # `user_id`. Hence, we have to iterate over the list of devices and log them out
+        # one by one.
+        for device in devices:
+            user_id = device["user_id"]
+            device_id = device["device_id"]
+
+            # If the user_id associated with that device/session is not the one we got
+            # out of the `sub` claim, skip that device and show log an error.
+            if expected_user_id is not None and user_id != expected_user_id:
+                logger.error(
+                    "Received a logout notification from SSO provider "
+                    f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+                    f"a session ID ({auth_provider_session_id!r}) which belongs to "
+                    f"{user_id!r}. This may happen when the SSO provider user mapper "
+                    "uses something else than the standard attribute as mapping ID. "
+                    "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+                    "in the provider config if that is the case."
+                )
+                continue
+
+            logger.info(
+                "Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
+                user_id,
+                device_id,
+                auth_provider_id,
+                auth_provider_session_id,
+            )
+            await self._device_handler.delete_devices(user_id, [device_id])
+
 
 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 1db5d68021..7d6a653747 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -31,18 +31,24 @@ from typing import (
 import attr
 from prometheus_client import Counter
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.filtering import FilterCollection
 from synapse.api.presence import UserPresenceState
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
 from synapse.handlers.relations import BundledAggregations
 from synapse.logging.context import current_context
-from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
+from synapse.logging.opentracing import (
+    SynapseTags,
+    log_kv,
+    set_tag,
+    start_active_span,
+    trace,
+)
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
-from synapse.storage.state import StateFilter
 from synapse.types import (
     DeviceListUpdates,
     JsonDict,
@@ -54,6 +60,7 @@ from synapse.types import (
     StreamToken,
     UserID,
 )
+from synapse.types.state import StateFilter
 from synapse.util.async_helpers import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.caches.lrucache import LruCache
@@ -805,18 +812,6 @@ class SyncHandler:
             if canonical_alias and canonical_alias.content.get("alias"):
                 return summary
 
-        me = sync_config.user.to_string()
-
-        joined_user_ids = [
-            r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
-        ]
-        invited_user_ids = [
-            r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
-        ]
-        gone_user_ids = [
-            r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
-        ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
-
         # FIXME: only build up a member_ids list for our heroes
         member_ids = {}
         for membership in (
@@ -828,11 +823,8 @@ class SyncHandler:
             for user_id, event_id in details.get(membership, empty_ms).members:
                 member_ids[user_id] = event_id
 
-        # FIXME: order by stream ordering rather than as returned by SQL
-        if joined_user_ids or invited_user_ids:
-            summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
-        else:
-            summary["m.heroes"] = sorted(gone_user_ids)[0:5]
+        me = sync_config.user.to_string()
+        summary["m.heroes"] = extract_heroes_from_room_summary(details, me)
 
         if not sync_config.filter_collection.lazy_load_members():
             return summary
@@ -1440,14 +1432,14 @@ class SyncHandler:
 
         logger.debug("Fetching OTK data")
         device_id = sync_config.device_id
-        one_time_key_counts: JsonDict = {}
+        one_time_keys_count: JsonDict = {}
         unused_fallback_key_types: List[str] = []
         if device_id:
             # TODO: We should have a way to let clients differentiate between the states of:
             #   * no change in OTK count since the provided since token
             #   * the server has zero OTKs left for this device
             #  Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298
-            one_time_key_counts = await self.store.count_e2e_one_time_keys(
+            one_time_keys_count = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
             unused_fallback_key_types = (
@@ -1477,7 +1469,7 @@ class SyncHandler:
             archived=sync_result_builder.archived,
             to_device=sync_result_builder.to_device,
             device_lists=device_lists,
-            device_one_time_keys_count=one_time_key_counts,
+            device_one_time_keys_count=one_time_keys_count,
             device_unused_fallback_key_types=unused_fallback_key_types,
             next_batch=sync_result_builder.now_token,
         )
@@ -1542,10 +1534,12 @@ class SyncHandler:
             #
             # If we don't have that info cached then we get all the users that
             # share a room with our user and check if those users have changed.
-            changed_users = self.store.get_cached_device_list_changes(
+            cache_result = self.store.get_cached_device_list_changes(
                 since_token.device_list_key
             )
-            if changed_users is not None:
+            if cache_result.hit:
+                changed_users = cache_result.entities
+
                 result = await self.store.get_rooms_for_users(changed_users)
 
                 for changed_user_id, entries in result.items():
@@ -1598,6 +1592,7 @@ class SyncHandler:
         else:
             return DeviceListUpdates()
 
+    @trace
     async def _generate_sync_entry_for_to_device(
         self, sync_result_builder: "SyncResultBuilder"
     ) -> None:
@@ -1617,11 +1612,16 @@ class SyncHandler:
             )
 
             for message in messages:
-                # We pop here as we shouldn't be sending the message ID down
-                # `/sync`
-                message_id = message.pop("message_id", None)
-                if message_id:
-                    set_tag(SynapseTags.TO_DEVICE_MESSAGE_ID, message_id)
+                log_kv(
+                    {
+                        "event": "to_device_message",
+                        "sender": message["sender"],
+                        "type": message["type"],
+                        EventContentFields.TO_DEVICE_MSGID: message["content"].get(
+                            EventContentFields.TO_DEVICE_MSGID
+                        ),
+                    }
+                )
 
             logger.debug(
                 "Returning %d to-device messages between %d and %d (current token: %d)",
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f953691669..3f656ea4f5 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -420,11 +420,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         if last_id == current_id:
             return [], current_id, False
 
-        changed_rooms: Optional[
-            Iterable[str]
-        ] = self._typing_stream_change_cache.get_all_entities_changed(last_id)
+        result = self._typing_stream_change_cache.get_all_entities_changed(last_id)
 
-        if changed_rooms is None:
+        if result.hit:
+            changed_rooms: Iterable[str] = result.entities
+        else:
             changed_rooms = self._room_serials
 
         rows = []
@@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
         self,
         user: UserID,
         from_key: int,
-        limit: Optional[int],
+        limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,