summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/experimental.py7
-rw-r--r--synapse/config/logger.py12
-rw-r--r--synapse/events/utils.py69
-rw-r--r--synapse/federation/federation_base.py10
-rw-r--r--synapse/federation/federation_client.py114
-rw-r--r--synapse/federation/sender/per_destination_queue.py18
-rw-r--r--synapse/federation/transport/client.py188
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/federation_event.py46
-rw-r--r--synapse/handlers/message.py10
-rw-r--r--synapse/handlers/presence.py26
-rw-r--r--synapse/handlers/register.py6
-rw-r--r--synapse/handlers/room_member.py46
-rw-r--r--synapse/handlers/search.py643
-rw-r--r--synapse/handlers/sync.py68
-rw-r--r--synapse/http/matrixfederationclient.py50
-rw-r--r--synapse/logging/_structured.py163
-rw-r--r--synapse/module_api/__init__.py11
-rw-r--r--synapse/notifier.py4
-rw-r--r--synapse/push/baserules.py10
-rw-r--r--synapse/push/httppusher.py9
-rw-r--r--synapse/python_dependencies.py3
-rw-r--r--synapse/res/providers.json4
-rw-r--r--synapse/rest/client/account.py2
-rw-r--r--synapse/rest/client/auth.py8
-rw-r--r--synapse/rest/client/capabilities.py14
-rw-r--r--synapse/rest/client/register.py7
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/rest/media/v1/preview_url_resource.py10
-rw-r--r--synapse/storage/databases/main/devices.py10
-rw-r--r--synapse/storage/databases/main/events.py27
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/presence.py61
-rw-r--r--synapse/storage/databases/main/purge_events.py13
-rw-r--r--synapse/storage/databases/main/registration.py21
-rw-r--r--synapse/storage/databases/main/relations.py26
-rw-r--r--synapse/storage/databases/main/room.py4
-rw-r--r--synapse/storage/databases/main/roommember.py62
-rw-r--r--synapse/storage/databases/main/search.py17
-rw-r--r--synapse/storage/databases/main/user_directory.py22
-rw-r--r--synapse/storage/databases/state/store.py203
-rw-r--r--synapse/storage/state.py8
-rw-r--r--synapse/streams/events.py6
-rw-r--r--synapse/types.py6
-rw-r--r--synapse/util/caches/__init__.py1
-rw-r--r--synapse/util/caches/expiringcache.py5
-rw-r--r--synapse/util/caches/lrucache.py4
-rw-r--r--synapse/util/daemonize.py8
-rw-r--r--synapse/util/patch_inline_callbacks.py6
50 files changed, 1449 insertions, 685 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 09d692d9a1..bcdeb9ee23 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -41,9 +41,6 @@ class ExperimentalConfig(Config):
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
-        # MSC3283 (set displayname, avatar_url and change 3pid capabilities)
-        self.msc3283_enabled: bool = experimental.get("msc3283_enabled", False)
-
         # MSC3266 (room summary api)
         self.msc3266_enabled: bool = experimental.get("msc3266_enabled", False)
 
@@ -64,3 +61,7 @@ class ExperimentalConfig(Config):
 
         # MSC3706 (server-side support for partial state in /send_join responses)
         self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
+
+        # experimental support for faster joins over federation (msc2775, msc3706)
+        # requires a target server with msc3706_enabled enabled.
+        self.faster_joins_enabled: bool = experimental.get("faster_joins", False)
diff --git a/synapse/config/logger.py b/synapse/config/logger.py
index b7145a44ae..cbbe221965 100644
--- a/synapse/config/logger.py
+++ b/synapse/config/logger.py
@@ -33,7 +33,6 @@ from twisted.logger import (
     globalLogBeginner,
 )
 
-from synapse.logging._structured import setup_structured_logging
 from synapse.logging.context import LoggingContextFilter
 from synapse.logging.filter import MetadataFilter
 
@@ -138,6 +137,12 @@ Support for the log_file configuration option and --log-file command-line option
 removed in Synapse 1.3.0. You should instead set up a separate log configuration file.
 """
 
+STRUCTURED_ERROR = """\
+Support for the structured configuration option was removed in Synapse 1.54.0.
+You should instead use the standard logging configuration. See
+https://matrix-org.github.io/synapse/v1.54/structured_logging.html
+"""
+
 
 class LoggingConfig(Config):
     section = "logging"
@@ -292,10 +297,9 @@ def _load_logging_config(log_config_path: str) -> None:
     if not log_config:
         logging.warning("Loaded a blank logging config?")
 
-    # If the old structured logging configuration is being used, convert it to
-    # the new style configuration.
+    # If the old structured logging configuration is being used, raise an error.
     if "structured" in log_config and log_config.get("structured"):
-        log_config = setup_structured_logging(log_config)
+        raise ConfigError(STRUCTURED_ERROR)
 
     logging.config.dictConfig(log_config)
 
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 243696b357..9386fa29dd 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -425,6 +425,33 @@ class EventClientSerializer:
 
         return serialized_event
 
+    def _apply_edit(
+        self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
+    ) -> None:
+        """Replace the content, preserving existing relations of the serialized event.
+
+        Args:
+            orig_event: The original event.
+            serialized_event: The original event, serialized. This is modified.
+            edit: The event which edits the above.
+        """
+
+        # Ensure we take copies of the edit content, otherwise we risk modifying
+        # the original event.
+        edit_content = edit.content.copy()
+
+        # Unfreeze the event content if necessary, so that we may modify it below
+        edit_content = unfreeze(edit_content)
+        serialized_event["content"] = edit_content.get("m.new_content", {})
+
+        # Check for existing relations
+        relates_to = orig_event.content.get("m.relates_to")
+        if relates_to:
+            # Keep the relations, ensuring we use a dict copy of the original
+            serialized_event["content"]["m.relates_to"] = relates_to.copy()
+        else:
+            serialized_event["content"].pop("m.relates_to", None)
+
     def _inject_bundled_aggregations(
         self,
         event: EventBase,
@@ -450,26 +477,11 @@ class EventClientSerializer:
             serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
 
         if aggregations.replace:
-            # If there is an edit replace the content, preserving existing
-            # relations.
+            # If there is an edit, apply it to the event.
             edit = aggregations.replace
+            self._apply_edit(event, serialized_event, edit)
 
-            # Ensure we take copies of the edit content, otherwise we risk modifying
-            # the original event.
-            edit_content = edit.content.copy()
-
-            # Unfreeze the event content if necessary, so that we may modify it below
-            edit_content = unfreeze(edit_content)
-            serialized_event["content"] = edit_content.get("m.new_content", {})
-
-            # Check for existing relations
-            relates_to = event.content.get("m.relates_to")
-            if relates_to:
-                # Keep the relations, ensuring we use a dict copy of the original
-                serialized_event["content"]["m.relates_to"] = relates_to.copy()
-            else:
-                serialized_event["content"].pop("m.relates_to", None)
-
+            # Include information about it in the relations dict.
             serialized_aggregations[RelationTypes.REPLACE] = {
                 "event_id": edit.event_id,
                 "origin_server_ts": edit.origin_server_ts,
@@ -478,13 +490,22 @@ class EventClientSerializer:
 
         # If this event is the start of a thread, include a summary of the replies.
         if aggregations.thread:
+            thread = aggregations.thread
+
+            # Don't bundle aggregations as this could recurse forever.
+            serialized_latest_event = self.serialize_event(
+                thread.latest_event, time_now, bundle_aggregations=None
+            )
+            # Manually apply an edit, if one exists.
+            if thread.latest_edit:
+                self._apply_edit(
+                    thread.latest_event, serialized_latest_event, thread.latest_edit
+                )
+
             serialized_aggregations[RelationTypes.THREAD] = {
-                # Don't bundle aggregations as this could recurse forever.
-                "latest_event": self.serialize_event(
-                    aggregations.thread.latest_event, time_now, bundle_aggregations=None
-                ),
-                "count": aggregations.thread.count,
-                "current_user_participated": aggregations.thread.current_user_participated,
+                "latest_event": serialized_latest_event,
+                "count": thread.count,
+                "current_user_participated": thread.current_user_participated,
             }
 
         # Include the bundled aggregations in the event.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 896168c05c..fab6da3c08 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -47,6 +47,11 @@ class FederationBase:
     ) -> EventBase:
         """Checks that event is correctly signed by the sending server.
 
+        Also checks the content hash, and redacts the event if there is a mismatch.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
         Args:
             room_version: The room version of the PDU
             pdu: the event to be checked
@@ -55,7 +60,10 @@ class FederationBase:
               * the original event if the checks pass
               * a redacted version of the event (if the signature
                 matched but the hash did not)
-              * throws a SynapseError if the signature check failed."""
+
+        Raises:
+              SynapseError if the signature check failed.
+        """
         try:
             await _check_sigs_on_pdu(self.keyring, room_version, pdu)
         except SynapseError as e:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 74f17aa4da..c2997997da 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -1,4 +1,4 @@
-# Copyright 2015-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2015-2022 The Matrix.org Foundation C.I.C.
 # Copyright 2020 Sorunome
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -89,6 +89,12 @@ class SendJoinResult:
     state: List[EventBase]
     auth_chain: List[EventBase]
 
+    # True if 'state' elides non-critical membership events
+    partial_state: bool
+
+    # if 'partial_state' is set, a list of the servers in the room (otherwise empty)
+    servers_in_room: List[str]
+
 
 class FederationClient(FederationBase):
     def __init__(self, hs: "HomeServer"):
@@ -413,26 +419,90 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    async def get_room_state(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        room_version: RoomVersion,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
+        """Calls the /state endpoint to fetch the state at a particular point
+        in the room.
+
+        Any invalid events (those with incorrect or unverifiable signatures or hashes)
+        are filtered out from the response, and any duplicate events are removed.
+
+        (Size limits and other event-format checks are *not* performed.)
+
+        Note that the result is not ordered, so callers must be careful to process
+        the events in an order that handles dependencies.
+
+        Returns:
+            a tuple of (state events, auth events)
+        """
+        result = await self.transport_layer.get_room_state(
+            room_version,
+            destination,
+            room_id,
+            event_id,
+        )
+        state_events = result.state
+        auth_events = result.auth_events
+
+        # we may as well filter out any duplicates from the response, to save
+        # processing them multiple times. (In particular, events may be present in
+        # `auth_events` as well as `state`, which is redundant).
+        #
+        # We don't rely on the sort order of the events, so we can just stick them
+        # in a dict.
+        state_event_map = {event.event_id: event for event in state_events}
+        auth_event_map = {
+            event.event_id: event
+            for event in auth_events
+            if event.event_id not in state_event_map
+        }
+
+        logger.info(
+            "Processing from /state: %d state events, %d auth events",
+            len(state_event_map),
+            len(auth_event_map),
+        )
+
+        valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+            destination, auth_event_map.values(), room_version
+        )
+
+        valid_state_events = await self._check_sigs_and_hash_and_fetch(
+            destination, state_event_map.values(), room_version
+        )
+
+        return valid_state_events, valid_auth_events
+
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
         pdus: Collection[EventBase],
         room_version: RoomVersion,
     ) -> List[EventBase]:
-        """Takes a list of PDUs and checks the signatures and hashes of each
-        one. If a PDU fails its signature check then we check if we have it in
-        the database and if not then request if from the originating server of
-        that PDU.
+        """Checks the signatures and hashes of a list of events.
+
+        If a PDU fails its signature check then we check if we have it in
+        the database, and if not then request it from the sender's server (if that
+        is different from `origin`). If that still fails, the event is omitted from
+        the returned list.
 
         If a PDU fails its content hash check then it is redacted.
 
-        The given list of PDUs are not modified, instead the function returns
+        Also runs each event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
+        The given list of PDUs are not modified; instead the function returns
         a new list.
 
         Args:
-            origin
-            pdu
-            room_version
+            origin: The server that sent us these events
+            pdus: The events to be checked
+            room_version: the version of the room these events are in
 
         Returns:
             A list of PDUs that have valid signatures and hashes.
@@ -463,11 +533,16 @@ class FederationClient(FederationBase):
         origin: str,
         room_version: RoomVersion,
     ) -> Optional[EventBase]:
-        """Takes a PDU and checks its signatures and hashes. If the PDU fails
-        its signature check then we check if we have it in the database and if
-        not then request if from the originating server of that PDU.
+        """Takes a PDU and checks its signatures and hashes.
+
+        If the PDU fails its signature check then we check if we have it in the
+        database; if not, we then request it from sender's server (if that is not the
+        same as `origin`). If that still fails, we return None.
 
-        If then PDU fails its content hash check then it is redacted.
+        If the PDU fails its content hash check, it is redacted.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
 
         Args:
             origin
@@ -864,23 +939,32 @@ class FederationClient(FederationBase):
             for s in signed_state:
                 s.internal_metadata = copy.deepcopy(s.internal_metadata)
 
-            # double-check that the same create event has ended up in the auth chain
+            # double-check that the auth chain doesn't include a different create event
             auth_chain_create_events = [
                 e.event_id
                 for e in signed_auth
                 if (e.type, e.state_key) == (EventTypes.Create, "")
             ]
-            if auth_chain_create_events != [create_event.event_id]:
+            if auth_chain_create_events and auth_chain_create_events != [
+                create_event.event_id
+            ]:
                 raise InvalidResponseError(
                     "Unexpected create event(s) in auth chain: %s"
                     % (auth_chain_create_events,)
                 )
 
+            if response.partial_state and not response.servers_in_room:
+                raise InvalidResponseError(
+                    "partial_state was set, but no servers were listed in the room"
+                )
+
             return SendJoinResult(
                 event=event,
                 state=signed_state,
                 auth_chain=signed_auth,
                 origin=destination,
+                partial_state=response.partial_state,
+                servers_in_room=response.servers_in_room or [],
             )
 
         # MSC3083 defines additional error codes for room joins.
diff --git a/synapse/federation/sender/per_destination_queue.py b/synapse/federation/sender/per_destination_queue.py
index 8152e80b88..c3132f7319 100644
--- a/synapse/federation/sender/per_destination_queue.py
+++ b/synapse/federation/sender/per_destination_queue.py
@@ -381,7 +381,9 @@ class PerDestinationQueue:
                 )
             )
 
-        if self._last_successful_stream_ordering is None:
+        last_successful_stream_ordering = self._last_successful_stream_ordering
+
+        if last_successful_stream_ordering is None:
             # if it's still None, then this means we don't have the information
             # in our database ­ we haven't successfully sent a PDU to this server
             # (at least since the introduction of the feature tracking
@@ -394,8 +396,7 @@ class PerDestinationQueue:
         # get at most 50 catchup room/PDUs
         while True:
             event_ids = await self._store.get_catch_up_room_event_ids(
-                self._destination,
-                self._last_successful_stream_ordering,
+                self._destination, last_successful_stream_ordering
             )
 
             if not event_ids:
@@ -403,7 +404,7 @@ class PerDestinationQueue:
                 # of a race condition, so we check that no new events have been
                 # skipped due to us being in catch-up mode
 
-                if self._catchup_last_skipped > self._last_successful_stream_ordering:
+                if self._catchup_last_skipped > last_successful_stream_ordering:
                     # another event has been skipped because we were in catch-up mode
                     continue
 
@@ -470,7 +471,7 @@ class PerDestinationQueue:
                         # offline
                         if (
                             p.internal_metadata.stream_ordering
-                            < self._last_successful_stream_ordering
+                            < last_successful_stream_ordering
                         ):
                             continue
 
@@ -513,12 +514,11 @@ class PerDestinationQueue:
                 # from the *original* PDU, rather than the PDU(s) we actually
                 # send. This is because we use it to mark our position in the
                 # queue of missed PDUs to process.
-                self._last_successful_stream_ordering = (
-                    pdu.internal_metadata.stream_ordering
-                )
+                last_successful_stream_ordering = pdu.internal_metadata.stream_ordering
 
+                self._last_successful_stream_ordering = last_successful_stream_ordering
                 await self._store.set_destination_last_successful_stream_ordering(
-                    self._destination, self._last_successful_stream_ordering
+                    self._destination, last_successful_stream_ordering
                 )
 
     def _get_rr_edus(self, force_flush: bool) -> Iterable[Edu]:
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8782586cd6..7e510e224a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
 # Copyright 2020 Sorunome
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,17 +60,17 @@ class TransportLayerClient:
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.client = hs.get_federation_http_client()
+        self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
 
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
-        """Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
+        """Requests the IDs of all state for a given room at the given event.
 
         Args:
             destination: The host name of the remote homeserver we want
                 to get the state from.
-            context: The name of the context we want the state of
+            room_id: the room we want the state of
             event_id: The event we want the context at.
 
         Returns:
@@ -86,6 +86,29 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
+    async def get_room_state(
+        self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
+    ) -> "StateRequestResponse":
+        """Requests the full state for a given room at the given event.
+
+        Args:
+            room_version: the version of the room (required to build the event objects)
+            destination: The host name of the remote homeserver we want
+                to get the state from.
+            room_id: the room we want the state of
+            event_id: The event we want the context at.
+
+        Returns:
+            Results in a dict received from the remote homeserver.
+        """
+        path = _create_v1_path("/state/%s", room_id)
+        return await self.client.get_json(
+            destination,
+            path=path,
+            args={"event_id": event_id},
+            parser=_StateParser(room_version),
+        )
+
     async def get_event(
         self, destination: str, event_id: str, timeout: Optional[int] = None
     ) -> JsonDict:
@@ -336,10 +359,15 @@ class TransportLayerClient:
         content: JsonDict,
     ) -> "SendJoinResponse":
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+        query_params: Dict[str, str] = {}
+        if self._faster_joins_enabled:
+            # lazy-load state on join
+            query_params["org.matrix.msc3706.partial_state"] = "true"
 
         return await self.client.put_json(
             destination=destination,
             path=path,
+            args=query_params,
             data=content,
             parser=SendJoinParser(room_version, v1_api=False),
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
@@ -1271,6 +1299,20 @@ class SendJoinResponse:
     # "event" is not included in the response.
     event: Optional[EventBase] = None
 
+    # The room state is incomplete
+    partial_state: bool = False
+
+    # List of servers in the room
+    servers_in_room: Optional[List[str]] = None
+
+
+@attr.s(slots=True, auto_attribs=True)
+class StateRequestResponse:
+    """The parsed response of a `/state` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
 
 @ijson.coroutine
 def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
@@ -1297,6 +1339,32 @@ def _event_list_parser(
         events.append(event)
 
 
+@ijson.coroutine
+def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+    """Helper function for use with `ijson.items_coro`
+
+    Parses the partial_state field in send_join responses
+    """
+    while True:
+        val = yield
+        if not isinstance(val, bool):
+            raise TypeError("partial_state must be a boolean")
+        response.partial_state = val
+
+
+@ijson.coroutine
+def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+    """Helper function for use with `ijson.items_coro`
+
+    Parses the servers_in_room field in send_join responses
+    """
+    while True:
+        val = yield
+        if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
+            raise TypeError("servers_in_room must be a list of strings")
+        response.servers_in_room = val
+
+
 class SendJoinParser(ByteParser[SendJoinResponse]):
     """A parser for the response to `/send_join` requests.
 
@@ -1308,44 +1376,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
     CONTENT_TYPE = "application/json"
 
     def __init__(self, room_version: RoomVersion, v1_api: bool):
-        self._response = SendJoinResponse([], [], {})
+        self._response = SendJoinResponse([], [], event_dict={})
         self._room_version = room_version
+        self._coros = []
 
         # The V1 API has the shape of `[200, {...}]`, which we handle by
         # prefixing with `item.*`.
         prefix = "item." if v1_api else ""
 
-        self._coro_state = ijson.items_coro(
-            _event_list_parser(room_version, self._response.state),
-            prefix + "state.item",
-            use_float=True,
-        )
-        self._coro_auth = ijson.items_coro(
-            _event_list_parser(room_version, self._response.auth_events),
-            prefix + "auth_chain.item",
-            use_float=True,
-        )
-        # TODO Remove the unstable prefix when servers have updated.
-        #
-        # By re-using the same event dictionary this will cause the parsing of
-        # org.matrix.msc3083.v2.event and event to stomp over each other.
-        # Generally this should be fine.
-        self._coro_unstable_event = ijson.kvitems_coro(
-            _event_parser(self._response.event_dict),
-            prefix + "org.matrix.msc3083.v2.event",
-            use_float=True,
-        )
-        self._coro_event = ijson.kvitems_coro(
-            _event_parser(self._response.event_dict),
-            prefix + "event",
-            use_float=True,
-        )
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                prefix + "state.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                prefix + "auth_chain.item",
+                use_float=True,
+            ),
+            # TODO Remove the unstable prefix when servers have updated.
+            #
+            # By re-using the same event dictionary this will cause the parsing of
+            # org.matrix.msc3083.v2.event and event to stomp over each other.
+            # Generally this should be fine.
+            ijson.kvitems_coro(
+                _event_parser(self._response.event_dict),
+                prefix + "org.matrix.msc3083.v2.event",
+                use_float=True,
+            ),
+            ijson.kvitems_coro(
+                _event_parser(self._response.event_dict),
+                prefix + "event",
+                use_float=True,
+            ),
+        ]
+
+        if not v1_api:
+            self._coros.append(
+                ijson.items_coro(
+                    _partial_state_parser(self._response),
+                    "org.matrix.msc3706.partial_state",
+                    use_float="True",
+                )
+            )
+
+            self._coros.append(
+                ijson.items_coro(
+                    _servers_in_room_parser(self._response),
+                    "org.matrix.msc3706.servers_in_room",
+                    use_float="True",
+                )
+            )
 
     def write(self, data: bytes) -> int:
-        self._coro_state.send(data)
-        self._coro_auth.send(data)
-        self._coro_unstable_event.send(data)
-        self._coro_event.send(data)
+        for c in self._coros:
+            c.send(data)
 
         return len(data)
 
@@ -1355,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
                 self._response.event_dict, self._room_version
             )
         return self._response
+
+
+class _StateParser(ByteParser[StateRequestResponse]):
+    """A parser for the response to `/state` requests.
+
+    Args:
+        room_version: The version of the room.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion):
+        self._response = StateRequestResponse([], [])
+        self._room_version = room_version
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                "pdus.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                "auth_chain.item",
+                use_float=True,
+            ),
+        ]
+
+    def write(self, data: bytes) -> int:
+        for c in self._coros:
+            c.send(data)
+        return len(data)
+
+    def finish(self) -> StateRequestResponse:
+        return self._response
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 6959d1aa7e..572f54b1e3 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -2064,6 +2064,10 @@ GET_USERNAME_FOR_REGISTRATION_CALLBACK = Callable[
     [JsonDict, JsonDict],
     Awaitable[Optional[str]],
 ]
+GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK = Callable[
+    [JsonDict, JsonDict],
+    Awaitable[Optional[str]],
+]
 IS_3PID_ALLOWED_CALLBACK = Callable[[str, str, bool], Awaitable[bool]]
 
 
@@ -2080,6 +2084,9 @@ class PasswordAuthProvider:
         self.get_username_for_registration_callbacks: List[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = []
+        self.get_displayname_for_registration_callbacks: List[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = []
         self.is_3pid_allowed_callbacks: List[IS_3PID_ALLOWED_CALLBACK] = []
 
         # Mapping from login type to login parameters
@@ -2099,6 +2106,9 @@ class PasswordAuthProvider:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         # Register check_3pid_auth callback
         if check_3pid_auth is not None:
@@ -2148,6 +2158,11 @@ class PasswordAuthProvider:
                 get_username_for_registration,
             )
 
+        if get_displayname_for_registration is not None:
+            self.get_displayname_for_registration_callbacks.append(
+                get_displayname_for_registration,
+            )
+
         if is_3pid_allowed is not None:
             self.is_3pid_allowed_callbacks.append(is_3pid_allowed)
 
@@ -2350,6 +2365,49 @@ class PasswordAuthProvider:
 
         return None
 
+    async def get_displayname_for_registration(
+        self,
+        uia_results: JsonDict,
+        params: JsonDict,
+    ) -> Optional[str]:
+        """Defines the display name to use when registering the user, using the
+        credentials and parameters provided during the UIA flow.
+
+        Stops at the first callback that returns a tuple containing at least one string.
+
+        Args:
+            uia_results: The credentials provided during the UIA flow.
+            params: The parameters provided by the registration request.
+
+        Returns:
+            A tuple which first element is the display name, and the second is an MXC URL
+            to the user's avatar.
+        """
+        for callback in self.get_displayname_for_registration_callbacks:
+            try:
+                res = await callback(uia_results, params)
+
+                if isinstance(res, str):
+                    return res
+                elif res is not None:
+                    # mypy complains that this line is unreachable because it assumes the
+                    # data returned by the module fits the expected type. We just want
+                    # to make sure this is the case.
+                    logger.warning(  # type: ignore[unreachable]
+                        "Ignoring non-string value returned by"
+                        " get_displayname_for_registration callback %s: %s",
+                        callback,
+                        res,
+                    )
+            except Exception as e:
+                logger.error(
+                    "Module raised an exception in get_displayname_for_registration: %s",
+                    e,
+                )
+                raise SynapseError(code=500, msg="Internal Server Error")
+
+        return None
+
     async def is_3pid_allowed(
         self,
         medium: str,
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c0f642005f..c8356f233d 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -516,7 +516,7 @@ class FederationHandler:
             await self.store.upsert_room_on_join(
                 room_id=room_id,
                 room_version=room_version_obj,
-                auth_events=auth_chain,
+                state_events=state,
             )
 
             max_stream_id = await self._federation_event_handler.process_remote_join(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9edc7369d6..7683246bef 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -419,10 +419,8 @@ class FederationEventHandler:
         Raises:
             SynapseError if the response is in some way invalid.
         """
-        event_map = {e.event_id: e for e in itertools.chain(auth_events, state)}
-
         create_event = None
-        for e in auth_events:
+        for e in state:
             if (e.type, e.state_key) == (EventTypes.Create, ""):
                 create_event = e
                 break
@@ -439,11 +437,6 @@ class FederationEventHandler:
         if room_version.identifier != room_version_id:
             raise SynapseError(400, "Room version mismatch")
 
-        # filter out any events we have already seen
-        seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
-        for s in seen_remotes:
-            event_map.pop(s, None)
-
         # persist the auth chain and state events.
         #
         # any invalid events here will be marked as rejected, and we'll carry on.
@@ -455,7 +448,9 @@ class FederationEventHandler:
         # signatures right now doesn't mean that we will *never* be able to, so it
         # is premature to reject them.
         #
-        await self._auth_and_persist_outliers(room_id, event_map.values())
+        await self._auth_and_persist_outliers(
+            room_id, itertools.chain(auth_events, state)
+        )
 
         # and now persist the join event itself.
         logger.info("Peristing join-via-remote %s", event)
@@ -1245,6 +1240,16 @@ class FederationEventHandler:
         """
         event_map = {event.event_id: event for event in events}
 
+        # filter out any events we have already seen. This might happen because
+        # the events were eagerly pushed to us (eg, during a room join), or because
+        # another thread has raced against us since we decided to request the event.
+        #
+        # This is just an optimisation, so it doesn't need to be watertight - the event
+        # persister does another round of deduplication.
+        seen_remotes = await self._store.have_seen_events(room_id, event_map.keys())
+        for s in seen_remotes:
+            event_map.pop(s, None)
+
         # XXX: it might be possible to kick this process off in parallel with fetching
         # the events.
         while event_map:
@@ -1717,31 +1722,22 @@ class FederationEventHandler:
             event_id: the event for which we are lacking auth events
         """
         try:
-            remote_event_map = {
-                e.event_id: e
-                for e in await self._federation_client.get_event_auth(
-                    destination, room_id, event_id
-                )
-            }
+            remote_events = await self._federation_client.get_event_auth(
+                destination, room_id, event_id
+            )
+
         except RequestSendFailed as e1:
             # The other side isn't around or doesn't implement the
             # endpoint, so lets just bail out.
             logger.info("Failed to get event auth from remote: %s", e1)
             return
 
-        logger.info("/event_auth returned %i events", len(remote_event_map))
+        logger.info("/event_auth returned %i events", len(remote_events))
 
         # `event` may be returned, but we should not yet process it.
-        remote_event_map.pop(event_id, None)
-
-        # nor should we reprocess any events we have already seen.
-        seen_remotes = await self._store.have_seen_events(
-            room_id, remote_event_map.keys()
-        )
-        for s in seen_remotes:
-            remote_event_map.pop(s, None)
+        remote_auth_events = (e for e in remote_events if e.event_id != event_id)
 
-        await self._auth_and_persist_outliers(room_id, remote_event_map.values())
+        await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
     async def _update_context_for_auth_events(
         self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9267e586a8..4d0da84287 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -550,10 +550,11 @@ class EventCreationHandler:
 
         if event_dict["type"] == EventTypes.Create and event_dict["state_key"] == "":
             room_version_id = event_dict["content"]["room_version"]
-            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
-            if not room_version_obj:
+            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+            if not maybe_room_version_obj:
                 # this can happen if support is withdrawn for a room version
                 raise UnsupportedRoomVersionError(room_version_id)
+            room_version_obj = maybe_room_version_obj
         else:
             try:
                 room_version_obj = await self.store.get_room_version(
@@ -1145,12 +1146,13 @@ class EventCreationHandler:
             room_version_id = event.content.get(
                 "room_version", RoomVersions.V1.identifier
             )
-            room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
-            if not room_version_obj:
+            maybe_room_version_obj = KNOWN_ROOM_VERSIONS.get(room_version_id)
+            if not maybe_room_version_obj:
                 raise UnsupportedRoomVersionError(
                     "Attempt to create a room with unsupported room version %s"
                     % (room_version_id,)
                 )
+            room_version_obj = maybe_room_version_obj
         else:
             room_version_obj = await self.store.get_room_version(event.room_id)
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 067c43ae47..b223b72623 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -204,25 +204,27 @@ class BasePresenceHandler(abc.ABC):
         Returns:
             dict: `user_id` -> `UserPresenceState`
         """
-        states = {
-            user_id: self.user_to_current_state.get(user_id, None)
-            for user_id in user_ids
-        }
+        states = {}
+        missing = []
+        for user_id in user_ids:
+            state = self.user_to_current_state.get(user_id, None)
+            if state:
+                states[user_id] = state
+            else:
+                missing.append(user_id)
 
-        missing = [user_id for user_id, state in states.items() if not state]
         if missing:
             # There are things not in our in memory cache. Lets pull them out of
             # the database.
             res = await self.store.get_presence_for_users(missing)
             states.update(res)
 
-            missing = [user_id for user_id, state in states.items() if not state]
-            if missing:
-                new = {
-                    user_id: UserPresenceState.default(user_id) for user_id in missing
-                }
-                states.update(new)
-                self.user_to_current_state.update(new)
+            for user_id in missing:
+                # if user has no state in database, create the state
+                if not res.get(user_id, None):
+                    new_state = UserPresenceState.default(user_id)
+                    states[user_id] = new_state
+                    self.user_to_current_state[user_id] = new_state
 
         return states
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index a719d5eef3..80320d2c07 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -320,12 +320,12 @@ class RegistrationHandler:
                 if fail_count > 10:
                     raise SynapseError(500, "Unable to find a suitable guest user ID")
 
-                localpart = await self.store.generate_user_id()
-                user = UserID(localpart, self.hs.hostname)
+                generated_localpart = await self.store.generate_user_id()
+                user = UserID(generated_localpart, self.hs.hostname)
                 user_id = user.to_string()
                 self.check_user_id_not_appservice_exclusive(user_id)
                 if generate_display_name:
-                    default_display_name = localpart
+                    default_display_name = generated_localpart
                 try:
                     await self.register_with_store(
                         user_id=user_id,
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index bf1a47efb0..b2adc0f48b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -82,6 +82,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         self.event_auth_handler = hs.get_event_auth_handler()
 
         self.member_linearizer: Linearizer = Linearizer(name="member")
+        self.member_as_limiter = Linearizer(max_count=10, name="member_as_limiter")
 
         self.clock = hs.get_clock()
         self.spam_checker = hs.get_spam_checker()
@@ -500,25 +501,32 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         key = (room_id,)
 
-        with (await self.member_linearizer.queue(key)):
-            result = await self.update_membership_locked(
-                requester,
-                target,
-                room_id,
-                action,
-                txn_id=txn_id,
-                remote_room_hosts=remote_room_hosts,
-                third_party_signed=third_party_signed,
-                ratelimit=ratelimit,
-                content=content,
-                new_room=new_room,
-                require_consent=require_consent,
-                outlier=outlier,
-                historical=historical,
-                allow_no_prev_events=allow_no_prev_events,
-                prev_event_ids=prev_event_ids,
-                auth_event_ids=auth_event_ids,
-            )
+        as_id = object()
+        if requester.app_service:
+            as_id = requester.app_service.id
+
+        # We first linearise by the application service (to try to limit concurrent joins
+        # by application services), and then by room ID.
+        with (await self.member_as_limiter.queue(as_id)):
+            with (await self.member_linearizer.queue(key)):
+                result = await self.update_membership_locked(
+                    requester,
+                    target,
+                    room_id,
+                    action,
+                    txn_id=txn_id,
+                    remote_room_hosts=remote_room_hosts,
+                    third_party_signed=third_party_signed,
+                    ratelimit=ratelimit,
+                    content=content,
+                    new_room=new_room,
+                    require_consent=require_consent,
+                    outlier=outlier,
+                    historical=historical,
+                    allow_no_prev_events=allow_no_prev_events,
+                    prev_event_ids=prev_event_ids,
+                    auth_event_ids=auth_event_ids,
+                )
 
         return result
 
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 41cb809078..0e0e58de02 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -14,8 +14,9 @@
 
 import itertools
 import logging
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional
+from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
 
+import attr
 from unpaddedbase64 import decode_base64, encode_base64
 
 from synapse.api.constants import EventTypes, Membership
@@ -32,6 +33,20 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _SearchResult:
+    # The count of results.
+    count: int
+    # A mapping of event ID to the rank of that event.
+    rank_map: Dict[str, int]
+    # A list of the resulting events.
+    allowed_events: List[EventBase]
+    # A map of room ID to results.
+    room_groups: Dict[str, JsonDict]
+    # A set of event IDs to highlight.
+    highlights: Set[str]
+
+
 class SearchHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -100,7 +115,7 @@ class SearchHandler:
         """Performs a full text search for a user.
 
         Args:
-            user
+            user: The user performing the search.
             content: Search parameters
             batch: The next_batch parameter. Used for pagination.
 
@@ -156,6 +171,8 @@ class SearchHandler:
 
             # Include context around each event?
             event_context = room_cat.get("event_context", None)
+            before_limit = after_limit = None
+            include_profile = False
 
             # Group results together? May allow clients to paginate within a
             # group
@@ -182,6 +199,73 @@ class SearchHandler:
                 % (set(group_keys) - {"room_id", "sender"},),
             )
 
+        return await self._search(
+            user,
+            batch_group,
+            batch_group_key,
+            batch_token,
+            search_term,
+            keys,
+            filter_dict,
+            order_by,
+            include_state,
+            group_keys,
+            event_context,
+            before_limit,
+            after_limit,
+            include_profile,
+        )
+
+    async def _search(
+        self,
+        user: UserID,
+        batch_group: Optional[str],
+        batch_group_key: Optional[str],
+        batch_token: Optional[str],
+        search_term: str,
+        keys: List[str],
+        filter_dict: JsonDict,
+        order_by: str,
+        include_state: bool,
+        group_keys: List[str],
+        event_context: Optional[bool],
+        before_limit: Optional[int],
+        after_limit: Optional[int],
+        include_profile: bool,
+    ) -> JsonDict:
+        """Performs a full text search for a user.
+
+        Args:
+            user: The user performing the search.
+            batch_group: Pagination information.
+            batch_group_key: Pagination information.
+            batch_token: Pagination information.
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            filter_dict: The JSON to build a filter out of.
+            order_by: How to order the results. Valid values ore "rank" and "recent".
+            include_state: True if the state of the room at each result should
+                be included.
+            group_keys: A list of ways to group the results. Valid values are
+                "room_id" and "sender".
+            event_context: True to include contextual events around results.
+            before_limit:
+                The number of events before a result to include as context.
+
+                Only used if event_context is True.
+            after_limit:
+                The number of events after a result to include as context.
+
+                Only used if event_context is True.
+            include_profile: True if historical profile information should be
+                included in the event context.
+
+                Only used if event_context is True.
+
+        Returns:
+            dict to be returned to the client with results of search
+        """
         search_filter = Filter(self.hs, filter_dict)
 
         # TODO: Search through left rooms too
@@ -216,278 +300,399 @@ class SearchHandler:
                 }
             }
 
-        rank_map = {}  # event_id -> rank of event
-        allowed_events = []
-        # Holds result of grouping by room, if applicable
-        room_groups: Dict[str, JsonDict] = {}
-        # Holds result of grouping by sender, if applicable
-        sender_group: Dict[str, JsonDict] = {}
+        sender_group: Optional[Dict[str, JsonDict]]
 
-        # Holds the next_batch for the entire result set if one of those exists
-        global_next_batch = None
-
-        highlights = set()
+        if order_by == "rank":
+            search_result, sender_group = await self._search_by_rank(
+                user, room_ids, search_term, keys, search_filter
+            )
+            # Unused return values for rank search.
+            global_next_batch = None
+        elif order_by == "recent":
+            search_result, global_next_batch = await self._search_by_recent(
+                user,
+                room_ids,
+                search_term,
+                keys,
+                search_filter,
+                batch_group,
+                batch_group_key,
+                batch_token,
+            )
+            # Unused return values for recent search.
+            sender_group = None
+        else:
+            # We should never get here due to the guard earlier.
+            raise NotImplementedError()
 
-        count = None
+        logger.info("Found %d events to return", len(search_result.allowed_events))
 
-        if order_by == "rank":
-            search_result = await self.store.search_msgs(room_ids, search_term, keys)
+        # If client has asked for "context" for each event (i.e. some surrounding
+        # events and state), fetch that
+        if event_context is not None:
+            # Note that before and after limit must be set in this case.
+            assert before_limit is not None
+            assert after_limit is not None
+
+            contexts = await self._calculate_event_contexts(
+                user,
+                search_result.allowed_events,
+                before_limit,
+                after_limit,
+                include_profile,
+            )
+        else:
+            contexts = {}
 
-            count = search_result["count"]
+        # TODO: Add a limit
 
-            if search_result["highlights"]:
-                highlights.update(search_result["highlights"])
+        state_results = {}
+        if include_state:
+            for room_id in {e.room_id for e in search_result.allowed_events}:
+                state = await self.state_handler.get_current_state(room_id)
+                state_results[room_id] = list(state.values())
 
-            results = search_result["results"]
+        aggregations = None
+        if self._msc3666_enabled:
+            aggregations = await self.store.get_bundled_aggregations(
+                # Generate an iterable of EventBase for all the events that will be
+                # returned, including contextual events.
+                itertools.chain(
+                    # The events_before and events_after for each context.
+                    itertools.chain.from_iterable(
+                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                        for context in contexts.values()
+                    ),
+                    # The returned events.
+                    search_result.allowed_events,
+                ),
+                user.to_string(),
+            )
 
-            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+        # We're now about to serialize the events. We should not make any
+        # blocking calls after this. Otherwise, the 'age' will be wrong.
 
-            filtered_events = await search_filter.filter([r["event"] for r in results])
+        time_now = self.clock.time_msec()
 
-            events = await filter_events_for_client(
-                self.storage, user.to_string(), filtered_events
+        for context in contexts.values():
+            context["events_before"] = self._event_serializer.serialize_events(
+                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+            )
+            context["events_after"] = self._event_serializer.serialize_events(
+                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
 
-            events.sort(key=lambda e: -rank_map[e.event_id])
-            allowed_events = events[: search_filter.limit]
+        results = [
+            {
+                "rank": search_result.rank_map[e.event_id],
+                "result": self._event_serializer.serialize_event(
+                    e, time_now, bundle_aggregations=aggregations
+                ),
+                "context": contexts.get(e.event_id, {}),
+            }
+            for e in search_result.allowed_events
+        ]
 
-            for e in allowed_events:
-                rm = room_groups.setdefault(
-                    e.room_id, {"results": [], "order": rank_map[e.event_id]}
-                )
-                rm["results"].append(e.event_id)
+        rooms_cat_res: JsonDict = {
+            "results": results,
+            "count": search_result.count,
+            "highlights": list(search_result.highlights),
+        }
 
-                s = sender_group.setdefault(
-                    e.sender, {"results": [], "order": rank_map[e.event_id]}
-                )
-                s["results"].append(e.event_id)
+        if state_results:
+            rooms_cat_res["state"] = {
+                room_id: self._event_serializer.serialize_events(state_events, time_now)
+                for room_id, state_events in state_results.items()
+            }
 
-        elif order_by == "recent":
-            room_events: List[EventBase] = []
-            i = 0
-
-            pagination_token = batch_token
-
-            # We keep looping and we keep filtering until we reach the limit
-            # or we run out of things.
-            # But only go around 5 times since otherwise synapse will be sad.
-            while len(room_events) < search_filter.limit and i < 5:
-                i += 1
-                search_result = await self.store.search_rooms(
-                    room_ids,
-                    search_term,
-                    keys,
-                    search_filter.limit * 2,
-                    pagination_token=pagination_token,
-                )
+        if search_result.room_groups and "room_id" in group_keys:
+            rooms_cat_res.setdefault("groups", {})[
+                "room_id"
+            ] = search_result.room_groups
 
-                if search_result["highlights"]:
-                    highlights.update(search_result["highlights"])
+        if sender_group and "sender" in group_keys:
+            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
 
-                count = search_result["count"]
+        if global_next_batch:
+            rooms_cat_res["next_batch"] = global_next_batch
 
-                results = search_result["results"]
+        return {"search_categories": {"room_events": rooms_cat_res}}
 
-                results_map = {r["event"].event_id: r for r in results}
+    async def _search_by_rank(
+        self,
+        user: UserID,
+        room_ids: Collection[str],
+        search_term: str,
+        keys: Iterable[str],
+        search_filter: Filter,
+    ) -> Tuple[_SearchResult, Dict[str, JsonDict]]:
+        """
+        Performs a full text search for a user ordering by rank.
 
-                rank_map.update({r["event"].event_id: r["rank"] for r in results})
+        Args:
+            user: The user performing the search.
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            search_filter: The event filter to use.
 
-                filtered_events = await search_filter.filter(
-                    [r["event"] for r in results]
-                )
+        Returns:
+            A tuple of:
+                The search results.
+                A map of sender ID to results.
+        """
+        rank_map = {}  # event_id -> rank of event
+        # Holds result of grouping by room, if applicable
+        room_groups: Dict[str, JsonDict] = {}
+        # Holds result of grouping by sender, if applicable
+        sender_group: Dict[str, JsonDict] = {}
 
-                events = await filter_events_for_client(
-                    self.storage, user.to_string(), filtered_events
-                )
+        search_result = await self.store.search_msgs(room_ids, search_term, keys)
 
-                room_events.extend(events)
-                room_events = room_events[: search_filter.limit]
+        if search_result["highlights"]:
+            highlights = search_result["highlights"]
+        else:
+            highlights = set()
 
-                if len(results) < search_filter.limit * 2:
-                    pagination_token = None
-                    break
-                else:
-                    pagination_token = results[-1]["pagination_token"]
-
-            for event in room_events:
-                group = room_groups.setdefault(event.room_id, {"results": []})
-                group["results"].append(event.event_id)
-
-            if room_events and len(room_events) >= search_filter.limit:
-                last_event_id = room_events[-1].event_id
-                pagination_token = results_map[last_event_id]["pagination_token"]
-
-                # We want to respect the given batch group and group keys so
-                # that if people blindly use the top level `next_batch` token
-                # it returns more from the same group (if applicable) rather
-                # than reverting to searching all results again.
-                if batch_group and batch_group_key:
-                    global_next_batch = encode_base64(
-                        (
-                            "%s\n%s\n%s"
-                            % (batch_group, batch_group_key, pagination_token)
-                        ).encode("ascii")
-                    )
-                else:
-                    global_next_batch = encode_base64(
-                        ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
-                    )
+        results = search_result["results"]
 
-                for room_id, group in room_groups.items():
-                    group["next_batch"] = encode_base64(
-                        ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
-                            "ascii"
-                        )
-                    )
+        # event_id -> rank of event
+        rank_map = {r["event"].event_id: r["rank"] for r in results}
 
-            allowed_events.extend(room_events)
+        filtered_events = await search_filter.filter([r["event"] for r in results])
 
-        else:
-            # We should never get here due to the guard earlier.
-            raise NotImplementedError()
+        events = await filter_events_for_client(
+            self.storage, user.to_string(), filtered_events
+        )
 
-        logger.info("Found %d events to return", len(allowed_events))
+        events.sort(key=lambda e: -rank_map[e.event_id])
+        allowed_events = events[: search_filter.limit]
 
-        # If client has asked for "context" for each event (i.e. some surrounding
-        # events and state), fetch that
-        if event_context is not None:
-            now_token = self.hs.get_event_sources().get_current_token()
+        for e in allowed_events:
+            rm = room_groups.setdefault(
+                e.room_id, {"results": [], "order": rank_map[e.event_id]}
+            )
+            rm["results"].append(e.event_id)
 
-            contexts = {}
-            for event in allowed_events:
-                res = await self.store.get_events_around(
-                    event.room_id, event.event_id, before_limit, after_limit
-                )
+            s = sender_group.setdefault(
+                e.sender, {"results": [], "order": rank_map[e.event_id]}
+            )
+            s["results"].append(e.event_id)
+
+        return (
+            _SearchResult(
+                search_result["count"],
+                rank_map,
+                allowed_events,
+                room_groups,
+                highlights,
+            ),
+            sender_group,
+        )
 
-                logger.info(
-                    "Context for search returned %d and %d events",
-                    len(res.events_before),
-                    len(res.events_after),
-                )
+    async def _search_by_recent(
+        self,
+        user: UserID,
+        room_ids: Collection[str],
+        search_term: str,
+        keys: Iterable[str],
+        search_filter: Filter,
+        batch_group: Optional[str],
+        batch_group_key: Optional[str],
+        batch_token: Optional[str],
+    ) -> Tuple[_SearchResult, Optional[str]]:
+        """
+        Performs a full text search for a user ordering by recent.
 
-                events_before = await filter_events_for_client(
-                    self.storage, user.to_string(), res.events_before
-                )
+        Args:
+            user: The user performing the search.
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
+                "content.body", "content.name", "content.topic"
+            search_filter: The event filter to use.
+            batch_group: Pagination information.
+            batch_group_key: Pagination information.
+            batch_token: Pagination information.
 
-                events_after = await filter_events_for_client(
-                    self.storage, user.to_string(), res.events_after
-                )
+        Returns:
+            A tuple of:
+                The search results.
+                Optionally, a pagination token.
+        """
+        rank_map = {}  # event_id -> rank of event
+        # Holds result of grouping by room, if applicable
+        room_groups: Dict[str, JsonDict] = {}
 
-                context = {
-                    "events_before": events_before,
-                    "events_after": events_after,
-                    "start": await now_token.copy_and_replace(
-                        "room_key", res.start
-                    ).to_string(self.store),
-                    "end": await now_token.copy_and_replace(
-                        "room_key", res.end
-                    ).to_string(self.store),
-                }
+        # Holds the next_batch for the entire result set if one of those exists
+        global_next_batch = None
 
-                if include_profile:
-                    senders = {
-                        ev.sender
-                        for ev in itertools.chain(events_before, [event], events_after)
-                    }
+        highlights = set()
 
-                    if events_after:
-                        last_event_id = events_after[-1].event_id
-                    else:
-                        last_event_id = event.event_id
+        room_events: List[EventBase] = []
+        i = 0
+
+        pagination_token = batch_token
+
+        # We keep looping and we keep filtering until we reach the limit
+        # or we run out of things.
+        # But only go around 5 times since otherwise synapse will be sad.
+        while len(room_events) < search_filter.limit and i < 5:
+            i += 1
+            search_result = await self.store.search_rooms(
+                room_ids,
+                search_term,
+                keys,
+                search_filter.limit * 2,
+                pagination_token=pagination_token,
+            )
 
-                    state_filter = StateFilter.from_types(
-                        [(EventTypes.Member, sender) for sender in senders]
-                    )
+            if search_result["highlights"]:
+                highlights.update(search_result["highlights"])
+
+            count = search_result["count"]
+
+            results = search_result["results"]
+
+            results_map = {r["event"].event_id: r for r in results}
+
+            rank_map.update({r["event"].event_id: r["rank"] for r in results})
+
+            filtered_events = await search_filter.filter([r["event"] for r in results])
 
-                    state = await self.state_store.get_state_for_event(
-                        last_event_id, state_filter
+            events = await filter_events_for_client(
+                self.storage, user.to_string(), filtered_events
+            )
+
+            room_events.extend(events)
+            room_events = room_events[: search_filter.limit]
+
+            if len(results) < search_filter.limit * 2:
+                break
+            else:
+                pagination_token = results[-1]["pagination_token"]
+
+        for event in room_events:
+            group = room_groups.setdefault(event.room_id, {"results": []})
+            group["results"].append(event.event_id)
+
+        if room_events and len(room_events) >= search_filter.limit:
+            last_event_id = room_events[-1].event_id
+            pagination_token = results_map[last_event_id]["pagination_token"]
+
+            # We want to respect the given batch group and group keys so
+            # that if people blindly use the top level `next_batch` token
+            # it returns more from the same group (if applicable) rather
+            # than reverting to searching all results again.
+            if batch_group and batch_group_key:
+                global_next_batch = encode_base64(
+                    (
+                        "%s\n%s\n%s" % (batch_group, batch_group_key, pagination_token)
+                    ).encode("ascii")
+                )
+            else:
+                global_next_batch = encode_base64(
+                    ("%s\n%s\n%s" % ("all", "", pagination_token)).encode("ascii")
+                )
+
+            for room_id, group in room_groups.items():
+                group["next_batch"] = encode_base64(
+                    ("%s\n%s\n%s" % ("room_id", room_id, pagination_token)).encode(
+                        "ascii"
                     )
+                )
 
-                    context["profile_info"] = {
-                        s.state_key: {
-                            "displayname": s.content.get("displayname", None),
-                            "avatar_url": s.content.get("avatar_url", None),
-                        }
-                        for s in state.values()
-                        if s.type == EventTypes.Member and s.state_key in senders
-                    }
+        return (
+            _SearchResult(count, rank_map, room_events, room_groups, highlights),
+            global_next_batch,
+        )
 
-                contexts[event.event_id] = context
-        else:
-            contexts = {}
+    async def _calculate_event_contexts(
+        self,
+        user: UserID,
+        allowed_events: List[EventBase],
+        before_limit: int,
+        after_limit: int,
+        include_profile: bool,
+    ) -> Dict[str, JsonDict]:
+        """
+        Calculates the contextual events for any search results.
 
-        # TODO: Add a limit
+        Args:
+            user: The user performing the search.
+            allowed_events: The search results.
+            before_limit:
+                The number of events before a result to include as context.
+            after_limit:
+                The number of events after a result to include as context.
+            include_profile: True if historical profile information should be
+                included in the event context.
 
-        time_now = self.clock.time_msec()
+        Returns:
+            A map of event ID to contextual information.
+        """
+        now_token = self.hs.get_event_sources().get_current_token()
 
-        aggregations = None
-        if self._msc3666_enabled:
-            aggregations = await self.store.get_bundled_aggregations(
-                # Generate an iterable of EventBase for all the events that will be
-                # returned, including contextual events.
-                itertools.chain(
-                    # The events_before and events_after for each context.
-                    itertools.chain.from_iterable(
-                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
-                        for context in contexts.values()
-                    ),
-                    # The returned events.
-                    allowed_events,
-                ),
-                user.to_string(),
+        contexts = {}
+        for event in allowed_events:
+            res = await self.store.get_events_around(
+                event.room_id, event.event_id, before_limit, after_limit
             )
 
-        for context in contexts.values():
-            context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+            logger.info(
+                "Context for search returned %d and %d events",
+                len(res.events_before),
+                len(res.events_after),
             )
-            context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
+
+            events_before = await filter_events_for_client(
+                self.storage, user.to_string(), res.events_before
             )
 
-        state_results = {}
-        if include_state:
-            for room_id in {e.room_id for e in allowed_events}:
-                state = await self.state_handler.get_current_state(room_id)
-                state_results[room_id] = list(state.values())
+            events_after = await filter_events_for_client(
+                self.storage, user.to_string(), res.events_after
+            )
 
-        # We're now about to serialize the events. We should not make any
-        # blocking calls after this. Otherwise the 'age' will be wrong
+            context: JsonDict = {
+                "events_before": events_before,
+                "events_after": events_after,
+                "start": await now_token.copy_and_replace(
+                    "room_key", res.start
+                ).to_string(self.store),
+                "end": await now_token.copy_and_replace("room_key", res.end).to_string(
+                    self.store
+                ),
+            }
 
-        results = []
-        for e in allowed_events:
-            results.append(
-                {
-                    "rank": rank_map[e.event_id],
-                    "result": self._event_serializer.serialize_event(
-                        e, time_now, bundle_aggregations=aggregations
-                    ),
-                    "context": contexts.get(e.event_id, {}),
+            if include_profile:
+                senders = {
+                    ev.sender
+                    for ev in itertools.chain(events_before, [event], events_after)
                 }
-            )
 
-        rooms_cat_res = {
-            "results": results,
-            "count": count,
-            "highlights": list(highlights),
-        }
+                if events_after:
+                    last_event_id = events_after[-1].event_id
+                else:
+                    last_event_id = event.event_id
 
-        if state_results:
-            s = {}
-            for room_id, state_events in state_results.items():
-                s[room_id] = self._event_serializer.serialize_events(
-                    state_events, time_now
+                state_filter = StateFilter.from_types(
+                    [(EventTypes.Member, sender) for sender in senders]
                 )
 
-            rooms_cat_res["state"] = s
-
-        if room_groups and "room_id" in group_keys:
-            rooms_cat_res.setdefault("groups", {})["room_id"] = room_groups
+                state = await self.state_store.get_state_for_event(
+                    last_event_id, state_filter
+                )
 
-        if sender_group and "sender" in group_keys:
-            rooms_cat_res.setdefault("groups", {})["sender"] = sender_group
+                context["profile_info"] = {
+                    s.state_key: {
+                        "displayname": s.content.get("displayname", None),
+                        "avatar_url": s.content.get("avatar_url", None),
+                    }
+                    for s in state.values()
+                    if s.type == EventTypes.Member and s.state_key in senders
+                }
 
-        if global_next_batch:
-            rooms_cat_res["next_batch"] = global_next_batch
+            contexts[event.event_id] = context
 
-        return {"search_categories": {"room_events": rooms_cat_res}}
+        return contexts
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index aa9a76f8a9..e6050cbce6 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1289,23 +1289,54 @@ class SyncHandler:
             # room with by looking at all users that have left a room plus users
             # that were in a room we've left.
 
-            users_who_share_room = await self.store.get_users_who_share_room_with_user(
-                user_id
-            )
-
-            # Always tell the user about their own devices. We check as the user
-            # ID is almost certainly already included (unless they're not in any
-            # rooms) and taking a copy of the set is relatively expensive.
-            if user_id not in users_who_share_room:
-                users_who_share_room = set(users_who_share_room)
-                users_who_share_room.add(user_id)
+            users_that_have_changed = set()
 
-            tracked_users = users_who_share_room
+            joined_rooms = sync_result_builder.joined_room_ids
 
-            # Step 1a, check for changes in devices of users we share a room with
-            users_that_have_changed = await self.store.get_users_whose_devices_changed(
-                since_token.device_list_key, tracked_users
+            # Step 1a, check for changes in devices of users we share a room
+            # with
+            #
+            # We do this in two different ways depending on what we have cached.
+            # If we already have a list of all the user that have changed since
+            # the last sync then it's likely more efficient to compare the rooms
+            # they're in with the rooms the syncing user is in.
+            #
+            # If we don't have that info cached then we get all the users that
+            # share a room with our user and check if those users have changed.
+            changed_users = self.store.get_cached_device_list_changes(
+                since_token.device_list_key
             )
+            if changed_users is not None:
+                result = await self.store.get_rooms_for_users_with_stream_ordering(
+                    changed_users
+                )
+
+                for changed_user_id, entries in result.items():
+                    # Check if the changed user shares any rooms with the user,
+                    # or if the changed user is the syncing user (as we always
+                    # want to include device list updates of their own devices).
+                    if user_id == changed_user_id or any(
+                        e.room_id in joined_rooms for e in entries
+                    ):
+                        users_that_have_changed.add(changed_user_id)
+            else:
+                users_who_share_room = (
+                    await self.store.get_users_who_share_room_with_user(user_id)
+                )
+
+                # Always tell the user about their own devices. We check as the user
+                # ID is almost certainly already included (unless they're not in any
+                # rooms) and taking a copy of the set is relatively expensive.
+                if user_id not in users_who_share_room:
+                    users_who_share_room = set(users_who_share_room)
+                    users_who_share_room.add(user_id)
+
+                tracked_users = users_who_share_room
+                users_that_have_changed = (
+                    await self.store.get_users_whose_devices_changed(
+                        since_token.device_list_key, tracked_users
+                    )
+                )
 
             # Step 1b, check for newly joined rooms
             for room_id in newly_joined_rooms:
@@ -1329,7 +1360,14 @@ class SyncHandler:
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
-            newly_left_users -= users_who_share_room
+            left_users_rooms = (
+                await self.store.get_rooms_for_users_with_stream_ordering(
+                    newly_left_users
+                )
+            )
+            for user_id, entries in left_users_rooms.items():
+                if any(e.room_id in joined_rooms for e in entries):
+                    newly_left_users.discard(user_id)
 
             return DeviceLists(changed=users_that_have_changed, left=newly_left_users)
         else:
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c5f8fcbb2a..e7656fbb9f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
         )
         return body
 
+    @overload
     async def get_json(
         self,
         destination: str,
@@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
+        parser: Literal[None] = None,
+        max_response_size: Optional[int] = None,
     ) -> Union[JsonDict, list]:
+        ...
+
+    @overload
+    async def get_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = ...,
+        retry_on_dns_fail: bool = ...,
+        timeout: Optional[int] = ...,
+        ignore_backoff: bool = ...,
+        try_trailing_slash_on_400: bool = ...,
+        parser: ByteParser[T] = ...,
+        max_response_size: Optional[int] = ...,
+    ) -> T:
+        ...
+
+    async def get_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        retry_on_dns_fail: bool = True,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser] = None,
+        max_response_size: Optional[int] = None,
+    ):
         """GETs some json from the given host homeserver and path
 
         Args:
@@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
             try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end of
                 the request. Workaround for #3622 in Synapse <= v0.99.3.
+
+            parser: The parser to use to decode the response. Defaults to
+                parsing as JSON.
+
+            max_response_size: The maximum size to read from the response. If None,
+                uses the default.
+
         Returns:
             Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
@@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
+        if parser is None:
+            parser = JsonParser()
+
         body = await _handle_response(
-            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            parser=parser,
+            max_response_size=max_response_size,
         )
 
         return body
diff --git a/synapse/logging/_structured.py b/synapse/logging/_structured.py
deleted file mode 100644
index b9933a1528..0000000000
--- a/synapse/logging/_structured.py
+++ /dev/null
@@ -1,163 +0,0 @@
-# Copyright 2019 The Matrix.org Foundation C.I.C.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import os.path
-from typing import Any, Dict, Generator, Optional, Tuple
-
-from constantly import NamedConstant, Names
-
-from synapse.config._base import ConfigError
-
-
-class DrainType(Names):
-    CONSOLE = NamedConstant()
-    CONSOLE_JSON = NamedConstant()
-    CONSOLE_JSON_TERSE = NamedConstant()
-    FILE = NamedConstant()
-    FILE_JSON = NamedConstant()
-    NETWORK_JSON_TERSE = NamedConstant()
-
-
-DEFAULT_LOGGERS = {"synapse": {"level": "info"}}
-
-
-def parse_drain_configs(
-    drains: dict,
-) -> Generator[Tuple[str, Dict[str, Any]], None, None]:
-    """
-    Parse the drain configurations.
-
-    Args:
-        drains (dict): A list of drain configurations.
-
-    Yields:
-        dict instances representing a logging handler.
-
-    Raises:
-        ConfigError: If any of the drain configuration items are invalid.
-    """
-
-    for name, config in drains.items():
-        if "type" not in config:
-            raise ConfigError("Logging drains require a 'type' key.")
-
-        try:
-            logging_type = DrainType.lookupByName(config["type"].upper())
-        except ValueError:
-            raise ConfigError(
-                "%s is not a known logging drain type." % (config["type"],)
-            )
-
-        # Either use the default formatter or the tersejson one.
-        if logging_type in (
-            DrainType.CONSOLE_JSON,
-            DrainType.FILE_JSON,
-        ):
-            formatter: Optional[str] = "json"
-        elif logging_type in (
-            DrainType.CONSOLE_JSON_TERSE,
-            DrainType.NETWORK_JSON_TERSE,
-        ):
-            formatter = "tersejson"
-        else:
-            # A formatter of None implies using the default formatter.
-            formatter = None
-
-        if logging_type in [
-            DrainType.CONSOLE,
-            DrainType.CONSOLE_JSON,
-            DrainType.CONSOLE_JSON_TERSE,
-        ]:
-            location = config.get("location")
-            if location is None or location not in ["stdout", "stderr"]:
-                raise ConfigError(
-                    (
-                        "The %s drain needs the 'location' key set to "
-                        "either 'stdout' or 'stderr'."
-                    )
-                    % (logging_type,)
-                )
-
-            yield name, {
-                "class": "logging.StreamHandler",
-                "formatter": formatter,
-                "stream": "ext://sys." + location,
-            }
-
-        elif logging_type in [DrainType.FILE, DrainType.FILE_JSON]:
-            if "location" not in config:
-                raise ConfigError(
-                    "The %s drain needs the 'location' key set." % (logging_type,)
-                )
-
-            location = config.get("location")
-            if os.path.abspath(location) != location:
-                raise ConfigError(
-                    "File paths need to be absolute, '%s' is a relative path"
-                    % (location,)
-                )
-
-            yield name, {
-                "class": "logging.FileHandler",
-                "formatter": formatter,
-                "filename": location,
-            }
-
-        elif logging_type in [DrainType.NETWORK_JSON_TERSE]:
-            host = config.get("host")
-            port = config.get("port")
-            maximum_buffer = config.get("maximum_buffer", 1000)
-
-            yield name, {
-                "class": "synapse.logging.RemoteHandler",
-                "formatter": formatter,
-                "host": host,
-                "port": port,
-                "maximum_buffer": maximum_buffer,
-            }
-
-        else:
-            raise ConfigError(
-                "The %s drain type is currently not implemented."
-                % (config["type"].upper(),)
-            )
-
-
-def setup_structured_logging(
-    log_config: dict,
-) -> dict:
-    """
-    Convert a legacy structured logging configuration (from Synapse < v1.23.0)
-    to one compatible with the new standard library handlers.
-    """
-    if "drains" not in log_config:
-        raise ConfigError("The logging configuration requires a list of drains.")
-
-    new_config = {
-        "version": 1,
-        "formatters": {
-            "json": {"class": "synapse.logging.JsonFormatter"},
-            "tersejson": {"class": "synapse.logging.TerseJsonFormatter"},
-        },
-        "handlers": {},
-        "loggers": log_config.get("loggers", DEFAULT_LOGGERS),
-        "root": {"handlers": []},
-    }
-
-    for handler_name, handler in parse_drain_configs(log_config["drains"]):
-        new_config["handlers"][handler_name] = handler
-
-        # Add each handler to the root logger.
-        new_config["root"]["handlers"].append(handler_name)
-
-    return new_config
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index d4fca36923..07020bfb8d 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -70,6 +70,7 @@ from synapse.handlers.account_validity import (
 from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
     CHECK_AUTH_CALLBACK,
+    GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK,
     GET_USERNAME_FOR_REGISTRATION_CALLBACK,
     IS_3PID_ALLOWED_CALLBACK,
     ON_LOGGED_OUT_CALLBACK,
@@ -317,6 +318,9 @@ class ModuleApi:
         get_username_for_registration: Optional[
             GET_USERNAME_FOR_REGISTRATION_CALLBACK
         ] = None,
+        get_displayname_for_registration: Optional[
+            GET_DISPLAYNAME_FOR_REGISTRATION_CALLBACK
+        ] = None,
     ) -> None:
         """Registers callbacks for password auth provider capabilities.
 
@@ -328,6 +332,7 @@ class ModuleApi:
             is_3pid_allowed=is_3pid_allowed,
             auth_checkers=auth_checkers,
             get_username_for_registration=get_username_for_registration,
+            get_displayname_for_registration=get_displayname_for_registration,
         )
 
     def register_background_update_controller_callbacks(
@@ -648,7 +653,11 @@ class ModuleApi:
         Added in Synapse v1.9.0.
 
         Args:
-            auth_provider: identifier for the remote auth provider
+            auth_provider: identifier for the remote auth provider, see `sso` and
+                `oidc_providers` in the homeserver configuration.
+
+                Note that no error is raised if the provided value is not in the
+                homeserver configuration.
             external_id: id on that system
             user_id: complete mxid that it is mapped to
         """
diff --git a/synapse/notifier.py b/synapse/notifier.py
index e0fad2da66..753dd6b6a5 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -138,7 +138,7 @@ class _NotifierUserStream:
         self.current_token = self.current_token.copy_and_advance(stream_key, stream_id)
         self.last_notified_token = self.current_token
         self.last_notified_ms = time_now_ms
-        noify_deferred = self.notify_deferred
+        notify_deferred = self.notify_deferred
 
         log_kv(
             {
@@ -153,7 +153,7 @@ class _NotifierUserStream:
 
         with PreserveLoggingContext():
             self.notify_deferred = ObservableDeferred(defer.Deferred())
-            noify_deferred.callback(self.current_token)
+            notify_deferred.callback(self.current_token)
 
     def remove(self, notifier: "Notifier") -> None:
         """Remove this listener from all the indexes in the Notifier
diff --git a/synapse/push/baserules.py b/synapse/push/baserules.py
index 910b05c0da..832eaa34e9 100644
--- a/synapse/push/baserules.py
+++ b/synapse/push/baserules.py
@@ -130,7 +130,9 @@ def make_base_prepend_rules(
     return rules
 
 
-BASE_APPEND_CONTENT_RULES = [
+# We have to annotate these types, otherwise mypy infers them as
+# `List[Dict[str, Sequence[Collection[str]]]]`.
+BASE_APPEND_CONTENT_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/content/.m.rule.contains_user_name",
         "conditions": [
@@ -149,7 +151,7 @@ BASE_APPEND_CONTENT_RULES = [
 ]
 
 
-BASE_PREPEND_OVERRIDE_RULES = [
+BASE_PREPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/override/.m.rule.master",
         "enabled": False,
@@ -159,7 +161,7 @@ BASE_PREPEND_OVERRIDE_RULES = [
 ]
 
 
-BASE_APPEND_OVERRIDE_RULES = [
+BASE_APPEND_OVERRIDE_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/override/.m.rule.suppress_notices",
         "conditions": [
@@ -278,7 +280,7 @@ BASE_APPEND_OVERRIDE_RULES = [
 ]
 
 
-BASE_APPEND_UNDERRIDE_RULES = [
+BASE_APPEND_UNDERRIDE_RULES: List[Dict[str, Any]] = [
     {
         "rule_id": "global/underride/.m.rule.call",
         "conditions": [
diff --git a/synapse/push/httppusher.py b/synapse/push/httppusher.py
index 96559081d0..52c7ff3572 100644
--- a/synapse/push/httppusher.py
+++ b/synapse/push/httppusher.py
@@ -109,6 +109,7 @@ class HttpPusher(Pusher):
         self.data_minus_url = {}
         self.data_minus_url.update(self.data)
         del self.data_minus_url["url"]
+        self.badge_count_last_call: Optional[int] = None
 
     def on_started(self, should_check_for_notifs: bool) -> None:
         """Called when this pusher has been started.
@@ -136,7 +137,9 @@ class HttpPusher(Pusher):
             self.user_id,
             group_by_room=self._group_unread_count_by_room,
         )
-        await self._send_badge(badge)
+        if self.badge_count_last_call is None or self.badge_count_last_call != badge:
+            self.badge_count_last_call = badge
+            await self._send_badge(badge)
 
     def on_timer(self) -> None:
         self._start_processing()
@@ -322,7 +325,7 @@ class HttpPusher(Pusher):
         # This was checked in the __init__, but mypy doesn't seem to know that.
         assert self.data is not None
         if self.data.get("format") == "event_id_only":
-            d = {
+            d: Dict[str, Any] = {
                 "notification": {
                     "event_id": event.event_id,
                     "room_id": event.room_id,
@@ -402,6 +405,8 @@ class HttpPusher(Pusher):
         rejected = []
         if "rejected" in resp:
             rejected = resp["rejected"]
+        else:
+            self.badge_count_last_call = badge
         return rejected
 
     async def _send_badge(self, badge: int) -> None:
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 86162e0f2c..f43fbb5842 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -87,7 +87,8 @@ REQUIREMENTS = [
     # We enforce that we have a `cryptography` version that bundles an `openssl`
     # with the latest security patches.
     "cryptography>=3.4.7",
-    "ijson>=3.1",
+    # ijson 3.1.4 fixes a bug with "." in property names
+    "ijson>=3.1.4",
     "matrix-common~=1.1.0",
 ]
 
diff --git a/synapse/res/providers.json b/synapse/res/providers.json
index f1838f9559..7b9958e454 100644
--- a/synapse/res/providers.json
+++ b/synapse/res/providers.json
@@ -5,8 +5,6 @@
         "endpoints": [
             {
                 "schemes": [
-                    "https://twitter.com/*/status/*",
-                    "https://*.twitter.com/*/status/*",
                     "https://twitter.com/*/moments/*",
                     "https://*.twitter.com/*/moments/*"
                 ],
@@ -14,4 +12,4 @@
             }
         ]
     }
-]
\ No newline at end of file
+]
diff --git a/synapse/rest/client/account.py b/synapse/rest/client/account.py
index cfa2aee76d..efe299e698 100644
--- a/synapse/rest/client/account.py
+++ b/synapse/rest/client/account.py
@@ -883,7 +883,9 @@ class WhoamiRestServlet(RestServlet):
         response = {
             "user_id": requester.user.to_string(),
             # MSC: https://github.com/matrix-org/matrix-doc/pull/3069
+            # Entered spec in Matrix 1.2
             "org.matrix.msc3069.is_guest": bool(requester.is_guest),
+            "is_guest": bool(requester.is_guest),
         }
 
         # Appservices and similar accounts do not have device IDs
diff --git a/synapse/rest/client/auth.py b/synapse/rest/client/auth.py
index 9c15a04338..e0b2b80e5b 100644
--- a/synapse/rest/client/auth.py
+++ b/synapse/rest/client/auth.py
@@ -62,7 +62,7 @@ class AuthRestServlet(RestServlet):
         if stagetype == LoginType.RECAPTCHA:
             html = self.recaptcha_template.render(
                 session=session,
-                myurl="%s/r0/auth/%s/fallback/web"
+                myurl="%s/v3/auth/%s/fallback/web"
                 % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                 sitekey=self.hs.config.captcha.recaptcha_public_key,
             )
@@ -74,7 +74,7 @@ class AuthRestServlet(RestServlet):
                     self.hs.config.server.public_baseurl,
                     self.hs.config.consent.user_consent_version,
                 ),
-                myurl="%s/r0/auth/%s/fallback/web"
+                myurl="%s/v3/auth/%s/fallback/web"
                 % (CLIENT_API_PREFIX, LoginType.TERMS),
             )
 
@@ -118,7 +118,7 @@ class AuthRestServlet(RestServlet):
                 # Authentication failed, let user try again
                 html = self.recaptcha_template.render(
                     session=session,
-                    myurl="%s/r0/auth/%s/fallback/web"
+                    myurl="%s/v3/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.RECAPTCHA),
                     sitekey=self.hs.config.captcha.recaptcha_public_key,
                     error=e.msg,
@@ -143,7 +143,7 @@ class AuthRestServlet(RestServlet):
                         self.hs.config.server.public_baseurl,
                         self.hs.config.consent.user_consent_version,
                     ),
-                    myurl="%s/r0/auth/%s/fallback/web"
+                    myurl="%s/v3/auth/%s/fallback/web"
                     % (CLIENT_API_PREFIX, LoginType.TERMS),
                     error=e.msg,
                 )
diff --git a/synapse/rest/client/capabilities.py b/synapse/rest/client/capabilities.py
index 6682da077a..e05c926b6f 100644
--- a/synapse/rest/client/capabilities.py
+++ b/synapse/rest/client/capabilities.py
@@ -72,20 +72,6 @@ class CapabilitiesRestServlet(RestServlet):
                 "org.matrix.msc3244.room_capabilities"
             ] = MSC3244_CAPABILITIES
 
-        # Must be removed in later versions.
-        # Is only included for migration.
-        # Also the parts in `synapse/config/experimental.py`.
-        if self.config.experimental.msc3283_enabled:
-            response["capabilities"]["org.matrix.msc3283.set_displayname"] = {
-                "enabled": self.config.registration.enable_set_displayname
-            }
-            response["capabilities"]["org.matrix.msc3283.set_avatar_url"] = {
-                "enabled": self.config.registration.enable_set_avatar_url
-            }
-            response["capabilities"]["org.matrix.msc3283.3pid_changes"] = {
-                "enabled": self.config.registration.enable_3pid_changes
-            }
-
         if self.config.experimental.msc3440_enabled:
             response["capabilities"]["io.element.thread"] = {"enabled": True}
 
diff --git a/synapse/rest/client/register.py b/synapse/rest/client/register.py
index c965e2bda2..b8a5135e02 100644
--- a/synapse/rest/client/register.py
+++ b/synapse/rest/client/register.py
@@ -694,11 +694,18 @@ class RegisterRestServlet(RestServlet):
                 session_id
             )
 
+            display_name = await (
+                self.password_auth_provider.get_displayname_for_registration(
+                    auth_result, params
+                )
+            )
+
             registered_user_id = await self.registration_handler.register_user(
                 localpart=desired_username,
                 password_hash=password_hash,
                 guest_access_token=guest_access_token,
                 threepid=threepid,
+                default_display_name=display_name,
                 address=client_addr,
                 user_agent_ips=entries,
             )
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index 2290c57c12..00f29344a8 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -73,6 +73,8 @@ class VersionsRestServlet(RestServlet):
                     "r0.5.0",
                     "r0.6.0",
                     "r0.6.1",
+                    "v1.1",
+                    "v1.2",
                 ],
                 # as per MSC1497:
                 "unstable_features": {
diff --git a/synapse/rest/media/v1/preview_url_resource.py b/synapse/rest/media/v1/preview_url_resource.py
index 8d3d1e54dc..c08b60d10a 100644
--- a/synapse/rest/media/v1/preview_url_resource.py
+++ b/synapse/rest/media/v1/preview_url_resource.py
@@ -402,7 +402,15 @@ class PreviewUrlResource(DirectServeJsonResource):
                 url,
                 output_stream=output_stream,
                 max_size=self.max_spider_size,
-                headers={"Accept-Language": self.url_preview_accept_language},
+                headers={
+                    b"Accept-Language": self.url_preview_accept_language,
+                    # Use a custom user agent for the preview because some sites will only return
+                    # Open Graph metadata to crawler user agents. Omit the Synapse version
+                    # string to avoid leaking information.
+                    b"User-Agent": [
+                        "Synapse (bot; +https://github.com/matrix-org/synapse)"
+                    ],
+                },
                 is_allowed_content_type=_is_previewable,
             )
         except SynapseError:
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 8d845fe951..3b3a089b76 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -670,6 +670,16 @@ class DeviceWorkerStore(SQLBaseStore):
             device["device_id"]: db_to_json(device["content"]) for device in devices
         }
 
+    def get_cached_device_list_changes(
+        self,
+        from_key: int,
+    ) -> Optional[Set[str]]:
+        """Get set of users whose devices have changed since `from_key`, or None
+        if that information is not in our cache.
+        """
+
+        return self._device_list_stream_cache.get_all_entities_changed(from_key)
+
     async def get_users_whose_devices_changed(
         self, from_key: int, user_ids: Iterable[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 5246fccad5..a1d7a9b413 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -975,6 +975,17 @@ class PersistEventsStore:
             to_delete = delta_state.to_delete
             to_insert = delta_state.to_insert
 
+            # Figure out the changes of membership to invalidate the
+            # `get_rooms_for_user` cache.
+            # We find out which membership events we may have deleted
+            # and which we have added, then we invalidate the caches for all
+            # those users.
+            members_changed = {
+                state_key
+                for ev_type, state_key in itertools.chain(to_delete, to_insert)
+                if ev_type == EventTypes.Member
+            }
+
             if delta_state.no_longer_in_room:
                 # Server is no longer in the room so we delete the room from
                 # current_state_events, being careful we've already updated the
@@ -993,6 +1004,11 @@ class PersistEventsStore:
                 """
                 txn.execute(sql, (stream_id, self._instance_name, room_id))
 
+                # We also want to invalidate the membership caches for users
+                # that were in the room.
+                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+                members_changed.update(users_in_room)
+
                 self.db_pool.simple_delete_txn(
                     txn,
                     table="current_state_events",
@@ -1102,17 +1118,6 @@ class PersistEventsStore:
 
             # Invalidate the various caches
 
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
-
             for member in members_changed:
                 txn.call_after(
                     self.store.get_rooms_for_user_with_stream_ordering.invalidate,
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8d4287045a..2a255d1031 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
                 include the previous states content in the unsigned field.
 
             allow_rejected: If True, return rejected events. Otherwise,
-                omits rejeted events from the response.
+                omits rejected events from the response.
 
         Returns:
             A mapping from event_id to event.
@@ -1854,7 +1854,7 @@ class EventsWorkerStore(SQLBaseStore):
             forward_edge_query = """
                 SELECT 1 FROM event_edges
                 /* Check to make sure the event referencing our event in question is not rejected */
-                LEFT JOIN rejections ON event_edges.event_id == rejections.event_id
+                LEFT JOIN rejections ON event_edges.event_id = rejections.event_id
                 WHERE
                     event_edges.room_id = ?
                     AND event_edges.prev_event_id = ?
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 4f05811a77..d3c4611686 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,15 +12,23 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
 from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.engines import PostgresEngine
 from synapse.storage.types import Connection
-from synapse.storage.util.id_generators import MultiWriterIdGenerator, StreamIdGenerator
+from synapse.storage.util.id_generators import (
+    AbstractStreamIdGenerator,
+    MultiWriterIdGenerator,
+    StreamIdGenerator,
+)
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.iterutils import batch_iter
@@ -35,7 +43,7 @@ class PresenceBackgroundUpdateStore(SQLBaseStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         # Used by `PresenceStore._get_active_presence()`
@@ -54,11 +62,14 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
+        self._instance_name = hs.get_instance_name()
+        self._presence_id_gen: AbstractStreamIdGenerator
+
         self._can_persist_presence = (
-            hs.get_instance_name() in hs.config.worker.writers.presence
+            self._instance_name in hs.config.worker.writers.presence
         )
 
         if isinstance(database.engine, PostgresEngine):
@@ -109,7 +120,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
-    def _update_presence_txn(self, txn, stream_orderings, presence_states):
+    def _update_presence_txn(
+        self, txn: LoggingTransaction, stream_orderings, presence_states
+    ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
                 self.presence_stream_cache.entity_has_changed, state.user_id, stream_id
@@ -183,19 +196,23 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_presence_updates_txn(txn):
+        def get_all_presence_updates_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, list]], int, bool]:
             sql = """
                 SELECT stream_id, user_id, state, last_active_ts,
                     last_federation_update_ts, last_user_sync_ts,
-                    status_msg,
-                currently_active
+                    status_msg, currently_active
                 FROM presence_stream
                 WHERE ? < stream_id AND stream_id <= ?
                 ORDER BY stream_id ASC
                 LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [(row[0], row[1:]) for row in txn]
+            updates = cast(
+                List[Tuple[int, list]],
+                [(row[0], row[1:]) for row in txn],
+            )
 
             upper_bound = current_id
             limited = False
@@ -210,7 +227,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         )
 
     @cached()
-    def _get_presence_for_user(self, user_id):
+    def _get_presence_for_user(self, user_id: str) -> None:
         raise NotImplementedError()
 
     @cachedList(
@@ -218,7 +235,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         list_name="user_ids",
         num_args=1,
     )
-    async def get_presence_for_users(self, user_ids):
+    async def get_presence_for_users(
+        self, user_ids: Iterable[str]
+    ) -> Dict[str, UserPresenceState]:
         rows = await self.db_pool.simple_select_many_batch(
             table="presence_stream",
             column="user_id",
@@ -257,7 +276,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             True if the user should have full presence sent to them, False otherwise.
         """
 
-        def _should_user_receive_full_presence_with_token_txn(txn):
+        def _should_user_receive_full_presence_with_token_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
             sql = """
                 SELECT 1 FROM users_to_send_full_presence_to
                 WHERE user_id = ?
@@ -271,7 +292,7 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             _should_user_receive_full_presence_with_token_txn,
         )
 
-    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]):
+    async def add_users_to_send_full_presence_to(self, user_ids: Iterable[str]) -> None:
         """Adds to the list of users who should receive a full snapshot of presence
         upon their next sync.
 
@@ -353,10 +374,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return users_to_state
 
-    def get_current_presence_token(self):
+    def get_current_presence_token(self) -> int:
         return self._presence_id_gen.get_current_token()
 
-    def _get_active_presence(self, db_conn: Connection):
+    def _get_active_presence(self, db_conn: Connection) -> List[UserPresenceState]:
         """Fetch non-offline presence from the database so that we can register
         the appropriate time outs.
         """
@@ -379,12 +400,12 @@ class PresenceStore(PresenceBackgroundUpdateStore):
 
         return [UserPresenceState(**row) for row in rows]
 
-    def take_presence_startup_info(self):
+    def take_presence_startup_info(self) -> List[UserPresenceState]:
         active_on_startup = self._presence_on_startup
-        self._presence_on_startup = None
+        self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/storage/databases/main/purge_events.py b/synapse/storage/databases/main/purge_events.py
index e87a8fb85d..2e3818e432 100644
--- a/synapse/storage/databases/main/purge_events.py
+++ b/synapse/storage/databases/main/purge_events.py
@@ -13,9 +13,10 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Set, Tuple
+from typing import Any, List, Set, Tuple, cast
 
 from synapse.api.errors import SynapseError
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main import CacheInvalidationWorkerStore
 from synapse.storage.databases.main.state import StateGroupWorkerStore
 from synapse.types import RoomStreamToken
@@ -55,7 +56,11 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         )
 
     def _purge_history_txn(
-        self, txn, room_id: str, token: RoomStreamToken, delete_local_events: bool
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        token: RoomStreamToken,
+        delete_local_events: bool,
     ) -> Set[int]:
         # Tables that should be pruned:
         #     event_auth
@@ -273,7 +278,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
         """,
             (room_id,),
         )
-        (min_depth,) = txn.fetchone()
+        (min_depth,) = cast(Tuple[int], txn.fetchone())
 
         logger.info("[purge] updating room_depth to %d", min_depth)
 
@@ -318,7 +323,7 @@ class PurgeEventsStore(StateGroupWorkerStore, CacheInvalidationWorkerStore):
             "purge_room", self._purge_room_txn, room_id
         )
 
-    def _purge_room_txn(self, txn, room_id: str) -> List[int]:
+    def _purge_room_txn(self, txn: LoggingTransaction, room_id: str) -> List[int]:
         # First we fetch all the state groups that should be deleted, before
         # we delete that information.
         txn.execute(
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index aac94fa464..17110bb033 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -622,10 +622,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ) -> None:
         """Record a mapping from an external user id to a mxid
 
+        See notes in _record_user_external_id_txn about what constitutes valid data.
+
         Args:
             auth_provider: identifier for the remote auth provider
             external_id: id on that system
             user_id: complete mxid that it is mapped to
+
         Raises:
             ExternalIDReuseException if the new external_id could not be mapped.
         """
@@ -648,6 +651,21 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         external_id: str,
         user_id: str,
     ) -> None:
+        """
+        Record a mapping from an external user id to a mxid.
+
+        Note that the auth provider IDs (and the external IDs) are not validated
+        against configured IdPs as Synapse does not know its relationship to
+        external systems. For example, it might be useful to pre-configure users
+        before enabling a new IdP or an IdP might be temporarily offline, but
+        still valid.
+
+        Args:
+            txn: The database transaction.
+            auth_provider: identifier for the remote auth provider
+            external_id: id on that system
+            user_id: complete mxid that it is mapped to
+        """
 
         self.db_pool.simple_insert_txn(
             txn,
@@ -687,10 +705,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         """Replace mappings from external user ids to a mxid in a single transaction.
         All mappings are deleted and the new ones are created.
 
+        See notes in _record_user_external_id_txn about what constitutes valid data.
+
         Args:
             record_external_ids:
                 List with tuple of auth_provider and external_id to record
             user_id: complete mxid that it is mapped to
+
         Raises:
             ExternalIDReuseException if the new external_id could not be mapped.
         """
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e2c27e594b..36aa1092f6 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
+    # The latest event in the thread.
     latest_event: EventBase
+    # The latest edit to the latest event in the thread.
+    latest_edit: Optional[EventBase]
+    # The total number of events in the thread.
     count: int
+    # True if the current user has sent an event to the thread.
     current_user_participated: bool
 
 
@@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
     async def _get_thread_summaries(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
-        """Get the number of threaded replies and the latest reply (if any) for the given event.
+    ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
+        """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
 
         Args:
             event_ids: Summarize the thread related to this event ID.
@@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of the thread summary each event. A missing event implies there
             are no threaded replies.
 
-            Each summary includes the number of items in the thread and the most
-            recent response.
+            Each summary is a tuple of:
+                The number of events in the thread.
+                The most recent event in the thread.
+                The most recent edit to the most recent event in the thread, if applicable.
         """
 
         def _get_thread_summaries_txn(
@@ -482,7 +489,7 @@ class RelationsWorkerStore(SQLBaseStore):
             # TODO Should this only allow m.room.message events.
             if isinstance(self.database_engine, PostgresEngine):
                 # The `DISTINCT ON` clause will pick the *first* row it encounters,
-                # so ordering by topologica ordering + stream ordering desc will
+                # so ordering by topological ordering + stream ordering desc will
                 # ensure we get the latest event in the thread.
                 sql = """
                     SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
@@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
+        # Check to see if any of those events are edited.
+        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+
         # Map to the event IDs to the thread summary.
         #
         # There might not be a summary due to there not being a thread or
@@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
 
             summary = None
             if latest_event:
-                summary = (counts[parent_event_id], latest_event)
+                latest_edit = latest_edits.get(latest_event_id)
+                summary = (counts[parent_event_id], latest_event, latest_edit)
             summaries[parent_event_id] = summary
 
         return summaries
@@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
             )
             for event_id, summary in summaries.items():
                 if summary:
-                    thread_count, latest_thread_event = summary
+                    thread_count, latest_thread_event, edit = summary
                     results.setdefault(
                         event_id, BundledAggregations()
                     ).thread = _ThreadAggregation(
                         latest_event=latest_thread_event,
+                        latest_edit=edit,
                         count=thread_count,
                         # If there's a thread summary it must also exist in the
                         # participated dictionary.
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 95167116c9..0416df64ce 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -1498,7 +1498,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         self._event_reports_id_gen = IdGenerator(db_conn, "event_reports", "id")
 
     async def upsert_room_on_join(
-        self, room_id: str, room_version: RoomVersion, auth_events: List[EventBase]
+        self, room_id: str, room_version: RoomVersion, state_events: List[EventBase]
     ) -> None:
         """Ensure that the room is stored in the table
 
@@ -1511,7 +1511,7 @@ class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
         has_auth_chain_index = await self.has_auth_chain_index(room_id)
 
         create_event = None
-        for e in auth_events:
+        for e in state_events:
             if (e.type, e.state_key) == (EventTypes.Create, ""):
                 create_event = e
                 break
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4489732fda..e48ec5f495 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -504,6 +504,68 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             for room_id, instance, stream_id in txn
         )
 
+    @cachedList(
+        cached_method_name="get_rooms_for_user_with_stream_ordering",
+        list_name="user_ids",
+    )
+    async def get_rooms_for_users_with_stream_ordering(
+        self, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+        """A batched version of `get_rooms_for_user_with_stream_ordering`.
+
+        Returns:
+            Map from user_id to set of rooms that is currently in.
+        """
+        return await self.db_pool.runInteraction(
+            "get_rooms_for_users_with_stream_ordering",
+            self._get_rooms_for_users_with_stream_ordering_txn,
+            user_ids,
+        )
+
+    def _get_rooms_for_users_with_stream_ordering_txn(
+        self, txn, user_ids: Collection[str]
+    ) -> Dict[str, FrozenSet[GetRoomsForUserWithStreamOrdering]]:
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine,
+            "c.state_key",
+            user_ids,
+        )
+
+        if self._current_state_events_membership_up_to_date:
+            sql = f"""
+                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND c.membership = ?
+                    AND {clause}
+            """
+        else:
+            sql = f"""
+                SELECT c.state_key, room_id, e.instance_name, e.stream_ordering
+                FROM current_state_events AS c
+                INNER JOIN room_memberships AS m USING (room_id, event_id)
+                INNER JOIN events AS e USING (room_id, event_id)
+                WHERE
+                    c.type = 'm.room.member'
+                    AND m.membership = ?
+                    AND {clause}
+            """
+
+        txn.execute(sql, [Membership.JOIN] + args)
+
+        result = {user_id: set() for user_id in user_ids}
+        for user_id, room_id, instance, stream_id in txn:
+            result[user_id].add(
+                GetRoomsForUserWithStreamOrdering(
+                    room_id, PersistedEventPosition(instance, stream_id)
+                )
+            )
+
+        return {user_id: frozenset(v) for user_id, v in result.items()}
+
     async def get_users_server_still_shares_room_with(
         self, user_ids: Collection[str]
     ) -> Set[str]:
diff --git a/synapse/storage/databases/main/search.py b/synapse/storage/databases/main/search.py
index 2d085a5764..acea300ed3 100644
--- a/synapse/storage/databases/main/search.py
+++ b/synapse/storage/databases/main/search.py
@@ -28,6 +28,7 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.types import JsonDict
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -381,17 +382,19 @@ class SearchStore(SearchBackgroundUpdateStore):
     ):
         super().__init__(database, db_conn, hs)
 
-    async def search_msgs(self, room_ids, search_term, keys):
+    async def search_msgs(
+        self, room_ids: Collection[str], search_term: str, keys: Iterable[str]
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
-            room_ids (list): List of room ids to search in
-            search_term (str): Search term to search for
-            keys (list): List of keys to search in, currently supports
+            room_ids: List of room ids to search in
+            search_term: Search term to search for
+            keys: List of keys to search in, currently supports
                 "content.body", "content.name", "content.topic"
 
         Returns:
-            list of dicts
+            Dictionary of results
         """
         clauses = []
 
@@ -499,10 +502,10 @@ class SearchStore(SearchBackgroundUpdateStore):
         self,
         room_ids: Collection[str],
         search_term: str,
-        keys: List[str],
+        keys: Iterable[str],
         limit,
         pagination_token: Optional[str] = None,
-    ) -> List[dict]:
+    ) -> JsonDict:
         """Performs a full text search over events with given keys.
 
         Args:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index f7c778bdf2..e7fddd2426 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -58,7 +58,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         database: DatabasePool,
         db_conn: LoggingDatabaseConnection,
         hs: "HomeServer",
-    ):
+    ) -> None:
         super().__init__(database, db_conn, hs)
 
         self.server_name = hs.hostname
@@ -234,10 +234,10 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         processed_event_count = 0
 
         for room_id, event_count in rooms_to_work_on:
-            is_in_room = await self.is_host_joined(room_id, self.server_name)
+            is_in_room = await self.is_host_joined(room_id, self.server_name)  # type: ignore[attr-defined]
 
             if is_in_room:
-                users_with_profile = await self.get_users_in_room_with_profiles(room_id)
+                users_with_profile = await self.get_users_in_room_with_profiles(room_id)  # type: ignore[attr-defined]
                 # Throw away users excluded from the directory.
                 users_with_profile = {
                     user_id: profile
@@ -368,7 +368,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
 
         for user_id in users_to_work_on:
             if await self.should_include_local_user_in_dir(user_id):
-                profile = await self.get_profileinfo(get_localpart_from_id(user_id))
+                profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined]
                 await self.update_profile_in_user_dir(
                     user_id, profile.display_name, profile.avatar_url
                 )
@@ -397,7 +397,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # technically it could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice sender can be
         # contacted.
-        if self.get_app_service_by_user_id(user) is not None:
+        if self.get_app_service_by_user_id(user) is not None:  # type: ignore[attr-defined]
             return False
 
         # We're opting to exclude appservice users (anyone matching the user
@@ -405,17 +405,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         # they could be DM-able. In the future, this could potentially
         # be configurable per-appservice whether the appservice users can be
         # contacted.
-        if self.get_if_app_services_interested_in_user(user):
+        if self.get_if_app_services_interested_in_user(user):  # type: ignore[attr-defined]
             # TODO we might want to make this configurable for each app service
             return False
 
         # Support users are for diagnostics and should not appear in the user directory.
-        if await self.is_support_user(user):
+        if await self.is_support_user(user):  # type: ignore[attr-defined]
             return False
 
         # Deactivated users aren't contactable, so should not appear in the user directory.
         try:
-            if await self.get_user_deactivated_status(user):
+            if await self.get_user_deactivated_status(user):  # type: ignore[attr-defined]
                 return False
         except StoreError:
             # No such user in the users table. No need to do this when calling
@@ -433,20 +433,20 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             (EventTypes.RoomHistoryVisibility, ""),
         )
 
-        current_state_ids = await self.get_filtered_current_state_ids(
+        current_state_ids = await self.get_filtered_current_state_ids(  # type: ignore[attr-defined]
             room_id, StateFilter.from_types(types_to_filter)
         )
 
         join_rules_id = current_state_ids.get((EventTypes.JoinRules, ""))
         if join_rules_id:
-            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)
+            join_rule_ev = await self.get_event(join_rules_id, allow_none=True)  # type: ignore[attr-defined]
             if join_rule_ev:
                 if join_rule_ev.content.get("join_rule") == JoinRules.PUBLIC:
                     return True
 
         hist_vis_id = current_state_ids.get((EventTypes.RoomHistoryVisibility, ""))
         if hist_vis_id:
-            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)
+            hist_vis_ev = await self.get_event(hist_vis_id, allow_none=True)  # type: ignore[attr-defined]
             if hist_vis_ev:
                 if (
                     hist_vis_ev.content.get("history_visibility")
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index 7614d76ac6..3af69a2076 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -13,11 +13,23 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 
+from twisted.internet import defer
+
 from synapse.api.constants import EventTypes
+from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.storage._base import SQLBaseStore
 from synapse.storage.database import (
     DatabasePool,
@@ -29,6 +41,12 @@ from synapse.storage.state import StateFilter
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
+from synapse.util import unwrapFirstError
+from synapse.util.async_helpers import (
+    AbstractObservableDeferred,
+    ObservableDeferred,
+    yieldable_gather_results,
+)
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
 
@@ -37,7 +55,6 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
 MAX_STATE_DELTA_HOPS = 100
 
 
@@ -106,6 +123,12 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             500000,
         )
 
+        # Current ongoing get_state_for_groups in-flight requests
+        # {group ID -> {StateFilter -> ObservableDeferred}}
+        self._state_group_inflight_requests: Dict[
+            int, Dict[StateFilter, AbstractObservableDeferred[StateMap[str]]]
+        ] = {}
+
         def get_max_state_group_txn(txn: Cursor) -> int:
             txn.execute("SELECT COALESCE(max(id), 0) FROM state_groups")
             return txn.fetchone()[0]  # type: ignore
@@ -157,7 +180,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         )
 
     async def _get_state_groups_from_groups(
-        self, groups: List[int], state_filter: StateFilter
+        self, groups: Sequence[int], state_filter: StateFilter
     ) -> Dict[int, StateMap[str]]:
         """Returns the state groups for a given set of groups from the
         database, filtering on types of state events.
@@ -228,6 +251,150 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
+    def _get_state_for_group_gather_inflight_requests(
+        self, group: int, state_filter_left_over: StateFilter
+    ) -> Tuple[Sequence[AbstractObservableDeferred[StateMap[str]]], StateFilter]:
+        """
+        Attempts to gather in-flight requests and re-use them to retrieve state
+        for the given state group, filtered with the given state filter.
+
+        Used as part of _get_state_for_group_using_inflight_cache.
+
+        Returns:
+            Tuple of two values:
+                A sequence of ObservableDeferreds to observe
+                A StateFilter representing what else needs to be requested to fulfill the request
+        """
+
+        inflight_requests = self._state_group_inflight_requests.get(group)
+        if inflight_requests is None:
+            # no requests for this group, need to retrieve it all ourselves
+            return (), state_filter_left_over
+
+        # The list of ongoing requests which will help narrow the current request.
+        reusable_requests = []
+        for (request_state_filter, request_deferred) in inflight_requests.items():
+            new_state_filter_left_over = state_filter_left_over.approx_difference(
+                request_state_filter
+            )
+            if new_state_filter_left_over == state_filter_left_over:
+                # Reusing this request would not gain us anything, so don't bother.
+                continue
+
+            reusable_requests.append(request_deferred)
+            state_filter_left_over = new_state_filter_left_over
+            if state_filter_left_over == StateFilter.none():
+                # we have managed to collect enough of the in-flight requests
+                # to cover our StateFilter and give us the state we need.
+                break
+
+        return reusable_requests, state_filter_left_over
+
+    async def _get_state_for_group_fire_request(
+        self, group: int, state_filter: StateFilter
+    ) -> StateMap[str]:
+        """
+        Fires off a request to get the state at a state group,
+        potentially filtering by type and/or state key.
+
+        This request will be tracked in the in-flight request cache and automatically
+        removed when it is finished.
+
+        Used as part of _get_state_for_group_using_inflight_cache.
+
+        Args:
+            group: ID of the state group for which we want to get state
+            state_filter: the state filter used to fetch state from the database
+        """
+        cache_sequence_nm = self._state_group_cache.sequence
+        cache_sequence_m = self._state_group_members_cache.sequence
+
+        # Help the cache hit ratio by expanding the filter a bit
+        db_state_filter = state_filter.return_expanded()
+
+        async def _the_request() -> StateMap[str]:
+            group_to_state_dict = await self._get_state_groups_from_groups(
+                (group,), state_filter=db_state_filter
+            )
+
+            # Now let's update the caches
+            self._insert_into_cache(
+                group_to_state_dict,
+                db_state_filter,
+                cache_seq_num_members=cache_sequence_m,
+                cache_seq_num_non_members=cache_sequence_nm,
+            )
+
+            # Remove ourselves from the in-flight cache
+            group_request_dict = self._state_group_inflight_requests[group]
+            del group_request_dict[db_state_filter]
+            if not group_request_dict:
+                # If there are no more requests in-flight for this group,
+                # clean up the cache by removing the empty dictionary
+                del self._state_group_inflight_requests[group]
+
+            return group_to_state_dict[group]
+
+        # We don't immediately await the result, so must use run_in_background
+        # But we DO await the result before the current log context (request)
+        # finishes, so don't need to run it as a background process.
+        request_deferred = run_in_background(_the_request)
+        observable_deferred = ObservableDeferred(request_deferred, consumeErrors=True)
+
+        # Insert the ObservableDeferred into the cache
+        group_request_dict = self._state_group_inflight_requests.setdefault(group, {})
+        group_request_dict[db_state_filter] = observable_deferred
+
+        return await make_deferred_yieldable(observable_deferred.observe())
+
+    async def _get_state_for_group_using_inflight_cache(
+        self, group: int, state_filter: StateFilter
+    ) -> MutableStateMap[str]:
+        """
+        Gets the state at a state group, potentially filtering by type and/or
+        state key.
+
+        1. Calls _get_state_for_group_gather_inflight_requests to gather any
+           ongoing requests which might overlap with the current request.
+        2. Fires a new request, using _get_state_for_group_fire_request,
+           for any state which cannot be gathered from ongoing requests.
+
+        Args:
+            group: ID of the state group for which we want to get state
+            state_filter: the state filter used to fetch state from the database
+        Returns:
+            state map
+        """
+
+        # first, figure out whether we can re-use any in-flight requests
+        # (and if so, what would be left over)
+        (
+            reusable_requests,
+            state_filter_left_over,
+        ) = self._get_state_for_group_gather_inflight_requests(group, state_filter)
+
+        if state_filter_left_over != StateFilter.none():
+            # Fetch remaining state
+            remaining = await self._get_state_for_group_fire_request(
+                group, state_filter_left_over
+            )
+            assembled_state: MutableStateMap[str] = dict(remaining)
+        else:
+            assembled_state = {}
+
+        gathered = await make_deferred_yieldable(
+            defer.gatherResults(
+                (r.observe() for r in reusable_requests), consumeErrors=True
+            )
+        ).addErrback(unwrapFirstError)
+
+        # assemble our result.
+        for result_piece in gathered:
+            assembled_state.update(result_piece)
+
+        # Filter out any state that may be more than what we asked for.
+        return state_filter.filter_state(assembled_state)
+
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
@@ -269,31 +436,17 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
         if not incomplete_groups:
             return state
 
-        cache_sequence_nm = self._state_group_cache.sequence
-        cache_sequence_m = self._state_group_members_cache.sequence
-
-        # Help the cache hit ratio by expanding the filter a bit
-        db_state_filter = state_filter.return_expanded()
-
-        group_to_state_dict = await self._get_state_groups_from_groups(
-            list(incomplete_groups), state_filter=db_state_filter
-        )
+        async def get_from_cache(group: int, state_filter: StateFilter) -> None:
+            state[group] = await self._get_state_for_group_using_inflight_cache(
+                group, state_filter
+            )
 
-        # Now lets update the caches
-        self._insert_into_cache(
-            group_to_state_dict,
-            db_state_filter,
-            cache_seq_num_members=cache_sequence_m,
-            cache_seq_num_non_members=cache_sequence_nm,
+        await yieldable_gather_results(
+            get_from_cache,
+            incomplete_groups,
+            state_filter,
         )
 
-        # And finally update the result dict, by filtering out any extra
-        # stuff we pulled out of the database.
-        for group, group_state_dict in group_to_state_dict.items():
-            # We just replace any existing entries, as we will have loaded
-            # everything we need from the database anyway.
-            state[group] = state_filter.filter_state(group_state_dict)
-
         return state
 
     def _get_state_for_groups_using_cache(
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 913448f0f9..e79ecf64a0 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -204,13 +204,16 @@ class StateFilter:
         if get_all_members:
             # We want to return everything.
             return StateFilter.all()
-        else:
+        elif EventTypes.Member in self.types:
             # We want to return all non-members, but only particular
             # memberships
             return StateFilter(
                 types=frozendict({EventTypes.Member: self.types[EventTypes.Member]}),
                 include_others=True,
             )
+        else:
+            # We want to return all non-members
+            return _ALL_NON_MEMBER_STATE_FILTER
 
     def make_sql_filter_clause(self) -> Tuple[str, List[str]]:
         """Converts the filter to an SQL clause.
@@ -528,6 +531,9 @@ class StateFilter:
 
 
 _ALL_STATE_FILTER = StateFilter(types=frozendict(), include_others=True)
+_ALL_NON_MEMBER_STATE_FILTER = StateFilter(
+    types=frozendict({EventTypes.Member: frozenset()}), include_others=True
+)
 _NONE_STATE_FILTER = StateFilter(types=frozendict(), include_others=False)
 
 
diff --git a/synapse/streams/events.py b/synapse/streams/events.py
index 21591d0bfd..4ec2a713cf 100644
--- a/synapse/streams/events.py
+++ b/synapse/streams/events.py
@@ -37,14 +37,16 @@ class _EventSourcesInner:
     account_data: AccountDataEventSource
 
     def get_sources(self) -> Iterator[Tuple[str, EventSource]]:
-        for attribute in _EventSourcesInner.__attrs_attrs__:  # type: ignore[attr-defined]
+        for attribute in attr.fields(_EventSourcesInner):
             yield attribute.name, getattr(self, attribute.name)
 
 
 class EventSources:
     def __init__(self, hs: "HomeServer"):
         self.sources = _EventSourcesInner(
-            *(attribute.type(hs) for attribute in _EventSourcesInner.__attrs_attrs__)  # type: ignore[attr-defined]
+            # mypy thinks attribute.type is `Optional`, but we know it's never `None` here since
+            # all the attributes of `_EventSourcesInner` are annotated.
+            *(attribute.type(hs) for attribute in attr.fields(_EventSourcesInner))  # type: ignore[misc]
         )
         self.store = hs.get_datastore()
 
diff --git a/synapse/types.py b/synapse/types.py
index f89fb216a6..53be3583a0 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -51,7 +51,7 @@ from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
     from synapse.appservice.api import ApplicationService
-    from synapse.storage.databases.main import DataStore
+    from synapse.storage.databases.main import DataStore, PurgeEventsStore
 
 # Define a state map type from type/state_key to T (usually an event ID or
 # event)
@@ -485,7 +485,7 @@ class RoomStreamToken:
             )
 
     @classmethod
-    async def parse(cls, store: "DataStore", string: str) -> "RoomStreamToken":
+    async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
             if string[0] == "s":
                 return cls(topological=None, stream=int(string[1:]))
@@ -502,7 +502,7 @@ class RoomStreamToken:
                     instance_id = int(key)
                     pos = int(value)
 
-                    instance_name = await store.get_name_from_instance_id(instance_id)
+                    instance_name = await store.get_name_from_instance_id(instance_id)  # type: ignore[attr-defined]
                     instance_map[instance_name] = pos
 
                 return cls(
diff --git a/synapse/util/caches/__init__.py b/synapse/util/caches/__init__.py
index 15debd6c46..1cbc180eda 100644
--- a/synapse/util/caches/__init__.py
+++ b/synapse/util/caches/__init__.py
@@ -56,6 +56,7 @@ response_cache_total = Gauge("synapse_util_caches_response_cache:total", "", ["n
 class EvictionReason(Enum):
     size = auto()
     time = auto()
+    invalidation = auto()
 
 
 @attr.s(slots=True, auto_attribs=True)
diff --git a/synapse/util/caches/expiringcache.py b/synapse/util/caches/expiringcache.py
index 67ee4c693b..c6a5d0dfc0 100644
--- a/synapse/util/caches/expiringcache.py
+++ b/synapse/util/caches/expiringcache.py
@@ -133,6 +133,11 @@ class ExpiringCache(Generic[KT, VT]):
                 raise KeyError(key)
             return default
 
+        if self.iterable:
+            self.metrics.inc_evictions(EvictionReason.invalidation, len(value.value))
+        else:
+            self.metrics.inc_evictions(EvictionReason.invalidation)
+
         return value.value
 
     def __contains__(self, key: KT) -> bool:
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 7548b38548..45ff0de638 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -560,8 +560,10 @@ class LruCache(Generic[KT, VT]):
         def cache_pop(key: KT, default: Optional[T] = None) -> Union[None, T, VT]:
             node = cache.get(key, None)
             if node:
-                delete_node(node)
+                evicted_len = delete_node(node)
                 cache.pop(node.key, None)
+                if metrics:
+                    metrics.inc_evictions(EvictionReason.invalidation, evicted_len)
                 return node.value
             else:
                 return default
diff --git a/synapse/util/daemonize.py b/synapse/util/daemonize.py
index de04f34e4e..031880ec39 100644
--- a/synapse/util/daemonize.py
+++ b/synapse/util/daemonize.py
@@ -20,7 +20,7 @@ import os
 import signal
 import sys
 from types import FrameType, TracebackType
-from typing import NoReturn, Type
+from typing import NoReturn, Optional, Type
 
 
 def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -> None:
@@ -100,7 +100,9 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
     # also catch any other uncaught exceptions before we get that far.)
 
     def excepthook(
-        type_: Type[BaseException], value: BaseException, traceback: TracebackType
+        type_: Type[BaseException],
+        value: BaseException,
+        traceback: Optional[TracebackType],
     ) -> None:
         logger.critical("Unhanded exception", exc_info=(type_, value, traceback))
 
@@ -123,7 +125,7 @@ def daemonize_process(pid_file: str, logger: logging.Logger, chdir: str = "/") -
         sys.exit(1)
 
     # write a log line on SIGTERM.
-    def sigterm(signum: signal.Signals, frame: FrameType) -> NoReturn:
+    def sigterm(signum: int, frame: Optional[FrameType]) -> NoReturn:
         logger.warning("Caught signal %s. Stopping daemon." % signum)
         sys.exit(0)
 
diff --git a/synapse/util/patch_inline_callbacks.py b/synapse/util/patch_inline_callbacks.py
index 1f18654d47..6d4b0b7c5a 100644
--- a/synapse/util/patch_inline_callbacks.py
+++ b/synapse/util/patch_inline_callbacks.py
@@ -14,7 +14,7 @@
 
 import functools
 import sys
-from typing import Any, Callable, Generator, List, TypeVar
+from typing import Any, Callable, Generator, List, TypeVar, cast
 
 from twisted.internet import defer
 from twisted.internet.defer import Deferred
@@ -174,7 +174,9 @@ def _check_yield_points(
                         )
                     )
                     changes.append(err)
-                return getattr(e, "value", None)
+                # The `StopIteration` or `_DefGen_Return` contains the return value from the
+                # generator.
+                return cast(T, e.value)
 
             frame = gen.gi_frame