summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/admin.py1
-rw-r--r--synapse/handlers/appservice.py11
-rw-r--r--synapse/handlers/auth.py19
-rw-r--r--synapse/handlers/device.py39
-rw-r--r--synapse/handlers/directory.py34
-rw-r--r--synapse/handlers/e2e_keys.py98
-rw-r--r--synapse/handlers/e2e_room_keys.py4
-rw-r--r--synapse/handlers/event_auth.py9
-rw-r--r--synapse/handlers/events.py9
-rw-r--r--synapse/handlers/federation.py178
-rw-r--r--synapse/handlers/federation_event.py719
-rw-r--r--synapse/handlers/identity.py232
-rw-r--r--synapse/handlers/initial_sync.py17
-rw-r--r--synapse/handlers/message.py95
-rw-r--r--synapse/handlers/pagination.py45
-rw-r--r--synapse/handlers/presence.py115
-rw-r--r--synapse/handlers/receipts.py7
-rw-r--r--synapse/handlers/register.py15
-rw-r--r--synapse/handlers/relations.py7
-rw-r--r--synapse/handlers/room.py149
-rw-r--r--synapse/handlers/room_list.py2
-rw-r--r--synapse/handlers/room_member.py209
-rw-r--r--synapse/handlers/room_summary.py6
-rw-r--r--synapse/handlers/send_email.py36
-rw-r--r--synapse/handlers/sync.py410
-rw-r--r--synapse/handlers/typing.py30
-rw-r--r--synapse/handlers/ui_auth/checkers.py21
27 files changed, 1499 insertions, 1018 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index d4fe7df533..cf9f19608a 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -70,6 +70,7 @@ class AdminHandler:
             "appservice_id",
             "consent_server_notice_sent",
             "consent_version",
+            "consent_ts",
             "user_type",
             "is_guest",
         }
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 814553e098..203b62e015 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -104,14 +104,15 @@ class ApplicationServicesHandler:
         with Measure(self.clock, "notify_interested_services"):
             self.is_processing = True
             try:
-                limit = 100
                 upper_bound = -1
                 while upper_bound < self.current_max:
+                    last_token = await self.store.get_appservice_last_pos()
                     (
                         upper_bound,
                         events,
-                    ) = await self.store.get_new_events_for_appservice(
-                        self.current_max, limit
+                        event_to_received_ts,
+                    ) = await self.store.get_all_new_events_stream(
+                        last_token, self.current_max, limit=100, get_prev_content=True
                     )
 
                     events_by_room: Dict[str, List[EventBase]] = {}
@@ -150,7 +151,7 @@ class ApplicationServicesHandler:
                             )
 
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(event.event_id)
+                        ts = event_to_received_ts[event.event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag_by_event.labels(
@@ -187,7 +188,7 @@ class ApplicationServicesHandler:
 
                     if events:
                         now = self.clock.time_msec()
-                        ts = await self.store.get_received_ts(events[-1].event_id)
+                        ts = event_to_received_ts[events[-1].event_id]
                         assert ts is not None
 
                         synapse.metrics.event_processing_lag.labels(
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 3d83236b0c..0327fc57a4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -280,7 +280,7 @@ class AuthHandler:
         that it isn't stolen by re-authenticating them.
 
         Args:
-            requester: The user, as given by the access token
+            requester: The user making the request, according to the access token.
 
             request: The request sent by the client.
 
@@ -565,7 +565,7 @@ class AuthHandler:
             except LoginError as e:
                 # this step failed. Merge the error dict into the response
                 # so that the client can have another go.
-                errordict = e.error_dict()
+                errordict = e.error_dict(self.hs.config)
 
         creds = await self.store.get_completed_ui_auth_stages(session.session_id)
         for f in flows:
@@ -1435,20 +1435,25 @@ class AuthHandler:
             access_token: access token to be deleted
 
         """
-        user_info = await self.auth.get_user_by_access_token(access_token)
+        token = await self.store.get_user_by_access_token(access_token)
+        if not token:
+            # At this point, the token should already have been fetched once by
+            # the caller, so this should not happen, unless of a race condition
+            # between two delete requests
+            raise SynapseError(HTTPStatus.UNAUTHORIZED, "Unrecognised access token")
         await self.store.delete_access_token(access_token)
 
         # see if any modules want to know about this
         await self.password_auth_provider.on_logged_out(
-            user_id=user_info.user_id,
-            device_id=user_info.device_id,
+            user_id=token.user_id,
+            device_id=token.device_id,
             access_token=access_token,
         )
 
         # delete pushers associated with this access token
-        if user_info.token_id is not None:
+        if token.token_id is not None:
             await self.hs.get_pusherpool().remove_pushers_by_access_token(
-                user_info.user_id, (user_info.token_id,)
+                token.user_id, (token.token_id,)
             )
 
     async def delete_access_tokens_for_user(
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index c05a170c55..901e2310b7 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -45,13 +45,13 @@ from synapse.types import (
     JsonDict,
     StreamKeyType,
     StreamToken,
-    UserID,
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
 from synapse.util.metrics import measure_func
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -74,6 +74,7 @@ class DeviceWorkerHandler:
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
+        self._msc3852_enabled = hs.config.experimental.msc3852_enabled
 
     @trace
     async def get_devices_by_user(self, user_id: str) -> List[JsonDict]:
@@ -118,11 +119,12 @@ class DeviceWorkerHandler:
         ips = await self.store.get_last_client_ip_by_device(user_id, device_id)
         _update_device_from_client_ips(device, ips)
 
-        set_tag("device", device)
-        set_tag("ips", ips)
+        set_tag("device", str(device))
+        set_tag("ips", str(ips))
 
         return device
 
+    @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
     ) -> Collection[str]:
@@ -162,6 +164,7 @@ class DeviceWorkerHandler:
 
     @trace
     @measure_func("device.get_user_ids_changed")
+    @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
     ) -> JsonDict:
@@ -170,7 +173,7 @@ class DeviceWorkerHandler:
         """
 
         set_tag("user_id", user_id)
-        set_tag("from_token", from_token)
+        set_tag("from_token", str(from_token))
         now_room_key = self.store.get_room_max_token()
 
         room_ids = await self.store.get_rooms_for_user(user_id)
@@ -309,6 +312,7 @@ class DeviceHandler(DeviceWorkerHandler):
         super().__init__(hs)
 
         self.federation_sender = hs.get_federation_sender()
+        self._storage_controllers = hs.get_storage_controllers()
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -319,8 +323,6 @@ class DeviceHandler(DeviceWorkerHandler):
             self.device_list_updater.incoming_device_list_update,
         )
 
-        hs.get_distributor().observe("user_left_room", self.user_left_room)
-
         # Whether `_handle_new_device_update_async` is currently processing.
         self._handle_new_device_update_is_processing = False
 
@@ -564,14 +566,6 @@ class DeviceHandler(DeviceWorkerHandler):
             StreamKeyType.DEVICE_LIST, position, users=[from_user_id]
         )
 
-    async def user_left_room(self, user: UserID, room_id: str) -> None:
-        user_id = user.to_string()
-        room_ids = await self.store.get_rooms_for_user(user_id)
-        if not room_ids:
-            # We no longer share rooms with this user, so we'll no longer
-            # receive device updates. Mark this in DB.
-            await self.store.mark_remote_user_device_list_as_unsubscribed(user_id)
-
     async def store_dehydrated_device(
         self,
         user_id: str,
@@ -693,8 +687,11 @@ class DeviceHandler(DeviceWorkerHandler):
 
                     # Ignore any users that aren't ours
                     if self.hs.is_mine_id(user_id):
-                        joined_user_ids = await self.store.get_users_in_room(room_id)
-                        hosts = {get_domain_from_id(u) for u in joined_user_ids}
+                        hosts = set(
+                            await self._storage_controllers.state.get_current_hosts_in_room(
+                                room_id
+                            )
+                        )
                         hosts.discard(self.server_name)
 
                     # Check if we've already sent this update to some hosts
@@ -747,7 +744,13 @@ def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
 ) -> None:
     ip = client_ips.get((device["user_id"], device["device_id"]), {})
-    device.update({"last_seen_ts": ip.get("last_seen"), "last_seen_ip": ip.get("ip")})
+    device.update(
+        {
+            "last_seen_user_agent": ip.get("user_agent"),
+            "last_seen_ts": ip.get("last_seen"),
+            "last_seen_ip": ip.get("ip"),
+        }
+    )
 
 
 class DeviceListUpdater:
@@ -795,7 +798,7 @@ class DeviceListUpdater:
         """
 
         set_tag("origin", origin)
-        set_tag("edu_content", edu_content)
+        set_tag("edu_content", str(edu_content))
         user_id = edu_content.pop("user_id")
         device_id = edu_content.pop("device_id")
         stream_id = str(edu_content.pop("stream_id"))  # They may come as ints
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 09a7a4b238..7127d5aefc 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -30,7 +30,7 @@ from synapse.api.errors import (
 from synapse.appservice import ApplicationService
 from synapse.module_api import NOT_SPAM
 from synapse.storage.databases.main.directory import RoomAliasMapping
-from synapse.types import JsonDict, Requester, RoomAlias, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, RoomAlias
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -83,8 +83,9 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            users = await self.store.get_users_in_room(room_id)
-            servers = {get_domain_from_id(u) for u in users}
+            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+                room_id
+            )
 
         if not servers:
             raise SynapseError(400, "Failed to get server list")
@@ -133,7 +134,7 @@ class DirectoryHandler:
         else:
             # Server admins are not subject to the same constraints as normal
             # users when creating an alias (e.g. being in the room).
-            is_admin = await self.auth.is_server_admin(requester.user)
+            is_admin = await self.auth.is_server_admin(requester)
 
             if (self.require_membership and check_membership) and not is_admin:
                 rooms_for_user = await self.store.get_rooms_for_user(user_id)
@@ -197,7 +198,7 @@ class DirectoryHandler:
         user_id = requester.user.to_string()
 
         try:
-            can_delete = await self._user_can_delete_alias(room_alias, user_id)
+            can_delete = await self._user_can_delete_alias(room_alias, requester)
         except StoreError as e:
             if e.code == 404:
                 raise NotFoundError("Unknown room alias")
@@ -287,8 +288,9 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        users = await self.store.get_users_in_room(room_id)
-        extra_servers = {get_domain_from_id(u) for u in users}
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
         servers_set = set(extra_servers) | set(servers)
 
         # If this server is in the list of servers, return it first.
@@ -400,7 +402,9 @@ class DirectoryHandler:
         # either no interested services, or no service with an exclusive lock
         return True
 
-    async def _user_can_delete_alias(self, alias: RoomAlias, user_id: str) -> bool:
+    async def _user_can_delete_alias(
+        self, alias: RoomAlias, requester: Requester
+    ) -> bool:
         """Determine whether a user can delete an alias.
 
         One of the following must be true:
@@ -413,7 +417,7 @@ class DirectoryHandler:
         """
         creator = await self.store.get_room_alias_creator(alias.to_string())
 
-        if creator == user_id:
+        if creator == requester.user.to_string():
             return True
 
         # Resolve the alias to the corresponding room.
@@ -422,9 +426,7 @@ class DirectoryHandler:
         if not room_id:
             return False
 
-        return await self.auth.check_can_change_room_list(
-            room_id, UserID.from_string(user_id)
-        )
+        return await self.auth.check_can_change_room_list(room_id, requester)
 
     async def edit_published_room_list(
         self, requester: Requester, room_id: str, visibility: str
@@ -463,7 +465,7 @@ class DirectoryHandler:
             raise SynapseError(400, "Unknown room")
 
         can_change_room_list = await self.auth.check_can_change_room_list(
-            room_id, requester.user
+            room_id, requester
         )
         if not can_change_room_list:
             raise AuthError(
@@ -528,10 +530,8 @@ class DirectoryHandler:
         Get a list of the aliases that currently point to this room on this server
         """
         # allow access to server admins and current members of the room
-        is_admin = await self.auth.is_server_admin(requester.user)
+        is_admin = await self.auth.is_server_admin(requester)
         if not is_admin:
-            await self.auth.check_user_in_room_or_world_readable(
-                room_id, requester.user.to_string()
-            )
+            await self.auth.check_user_in_room_or_world_readable(room_id, requester)
 
         return await self.store.get_aliases_for_room(room_id)
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 52bb5c9c55..8eed63ccf3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 import attr
 from canonicaljson import encode_canonical_json
@@ -37,7 +37,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -91,8 +92,13 @@ class E2eKeysHandler:
         )
 
     @trace
+    @cancellable
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
+        self,
+        query_body: JsonDict,
+        timeout: int,
+        from_user_id: str,
+        from_device_id: Optional[str],
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -120,9 +126,7 @@ class E2eKeysHandler:
                 the number of in-flight queries at a time.
         """
         async with self._query_devices_linearizer.queue((from_user_id, from_device_id)):
-            device_keys_query: Dict[str, Iterable[str]] = query_body.get(
-                "device_keys", {}
-            )
+            device_keys_query: Dict[str, List[str]] = query_body.get("device_keys", {})
 
             # separate users by domain.
             # make a map from domain to user_id to device_ids
@@ -136,8 +140,8 @@ class E2eKeysHandler:
                 else:
                     remote_queries[user_id] = device_ids
 
-            set_tag("local_key_query", local_query)
-            set_tag("remote_key_query", remote_queries)
+            set_tag("local_key_query", str(local_query))
+            set_tag("remote_key_query", str(remote_queries))
 
             # First get local devices.
             # A map of destination -> failure response.
@@ -171,6 +175,32 @@ class E2eKeysHandler:
                     user_ids_not_in_cache,
                     remote_results,
                 ) = await self.store.get_user_devices_from_cache(query_list)
+
+                # Check that the homeserver still shares a room with all cached users.
+                # Note that this check may be slightly racy when a remote user leaves a
+                # room after we have fetched their cached device list. In the worst case
+                # we will do extra federation queries for devices that we had cached.
+                cached_users = set(remote_results.keys())
+                valid_cached_users = (
+                    await self.store.get_users_server_still_shares_room_with(
+                        remote_results.keys()
+                    )
+                )
+                invalid_cached_users = cached_users - valid_cached_users
+                if invalid_cached_users:
+                    # Fix up results. If we get here, there is either a bug in device
+                    # list tracking, or we hit the race mentioned above.
+                    user_ids_not_in_cache.update(invalid_cached_users)
+                    for invalid_user_id in invalid_cached_users:
+                        remote_results.pop(invalid_user_id)
+                    # This log message may be removed if it turns out it's almost
+                    # entirely triggered by races.
+                    logger.error(
+                        "Devices for %s were cached, but the server no longer shares "
+                        "any rooms with them. The cached device lists are stale.",
+                        invalid_cached_users,
+                    )
+
                 for user_id, devices in remote_results.items():
                     user_devices = results.setdefault(user_id, {})
                     for device_id, device in devices.items():
@@ -206,22 +236,26 @@ class E2eKeysHandler:
                     r[user_id] = remote_queries[user_id]
 
             # Now fetch any devices that we don't have in our cache
+            # TODO It might make sense to propagate cancellations into the
+            #      deferreds which are querying remote homeservers.
             await make_deferred_yieldable(
-                defer.gatherResults(
-                    [
-                        run_in_background(
-                            self._query_devices_for_destination,
-                            results,
-                            cross_signing_keys,
-                            failures,
-                            destination,
-                            queries,
-                            timeout,
-                        )
-                        for destination, queries in remote_queries_not_in_cache.items()
-                    ],
-                    consumeErrors=True,
-                ).addErrback(unwrapFirstError)
+                delay_cancellation(
+                    defer.gatherResults(
+                        [
+                            run_in_background(
+                                self._query_devices_for_destination,
+                                results,
+                                cross_signing_keys,
+                                failures,
+                                destination,
+                                queries,
+                                timeout,
+                            )
+                            for destination, queries in remote_queries_not_in_cache.items()
+                        ],
+                        consumeErrors=True,
+                    ).addErrback(unwrapFirstError)
+                )
             )
 
             ret = {"device_keys": results, "failures": failures}
@@ -341,10 +375,11 @@ class E2eKeysHandler:
             failure = _exception_to_failure(e)
             failures[destination] = failure
             set_tag("error", True)
-            set_tag("reason", failure)
+            set_tag("reason", str(failure))
 
         return
 
+    @cancellable
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
@@ -391,8 +426,9 @@ class E2eKeysHandler:
         }
 
     @trace
+    @cancellable
     async def query_local_devices(
-        self, query: Dict[str, Optional[List[str]]]
+        self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
@@ -403,7 +439,7 @@ class E2eKeysHandler:
         Returns:
             A map from user_id -> device_id -> device details
         """
-        set_tag("local_query", query)
+        set_tag("local_query", str(query))
         local_query: List[Tuple[str, Optional[str]]] = []
 
         result_dict: Dict[str, Dict[str, dict]] = {}
@@ -461,7 +497,7 @@ class E2eKeysHandler:
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: int
+        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -475,8 +511,8 @@ class E2eKeysHandler:
                 domain = get_domain_from_id(user_id)
                 remote_queries.setdefault(domain, {})[user_id] = one_time_keys
 
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
+        set_tag("local_key_query", str(local_query))
+        set_tag("remote_key_query", str(remote_queries))
 
         results = await self.store.claim_e2e_one_time_keys(local_query)
 
@@ -506,7 +542,7 @@ class E2eKeysHandler:
                 failure = _exception_to_failure(e)
                 failures[destination] = failure
                 set_tag("error", True)
-                set_tag("reason", failure)
+                set_tag("reason", str(failure))
 
         await make_deferred_yieldable(
             defer.gatherResults(
@@ -609,7 +645,7 @@ class E2eKeysHandler:
 
         result = await self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        set_tag("one_time_key_counts", result)
+        set_tag("one_time_key_counts", str(result))
         return {"one_time_key_counts": result}
 
     async def _upload_one_time_keys_for_user(
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 446f509bdc..28dc08c22a 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, cast
 
 from typing_extensions import Literal
 
@@ -97,7 +97,7 @@ class E2eRoomKeysHandler:
                 user_id, version, room_id, session_id
             )
 
-            log_kv(results)
+            log_kv(cast(JsonDict, results))
             return results
 
     @trace
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a2dd9c7efa..c3ddc5d182 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -129,12 +129,9 @@ class EventAuthHandler:
         else:
             users = {}
 
-        # Find the user with the highest power level.
-        users_in_room = await self._store.get_users_in_room(room_id)
-        # Only interested in local users.
-        local_users_in_room = [
-            u for u in users_in_room if get_domain_from_id(u) == self._server_name
-        ]
+        # Find the user with the highest power level (only interested in local
+        # users).
+        local_users_in_room = await self._store.get_local_users_in_room(room_id)
         chosen_user = max(
             local_users_in_room,
             key=lambda user: users.get(user, users_default_level),
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index ac13340d3a..949b69cb41 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -151,7 +151,7 @@ class EventHandler:
         """Retrieve a single specified event.
 
         Args:
-            user: The user requesting the event
+            user: The local user requesting the event
             room_id: The expected room id. We'll return None if the
                 event's room does not match.
             event_id: The event ID to obtain.
@@ -173,8 +173,11 @@ class EventHandler:
         if not event:
             return None
 
-        users = await self.store.get_users_in_room(event.room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=event.room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         filtered = await filter_events_for_client(
             self._storage_controllers, user.to_string(), [event], is_peeking=is_peeking
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3b5eaf5156..dd4b9f66d1 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -32,6 +32,7 @@ from typing import (
 )
 
 import attr
+from prometheus_client import Histogram
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -59,6 +60,7 @@ from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
 from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import SynapseTags, set_tag, tag_args, trace
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.module_api import NOT_SPAM
 from synapse.replication.http.federation import (
@@ -68,7 +70,7 @@ from synapse.replication.http.federation import (
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import JsonDict, StateMap, get_domain_from_id
+from synapse.types import JsonDict, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -78,36 +80,28 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
-
-def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
-    """Get joined domains from state
-
-    Args:
-        state: State map from type/state key to event.
-
-    Returns:
-        Returns a list of servers with the lowest depth of their joins.
-            Sorted by lowest depth first.
-    """
-    joined_users = [
-        (state_key, int(event.depth))
-        for (e_type, state_key), event in state.items()
-        if e_type == EventTypes.Member and event.membership == Membership.JOIN
-    ]
-
-    joined_domains: Dict[str, int] = {}
-    for u, d in joined_users:
-        try:
-            dom = get_domain_from_id(u)
-            old_d = joined_domains.get(dom)
-            if old_d:
-                joined_domains[dom] = min(d, old_d)
-            else:
-                joined_domains[dom] = d
-        except Exception:
-            pass
-
-    return sorted(joined_domains.items(), key=lambda d: d[1])
+# Added to debug performance and track progress on optimizations
+backfill_processing_before_timer = Histogram(
+    "synapse_federation_backfill_processing_before_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.1,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        7.5,
+        10.0,
+        15.0,
+        20.0,
+        30.0,
+        40.0,
+        60.0,
+        80.0,
+        "+Inf",
+    ),
+)
 
 
 class _BackfillPointType(Enum):
@@ -137,6 +131,7 @@ class FederationHandler:
     def __init__(self, hs: "HomeServer"):
         self.hs = hs
 
+        self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
@@ -180,6 +175,7 @@ class FederationHandler:
                 "resume_sync_partial_state_room", self._resume_sync_partial_state_room
             )
 
+    @trace
     async def maybe_backfill(
         self, room_id: str, current_depth: int, limit: int
     ) -> bool:
@@ -195,12 +191,39 @@ class FederationHandler:
                 return. This is used as part of the heuristic to decide if we
                 should back paginate.
         """
+        # Starting the processing time here so we can include the room backfill
+        # linearizer lock queue in the timing
+        processing_start_time = self.clock.time_msec()
+
         async with self._room_backfill.queue(room_id):
-            return await self._maybe_backfill_inner(room_id, current_depth, limit)
+            return await self._maybe_backfill_inner(
+                room_id,
+                current_depth,
+                limit,
+                processing_start_time=processing_start_time,
+            )
 
     async def _maybe_backfill_inner(
-        self, room_id: str, current_depth: int, limit: int
+        self,
+        room_id: str,
+        current_depth: int,
+        limit: int,
+        *,
+        processing_start_time: int,
     ) -> bool:
+        """
+        Checks whether the `current_depth` is at or approaching any backfill
+        points in the room and if so, will backfill. We only care about
+        checking backfill points that happened before the `current_depth`
+        (meaning less than or equal to the `current_depth`).
+
+        Args:
+            room_id: The room to backfill in.
+            current_depth: The depth to check at for any upcoming backfill points.
+            limit: The max number of events to request from the remote federated server.
+            processing_start_time: The time when `maybe_backfill` started
+                processing. Only used for timing.
+        """
         backwards_extremities = [
             _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
             for event_id, depth in await self.store.get_oldest_event_ids_with_depth_in_room(
@@ -368,23 +391,29 @@ class FederationHandler:
         logger.debug(
             "_maybe_backfill_inner: extremities_to_request %s", extremities_to_request
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request",
+            str(extremities_to_request),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "extremities_to_request.length",
+            str(len(extremities_to_request)),
+        )
 
         # Now we need to decide which hosts to hit first.
-
-        # First we try hosts that are already in the room
+        # First we try hosts that are already in the room.
         # TODO: HEURISTIC ALERT.
+        likely_domains = (
+            await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+        )
 
-        curr_state = await self._storage_controllers.state.get_current_state(room_id)
-
-        curr_domains = get_domains_from_state(curr_state)
-
-        likely_domains = [
-            domain for domain, depth in curr_domains if domain != self.server_name
-        ]
-
-        async def try_backfill(domains: List[str]) -> bool:
+        async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
+                # We don't want to ask our own server for information we don't have
+                if dom == self.server_name:
+                    continue
+
                 try:
                     await self._federation_event_handler.backfill(
                         dom, room_id, limit=100, extremities=extremities_to_request
@@ -423,6 +452,11 @@ class FederationHandler:
 
             return False
 
+        processing_end_time = self.clock.time_msec()
+        backfill_processing_before_timer.observe(
+            (processing_end_time - processing_start_time) / 1000
+        )
+
         success = await try_backfill(likely_domains)
         if success:
             return True
@@ -546,9 +580,9 @@ class FederationHandler:
             )
 
             if ret.partial_state:
-                # TODO(faster_joins): roll this back if we don't manage to start the
-                #   background resync (eg process_remote_join fails)
-                #   https://github.com/matrix-org/synapse/issues/12998
+                # Mark the room as having partial state.
+                # The background process is responsible for unmarking this flag,
+                # even if the join fails.
                 await self.store.store_partial_state_room(room_id, ret.servers_in_room)
 
             try:
@@ -574,17 +608,21 @@ class FederationHandler:
                     room_id,
                 )
                 raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
-
-            if ret.partial_state:
-                # Kick off the process of asynchronously fetching the state for this
-                # room.
-                run_as_background_process(
-                    desc="sync_partial_state_room",
-                    func=self._sync_partial_state_room,
-                    initial_destination=origin,
-                    other_destinations=ret.servers_in_room,
-                    room_id=room_id,
-                )
+            finally:
+                # Always kick off the background process that asynchronously fetches
+                # state for the room.
+                # If the join failed, the background process is responsible for
+                # cleaning up — including unmarking the room as a partial state room.
+                if ret.partial_state:
+                    # Kick off the process of asynchronously fetching the state for this
+                    # room.
+                    run_as_background_process(
+                        desc="sync_partial_state_room",
+                        func=self._sync_partial_state_room,
+                        initial_destination=origin,
+                        other_destinations=ret.servers_in_room,
+                        room_id=room_id,
+                    )
 
             # We wait here until this instance has seen the events come down
             # replication (if we're using replication) as the below uses caches.
@@ -748,6 +786,23 @@ class FederationHandler:
         # (and return a 404 otherwise)
         room_version = await self.store.get_room_version(room_id)
 
+        if await self.store.is_partial_state_room(room_id):
+            # If our server is still only partially joined, we can't give a complete
+            # response to /make_join, so return a 404 as we would if we weren't in the
+            # room at all.
+            # The main reason we can't respond properly is that we need to know about
+            # the auth events for the join event that we would return.
+            # We also should not bother entertaining the /make_join since we cannot
+            # handle the /send_join.
+            logger.info(
+                "Rejecting /make_join to %s because it's a partial state room", room_id
+            )
+            raise SynapseError(
+                404,
+                "Unable to handle /make_join right now; this server is not fully joined.",
+                errcode=Codes.NOT_FOUND,
+            )
+
         # now check that we are *still* in the room
         is_in_room = await self._event_auth_handler.check_host_in_room(
             room_id, self.server_name
@@ -1058,6 +1113,8 @@ class FederationHandler:
 
         return event
 
+    @trace
+    @tag_args
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event."""
         event = await self.store.get_event(event_id, check_room_id=room_id)
@@ -1539,15 +1596,16 @@ class FederationHandler:
 
         # Make an infinite iterator of destinations to try. Once we find a working
         # destination, we'll stick with it until it flakes.
+        destinations: Collection[str]
         if initial_destination is not None:
             # Move `initial_destination` to the front of the list.
             destinations = list(other_destinations)
             if initial_destination in destinations:
                 destinations.remove(initial_destination)
             destinations = [initial_destination] + destinations
-            destination_iter = itertools.cycle(destinations)
         else:
-            destination_iter = itertools.cycle(other_destinations)
+            destinations = other_destinations
+        destination_iter = itertools.cycle(destinations)
 
         # `destination` is the current remote homeserver we're pulling from.
         destination = next(destination_iter)
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 9ce5bea0ed..100785ebba 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import collections
 import itertools
 import logging
 from http import HTTPStatus
@@ -28,7 +29,7 @@ from typing import (
     Tuple,
 )
 
-from prometheus_client import Counter
+from prometheus_client import Counter, Histogram
 
 from synapse import event_auth
 from synapse.api.constants import (
@@ -58,6 +59,13 @@ from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.logging.context import nested_logging_context
+from synapse.logging.opentracing import (
+    SynapseTags,
+    set_tag,
+    start_active_span,
+    tag_args,
+    trace,
+)
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.devices import ReplicationUserDevicesResyncRestServlet
 from synapse.replication.http.federation import (
@@ -90,6 +98,36 @@ soft_failed_event_counter = Counter(
     "Events received over federation that we marked as soft_failed",
 )
 
+# Added to debug performance and track progress on optimizations
+backfill_processing_after_timer = Histogram(
+    "synapse_federation_backfill_processing_after_time_seconds",
+    "sec",
+    [],
+    buckets=(
+        0.1,
+        0.25,
+        0.5,
+        1.0,
+        2.5,
+        5.0,
+        7.5,
+        10.0,
+        15.0,
+        20.0,
+        25.0,
+        30.0,
+        40.0,
+        50.0,
+        60.0,
+        80.0,
+        100.0,
+        120.0,
+        150.0,
+        180.0,
+        "+Inf",
+    ),
+)
+
 
 class FederationEventHandler:
     """Handles events that originated from federation.
@@ -277,7 +315,8 @@ class FederationEventHandler:
                 )
 
         try:
-            await self._process_received_pdu(origin, pdu, state_ids=None)
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
         except PartialStateConflictError:
             # The room was un-partial stated while we were processing the PDU.
             # Try once more, with full state this time.
@@ -285,7 +324,8 @@ class FederationEventHandler:
                 "Room %s was un-partial stated while processing the PDU, trying again.",
                 room_id,
             )
-            await self._process_received_pdu(origin, pdu, state_ids=None)
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -315,6 +355,7 @@ class FederationEventHandler:
             The event and context of the event after inserting it into the room graph.
 
         Raises:
+            RuntimeError if any prev_events are missing
             SynapseError if the event is not accepted into the room
             PartialStateConflictError if the room was un-partial stated in between
                 computing the state at the event and persisting it. The caller should
@@ -347,7 +388,7 @@ class FederationEventHandler:
         event.internal_metadata.send_on_behalf_of = origin
 
         context = await self._state_handler.compute_event_context(event)
-        context = await self._check_event_auth(origin, event, context)
+        await self._check_event_auth(origin, event, context)
         if context.rejected:
             raise SynapseError(
                 403, f"{event.membership} event was rejected", Codes.FORBIDDEN
@@ -375,7 +416,7 @@ class FederationEventHandler:
         # need to.
         await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
 
-        await self._check_for_soft_fail(event, None, origin=origin)
+        await self._check_for_soft_fail(event, context=context, origin=origin)
         await self._run_push_actions_and_persist_event(event, context)
         return event, context
 
@@ -405,6 +446,7 @@ class FederationEventHandler:
             prev_member_event,
         )
 
+    @trace
     async def process_remote_join(
         self,
         origin: str,
@@ -485,7 +527,7 @@ class FederationEventHandler:
                 partial_state=partial_state,
             )
 
-            context = await self._check_event_auth(origin, event, context)
+            await self._check_event_auth(origin, event, context)
             if context.rejected:
                 raise SynapseError(400, "Join event was rejected")
 
@@ -533,32 +575,36 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state_ids = await self._resolve_state_at_missing_prevs(destination, event)
-
-            # build a new state group for it if need be
-            context = await self._state_handler.compute_event_context(
-                event,
-                state_ids_before_event=state_ids,
+            context = await self._compute_event_context_with_maybe_missing_prevs(
+                destination, event
             )
             if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
-                # partial state - ie, an event has an earlier stream_ordering than one
-                # or more of its prev_events, so we de-partial-state it before its
-                # prev_events.
+                # partial state. We were careful to only pick events from the db without
+                # partial-state prev events, so that implies that a prev event has
+                # been persisted (with partial state) since we did the query.
                 #
-                # TODO(faster_joins): we probably need to be more intelligent, and
-                #    exclude partial-state prev_events from consideration
-                #    https://github.com/matrix-org/synapse/issues/13001
+                # So, let's just ignore `event` for now; when we re-run the db query
+                # we should instead get its partial-state prev event, which we will
+                # de-partial-state, and then come back to event.
                 logger.warning(
-                    "%s still has partial state: can't de-partial-state it yet",
+                    "%s still has prev_events with partial state: can't de-partial-state it yet",
                     event.event_id,
                 )
                 return
+
+            # since the state at this event has changed, we should now re-evaluate
+            # whether it should have been rejected. We must already have all of the
+            # auth events (from last time we went round this path), so there is no
+            # need to pass the origin.
+            await self._check_event_auth(None, event, context)
+
             await self._store.update_state_for_partial_state_event(event, context)
             self._state_storage_controller.notify_event_un_partial_stated(
                 event.event_id
             )
 
+    @trace
     async def backfill(
         self, dest: str, room_id: str, limit: int, extremities: Collection[str]
     ) -> None:
@@ -588,21 +634,23 @@ class FederationEventHandler:
         if not events:
             return
 
-        # if there are any events in the wrong room, the remote server is buggy and
-        # should not be trusted.
-        for ev in events:
-            if ev.room_id != room_id:
-                raise InvalidResponseError(
-                    f"Remote server {dest} returned event {ev.event_id} which is in "
-                    f"room {ev.room_id}, when we were backfilling in {room_id}"
-                )
+        with backfill_processing_after_timer.time():
+            # if there are any events in the wrong room, the remote server is buggy and
+            # should not be trusted.
+            for ev in events:
+                if ev.room_id != room_id:
+                    raise InvalidResponseError(
+                        f"Remote server {dest} returned event {ev.event_id} which is in "
+                        f"room {ev.room_id}, when we were backfilling in {room_id}"
+                    )
 
-        await self._process_pulled_events(
-            dest,
-            events,
-            backfilled=True,
-        )
+            await self._process_pulled_events(
+                dest,
+                events,
+                backfilled=True,
+            )
 
+    @trace
     async def _get_missing_events_for_pdu(
         self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
     ) -> None:
@@ -703,8 +751,9 @@ class FederationEventHandler:
         logger.info("Got %d prev_events", len(missing_events))
         await self._process_pulled_events(origin, missing_events, backfilled=False)
 
+    @trace
     async def _process_pulled_events(
-        self, origin: str, events: Iterable[EventBase], backfilled: bool
+        self, origin: str, events: Collection[EventBase], backfilled: bool
     ) -> None:
         """Process a batch of events we have pulled from a remote server
 
@@ -719,6 +768,15 @@ class FederationEventHandler:
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
         """
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str([event.event_id for event in events]),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(events)),
+        )
+        set_tag(SynapseTags.FUNC_ARG_PREFIX + "backfilled", str(backfilled))
         logger.debug(
             "processing pulled backfilled=%s events=%s",
             backfilled,
@@ -741,6 +799,8 @@ class FederationEventHandler:
             with nested_logging_context(ev.event_id):
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
 
+    @trace
+    @tag_args
     async def _process_pulled_event(
         self, origin: str, event: EventBase, backfilled: bool
     ) -> None:
@@ -765,10 +825,24 @@ class FederationEventHandler:
         """
         logger.info("Processing pulled event %s", event)
 
-        # these should not be outliers.
-        assert (
-            not event.internal_metadata.is_outlier()
-        ), "pulled event unexpectedly flagged as outlier"
+        # This function should not be used to persist outliers (use something
+        # else) because this does a bunch of operations that aren't necessary
+        # (extra work; in particular, it makes sure we have all the prev_events
+        # and resolves the state across those prev events). If you happen to run
+        # into a situation where the event you're trying to process/backfill is
+        # marked as an `outlier`, then you should update that spot to return an
+        # `EventBase` copy that doesn't have `outlier` flag set.
+        #
+        # `EventBase` is used to represent both an event we have not yet
+        # persisted, and one that we have persisted and now keep in the cache.
+        # In an ideal world this method would only be called with the first type
+        # of event, but it turns out that's not actually the case and for
+        # example, you could get an event from cache that is marked as an
+        # `outlier` (fix up that spot though).
+        assert not event.internal_metadata.is_outlier(), (
+            "Outlier event passed to _process_pulled_event. "
+            "To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = true`."
+        )
 
         event_id = event.event_id
 
@@ -778,7 +852,7 @@ class FederationEventHandler:
         if existing:
             if not existing.internal_metadata.is_outlier():
                 logger.info(
-                    "Ignoring received event %s which we have already seen",
+                    "_process_pulled_event: Ignoring received event %s which we have already seen",
                     event_id,
                 )
                 return
@@ -788,32 +862,66 @@ class FederationEventHandler:
             self._sanity_check_event(event)
         except SynapseError as err:
             logger.warning("Event %s failed sanity check: %s", event_id, err)
+            await self._store.record_event_failed_pull_attempt(
+                event.room_id, event_id, str(err)
+            )
             return
 
         try:
-            state_ids = await self._resolve_state_at_missing_prevs(origin, event)
-            # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
-            #   not return partial state
-            #   https://github.com/matrix-org/synapse/issues/13002
+            try:
+                context = await self._compute_event_context_with_maybe_missing_prevs(
+                    origin, event
+                )
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    context,
+                    backfilled=backfilled,
+                )
+            except PartialStateConflictError:
+                # The room was un-partial stated while we were processing the event.
+                # Try once more, with full state this time.
+                context = await self._compute_event_context_with_maybe_missing_prevs(
+                    origin, event
+                )
 
-            await self._process_received_pdu(
-                origin, event, state_ids=state_ids, backfilled=backfilled
-            )
+                # We ought to have full state now, barring some unlikely race where we left and
+                # rejoned the room in the background.
+                if context.partial_state:
+                    raise AssertionError(
+                        f"Event {event.event_id} still has a partial resolved state "
+                        f"after room {event.room_id} was un-partial stated"
+                    )
+
+                await self._process_received_pdu(
+                    origin,
+                    event,
+                    context,
+                    backfilled=backfilled,
+                )
         except FederationError as e:
+            await self._store.record_event_failed_pull_attempt(
+                event.room_id, event_id, str(e)
+            )
+
             if e.code == 403:
                 logger.warning("Pulled event %s failed history check.", event_id)
             else:
                 raise
 
-    async def _resolve_state_at_missing_prevs(
+    @trace
+    async def _compute_event_context_with_maybe_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[StateMap[str]]:
-        """Calculate the state at an event with missing prev_events.
+    ) -> EventContext:
+        """Build an EventContext structure for a non-outlier event whose prev_events may
+        be missing.
 
-        This is used when we have pulled a batch of events from a remote server, and
-        still don't have all the prev_events.
+        This is used when we have pulled a batch of events from a remote server, and may
+        not have all the prev_events.
 
-        If we already have all the prev_events for `event`, this method does nothing.
+        To build an EventContext, we need to calculate the state before the event. If we
+        already have all the prev_events for `event`, we can simply use the state after
+        the prev_events to calculate the state before `event`.
 
         Otherwise, the missing prevs become new backwards extremities, and we fall back
         to asking the remote server for the state after each missing `prev_event`,
@@ -834,8 +942,7 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns
-            the event ids of the state at `event`.
+            The event context.
 
         Raises:
             FederationError if we fail to get the state from the remote server after any
@@ -849,7 +956,7 @@ class FederationEventHandler:
         missing_prevs = prevs - seen
 
         if not missing_prevs:
-            return None
+            return await self._state_handler.compute_event_context(event)
 
         logger.info(
             "Event %s is missing prev_events %s: calculating state for a "
@@ -861,9 +968,15 @@ class FederationEventHandler:
         # resolve them to find the correct state at the current event.
 
         try:
+            # Determine whether we may be about to retrieve partial state
+            # Events may be un-partial stated right after we compute the partial state
+            # flag, but that's okay, as long as the flag errs on the conservative side.
+            partial_state_flags = await self._store.get_partial_state_events(seen)
+            partial_state = any(partial_state_flags.values())
+
             # Get the state of the events we know about
             ours = await self._state_storage_controller.get_state_groups_ids(
-                room_id, seen
+                room_id, seen, await_full_state=False
             )
 
             # state_maps is a list of mappings from (type, state_key) to event_id
@@ -909,8 +1022,12 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state_map
+        return await self._state_handler.compute_event_context(
+            event, state_ids_before_event=state_map, partial_state=partial_state
+        )
 
+    @trace
+    @tag_args
     async def _get_state_ids_after_missing_prev_event(
         self,
         destination: str,
@@ -931,6 +1048,14 @@ class FederationEventHandler:
             InvalidResponseError: if the remote homeserver's response contains fields
                 of the wrong type.
         """
+
+        # It would be better if we could query the difference from our known
+        # state to the given `event_id` so the sending server doesn't have to
+        # send as much and we don't have to process as many events. For example
+        # in a room like #matrix:matrix.org, we get 200k events (77k state_events, 122k
+        # auth_events) from this call.
+        #
+        # Tracked by https://github.com/matrix-org/synapse/issues/13618
         (
             state_event_ids,
             auth_event_ids,
@@ -955,11 +1080,11 @@ class FederationEventHandler:
         )
         have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - have_events
+        missing_desired_event_ids = desired_events - have_events
         logger.debug(
             "_get_state_ids_after_missing_prev_event(event_id=%s): We are missing %i events (got %i)",
             event_id,
-            len(missing_desired_events),
+            len(missing_desired_event_ids),
             len(have_events),
         )
 
@@ -971,17 +1096,34 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - have_events
-        missing_auth_events.difference_update(
-            await self._store.have_seen_events(room_id, missing_auth_events)
+        missing_auth_event_ids = set(auth_event_ids) - have_events
+        missing_auth_event_ids.difference_update(
+            await self._store.have_seen_events(room_id, missing_auth_event_ids)
         )
         logger.debug(
             "_get_state_ids_after_missing_prev_event(event_id=%s): We are also missing %i auth events",
             event_id,
-            len(missing_auth_events),
+            len(missing_auth_event_ids),
         )
 
-        missing_events = missing_desired_events | missing_auth_events
+        missing_event_ids = missing_desired_event_ids | missing_auth_event_ids
+
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids",
+            str(missing_auth_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_auth_event_ids.length",
+            str(len(missing_auth_event_ids)),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids",
+            str(missing_desired_event_ids),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "missing_desired_event_ids.length",
+            str(len(missing_desired_event_ids)),
+        )
 
         # Making an individual request for each of 1000s of events has a lot of
         # overhead. On the other hand, we don't really want to fetch all of the events
@@ -992,7 +1134,7 @@ class FederationEventHandler:
         #
         # TODO: might it be better to have an API which lets us do an aggregate event
         #   request
-        if (len(missing_events) * 10) >= len(auth_event_ids) + len(state_event_ids):
+        if (len(missing_event_ids) * 10) >= len(auth_event_ids) + len(state_event_ids):
             logger.debug(
                 "_get_state_ids_after_missing_prev_event(event_id=%s): Requesting complete state from remote",
                 event_id,
@@ -1002,10 +1144,10 @@ class FederationEventHandler:
             logger.debug(
                 "_get_state_ids_after_missing_prev_event(event_id=%s): Fetching %i events from remote",
                 event_id,
-                len(missing_events),
+                len(missing_event_ids),
             )
             await self._get_events_and_persist(
-                destination=destination, room_id=room_id, event_ids=missing_events
+                destination=destination, room_id=room_id, event_ids=missing_event_ids
             )
 
         # We now need to fill out the state map, which involves fetching the
@@ -1063,12 +1205,24 @@ class FederationEventHandler:
         # missing state at that event is a warning, not a blocker
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
+        failed_to_fetch = desired_events - event_metadata.keys()
+        # `event_id` could be missing from `event_metadata` because it's not necessarily
+        # a state event. We've already checked that we've fetched it above.
+        failed_to_fetch.discard(event_id)
         if failed_to_fetch:
             logger.warning(
                 "_get_state_ids_after_missing_prev_event(event_id=%s): Failed to fetch missing state events %s",
                 event_id,
                 failed_to_fetch,
             )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch",
+                str(failed_to_fetch),
+            )
+            set_tag(
+                SynapseTags.RESULT_PREFIX + "failed_to_fetch.length",
+                str(len(failed_to_fetch)),
+            )
 
         if remote_event.is_state() and remote_event.rejected_reason is None:
             state_map[
@@ -1077,6 +1231,8 @@ class FederationEventHandler:
 
         return state_map
 
+    @trace
+    @tag_args
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1098,11 +1254,12 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=(event_id,)
             )
 
+    @trace
     async def _process_received_pdu(
         self,
         origin: str,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1124,30 +1281,20 @@ class FederationEventHandler:
 
             event: event to be persisted
 
-            state_ids: Normally None, but if we are handling a gap in the graph
-                (ie, we are missing one or more prev_events), the resolved state at the
-                event. Must not be partial state.
+            context: The `EventContext` to persist the event with.
 
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
 
         PartialStateConflictError: if the room was un-partial stated in between
-            computing the state at the event and persisting it. The caller should retry
-            exactly once in this case. Will never be raised if `state_ids` is provided.
+            computing the state at the event and persisting it. The caller should
+            recompute `context` and retry exactly once when this happens.
         """
         logger.debug("Processing event: %s", event)
         assert not event.internal_metadata.outlier
 
-        context = await self._state_handler.compute_event_context(
-            event,
-            state_ids_before_event=state_ids,
-        )
         try:
-            context = await self._check_event_auth(
-                origin,
-                event,
-                context,
-            )
+            await self._check_event_auth(origin, event, context)
         except AuthError as e:
             # This happens only if we couldn't find the auth events. We'll already have
             # logged a warning, so now we just convert to a FederationError.
@@ -1157,7 +1304,7 @@ class FederationEventHandler:
             # For new (non-backfilled and non-outlier) events we check if the event
             # passes auth based on the current state. If it doesn't then we
             # "soft-fail" the event.
-            await self._check_for_soft_fail(event, state_ids, origin=origin)
+            await self._check_for_soft_fail(event, context=context, origin=origin)
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
@@ -1258,6 +1405,7 @@ class FederationEventHandler:
         except Exception:
             logger.exception("Failed to resync device for %s", sender)
 
+    @trace
     async def _handle_marker_event(self, origin: str, marker_event: EventBase) -> None:
         """Handles backfilling the insertion event when we receive a marker
         event that points to one.
@@ -1289,7 +1437,7 @@ class FederationEventHandler:
         logger.debug("_handle_marker_event: received %s", marker_event)
 
         insertion_event_id = marker_event.content.get(
-            EventContentFields.MSC2716_MARKER_INSERTION
+            EventContentFields.MSC2716_INSERTION_EVENT_REFERENCE
         )
 
         if insertion_event_id is None:
@@ -1342,6 +1490,55 @@ class FederationEventHandler:
             marker_event,
         )
 
+    async def backfill_event_id(
+        self, destination: str, room_id: str, event_id: str
+    ) -> EventBase:
+        """Backfill a single event and persist it as a non-outlier which means
+        we also pull in all of the state and auth events necessary for it.
+
+        Args:
+            destination: The homeserver to pull the given event_id from.
+            room_id: The room where the event is from.
+            event_id: The event ID to backfill.
+
+        Raises:
+            FederationError if we are unable to find the event from the destination
+        """
+        logger.info(
+            "backfill_event_id: event_id=%s from destination=%s", event_id, destination
+        )
+
+        room_version = await self._store.get_room_version(room_id)
+
+        event_from_response = await self._federation_client.get_pdu(
+            [destination],
+            event_id,
+            room_version,
+        )
+
+        if not event_from_response:
+            raise FederationError(
+                "ERROR",
+                404,
+                "Unable to find event_id=%s from destination=%s to backfill."
+                % (event_id, destination),
+                affected=event_id,
+            )
+
+        # Persist the event we just fetched, including pulling all of the state
+        # and auth events to de-outlier it. This also sets up the necessary
+        # `state_groups` for the event.
+        await self._process_pulled_events(
+            destination,
+            [event_from_response],
+            # Prevent notifications going to clients
+            backfilled=True,
+        )
+
+        return event_from_response
+
+    @trace
+    @tag_args
     async def _get_events_and_persist(
         self, destination: str, room_id: str, event_ids: Collection[str]
     ) -> None:
@@ -1387,6 +1584,7 @@ class FederationEventHandler:
         logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
         await self._auth_and_persist_outliers(room_id, events)
 
+    @trace
     async def _auth_and_persist_outliers(
         self, room_id: str, events: Iterable[EventBase]
     ) -> None:
@@ -1405,6 +1603,16 @@ class FederationEventHandler:
         """
         event_map = {event.event_id: event for event in events}
 
+        event_ids = event_map.keys()
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids",
+            str(event_ids),
+        )
+        set_tag(
+            SynapseTags.FUNC_ARG_PREFIX + "event_ids.length",
+            str(len(event_ids)),
+        )
+
         # filter out any events we have already seen. This might happen because
         # the events were eagerly pushed to us (eg, during a room join), or because
         # another thread has raced against us since we decided to request the event.
@@ -1521,24 +1729,21 @@ class FederationEventHandler:
             backfilled=True,
         )
 
+    @trace
     async def _check_event_auth(
-        self,
-        origin: str,
-        event: EventBase,
-        context: EventContext,
-    ) -> EventContext:
+        self, origin: Optional[str], event: EventBase, context: EventContext
+    ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
         Args:
-            origin: The host the event originates from.
+            origin: The host the event originates from. This is used to fetch
+               any missing auth events. It can be set to None, but only if we are
+               sure that we already have all the auth events.
             event: The event itself.
             context:
                 The event context.
 
-        Returns:
-            The updated context object.
-
         Raises:
             AuthError if we were unable to find copies of the event's auth events.
                (Most other failures just cause us to set `context.rejected`.)
@@ -1553,7 +1758,7 @@ class FederationEventHandler:
             logger.warning("While validating received event %r: %s", event, e)
             # TODO: use a different rejected reason here?
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
 
         # next, check that we have all of the event's auth events.
         #
@@ -1563,8 +1768,19 @@ class FederationEventHandler:
         claimed_auth_events = await self._load_or_fetch_auth_events_for_event(
             origin, event
         )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events",
+            str([ev.event_id for ev in claimed_auth_events]),
+        )
+        set_tag(
+            SynapseTags.RESULT_PREFIX + "claimed_auth_events.length",
+            str(len(claimed_auth_events)),
+        )
 
         # ... and check that the event passes auth at those auth events.
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   4. Passes authorization rules based on the event’s auth events,
+        #      otherwise it is rejected.
         try:
             await check_state_independent_auth_rules(self._store, event)
             check_state_dependent_auth_rules(event, claimed_auth_events)
@@ -1573,55 +1789,91 @@ class FederationEventHandler:
                 "While checking auth of %r against auth_events: %s", event, e
             )
             context.rejected = RejectedReason.AUTH_ERROR
-            return context
+            return
+
+        # now check the auth rules pass against the room state before the event
+        # https://spec.matrix.org/v1.3/server-server-api/#checks-performed-on-receipt-of-a-pdu:
+        #   5. Passes authorization rules based on the state before the event,
+        #      otherwise it is rejected.
+        #
+        # ... however, if we only have partial state for the room, then there is a good
+        # chance that we'll be missing some of the state needed to auth the new event.
+        # So, we state-resolve the auth events that we are given against the state that
+        # we know about, which ensures things like bans are applied. (Note that we'll
+        # already have checked we have all the auth events, in
+        # _load_or_fetch_auth_events_for_event above)
+        if context.partial_state:
+            room_version = await self._store.get_room_version_id(event.room_id)
+
+            local_state_id_map = await context.get_prev_state_ids()
+            claimed_auth_events_id_map = {
+                (ev.type, ev.state_key): ev.event_id for ev in claimed_auth_events
+            }
+
+            state_for_auth_id_map = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    [local_state_id_map, claimed_auth_events_id_map],
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
+            )
+        else:
+            event_types = event_auth.auth_types_for_event(event.room_version, event)
+            state_for_auth_id_map = await context.get_prev_state_ids(
+                StateFilter.from_types(event_types)
+            )
 
-        # now check auth against what we think the auth events *should* be.
-        event_types = event_auth.auth_types_for_event(event.room_version, event)
-        prev_state_ids = await context.get_prev_state_ids(
-            StateFilter.from_types(event_types)
+        calculated_auth_event_ids = self._event_auth_handler.compute_auth_events(
+            event, state_for_auth_id_map, for_verification=True
         )
 
-        auth_events_ids = self._event_auth_handler.compute_auth_events(
-            event, prev_state_ids, for_verification=True
+        # if those are the same, we're done here.
+        if collections.Counter(event.auth_event_ids()) == collections.Counter(
+            calculated_auth_event_ids
+        ):
+            return
+
+        # otherwise, re-run the auth checks based on what we calculated.
+        calculated_auth_events = await self._store.get_events_as_list(
+            calculated_auth_event_ids
         )
-        auth_events_x = await self._store.get_events(auth_events_ids)
+
+        # log the differences
+
+        claimed_auth_event_map = {(e.type, e.state_key): e for e in claimed_auth_events}
         calculated_auth_event_map = {
-            (e.type, e.state_key): e for e in auth_events_x.values()
+            (e.type, e.state_key): e for e in calculated_auth_events
         }
+        logger.info(
+            "event's auth_events are different to our calculated auth_events. "
+            "Claimed but not calculated: %s. Calculated but not claimed: %s",
+            [
+                ev
+                for k, ev in claimed_auth_event_map.items()
+                if k not in calculated_auth_event_map
+                or calculated_auth_event_map[k].event_id != ev.event_id
+            ],
+            [
+                ev
+                for k, ev in calculated_auth_event_map.items()
+                if k not in claimed_auth_event_map
+                or claimed_auth_event_map[k].event_id != ev.event_id
+            ],
+        )
 
         try:
-            updated_auth_events = await self._update_auth_events_for_auth(
+            check_state_dependent_auth_rules(event, calculated_auth_events)
+        except AuthError as e:
+            logger.warning(
+                "While checking auth of %r against room state before the event: %s",
                 event,
-                calculated_auth_event_map=calculated_auth_event_map,
-            )
-        except Exception:
-            # We don't really mind if the above fails, so lets not fail
-            # processing if it does. However, it really shouldn't fail so
-            # let's still log as an exception since we'll still want to fix
-            # any bugs.
-            logger.exception(
-                "Failed to double check auth events for %s with remote. "
-                "Ignoring failure and continuing processing of event.",
-                event.event_id,
-            )
-            updated_auth_events = None
-
-        if updated_auth_events:
-            context = await self._update_context_for_auth_events(
-                event, context, updated_auth_events
+                e,
             )
-            auth_events_for_auth = updated_auth_events
-        else:
-            auth_events_for_auth = calculated_auth_event_map
-
-        try:
-            check_state_dependent_auth_rules(event, auth_events_for_auth.values())
-        except AuthError as e:
-            logger.warning("Failed auth resolution for %r because %s", event, e)
             context.rejected = RejectedReason.AUTH_ERROR
 
-        return context
-
+    @trace
     async def _maybe_kick_guest_users(self, event: EventBase) -> None:
         if event.type != EventTypes.GuestAccess:
             return
@@ -1639,17 +1891,27 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
         such.
 
+        Does nothing for events in rooms with partial state, since we may not have an
+        accurate membership event for the sender in the current state.
+
         Args:
             event
-            state_ids: The state at the event if we don't have all the event's prev events
+            context: The `EventContext` which we are about to persist the event with.
             origin: The host the event originates from.
         """
+        if await self._store.is_partial_state_room(event.room_id):
+            # We might not know the sender's membership in the current state, so don't
+            # soft fail anything. Even if we do have a membership for the sender in the
+            # current state, it may have been derived from state resolution between
+            # partial and full state and may not be accurate.
+            return
+
         extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
         extrem_ids = set(extrem_ids_list)
         prev_event_ids = set(event.prev_event_ids())
@@ -1666,11 +1928,15 @@ class FederationEventHandler:
         auth_types = auth_types_for_event(room_version_obj, event)
 
         # Calculate the "current state".
-        if state_ids is not None:
-            # If we're explicitly given the state then we won't have all the
-            # prev events, and so we have a gap in the graph. In this case
-            # we want to be a little careful as we might have been down for
-            # a while and have an incorrect view of the current state,
+        seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
+        has_missing_prevs = bool(prev_event_ids - seen_event_ids)
+        if has_missing_prevs:
+            # We don't have all the prev_events of this event, which means we have a
+            # gap in the graph, and the new event is going to become a new backwards
+            # extremity.
+            #
+            # In this case we want to be a little careful as we might have been
+            # down for a while and have an incorrect view of the current state,
             # however we still want to do checks as gaps are easy to
             # maliciously manufacture.
             #
@@ -1683,6 +1949,7 @@ class FederationEventHandler:
                 event.room_id, extrem_ids
             )
             state_sets: List[StateMap[str]] = list(state_sets_d.values())
+            state_ids = await context.get_prev_state_ids()
             state_sets.append(state_ids)
             current_state_ids = (
                 await self._state_resolution_handler.resolve_events_with_store(
@@ -1731,95 +1998,8 @@ class FederationEventHandler:
             soft_failed_event_counter.inc()
             event.internal_metadata.soft_failed = True
 
-    async def _update_auth_events_for_auth(
-        self,
-        event: EventBase,
-        calculated_auth_event_map: StateMap[EventBase],
-    ) -> Optional[StateMap[EventBase]]:
-        """Helper for _check_event_auth. See there for docs.
-
-        Checks whether a given event has the expected auth events. If it
-        doesn't then we talk to the remote server to compare state to see if
-        we can come to a consensus (e.g. if one server missed some valid
-        state).
-
-        This attempts to resolve any potential divergence of state between
-        servers, but is not essential and so failures should not block further
-        processing of the event.
-
-        Args:
-            event:
-
-            calculated_auth_event_map:
-                Our calculated auth_events based on the state of the room
-                at the event's position in the DAG.
-
-        Returns:
-            updated auth event map, or None if no changes are needed.
-
-        """
-        assert not event.internal_metadata.outlier
-
-        # check for events which are in the event's claimed auth_events, but not
-        # in our calculated event map.
-        event_auth_events = set(event.auth_event_ids())
-        different_auth = event_auth_events.difference(
-            e.event_id for e in calculated_auth_event_map.values()
-        )
-
-        if not different_auth:
-            return None
-
-        logger.info(
-            "auth_events refers to events which are not in our calculated auth "
-            "chain: %s",
-            different_auth,
-        )
-
-        # XXX: currently this checks for redactions but I'm not convinced that is
-        # necessary?
-        different_events = await self._store.get_events_as_list(different_auth)
-
-        # double-check they're all in the same room - we should already have checked
-        # this but it doesn't hurt to check again.
-        for d in different_events:
-            assert (
-                d.room_id == event.room_id
-            ), f"Event {event.event_id} refers to auth_event {d.event_id} which is in a different room"
-
-        # now we state-resolve between our own idea of the auth events, and the remote's
-        # idea of them.
-
-        local_state = calculated_auth_event_map.values()
-        remote_auth_events = dict(calculated_auth_event_map)
-        remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
-        remote_state = remote_auth_events.values()
-
-        room_version = await self._store.get_room_version_id(event.room_id)
-        new_state = await self._state_handler.resolve_events(
-            room_version, (local_state, remote_state), event
-        )
-        different_state = {
-            (d.type, d.state_key): d
-            for d in new_state.values()
-            if calculated_auth_event_map.get((d.type, d.state_key)) != d
-        }
-        if not different_state:
-            logger.info("State res returned no new state")
-            return None
-
-        logger.info(
-            "After state res: updating auth_events with new state %s",
-            different_state.values(),
-        )
-
-        # take a copy of calculated_auth_event_map before we modify it.
-        auth_events = dict(calculated_auth_event_map)
-        auth_events.update(different_state)
-        return auth_events
-
     async def _load_or_fetch_auth_events_for_event(
-        self, destination: str, event: EventBase
+        self, destination: Optional[str], event: EventBase
     ) -> Collection[EventBase]:
         """Fetch this event's auth_events, from database or remote
 
@@ -1835,12 +2015,19 @@ class FederationEventHandler:
         Args:
             destination: where to send the /event_auth request. Typically the server
                that sent us `event` in the first place.
+
+               If this is None, no attempt is made to load any missing auth events:
+               rather, an AssertionError is raised if there are any missing events.
+
             event: the event whose auth_events we want
 
         Returns:
             all of the events listed in `event.auth_events_ids`, after deduplication
 
         Raises:
+            AssertionError if some auth events were missing and no `destination` was
+            supplied.
+
             AuthError if we were unable to fetch the auth_events for any reason.
         """
         event_auth_event_ids = set(event.auth_event_ids())
@@ -1852,6 +2039,13 @@ class FederationEventHandler:
         )
         if not missing_auth_event_ids:
             return event_auth_events.values()
+        if destination is None:
+            # this shouldn't happen: destination must be set unless we know we have already
+            # persisted the auth events.
+            raise AssertionError(
+                "_load_or_fetch_auth_events_for_event() called with no destination for "
+                "an event with missing auth_events"
+            )
 
         logger.info(
             "Event %s refers to unknown auth events %s: fetching auth chain",
@@ -1887,6 +2081,8 @@ class FederationEventHandler:
         # instead we raise an AuthError, which will make the caller ignore it.
         raise AuthError(code=HTTPStatus.FORBIDDEN, msg="Auth events could not be found")
 
+    @trace
+    @tag_args
     async def _get_remote_auth_chain_for_event(
         self, destination: str, room_id: str, event_id: str
     ) -> None:
@@ -1915,61 +2111,7 @@ class FederationEventHandler:
 
         await self._auth_and_persist_outliers(room_id, remote_auth_events)
 
-    async def _update_context_for_auth_events(
-        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
-    ) -> EventContext:
-        """Update the state_ids in an event context after auth event resolution,
-        storing the changes as a new state group.
-
-        Args:
-            event: The event we're handling the context for
-
-            context: initial event context
-
-            auth_events: Events to update in the event context.
-
-        Returns:
-            new event context
-        """
-        # exclude the state key of the new event from the current_state in the context.
-        if event.is_state():
-            event_key: Optional[Tuple[str, str]] = (event.type, event.state_key)
-        else:
-            event_key = None
-        state_updates = {
-            k: a.event_id for k, a in auth_events.items() if k != event_key
-        }
-
-        current_state_ids = await context.get_current_state_ids()
-        current_state_ids = dict(current_state_ids)  # type: ignore
-
-        current_state_ids.update(state_updates)
-
-        prev_state_ids = await context.get_prev_state_ids()
-        prev_state_ids = dict(prev_state_ids)
-
-        prev_state_ids.update({k: a.event_id for k, a in auth_events.items()})
-
-        # create a new state group as a delta from the existing one.
-        prev_group = context.state_group
-        state_group = await self._state_storage_controller.store_state_group(
-            event.event_id,
-            event.room_id,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            current_state_ids=current_state_ids,
-        )
-
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group,
-            state_group_before_event=context.state_group_before_event,
-            state_delta_due_to_event=state_updates,
-            prev_group=prev_group,
-            delta_ids=state_updates,
-            partial_state=context.partial_state,
-        )
-
+    @trace
     async def _run_push_actions_and_persist_event(
         self, event: EventBase, context: EventContext, backfilled: bool = False
     ) -> None:
@@ -2078,8 +2220,17 @@ class FederationEventHandler:
                     self._message_handler.maybe_schedule_expiry(event)
 
             if not backfilled:  # Never notify for backfilled events
-                for event in events:
-                    await self._notify_persisted_event(event, max_stream_token)
+                with start_active_span("notify_persisted_events"):
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids",
+                        str([ev.event_id for ev in events]),
+                    )
+                    set_tag(
+                        SynapseTags.RESULT_PREFIX + "event_ids.length",
+                        str(len(events)),
+                    )
+                    for event in events:
+                        await self._notify_persisted_event(event, max_stream_token)
 
             return max_stream_token.stream
 
@@ -2120,6 +2271,10 @@ class FederationEventHandler:
             event, event_pos, max_stream_token, extra_users=extra_users
         )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            # TODO retrieve the previous state, and exclude join -> join transitions
+            self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
     def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9bca2bc4b2..93d09e9939 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -26,7 +26,6 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.api.ratelimiting import Ratelimiter
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.http import RequestTimedOutError
 from synapse.http.client import SimpleHttpClient
 from synapse.http.site import SynapseRequest
@@ -163,8 +162,7 @@ class IdentityHandler:
         sid: str,
         mxid: str,
         id_server: str,
-        id_access_token: Optional[str] = None,
-        use_v2: bool = True,
+        id_access_token: str,
     ) -> JsonDict:
         """Bind a 3PID to an identity server
 
@@ -174,8 +172,7 @@ class IdentityHandler:
             mxid: The MXID to bind the 3PID to
             id_server: The domain of the identity server to query
             id_access_token: The access token to authenticate to the identity
-                server with, if necessary. Required if use_v2 is true
-            use_v2: Whether to use v2 Identity Service API endpoints. Defaults to True
+                server with
 
         Raises:
             SynapseError: On any of the following conditions
@@ -187,24 +184,15 @@ class IdentityHandler:
         """
         logger.debug("Proxying threepid bind request for %s to %s", mxid, id_server)
 
-        # If an id_access_token is not supplied, force usage of v1
-        if id_access_token is None:
-            use_v2 = False
-
         if not valid_id_server_location(id_server):
             raise SynapseError(
                 400,
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        # Decide which API endpoint URLs to use
-        headers = {}
         bind_data = {"sid": sid, "client_secret": client_secret, "mxid": mxid}
-        if use_v2:
-            bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
-            headers["Authorization"] = create_id_access_token_header(id_access_token)  # type: ignore
-        else:
-            bind_url = "https://%s/_matrix/identity/api/v1/3pid/bind" % (id_server,)
+        bind_url = "https://%s/_matrix/identity/v2/3pid/bind" % (id_server,)
+        headers = {"Authorization": create_id_access_token_header(id_access_token)}
 
         try:
             # Use the blacklisting http client as this call is only to identity servers
@@ -223,21 +211,14 @@ class IdentityHandler:
 
             return data
         except HttpResponseException as e:
-            if e.code != 404 or not use_v2:
-                logger.error("3PID bind failed with Matrix error: %r", e)
-                raise e.to_synapse_error()
+            logger.error("3PID bind failed with Matrix error: %r", e)
+            raise e.to_synapse_error()
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
         except CodeMessageException as e:
             data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
-        logger.info("Got 404 when POSTing JSON %s, falling back to v1 URL", bind_url)
-        res = await self.bind_threepid(
-            client_secret, sid, mxid, id_server, id_access_token, use_v2=False
-        )
-        return res
-
     async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
         """Attempt to remove a 3PID from an identity server, or if one is not provided, all
         identity servers we're aware the binding is present on
@@ -300,8 +281,8 @@ class IdentityHandler:
                 "id_server must be a valid hostname with optional port and path components",
             )
 
-        url = "https://%s/_matrix/identity/api/v1/3pid/unbind" % (id_server,)
-        url_bytes = b"/_matrix/identity/api/v1/3pid/unbind"
+        url = "https://%s/_matrix/identity/v2/3pid/unbind" % (id_server,)
+        url_bytes = b"/_matrix/identity/v2/3pid/unbind"
 
         content = {
             "mxid": mxid,
@@ -434,48 +415,6 @@ class IdentityHandler:
 
         return session_id
 
-    async def requestEmailToken(
-        self,
-        id_server: str,
-        email: str,
-        client_secret: str,
-        send_attempt: int,
-        next_link: Optional[str] = None,
-    ) -> JsonDict:
-        """
-        Request an external server send an email on our behalf for the purposes of threepid
-        validation.
-
-        Args:
-            id_server: The identity server to proxy to
-            email: The email to send the message to
-            client_secret: The unique client_secret sends by the user
-            send_attempt: Which attempt this is
-            next_link: A link to redirect the user to once they submit the token
-
-        Returns:
-            The json response body from the server
-        """
-        params = {
-            "email": email,
-            "client_secret": client_secret,
-            "send_attempt": send_attempt,
-        }
-        if next_link:
-            params["next_link"] = next_link
-
-        try:
-            data = await self.http_client.post_json_get_json(
-                id_server + "/_matrix/identity/api/v1/validate/email/requestToken",
-                params,
-            )
-            return data
-        except HttpResponseException as e:
-            logger.info("Proxied requestToken failed: %r", e)
-            raise e.to_synapse_error()
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-
     async def requestMsisdnToken(
         self,
         id_server: str,
@@ -549,18 +488,7 @@ class IdentityHandler:
         validation_session = None
 
         # Try to validate as email
-        if self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.REMOTE:
-            # Remote emails will only be used if a valid identity server is provided.
-            assert (
-                self.hs.config.registration.account_threepid_delegate_email is not None
-            )
-
-            # Ask our delegated email identity server
-            validation_session = await self.threepid_from_creds(
-                self.hs.config.registration.account_threepid_delegate_email,
-                threepid_creds,
-            )
-        elif self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL:
+        if self.hs.config.email.can_verify_email:
             # Get a validated session matching these details
             validation_session = await self.store.get_threepid_validation_session(
                 "email", client_secret, sid=sid, validated=True
@@ -610,11 +538,7 @@ class IdentityHandler:
             raise SynapseError(400, "Error contacting the identity server")
 
     async def lookup_3pid(
-        self,
-        id_server: str,
-        medium: str,
-        address: str,
-        id_access_token: Optional[str] = None,
+        self, id_server: str, medium: str, address: str, id_access_token: str
     ) -> Optional[str]:
         """Looks up a 3pid in the passed identity server.
 
@@ -629,60 +553,15 @@ class IdentityHandler:
         Returns:
             the matrix ID of the 3pid, or None if it is not recognized.
         """
-        if id_access_token is not None:
-            try:
-                results = await self._lookup_3pid_v2(
-                    id_server, id_access_token, medium, address
-                )
-                return results
-
-            except Exception as e:
-                # Catch HttpResponseExcept for a non-200 response code
-                # Check if this identity server does not know about v2 lookups
-                if isinstance(e, HttpResponseException) and e.code == 404:
-                    # This is an old identity server that does not yet support v2 lookups
-                    logger.warning(
-                        "Attempted v2 lookup on v1 identity server %s. Falling "
-                        "back to v1",
-                        id_server,
-                    )
-                else:
-                    logger.warning("Error when looking up hashing details: %s", e)
-                    return None
-
-        return await self._lookup_3pid_v1(id_server, medium, address)
-
-    async def _lookup_3pid_v1(
-        self, id_server: str, medium: str, address: str
-    ) -> Optional[str]:
-        """Looks up a 3pid in the passed identity server using v1 lookup.
-
-        Args:
-            id_server: The server name (including port, if required)
-                of the identity server to use.
-            medium: The type of the third party identifier (e.g. "email").
-            address: The third party identifier (e.g. "foo@example.com").
 
-        Returns:
-            the matrix ID of the 3pid, or None if it is not recognized.
-        """
         try:
-            data = await self.blacklisting_http_client.get_json(
-                "%s%s/_matrix/identity/api/v1/lookup" % (id_server_scheme, id_server),
-                {"medium": medium, "address": address},
+            results = await self._lookup_3pid_v2(
+                id_server, id_access_token, medium, address
             )
-
-            if "mxid" in data:
-                # note: we used to verify the identity server's signature here, but no longer
-                # require or validate it. See the following for context:
-                # https://github.com/matrix-org/synapse/issues/5253#issuecomment-666246950
-                return data["mxid"]
-        except RequestTimedOutError:
-            raise SynapseError(500, "Timed out contacting identity server")
-        except OSError as e:
-            logger.warning("Error from v1 identity server lookup: %s" % (e,))
-
-        return None
+            return results
+        except Exception as e:
+            logger.warning("Error when looking up hashing details: %s", e)
+            return None
 
     async def _lookup_3pid_v2(
         self, id_server: str, id_access_token: str, medium: str, address: str
@@ -811,7 +690,7 @@ class IdentityHandler:
         room_type: Optional[str],
         inviter_display_name: str,
         inviter_avatar_url: str,
-        id_access_token: Optional[str] = None,
+        id_access_token: str,
     ) -> Tuple[str, List[Dict[str, str]], Dict[str, str], str]:
         """
         Asks an identity server for a third party invite.
@@ -832,7 +711,7 @@ class IdentityHandler:
             inviter_display_name: The current display name of the
                 inviter.
             inviter_avatar_url: The URL of the inviter's avatar.
-            id_access_token (str|None): The access token to authenticate to the identity
+            id_access_token (str): The access token to authenticate to the identity
                 server with
 
         Returns:
@@ -864,71 +743,24 @@ class IdentityHandler:
             invite_config["org.matrix.web_client_location"] = self._web_client_location
 
         # Add the identity service access token to the JSON body and use the v2
-        # Identity Service endpoints if id_access_token is present
+        # Identity Service endpoints
         data = None
-        base_url = "%s%s/_matrix/identity" % (id_server_scheme, id_server)
 
-        if id_access_token:
-            key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
-            )
+        key_validity_url = "%s%s/_matrix/identity/v2/pubkey/isvalid" % (
+            id_server_scheme,
+            id_server,
+        )
 
-            # Attempt a v2 lookup
-            url = base_url + "/v2/store-invite"
-            try:
-                data = await self.blacklisting_http_client.post_json_get_json(
-                    url,
-                    invite_config,
-                    {"Authorization": create_id_access_token_header(id_access_token)},
-                )
-            except RequestTimedOutError:
-                raise SynapseError(500, "Timed out contacting identity server")
-            except HttpResponseException as e:
-                if e.code != 404:
-                    logger.info("Failed to POST %s with JSON: %s", url, e)
-                    raise e
-
-        if data is None:
-            key_validity_url = "%s%s/_matrix/identity/api/v1/pubkey/isvalid" % (
-                id_server_scheme,
-                id_server,
+        url = "%s%s/_matrix/identity/v2/store-invite" % (id_server_scheme, id_server)
+        try:
+            data = await self.blacklisting_http_client.post_json_get_json(
+                url,
+                invite_config,
+                {"Authorization": create_id_access_token_header(id_access_token)},
             )
-            url = base_url + "/api/v1/store-invite"
-
-            try:
-                data = await self.blacklisting_http_client.post_json_get_json(
-                    url, invite_config
-                )
-            except RequestTimedOutError:
-                raise SynapseError(500, "Timed out contacting identity server")
-            except HttpResponseException as e:
-                logger.warning(
-                    "Error trying to call /store-invite on %s%s: %s",
-                    id_server_scheme,
-                    id_server,
-                    e,
-                )
-
-            if data is None:
-                # Some identity servers may only support application/x-www-form-urlencoded
-                # types. This is especially true with old instances of Sydent, see
-                # https://github.com/matrix-org/sydent/pull/170
-                try:
-                    data = await self.blacklisting_http_client.post_urlencoded_get_json(
-                        url, invite_config
-                    )
-                except HttpResponseException as e:
-                    logger.warning(
-                        "Error calling /store-invite on %s%s with fallback "
-                        "encoding: %s",
-                        id_server_scheme,
-                        id_server,
-                        e,
-                    )
-                    raise e
-
-        # TODO: Check for success
+        except RequestTimedOutError:
+            raise SynapseError(500, "Timed out contacting identity server")
+
         token = data["token"]
         public_keys = data.get("public_keys", [])
         if "public_key" in data:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 85b472f250..860c82c110 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -143,8 +143,8 @@ class InitialSyncHandler:
             joined_rooms,
             to_key=int(now_token.receipt_key),
         )
-        if self.hs.config.experimental.msc2285_enabled:
-            receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
+
+        receipt = ReceiptEventSource.filter_out_private_receipts(receipt, user_id)
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
@@ -309,18 +309,18 @@ class InitialSyncHandler:
         if blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
-        user_id = requester.user.to_string()
-
         (
             membership,
             member_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
             room_id,
-            user_id,
+            requester,
             allow_departed_users=True,
         )
         is_peeking = member_event_id is None
 
+        user_id = requester.user.to_string()
+
         if membership == Membership.JOIN:
             result = await self._room_initial_sync_joined(
                 user_id, room_id, pagin_config, membership, is_peeking
@@ -456,11 +456,8 @@ class InitialSyncHandler:
             )
             if not receipts:
                 return []
-            if self.hs.config.experimental.msc2285_enabled:
-                receipts = ReceiptEventSource.filter_out_private_receipts(
-                    receipts, user_id
-                )
-            return receipts
+
+            return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
             gather_results(
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1980e37dae..e07cda133a 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -41,6 +41,7 @@ from synapse.api.errors import (
     NotFoundError,
     ShadowBanError,
     SynapseError,
+    UnstableSpecAuthError,
     UnsupportedRoomVersionError,
 )
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
@@ -51,6 +52,7 @@ from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
+from synapse.logging import opentracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
@@ -102,7 +104,7 @@ class MessageHandler:
 
     async def get_room_data(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         event_type: str,
         state_key: str,
@@ -110,7 +112,7 @@ class MessageHandler:
         """Get data from a room.
 
         Args:
-            user_id
+            requester: The user who did the request.
             room_id
             event_type
             state_key
@@ -123,7 +125,7 @@ class MessageHandler:
             membership,
             membership_event_id,
         ) = await self.auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )
 
         if membership == Membership.JOIN:
@@ -149,17 +151,20 @@ class MessageHandler:
                 "Attempted to retrieve data from a room for a user that has never been in it. "
                 "This should not have happened."
             )
-            raise SynapseError(403, "User not in room", errcode=Codes.FORBIDDEN)
+            raise UnstableSpecAuthError(
+                403,
+                "User not in room",
+                errcode=Codes.NOT_JOINED,
+            )
 
         return data
 
     async def get_state_events(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         state_filter: Optional[StateFilter] = None,
         at_token: Optional[StreamToken] = None,
-        is_guest: bool = False,
     ) -> List[dict]:
         """Retrieve all state events for a given room. If the user is
         joined to the room then return the current state. If the user has
@@ -168,14 +173,13 @@ class MessageHandler:
         visible.
 
         Args:
-            user_id: The user requesting state events.
+            requester: The user requesting state events.
             room_id: The room ID to get all state events from.
             state_filter: The state filter used to fetch state from the database.
             at_token: the stream token of the at which we are requesting
                 the stats. If the user is not allowed to view the state as of that
                 stream token, we raise a 403 SynapseError. If None, returns the current
                 state based on the current_state_events table.
-            is_guest: whether this user is a guest
         Returns:
             A list of dicts representing state events. [{}, {}, {}]
         Raises:
@@ -185,6 +189,7 @@ class MessageHandler:
             members of this room.
         """
         state_filter = state_filter or StateFilter.all()
+        user_id = requester.user.to_string()
 
         if at_token:
             last_event_id = (
@@ -217,7 +222,7 @@ class MessageHandler:
                 membership,
                 membership_event_id,
             ) = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
 
             if membership == Membership.JOIN:
@@ -311,30 +316,42 @@ class MessageHandler:
         Returns:
             A dict of user_id to profile info
         """
-        user_id = requester.user.to_string()
         if not requester.app_service:
             # We check AS auth after fetching the room membership, as it
             # requires us to pull out all joined members anyway.
             membership, _ = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
+                room_id, requester, allow_departed_users=True
             )
             if membership != Membership.JOIN:
-                raise NotImplementedError(
-                    "Getting joined members after leaving is not implemented"
+                raise SynapseError(
+                    code=403,
+                    errcode=Codes.FORBIDDEN,
+                    msg="Getting joined members while not being a current member of the room is forbidden.",
                 )
 
-        users_with_profile = await self.store.get_users_in_room_with_profiles(room_id)
+        users_with_profile = (
+            await self._state_storage_controller.get_users_in_room_with_profiles(
+                room_id
+            )
+        )
 
         # If this is an AS, double check that they are allowed to see the members.
         # This can either be because the AS user is in the room or because there
         # is a user in the room that the AS is "interested in"
-        if requester.app_service and user_id not in users_with_profile:
+        if (
+            requester.app_service
+            and requester.user.to_string() not in users_with_profile
+        ):
             for uid in users_with_profile:
                 if requester.app_service.is_interested_in_user(uid):
                     break
             else:
                 # Loop fell through, AS has no interested users in room
-                raise AuthError(403, "Appservice not in room")
+                raise UnstableSpecAuthError(
+                    403,
+                    "Appservice not in room",
+                    errcode=Codes.NOT_JOINED,
+                )
 
         return {
             user_id: {
@@ -463,6 +480,7 @@ class EventCreationHandler:
         )
         self._events_shard_config = self.config.worker.events_shard_config
         self._instance_name = hs.get_instance_name()
+        self._notifier = hs.get_notifier()
 
         self.room_prejoin_state_types = self.hs.config.api.room_prejoin_state
 
@@ -734,18 +752,12 @@ class EventCreationHandler:
         if builder.type == EventTypes.Member:
             membership = builder.content.get("membership", None)
             if membership == Membership.JOIN:
-                return await self._is_server_notices_room(builder.room_id)
+                return await self.store.is_server_notice_room(builder.room_id)
             elif membership == Membership.LEAVE:
                 # the user is always allowed to leave (but not kick people)
                 return builder.state_key == requester.user.to_string()
         return False
 
-    async def _is_server_notices_room(self, room_id: str) -> bool:
-        if self.config.servernotices.server_notices_mxid is None:
-            return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self.config.servernotices.server_notices_mxid in user_ids
-
     async def assert_accepted_privacy_policy(self, requester: Requester) -> None:
         """Check if a user has accepted the privacy policy
 
@@ -1134,6 +1146,10 @@ class EventCreationHandler:
             context = await self.state.compute_event_context(
                 event,
                 state_ids_before_event=state_map_for_event,
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                partial_state=False,
             )
         else:
             context = await self.state.compute_event_context(event)
@@ -1358,9 +1374,10 @@ class EventCreationHandler:
         # and `state_groups` because they have `prev_events` that aren't persisted yet
         # (historical messages persisted in reverse-chronological order).
         if not event.internal_metadata.is_historical():
-            await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                event, context
-            )
+            with opentracing.start_active_span("calculate_push_actions"):
+                await self._bulk_push_rule_evaluator.action_for_event_by_user(
+                    event, context
+                )
 
         try:
             # If we're a worker we need to hit out to the master.
@@ -1444,7 +1461,13 @@ class EventCreationHandler:
             if state_entry.state_group in self._external_cache_joined_hosts_updates:
                 return
 
-            joined_hosts = await self.store.get_joined_hosts(event.room_id, state_entry)
+            state = await state_entry.get_state(
+                self._storage_controllers.state, StateFilter.all()
+            )
+            with opentracing.start_active_span("get_joined_hosts"):
+                joined_hosts = await self.store.get_joined_hosts(
+                    event.room_id, state, state_entry
+                )
 
             # Note that the expiry times must be larger than the expiry time in
             # _external_cache_joined_hosts_updates.
@@ -1545,6 +1568,16 @@ class EventCreationHandler:
                 requester, is_admin_redaction=is_admin_redaction
             )
 
+        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+            (
+                current_membership,
+                _,
+            ) = await self.store.get_local_current_membership_for_user_in_room(
+                event.state_key, event.room_id
+            )
+            if current_membership != Membership.JOIN:
+                self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+
         await self._maybe_kick_guest_users(event, context)
 
         if event.type == EventTypes.CanonicalAlias:
@@ -1844,13 +1877,8 @@ class EventCreationHandler:
 
         # For each room we need to find a joined member we can use to send
         # the dummy event with.
-        latest_event_ids = await self.store.get_prev_events_for_room(room_id)
-        members = await self.state.get_current_users_in_room(
-            room_id, latest_event_ids=latest_event_ids
-        )
+        members = await self.store.get_local_users_in_room(room_id)
         for user_id in members:
-            if not self.hs.is_mine_id(user_id):
-                continue
             requester = create_requester(user_id, authenticated_entity=self.server_name)
             try:
                 event, context = await self.create_event(
@@ -1861,7 +1889,6 @@ class EventCreationHandler:
                         "room_id": room_id,
                         "sender": user_id,
                     },
-                    prev_event_ids=latest_event_ids,
                 )
 
                 event.internal_metadata.proactively_send = False
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 6262a35822..1f83bab836 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -24,7 +24,9 @@ from synapse.api.errors import SynapseError
 from synapse.api.filtering import Filter
 from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.room import ShutdownRoomResponse
+from synapse.logging.opentracing import trace
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.rest.admin._base import assert_user_is_admin
 from synapse.storage.state import StateFilter
 from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamKeyType
@@ -158,11 +160,9 @@ class PaginationHandler:
         self._retention_allowed_lifetime_max = (
             hs.config.retention.retention_allowed_lifetime_max
         )
+        self._is_master = hs.config.worker.worker_app is None
 
-        if (
-            hs.config.worker.run_background_tasks
-            and hs.config.retention.retention_enabled
-        ):
+        if hs.config.retention.retention_enabled and self._is_master:
             # Run the purge jobs described in the configuration file.
             for job in hs.config.retention.retention_purge_jobs:
                 logger.info("Setting up purge job with config: %s", job)
@@ -416,6 +416,7 @@ class PaginationHandler:
 
             await self._storage_controllers.purge_events.purge_room(room_id)
 
+    @trace
     async def get_messages(
         self,
         requester: Requester,
@@ -423,6 +424,7 @@ class PaginationHandler:
         pagin_config: PaginationConfig,
         as_client_event: bool = True,
         event_filter: Optional[Filter] = None,
+        use_admin_priviledge: bool = False,
     ) -> JsonDict:
         """Get messages in a room.
 
@@ -432,10 +434,16 @@ class PaginationHandler:
             pagin_config: The pagination config rules to apply, if any.
             as_client_event: True to get events in client-server format.
             event_filter: Filter to apply to results or None
+            use_admin_priviledge: if `True`, return all events, regardless
+                of whether `user` has access to them. To be used **ONLY**
+                from the admin API.
 
         Returns:
             Pagination API results
         """
+        if use_admin_priviledge:
+            await assert_user_is_admin(self.auth, requester)
+
         user_id = requester.user.to_string()
 
         if pagin_config.from_token:
@@ -458,12 +466,14 @@ class PaginationHandler:
         room_token = from_token.room_key
 
         async with self.pagination_lock.read(room_id):
-            (
-                membership,
-                member_event_id,
-            ) = await self.auth.check_user_in_room_or_world_readable(
-                room_id, user_id, allow_departed_users=True
-            )
+            (membership, member_event_id) = (None, None)
+            if not use_admin_priviledge:
+                (
+                    membership,
+                    member_event_id,
+                ) = await self.auth.check_user_in_room_or_world_readable(
+                    room_id, requester, allow_departed_users=True
+                )
 
             if pagin_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -475,7 +485,7 @@ class PaginationHandler:
                         room_id, room_token.stream
                     )
 
-                if membership == Membership.LEAVE:
+                if not use_admin_priviledge and membership == Membership.LEAVE:
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
@@ -528,12 +538,13 @@ class PaginationHandler:
         if event_filter:
             events = await event_filter.filter(events)
 
-        events = await filter_events_for_client(
-            self._storage_controllers,
-            user_id,
-            events,
-            is_peeking=(member_event_id is None),
-        )
+        if not use_admin_priviledge:
+            events = await filter_events_for_client(
+                self._storage_controllers,
+                user_id,
+                events,
+                is_peeking=(member_event_id is None),
+            )
 
         # if after the filter applied there are no more events
         # return immediately - but there might be more in next_token batch
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 895ea63ed3..4e575ffbaa 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,6 @@ from typing import (
     Callable,
     Collection,
     Dict,
-    FrozenSet,
     Generator,
     Iterable,
     List,
@@ -42,7 +41,6 @@ from typing import (
     Set,
     Tuple,
     Type,
-    Union,
 )
 
 from prometheus_client import Counter
@@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
 from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
@@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 # doesn't return. C.f. #5503.
                 return [], max_token
 
-            # Figure out which other users this user should receive updates for
-            users_interested_in = await self._get_interested_in(user, explicit_room_id)
+            # Figure out which other users this user should explicitly receive
+            # updates for
+            additional_users_interested_in = (
+                await self.get_presence_router().get_interested_users(user.to_string())
+            )
 
             # We have a set of users that we're interested in the presence of. We want to
             # cross-reference that with the users that have actually changed their presence.
 
             # Check whether this user should see all user updates
 
-            if users_interested_in == PresenceRouter.ALL_USERS:
+            if additional_users_interested_in == PresenceRouter.ALL_USERS:
                 # Provide presence state for all users
                 presence_updates = await self._filter_all_presence_updates_for_user(
                     user_id, include_offline, from_key
@@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 return presence_updates, max_token
 
             # Make mypy happy. users_interested_in should now be a set
-            assert not isinstance(users_interested_in, str)
+            assert not isinstance(additional_users_interested_in, str)
+
+            # We always care about our own presence.
+            additional_users_interested_in.add(user_id)
+
+            if explicit_room_id:
+                user_ids = await self.store.get_users_in_room(explicit_room_id)
+                additional_users_interested_in.update(user_ids)
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
+            interested_and_updated_users: Collection[str]
 
             if from_key is not None:
                 # First get all users that have had a presence update
                 updated_users = stream_change_cache.get_all_entities_changed(from_key)
 
                 # Cross-reference users we're interested in with those that have had updates.
-                # Use a slightly-optimised method for processing smaller sets of updates.
-                if updated_users is not None and len(updated_users) < 500:
-                    # For small deltas, it's quicker to get all changes and then
-                    # cross-reference with the users we're interested in
+                if updated_users is not None:
+                    # If we have the full list of changes for presence we can
+                    # simply check which ones share a room with the user.
                     get_updates_counter.labels("stream").inc()
-                    for other_user_id in updated_users:
-                        if other_user_id in users_interested_in:
-                            # mypy thinks this variable could be a FrozenSet as it's possibly set
-                            # to one in the `get_entities_changed` call below, and `add()` is not
-                            # method on a FrozenSet. That doesn't affect us here though, as
-                            # `interested_and_updated_users` is clearly a set() above.
-                            interested_and_updated_users.add(other_user_id)  # type: ignore
+
+                    sharing_users = await self.store.do_users_share_a_room(
+                        user_id, updated_users
+                    )
+
+                    interested_and_updated_users = (
+                        sharing_users.union(additional_users_interested_in)
+                    ).intersection(updated_users)
+
                 else:
                     # Too many possible updates. Find all users we can see and check
                     # if any of them have changed.
                     get_updates_counter.labels("full").inc()
 
+                    users_interested_in = (
+                        await self.store.get_users_who_share_room_with_user(user_id)
+                    )
+                    users_interested_in.update(additional_users_interested_in)
+
                     interested_and_updated_users = (
                         stream_change_cache.get_entities_changed(
                             users_interested_in, from_key
@@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
             else:
                 # No from_key has been specified. Return the presence for all users
                 # this user is interested in
-                interested_and_updated_users = users_interested_in
+                interested_and_updated_users = (
+                    await self.store.get_users_who_share_room_with_user(user_id)
+                )
+                interested_and_updated_users.update(additional_users_interested_in)
 
             # Retrieve the current presence state for each user
             users_to_state = await self.get_presence_handler().current_state_for_users(
@@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
     def get_current_key(self) -> int:
         return self.store.get_current_presence_token()
 
-    @cached(num_args=2, cache_context=True)
-    async def _get_interested_in(
-        self,
-        user: UserID,
-        explicit_room_id: Optional[str] = None,
-        cache_context: Optional[_CacheContext] = None,
-    ) -> Union[Set[str], str]:
-        """Returns the set of users that the given user should see presence
-        updates for.
-
-        Args:
-            user: The user to retrieve presence updates for.
-            explicit_room_id: The users that are in the room will be returned.
-
-        Returns:
-            A set of user IDs to return presence updates for, or "ALL" to return all
-            known updates.
-        """
-        user_id = user.to_string()
-        users_interested_in = set()
-        users_interested_in.add(user_id)  # So that we receive our own presence
-
-        # cache_context isn't likely to ever be None due to the @cached decorator,
-        # but we can't have a non-optional argument after the optional argument
-        # explicit_room_id either. Assert cache_context is not None so we can use it
-        # without mypy complaining.
-        assert cache_context
-
-        # Check with the presence router whether we should poll additional users for
-        # their presence information
-        additional_users = await self.get_presence_router().get_interested_users(
-            user.to_string()
-        )
-        if additional_users == PresenceRouter.ALL_USERS:
-            # If the module requested that this user see the presence updates of *all*
-            # users, then simply return that instead of calculating what rooms this
-            # user shares
-            return PresenceRouter.ALL_USERS
-
-        # Add the additional users from the router
-        users_interested_in.update(additional_users)
-
-        # Find the users who share a room with this user
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
-        users_interested_in.update(users_who_share_room)
-
-        if explicit_room_id:
-            user_ids = await self.store.get_users_in_room(
-                explicit_room_id, on_invalidate=cache_context.invalidate
-            )
-            users_interested_in.update(user_ids)
-
-        return users_interested_in
-
 
 def handle_timeouts(
     user_states: List[UserPresenceState],
@@ -2091,8 +2051,7 @@ async def get_interested_remotes(
     )
 
     for room_id, states in room_ids_to_states.items():
-        user_ids = await store.get_users_in_room(room_id)
-        hosts = {get_domain_from_id(user_id) for user_id in user_ids}
+        hosts = await store.get_current_hosts_in_room(room_id)
         for host in hosts:
             hosts_and_states.setdefault(host, set()).update(states)
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 43d2882b0a..d2bdb9c8be 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -256,10 +256,9 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
             room_ids, from_key=from_key, to_key=to_key
         )
 
-        if self.config.experimental.msc2285_enabled:
-            events = ReceiptEventSource.filter_out_private_receipts(
-                events, user.to_string()
-            )
+        events = ReceiptEventSource.filter_out_private_receipts(
+            events, user.to_string()
+        )
 
         return events, to_key
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c77d181722..20ec22105a 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -29,7 +29,13 @@ from synapse.api.constants import (
     JoinRules,
     LoginType,
 )
-from synapse.api.errors import AuthError, Codes, ConsentNotGivenError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    ConsentNotGivenError,
+    InvalidClientTokenError,
+    SynapseError,
+)
 from synapse.appservice import ApplicationService
 from synapse.config.server import is_threepid_reserved
 from synapse.http.servlet import assert_params_in_dict
@@ -180,10 +186,7 @@ class RegistrationHandler:
                 )
             if guest_access_token:
                 user_data = await self.auth.get_user_by_access_token(guest_access_token)
-                if (
-                    not user_data.is_guest
-                    or UserID.from_string(user_data.user_id).localpart != localpart
-                ):
+                if not user_data.is_guest or user_data.user.localpart != localpart:
                     raise AuthError(
                         403,
                         "Cannot register taken user ID without valid guest "
@@ -618,7 +621,7 @@ class RegistrationHandler:
         user_id = user.to_string()
         service = self.store.get_app_service_by_token(as_token)
         if not service:
-            raise AuthError(403, "Invalid application service token.")
+            raise InvalidClientTokenError()
         if not service.is_interested_in_user(user_id):
             raise SynapseError(
                 400,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0b63cd2186..28d7093f08 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -19,6 +19,7 @@ import attr
 from synapse.api.constants import RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
+from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import _RelatedEvent
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
@@ -73,7 +74,6 @@ class RelationsHandler:
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
@@ -89,7 +89,6 @@ class RelationsHandler:
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            aggregation_key: Only fetch events with this aggregation key, if given.
             limit: Only fetch the most recent `limit` events.
             direction: Whether to fetch the most recent first (`"b"`) or the
                 oldest first (`"f"`).
@@ -104,7 +103,7 @@ class RelationsHandler:
 
         # TODO Properly handle a user leaving a room.
         (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
-            room_id, user_id, allow_departed_users=True
+            room_id, requester, allow_departed_users=True
         )
 
         # This gets the original event and checks that a) the event exists and
@@ -122,7 +121,6 @@ class RelationsHandler:
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
-            aggregation_key=aggregation_key,
             limit=limit,
             direction=direction,
             from_token=from_token,
@@ -364,6 +362,7 @@ class RelationsHandler:
 
         return results
 
+    @trace
     async def get_bundled_aggregations(
         self, events: Iterable[EventBase], user_id: str
     ) -> Dict[str, BundledAggregations]:
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index a54f163c0a..33e9a87002 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -19,6 +19,7 @@ import math
 import random
 import string
 from collections import OrderedDict
+from http import HTTPStatus
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -60,7 +61,6 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
-from synapse.handlers.federation import get_domains_from_state
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
@@ -705,8 +705,8 @@ class RoomCreationHandler:
                 was, requested, `room_alias`. Secondly, the stream_id of the
                 last persisted event.
         Raises:
-            SynapseError if the room ID couldn't be stored, or something went
-            horribly wrong.
+            SynapseError if the room ID couldn't be stored, 3pid invitation config
+            validation failed, or something went horribly wrong.
             ResourceLimitError if server is blocked to some resource being
             exceeded
         """
@@ -721,7 +721,7 @@ class RoomCreationHandler:
             # allow the server notices mxid to create rooms
             is_requester_admin = True
         else:
-            is_requester_admin = await self.auth.is_server_admin(requester.user)
+            is_requester_admin = await self.auth.is_server_admin(requester)
 
         # Let the third party rules modify the room creation config if needed, or abort
         # the room creation entirely with an exception.
@@ -732,6 +732,19 @@ class RoomCreationHandler:
         invite_3pid_list = config.get("invite_3pid", [])
         invite_list = config.get("invite", [])
 
+        # validate each entry for correctness
+        for invite_3pid in invite_3pid_list:
+            if not all(
+                key in invite_3pid
+                for key in ("medium", "address", "id_server", "id_access_token")
+            ):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "all of `medium`, `address`, `id_server` and `id_access_token` "
+                    "are required when making a 3pid invite",
+                    Codes.MISSING_PARAM,
+                )
+
         if not is_requester_admin:
             spam_check = await self.spam_checker.user_may_create_room(user_id)
             if spam_check != NOT_SPAM:
@@ -889,7 +902,11 @@ class RoomCreationHandler:
         # override any attempt to set room versions via the creation_content
         creation_content["room_version"] = room_version.identifier
 
-        last_stream_id = await self._send_events_for_new_room(
+        (
+            last_stream_id,
+            last_sent_event_id,
+            depth,
+        ) = await self._send_events_for_new_room(
             requester,
             room_id,
             preset_config=preset_config,
@@ -905,7 +922,7 @@ class RoomCreationHandler:
         if "name" in config:
             name = config["name"]
             (
-                _,
+                name_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -917,12 +934,16 @@ class RoomCreationHandler:
                     "content": {"name": name},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = name_event.event_id
+            depth += 1
 
         if "topic" in config:
             topic = config["topic"]
             (
-                _,
+                topic_event,
                 last_stream_id,
             ) = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester,
@@ -934,7 +955,11 @@ class RoomCreationHandler:
                     "content": {"topic": topic},
                 },
                 ratelimit=False,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = topic_event.event_id
+            depth += 1
 
         # we avoid dropping the lock between invites, as otherwise joins can
         # start coming in and making the createRoom slow.
@@ -949,7 +974,7 @@ class RoomCreationHandler:
 
             for invitee in invite_list:
                 (
-                    _,
+                    member_event_id,
                     last_stream_id,
                 ) = await self.room_member_handler.update_membership_locked(
                     requester,
@@ -959,16 +984,23 @@ class RoomCreationHandler:
                     ratelimit=False,
                     content=content,
                     new_room=True,
+                    prev_event_ids=[last_sent_event_id],
+                    depth=depth,
                 )
+                last_sent_event_id = member_event_id
+                depth += 1
 
         for invite_3pid in invite_3pid_list:
             id_server = invite_3pid["id_server"]
-            id_access_token = invite_3pid.get("id_access_token")  # optional
+            id_access_token = invite_3pid["id_access_token"]
             address = invite_3pid["address"]
             medium = invite_3pid["medium"]
             # Note that do_3pid_invite can raise a  ShadowBanError, but this was
             # handled above by emptying invite_3pid_list.
-            last_stream_id = await self.hs.get_room_member_handler().do_3pid_invite(
+            (
+                member_event_id,
+                last_stream_id,
+            ) = await self.hs.get_room_member_handler().do_3pid_invite(
                 room_id,
                 requester.user,
                 medium,
@@ -977,7 +1009,11 @@ class RoomCreationHandler:
                 requester,
                 txn_id=None,
                 id_access_token=id_access_token,
+                prev_event_ids=[last_sent_event_id],
+                depth=depth,
             )
+            last_sent_event_id = member_event_id
+            depth += 1
 
         result = {"room_id": room_id}
 
@@ -1005,20 +1041,22 @@ class RoomCreationHandler:
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
         ratelimit: bool = True,
-    ) -> int:
+    ) -> Tuple[int, str, int]:
         """Sends the initial events into a new room.
 
         `power_level_content_override` doesn't apply when initial state has
         power level state event content.
 
         Returns:
-            The stream_id of the last event persisted.
+            A tuple containing the stream ID, event ID and depth of the last
+            event sent to the room.
         """
 
         creator_id = creator.user.to_string()
 
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
 
+        depth = 1
         last_sent_event_id: Optional[str] = None
 
         def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
@@ -1031,6 +1069,7 @@ class RoomCreationHandler:
 
         async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
             nonlocal last_sent_event_id
+            nonlocal depth
 
             event = create(etype, content, **kwargs)
             logger.debug("Sending %s in new room", etype)
@@ -1047,9 +1086,11 @@ class RoomCreationHandler:
                 # Note: we don't pass state_event_ids here because this triggers
                 # an additional query per event to look them up from the events table.
                 prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
+                depth=depth,
             )
 
             last_sent_event_id = sent_event.event_id
+            depth += 1
 
             return last_stream_id
 
@@ -1075,6 +1116,7 @@ class RoomCreationHandler:
             content=creator_join_profile,
             new_room=True,
             prev_event_ids=[last_sent_event_id],
+            depth=depth,
         )
         last_sent_event_id = member_event_id
 
@@ -1168,7 +1210,7 @@ class RoomCreationHandler:
                 content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
             )
 
-        return last_sent_stream_id
+        return last_sent_stream_id, last_sent_event_id, depth
 
     def _generate_room_id(self) -> str:
         """Generates a random room ID.
@@ -1250,13 +1292,16 @@ class RoomContextHandler:
         """
         user = requester.user
         if use_admin_priviledge:
-            await assert_user_is_admin(self.auth, requester.user)
+            await assert_user_is_admin(self.auth, requester)
 
         before_limit = math.floor(limit / 2.0)
         after_limit = limit - before_limit
 
-        users = await self.store.get_users_in_room(room_id)
-        is_peeking = user.to_string() not in users
+        is_user_in_room = await self.store.check_local_user_in_room(
+            user_id=user.to_string(), room_id=room_id
+        )
+        # The user is peeking if they aren't in the room already
+        is_peeking = not is_user_in_room
 
         async def filter_evts(events: List[EventBase]) -> List[EventBase]:
             if use_admin_priviledge:
@@ -1355,6 +1400,7 @@ class TimestampLookupHandler:
         self.store = hs.get_datastores().main
         self.state_handler = hs.get_state_handler()
         self.federation_client = hs.get_federation_client()
+        self.federation_event_handler = hs.get_federation_event_handler()
         self._storage_controllers = hs.get_storage_controllers()
 
     async def get_event_for_timestamp(
@@ -1429,17 +1475,16 @@ class TimestampLookupHandler:
                 timestamp,
             )
 
-            # Find other homeservers from the given state in the room
-            curr_state = await self._storage_controllers.state.get_current_state(
-                room_id
+            likely_domains = (
+                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
             )
-            curr_domains = get_domains_from_state(curr_state)
-            likely_domains = [
-                domain for domain, depth in curr_domains if domain != self.server_name
-            ]
 
             # Loop through each homeserver candidate until we get a succesful response
             for domain in likely_domains:
+                # We don't want to ask our own server for information we don't have
+                if domain == self.server_name:
+                    continue
+
                 try:
                     remote_response = await self.federation_client.timestamp_to_event(
                         domain, room_id, timestamp, direction
@@ -1450,38 +1495,68 @@ class TimestampLookupHandler:
                         remote_response,
                     )
 
-                    # TODO: Do we want to persist this as an extremity?
-                    # TODO: I think ideally, we would try to backfill from
-                    # this event and run this whole
-                    # `get_event_for_timestamp` function again to make sure
-                    # they didn't give us an event from their gappy history.
                     remote_event_id = remote_response.event_id
-                    origin_server_ts = remote_response.origin_server_ts
+                    remote_origin_server_ts = remote_response.origin_server_ts
+
+                    # Backfill this event so we can get a pagination token for
+                    # it with `/context` and paginate `/messages` from this
+                    # point.
+                    #
+                    # TODO: The requested timestamp may lie in a part of the
+                    #   event graph that the remote server *also* didn't have,
+                    #   in which case they will have returned another event
+                    #   which may be nowhere near the requested timestamp. In
+                    #   the future, we may need to reconcile that gap and ask
+                    #   other homeservers, and/or extend `/timestamp_to_event`
+                    #   to return events on *both* sides of the timestamp to
+                    #   help reconcile the gap faster.
+                    remote_event = (
+                        await self.federation_event_handler.backfill_event_id(
+                            domain, room_id, remote_event_id
+                        )
+                    )
+
+                    # XXX: When we see that the remote server is not trustworthy,
+                    # maybe we should not ask them first in the future.
+                    if remote_origin_server_ts != remote_event.origin_server_ts:
+                        logger.info(
+                            "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+                            domain,
+                            remote_event_id,
+                            remote_origin_server_ts,
+                            remote_event.origin_server_ts,
+                        )
 
                     # Only return the remote event if it's closer than the local event
                     if not local_event or (
-                        abs(origin_server_ts - timestamp)
+                        abs(remote_event.origin_server_ts - timestamp)
                         < abs(local_event.origin_server_ts - timestamp)
                     ):
-                        return remote_event_id, origin_server_ts
+                        logger.info(
+                            "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+                            remote_event_id,
+                            remote_event.origin_server_ts,
+                            timestamp,
+                            local_event.event_id if local_event else None,
+                            local_event.origin_server_ts if local_event else None,
+                        )
+                        return remote_event_id, remote_origin_server_ts
                 except (HttpResponseException, InvalidResponseError) as ex:
                     # Let's not put a high priority on some other homeserver
                     # failing to respond or giving a random response
                     logger.debug(
-                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
                         domain,
                         type(ex).__name__,
                         ex,
                         ex.args,
                     )
-                except Exception as ex:
+                except Exception:
                     # But we do want to see some exceptions in our code
                     logger.warning(
-                        "Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
+                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
                         domain,
-                        type(ex).__name__,
-                        ex,
-                        ex.args,
+                        exc_info=True,
                     )
 
         # To appease mypy, we have to add both of these conditions to check for
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 29868eb743..bb0bdb8e6f 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -182,7 +182,7 @@ class RoomListHandler:
                 == HistoryVisibility.WORLD_READABLE,
                 "guest_can_join": room["guest_access"] == "can_join",
                 "join_rule": room["join_rules"],
-                "org.matrix.msc3827.room_type": room["room_type"],
+                "room_type": room["room_type"],
             }
 
             # Filter out Nones – rather omit the field altogether
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 04c44b2ccb..8d01f4bf2b 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -32,6 +32,7 @@ from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
+from synapse.logging import opentracing
 from synapse.module_api import NOT_SPAM
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -94,12 +95,29 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             rate_hz=hs.config.ratelimiting.rc_joins_local.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_local.burst_count,
         )
+        # Tracks joins from local users to rooms this server isn't a member of.
+        # I.e. joins this server makes by requesting /make_join /send_join from
+        # another server.
         self._join_rate_limiter_remote = Ratelimiter(
             store=self.store,
             clock=self.clock,
             rate_hz=hs.config.ratelimiting.rc_joins_remote.per_second,
             burst_count=hs.config.ratelimiting.rc_joins_remote.burst_count,
         )
+        # TODO: find a better place to keep this Ratelimiter.
+        #   It needs to be
+        #    - written to by event persistence code
+        #    - written to by something which can snoop on replication streams
+        #    - read by the RoomMemberHandler to rate limit joins from local users
+        #    - read by the FederationServer to rate limit make_joins and send_joins from
+        #      other homeservers
+        #   I wonder if a homeserver-wide collection of rate limiters might be cleaner?
+        self._join_rate_per_room_limiter = Ratelimiter(
+            store=self.store,
+            clock=self.clock,
+            rate_hz=hs.config.ratelimiting.rc_joins_per_room.per_second,
+            burst_count=hs.config.ratelimiting.rc_joins_per_room.burst_count,
+        )
 
         # Ratelimiter for invites, keyed by room (across all issuers, all
         # recipients).
@@ -136,6 +154,18 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
+        hs.get_notifier().add_new_join_in_room_callback(self._on_user_joined_room)
+
+    def _on_user_joined_room(self, event_id: str, room_id: str) -> None:
+        """Notify the rate limiter that a room join has occurred.
+
+        Use this to inform the RoomMemberHandler about joins that have either
+        - taken place on another homeserver, or
+        - on another worker in this homeserver.
+        Joins actioned by this worker should use the usual `ratelimit` method, which
+        checks the limit and increments the counter in one go.
+        """
+        self._join_rate_per_room_limiter.record_action(requester=None, key=room_id)
 
     @abc.abstractmethod
     async def _remote_join(
@@ -149,7 +179,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         """Try and join a room that this server is not in
 
         Args:
-            requester
+            requester: The user making the request, according to the access token.
             remote_room_hosts: List of servers that can be used to join via.
             room_id: Room that we are trying to join
             user: User who is trying to join
@@ -285,6 +315,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
         txn_id: Optional[str] = None,
         ratelimit: bool = True,
         content: Optional[dict] = None,
@@ -315,6 +346,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
             txn_id:
             ratelimit:
@@ -370,6 +404,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             allow_no_prev_events=allow_no_prev_events,
             prev_event_ids=prev_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             require_consent=require_consent,
             outlier=outlier,
             historical=historical,
@@ -391,14 +426,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # up blocking profile updates.
             if newly_joined and ratelimit:
                 await self._join_rate_limiter_local.ratelimit(requester)
-
-        result_event = await self.event_creation_handler.handle_new_client_event(
-            requester,
-            event,
-            context,
-            extra_users=[target],
-            ratelimit=ratelimit,
-        )
+                await self._join_rate_per_room_limiter.ratelimit(
+                    requester, key=room_id, update=False
+                )
+        with opentracing.start_active_span("handle_new_client_event"):
+            result_event = await self.event_creation_handler.handle_new_client_event(
+                requester,
+                event,
+                context,
+                extra_users=[target],
+                ratelimit=ratelimit,
+            )
 
         if event.membership == Membership.LEAVE:
             if prev_member_event_id:
@@ -466,6 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Update a user's membership in a room.
 
@@ -501,6 +540,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -523,24 +565,26 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # by application services), and then by room ID.
         async with self.member_as_limiter.queue(as_id):
             async with self.member_linearizer.queue(key):
-                result = await self.update_membership_locked(
-                    requester,
-                    target,
-                    room_id,
-                    action,
-                    txn_id=txn_id,
-                    remote_room_hosts=remote_room_hosts,
-                    third_party_signed=third_party_signed,
-                    ratelimit=ratelimit,
-                    content=content,
-                    new_room=new_room,
-                    require_consent=require_consent,
-                    outlier=outlier,
-                    historical=historical,
-                    allow_no_prev_events=allow_no_prev_events,
-                    prev_event_ids=prev_event_ids,
-                    state_event_ids=state_event_ids,
-                )
+                with opentracing.start_active_span("update_membership_locked"):
+                    result = await self.update_membership_locked(
+                        requester,
+                        target,
+                        room_id,
+                        action,
+                        txn_id=txn_id,
+                        remote_room_hosts=remote_room_hosts,
+                        third_party_signed=third_party_signed,
+                        ratelimit=ratelimit,
+                        content=content,
+                        new_room=new_room,
+                        require_consent=require_consent,
+                        outlier=outlier,
+                        historical=historical,
+                        allow_no_prev_events=allow_no_prev_events,
+                        prev_event_ids=prev_event_ids,
+                        state_event_ids=state_event_ids,
+                        depth=depth,
+                    )
 
         return result
 
@@ -562,6 +606,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         allow_no_prev_events: bool = False,
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Helper for update_membership.
 
@@ -599,10 +644,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 prev_events are set so we need to set them ourself via this argument.
                 This should normally be left as None, which will cause the auth_event_ids
                 to be calculated based on the room state at the prev_events.
+            depth: Override the depth used to order the event in the DAG.
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
             A tuple of the new event ID and stream ID.
         """
+
         content_specified = bool(content)
         if content is None:
             content = {}
@@ -640,7 +689,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 errcode=Codes.BAD_JSON,
             )
 
-        if "avatar_url" in content:
+        if "avatar_url" in content and content.get("avatar_url") is not None:
             if not await self.profile_handler.check_avatar_size_and_mime_type(
                 content["avatar_url"],
             ):
@@ -695,7 +744,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 is_requester_admin = True
 
             else:
-                is_requester_admin = await self.auth.is_server_admin(requester.user)
+                is_requester_admin = await self.auth.is_server_admin(requester)
 
             if not is_requester_admin:
                 if self.config.server.block_non_admin_invites:
@@ -732,6 +781,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 allow_no_prev_events=allow_no_prev_events,
                 prev_event_ids=prev_event_ids,
                 state_event_ids=state_event_ids,
+                depth=depth,
                 content=content,
                 require_consent=require_consent,
                 outlier=outlier,
@@ -740,14 +790,14 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        current_state_ids = await self.state_handler.get_current_state_ids(
-            room_id, latest_event_ids=latest_event_ids
+        state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids
         )
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = current_state_ids.get((EventTypes.Member, target.to_string()))
+        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -787,7 +837,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 old_membership == Membership.INVITE
                 and effective_membership_state == Membership.LEAVE
             ):
-                is_blocked = await self._is_server_notice_room(room_id)
+                is_blocked = await self.store.is_server_notice_room(room_id)
                 if is_blocked:
                     raise SynapseError(
                         HTTPStatus.FORBIDDEN,
@@ -798,11 +848,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(current_state_ids)
+        is_host_in_room = await self._is_host_in_room(state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(current_state_ids)
+                guest_can_join = await self._can_guest_join(state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -818,7 +868,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 bypass_spam_checker = True
 
             else:
-                bypass_spam_checker = await self.auth.is_server_admin(requester.user)
+                bypass_spam_checker = await self.auth.is_server_admin(requester)
 
             inviter = await self._get_inviter(target.to_string(), room_id)
             if (
@@ -840,13 +890,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
             # Check if a remote join should be performed.
             remote_join, remote_room_hosts = await self._should_perform_remote_join(
-                target.to_string(), room_id, remote_room_hosts, content, is_host_in_room
+                target.to_string(),
+                room_id,
+                remote_room_hosts,
+                content,
+                is_host_in_room,
+                state_before_join,
             )
             if remote_join:
                 if ratelimit:
                     await self._join_rate_limiter_remote.ratelimit(
                         requester,
                     )
+                    await self._join_rate_per_room_limiter.ratelimit(
+                        requester,
+                        key=room_id,
+                        update=False,
+                    )
 
                 inviter = await self._get_inviter(target.to_string(), room_id)
                 if inviter and not self.hs.is_mine(inviter):
@@ -967,6 +1027,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             ratelimit=ratelimit,
             prev_event_ids=latest_event_ids,
             state_event_ids=state_event_ids,
+            depth=depth,
             content=content,
             require_consent=require_consent,
             outlier=outlier,
@@ -979,6 +1040,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         remote_room_hosts: List[str],
         content: JsonDict,
         is_host_in_room: bool,
+        state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -998,6 +1060,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content: The content to use as the event body of the join. This may
                 be modified.
             is_host_in_room: True if the host is in the room.
+            state_before_join: The state before the join event (i.e. the resolution of
+                the states after its parent events).
 
         Returns:
             A tuple of:
@@ -1014,20 +1078,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
-        current_state_ids = await self._storage_controllers.state.get_current_state_ids(
-            room_id
-        )
 
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            current_state_ids, room_version
+            state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = current_state_ids.get((EventTypes.Member, user_id), None)
+        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
         prev_member_event = None
         if prev_member_event_id:
             prev_member_event = await self.store.get_event(prev_member_event_id)
@@ -1042,10 +1103,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         #
         # If not, generate a new list of remote hosts based on which
         # can issue invites.
-        event_map = await self.store.get_events(current_state_ids.values())
+        event_map = await self.store.get_events(state_before_join.values())
         current_state = {
             state_key: event_map[event_id]
-            for state_key, event_id in current_state_ids.items()
+            for state_key, event_id in state_before_join.items()
         }
         allowed_servers = get_servers_from_users(
             get_users_which_can_issue_invite(current_state)
@@ -1059,7 +1120,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            current_state_ids, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, prev_member_event
         )
 
         # If this is going to be a local join, additional information must
@@ -1069,7 +1130,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             EventContentFields.AUTHORISING_USER
         ] = await self.event_auth_handler.get_user_which_could_invite(
             room_id,
-            current_state_ids,
+            state_before_join,
         )
 
         return False, []
@@ -1321,8 +1382,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         id_server: str,
         requester: Requester,
         txn_id: Optional[str],
-        id_access_token: Optional[str] = None,
-    ) -> int:
+        id_access_token: str,
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[str, int]:
         """Invite a 3PID to a room.
 
         Args:
@@ -1334,16 +1397,20 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             requester: The user making the request.
             txn_id: The transaction ID this is part of, or None if this is not
                 part of a transaction.
-            id_access_token: The optional identity server access token.
+            id_access_token: Identity server access token.
+            depth: Override the depth used to order the event in the DAG.
+            prev_event_ids: The event IDs to use as the prev events
+                Should normally be set to None, which will cause the depth to be calculated
+                based on the prev_events.
 
         Returns:
-             The new stream ID.
+            Tuple of event ID and stream ordering position
 
         Raises:
             ShadowBanError if the requester has been shadow-banned.
         """
         if self.config.server.block_non_admin_invites:
-            is_requester_admin = await self.auth.is_server_admin(requester.user)
+            is_requester_admin = await self.auth.is_server_admin(requester)
             if not is_requester_admin:
                 raise SynapseError(
                     403, "Invites have been disabled on this server", Codes.FORBIDDEN
@@ -1383,7 +1450,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             # We don't check the invite against the spamchecker(s) here (through
             # user_may_invite) because we'll do it further down the line anyway (in
             # update_membership_locked).
-            _, stream_id = await self.update_membership(
+            event_id, stream_id = await self.update_membership(
                 requester, UserID.from_string(invitee), room_id, "invite", txn_id=txn_id
             )
         else:
@@ -1402,7 +1469,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     additional_fields=spam_check[1],
                 )
 
-            stream_id = await self._make_and_store_3pid_invite(
+            event, stream_id = await self._make_and_store_3pid_invite(
                 requester,
                 id_server,
                 medium,
@@ -1411,9 +1478,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 inviter,
                 txn_id=txn_id,
                 id_access_token=id_access_token,
+                prev_event_ids=prev_event_ids,
+                depth=depth,
             )
+            event_id = event.event_id
 
-        return stream_id
+        return event_id, stream_id
 
     async def _make_and_store_3pid_invite(
         self,
@@ -1424,8 +1494,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         room_id: str,
         user: UserID,
         txn_id: Optional[str],
-        id_access_token: Optional[str] = None,
-    ) -> int:
+        id_access_token: str,
+        prev_event_ids: Optional[List[str]] = None,
+        depth: Optional[int] = None,
+    ) -> Tuple[EventBase, int]:
         room_state = await self._storage_controllers.state.get_current_state(
             room_id,
             StateFilter.from_types(
@@ -1518,8 +1590,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             },
             ratelimit=False,
             txn_id=txn_id,
+            prev_event_ids=prev_event_ids,
+            depth=depth,
         )
-        return stream_id
+        return event, stream_id
 
     async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
         # Have we just created the room, and is this about to be the very
@@ -1543,12 +1617,6 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         return False
 
-    async def _is_server_notice_room(self, room_id: str) -> bool:
-        if self._server_notices_mxid is None:
-            return False
-        user_ids = await self.store.get_users_in_room(room_id)
-        return self._server_notices_mxid in user_ids
-
 
 class RoomMemberMasterHandler(RoomMemberHandler):
     def __init__(self, hs: "HomeServer"):
@@ -1608,14 +1676,18 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise SynapseError(
+                404,
+                "Can't join remote room because no servers "
+                "that are in the room have been provided.",
+            )
 
         check_complexity = self.hs.config.server.limit_remote_rooms.enabled
         if (
             check_complexity
             and self.hs.config.server.limit_remote_rooms.admins_can_join
         ):
-            check_complexity = not await self.auth.is_server_admin(user)
+            check_complexity = not await self.store.is_server_admin(user)
 
         if check_complexity:
             # Fetch the room complexity
@@ -1845,8 +1917,11 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]:
             raise SynapseError(400, "User %s in room %s" % (user_id, room_id))
 
-        if membership:
-            await self.store.forget(user_id, room_id)
+        # In normal case this call is only required if `membership` is not `None`.
+        # But: After the last member had left the room, the background update
+        # `_background_remove_left_rooms` is deleting rows related to this room from
+        # the table `current_state_events` and `get_current_state_events` is `None`.
+        await self.store.forget(user_id, room_id)
 
 
 def get_users_which_can_issue_invite(auth_events: StateMap[EventBase]) -> List[str]:
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 13098f56ed..ebd445adca 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -28,11 +28,11 @@ from synapse.api.constants import (
     RoomTypes,
 )
 from synapse.api.errors import (
-    AuthError,
     Codes,
     NotFoundError,
     StoreError,
     SynapseError,
+    UnstableSpecAuthError,
     UnsupportedRoomVersionError,
 )
 from synapse.api.ratelimiting import Ratelimiter
@@ -175,10 +175,11 @@ class RoomSummaryHandler:
 
         # First of all, check that the room is accessible.
         if not await self._is_local_room_accessible(requested_room_id, requester):
-            raise AuthError(
+            raise UnstableSpecAuthError(
                 403,
                 "User %s not in room %s, and room previews are disabled"
                 % (requester, requested_room_id),
+                errcode=Codes.NOT_JOINED,
             )
 
         # If this is continuing a previous session, pull the persisted data.
@@ -452,7 +453,6 @@ class RoomSummaryHandler:
                 "type": e.type,
                 "state_key": e.state_key,
                 "content": e.content,
-                "room_id": e.room_id,
                 "sender": e.sender,
                 "origin_server_ts": e.origin_server_ts,
             }
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index a305a66860..e2844799e8 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -23,10 +23,12 @@ from pkg_resources import parse_version
 
 import twisted
 from twisted.internet.defer import Deferred
-from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
+from twisted.internet.interfaces import IOpenSSLContextFactory
+from twisted.internet.ssl import optionsForClientTLS
 from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
 
 from synapse.logging.context import make_deferred_yieldable
+from synapse.types import ISynapseReactor
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -48,7 +50,7 @@ class _NoTLSESMTPSender(ESMTPSender):
 
 
 async def _sendmail(
-    reactor: IReactorTCP,
+    reactor: ISynapseReactor,
     smtphost: str,
     smtpport: int,
     from_addr: str,
@@ -59,6 +61,7 @@ async def _sendmail(
     require_auth: bool = False,
     require_tls: bool = False,
     enable_tls: bool = True,
+    force_tls: bool = False,
 ) -> None:
     """A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
 
@@ -73,8 +76,9 @@ async def _sendmail(
         password: password to give when authenticating
         require_auth: if auth is not offered, fail the request
         require_tls: if TLS is not offered, fail the reqest
-        enable_tls: True to enable TLS. If this is False and require_tls is True,
+        enable_tls: True to enable STARTTLS. If this is False and require_tls is True,
            the request will fail.
+        force_tls: True to enable Implicit TLS.
     """
     msg = BytesIO(msg_bytes)
     d: "Deferred[object]" = Deferred()
@@ -105,13 +109,23 @@ async def _sendmail(
         # set to enable TLS.
         factory = build_sender_factory(hostname=smtphost if enable_tls else None)
 
-    reactor.connectTCP(
-        smtphost,
-        smtpport,
-        factory,
-        timeout=30,
-        bindAddress=None,
-    )
+    if force_tls:
+        reactor.connectSSL(
+            smtphost,
+            smtpport,
+            factory,
+            optionsForClientTLS(smtphost),
+            timeout=30,
+            bindAddress=None,
+        )
+    else:
+        reactor.connectTCP(
+            smtphost,
+            smtpport,
+            factory,
+            timeout=30,
+            bindAddress=None,
+        )
 
     await make_deferred_yieldable(d)
 
@@ -132,6 +146,7 @@ class SendEmailHandler:
         self._smtp_pass = passwd.encode("utf-8") if passwd is not None else None
         self._require_transport_security = hs.config.email.require_transport_security
         self._enable_tls = hs.config.email.enable_smtp_tls
+        self._force_tls = hs.config.email.force_tls
 
         self._sendmail = _sendmail
 
@@ -189,4 +204,5 @@ class SendEmailHandler:
             require_auth=self._smtp_user is not None,
             require_tls=self._require_transport_security,
             enable_tls=self._enable_tls,
+            force_tls=self._force_tls,
         )
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index d42a414c90..5293fa4d0e 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -13,7 +13,20 @@
 # limitations under the License.
 import itertools
 import logging
-from typing import TYPE_CHECKING, Any, Dict, FrozenSet, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    AbstractSet,
+    Any,
+    Collection,
+    Dict,
+    FrozenSet,
+    List,
+    Mapping,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 import attr
 from prometheus_client import Counter
@@ -89,7 +102,7 @@ class SyncConfig:
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class TimelineBatch:
     prev_batch: StreamToken
-    events: List[EventBase]
+    events: Sequence[EventBase]
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
@@ -507,10 +520,17 @@ class SyncHandler:
                 # ensure that we always include current state in the timeline
                 current_state_ids: FrozenSet[str] = frozenset()
                 if any(e.is_state() for e in recents):
+                    # FIXME(faster_joins): We use the partial state here as
+                    # we don't want to block `/sync` on finishing a lazy join.
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `recents`, so partial state is only a problem when a membership
+                    # event turns up in `recents` but has not made it into the current
+                    # state.
                     current_state_ids_map = (
-                        await self._state_storage_controller.get_current_state_ids(
-                            room_id
-                        )
+                        await self.store.get_partial_current_state_ids(room_id)
                     )
                     current_state_ids = frozenset(current_state_ids_map.values())
 
@@ -579,7 +599,13 @@ class SyncHandler:
                 if any(e.is_state() for e in loaded_recents):
                     # FIXME(faster_joins): We use the partial state here as
                     # we don't want to block `/sync` on finishing a lazy join.
-                    # Is this the correct way of doing it?
+                    # Which should be fine once
+                    # https://github.com/matrix-org/synapse/issues/12989 is resolved,
+                    # since we shouldn't reach here anymore?
+                    # Note that we use the current state as a whitelist for filtering
+                    # `loaded_recents`, so partial state is only a problem when a
+                    # membership event turns up in `loaded_recents` but has not made it
+                    # into the current state.
                     current_state_ids_map = (
                         await self.store.get_partial_current_state_ids(room_id)
                     )
@@ -627,7 +653,10 @@ class SyncHandler:
         )
 
     async def get_state_after_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the room state after the given event
@@ -635,9 +664,14 @@ class SyncHandler:
         Args:
             event_id: event of interest
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
         """
         state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event_id, state_filter=state_filter or StateFilter.all()
+            event_id,
+            state_filter=state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
 
         # using get_metadata_for_events here (instead of get_event) sidesteps an issue
@@ -660,6 +694,7 @@ class SyncHandler:
         room_id: str,
         stream_position: StreamToken,
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the room state at a particular stream position
 
@@ -667,6 +702,9 @@ class SyncHandler:
             room_id: room for which to get state
             stream_position: point at which to get state
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the last event in the room before `stream_position` and
+                `state_filter` is not satisfied by partial state. Defaults to `True`.
         """
         # FIXME: This gets the state at the latest event before the stream ordering,
         # which might not be the same as the "current state" of the room at the time
@@ -678,7 +716,9 @@ class SyncHandler:
 
         if last_event_id:
             state = await self.get_state_after_event(
-                last_event_id, state_filter=state_filter or StateFilter.all()
+                last_event_id,
+                state_filter=state_filter or StateFilter.all(),
+                await_full_state=await_full_state,
             )
 
         else:
@@ -852,16 +892,26 @@ class SyncHandler:
         now_token: StreamToken,
         full_state: bool,
     ) -> MutableStateMap[EventBase]:
-        """Works out the difference in state between the start of the timeline
-        and the previous sync.
+        """Works out the difference in state between the end of the previous sync and
+        the start of the timeline.
 
         Args:
             room_id:
             batch: The timeline batch for the room that will be sent to the user.
             sync_config:
-            since_token: Token of the end of the previous batch. May be None.
+            since_token: Token of the end of the previous batch. May be `None`.
             now_token: Token of the end of the current batch.
             full_state: Whether to force returning the full state.
+                `lazy_load_members` still applies when `full_state` is `True`.
+
+        Returns:
+            The state to return in the sync response for the room.
+
+            Clients will overlay this onto the state at the end of the previous sync to
+            arrive at the state at the start of the timeline.
+
+            Clients will then overlay state events in the timeline to arrive at the
+            state at the end of the timeline, in preparation for the next sync.
         """
         # TODO(mjark) Check if the state events were received by the server
         # after the previous sync, since we need to include those state
@@ -869,8 +919,17 @@ class SyncHandler:
         # TODO(mjark) Check for new redactions in the state events.
 
         with Measure(self.clock, "compute_state_delta"):
+            # The memberships needed for events in the timeline.
+            # Only calculated when `lazy_load_members` is on.
+            members_to_fetch: Optional[Set[str]] = None
+
+            # A dictionary mapping user IDs to the first event in the timeline sent by
+            # them. Only calculated when `lazy_load_members` is on.
+            first_event_by_sender_map: Optional[Dict[str, EventBase]] = None
 
-            members_to_fetch = None
+            # The contribution to the room state from state events in the timeline.
+            # Only contains the last event for any given state key.
+            timeline_state: StateMap[str]
 
             lazy_load_members = sync_config.filter_collection.lazy_load_members()
             include_redundant_members = (
@@ -881,10 +940,23 @@ class SyncHandler:
                 # We only request state for the members needed to display the
                 # timeline:
 
-                members_to_fetch = {
-                    event.sender  # FIXME: we also care about invite targets etc.
-                    for event in batch.events
-                }
+                timeline_state = {}
+
+                members_to_fetch = set()
+                first_event_by_sender_map = {}
+                for event in batch.events:
+                    # Build the map from user IDs to the first timeline event they sent.
+                    if event.sender not in first_event_by_sender_map:
+                        first_event_by_sender_map[event.sender] = event
+
+                    # We need the event's sender, unless their membership was in a
+                    # previous timeline event.
+                    if (EventTypes.Member, event.sender) not in timeline_state:
+                        members_to_fetch.add(event.sender)
+                    # FIXME: we also care about invite targets etc.
+
+                    if event.is_state():
+                        timeline_state[(event.type, event.state_key)] = event.event_id
 
                 if full_state:
                     # always make sure we LL ourselves so we know we're in the room
@@ -894,55 +966,80 @@ class SyncHandler:
                     members_to_fetch.add(sync_config.user.to_string())
 
                 state_filter = StateFilter.from_lazy_load_member_list(members_to_fetch)
+
+                # We are happy to use partial state to compute the `/sync` response.
+                # Since partial state may not include the lazy-loaded memberships we
+                # require, we fix up the state response afterwards with memberships from
+                # auth events.
+                await_full_state = False
             else:
+                timeline_state = {
+                    (event.type, event.state_key): event.event_id
+                    for event in batch.events
+                    if event.is_state()
+                }
+
                 state_filter = StateFilter.all()
+                await_full_state = True
 
-            timeline_state = {
-                (event.type, event.state_key): event.event_id
-                for event in batch.events
-                if event.is_state()
-            }
+            # Now calculate the state to return in the sync response for the room.
+            # This is more or less the change in state between the end of the previous
+            # sync's timeline and the start of the current sync's timeline.
+            # See the docstring above for details.
+            state_ids: StateMap[str]
 
             if full_state:
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
-                    state_ids = (
+                    state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
 
                 else:
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
-                    state_ids = current_state_ids
+                    state_at_timeline_start = state_at_timeline_end
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
-                    timeline_start=state_ids,
-                    previous={},
-                    current=current_state_ids,
+                    timeline_start=state_at_timeline_start,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end={},
                     lazy_load_members=lazy_load_members,
                 )
             elif batch.limited:
                 if batch:
                     state_at_timeline_start = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[0].event_id, state_filter=state_filter
+                            batch.events[0].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
                     # We can get here if the user has ignored the senders of all
                     # the recent events.
                     state_at_timeline_start = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 # for now, we disable LL for gappy syncs - see
@@ -964,28 +1061,35 @@ class SyncHandler:
                 # is indeed the case.
                 assert since_token is not None
                 state_at_previous_sync = await self.get_state_at(
-                    room_id, stream_position=since_token, state_filter=state_filter
+                    room_id,
+                    stream_position=since_token,
+                    state_filter=state_filter,
+                    await_full_state=await_full_state,
                 )
 
                 if batch:
-                    current_state_ids = (
+                    state_at_timeline_end = (
                         await self._state_storage_controller.get_state_ids_for_event(
-                            batch.events[-1].event_id, state_filter=state_filter
+                            batch.events[-1].event_id,
+                            state_filter=state_filter,
+                            await_full_state=await_full_state,
                         )
                     )
                 else:
-                    # Its not clear how we get here, but empirically we do
-                    # (#5407). Logging has been added elsewhere to try and
-                    # figure out where this state comes from.
-                    current_state_ids = await self.get_state_at(
-                        room_id, stream_position=now_token, state_filter=state_filter
+                    # We can get here if the user has ignored the senders of all
+                    # the recent events.
+                    state_at_timeline_end = await self.get_state_at(
+                        room_id,
+                        stream_position=now_token,
+                        state_filter=state_filter,
+                        await_full_state=await_full_state,
                     )
 
                 state_ids = _calculate_state(
                     timeline_contains=timeline_state,
                     timeline_start=state_at_timeline_start,
-                    previous=state_at_previous_sync,
-                    current=current_state_ids,
+                    timeline_end=state_at_timeline_end,
+                    previous_timeline_end=state_at_previous_sync,
                     # we have to include LL members in case LL initial sync missed them
                     lazy_load_members=lazy_load_members,
                 )
@@ -1008,8 +1112,30 @@ class SyncHandler:
                                 (EventTypes.Member, member)
                                 for member in members_to_fetch
                             ),
+                            await_full_state=False,
                         )
 
+            # If we only have partial state for the room, `state_ids` may be missing the
+            # memberships we wanted. We attempt to find some by digging through the auth
+            # events of timeline events.
+            if lazy_load_members and await self.store.is_partial_state_room(room_id):
+                assert members_to_fetch is not None
+                assert first_event_by_sender_map is not None
+
+                additional_state_ids = (
+                    await self._find_missing_partial_state_memberships(
+                        room_id, members_to_fetch, first_event_by_sender_map, state_ids
+                    )
+                )
+                state_ids = {**state_ids, **additional_state_ids}
+
+            # At this point, if `lazy_load_members` is enabled, `state_ids` includes
+            # the memberships of all event senders in the timeline. This is because we
+            # may not have sent the memberships in a previous sync.
+
+            # When `include_redundant_members` is on, we send all the lazy-loaded
+            # memberships of event senders. Otherwise we make an effort to limit the set
+            # of memberships we send to those that we have not already sent to this client.
             if lazy_load_members and not include_redundant_members:
                 cache_key = (sync_config.user.to_string(), sync_config.device_id)
                 cache = self.get_lazy_loaded_members_cache(cache_key)
@@ -1051,6 +1177,99 @@ class SyncHandler:
             if e.type != EventTypes.Aliases  # until MSC2261 or alternative solution
         }
 
+    async def _find_missing_partial_state_memberships(
+        self,
+        room_id: str,
+        members_to_fetch: Collection[str],
+        events_with_membership_auth: Mapping[str, EventBase],
+        found_state_ids: StateMap[str],
+    ) -> StateMap[str]:
+        """Finds missing memberships from a set of auth events and returns them as a
+        state map.
+
+        Args:
+            room_id: The partial state room to find the remaining memberships for.
+            members_to_fetch: The memberships to find.
+            events_with_membership_auth: A mapping from user IDs to events whose auth
+                events are known to contain their membership.
+            found_state_ids: A dict from (type, state_key) -> state_event_id, containing
+                memberships that have been previously found. Entries in
+                `members_to_fetch` that have a membership in `found_state_ids` are
+                ignored.
+
+        Returns:
+            A dict from ("m.room.member", state_key) -> state_event_id, containing the
+            memberships missing from `found_state_ids`.
+
+        Raises:
+            KeyError: if `events_with_membership_auth` does not have an entry for a
+                missing membership. Memberships in `found_state_ids` do not need an
+                entry in `events_with_membership_auth`.
+        """
+        additional_state_ids: MutableStateMap[str] = {}
+
+        # Tracks the missing members for logging purposes.
+        missing_members = set()
+
+        # Identify memberships missing from `found_state_ids` and pick out the auth
+        # events in which to look for them.
+        auth_event_ids: Set[str] = set()
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            missing_members.add(member)
+            event_with_membership_auth = events_with_membership_auth[member]
+            auth_event_ids.update(event_with_membership_auth.auth_event_ids())
+
+        auth_events = await self.store.get_events(auth_event_ids)
+
+        # Run through the missing memberships once more, picking out the memberships
+        # from the pile of auth events we have just fetched.
+        for member in members_to_fetch:
+            if (EventTypes.Member, member) in found_state_ids:
+                continue
+
+            event_with_membership_auth = events_with_membership_auth[member]
+
+            # Dig through the auth events to find the desired membership.
+            for auth_event_id in event_with_membership_auth.auth_event_ids():
+                # We only store events once we have all their auth events,
+                # so the auth event must be in the pile we have just
+                # fetched.
+                auth_event = auth_events[auth_event_id]
+
+                if (
+                    auth_event.type == EventTypes.Member
+                    and auth_event.state_key == member
+                ):
+                    missing_members.remove(member)
+                    additional_state_ids[
+                        (EventTypes.Member, member)
+                    ] = auth_event.event_id
+                    break
+
+        if missing_members:
+            # There really shouldn't be any missing memberships now. Either:
+            #  * we couldn't find an auth event, which shouldn't happen because we do
+            #    not persist events with persisting their auth events first, or
+            #  * the set of auth events did not contain a membership we wanted, which
+            #    means our caller didn't compute the events in `members_to_fetch`
+            #    correctly, or we somehow accepted an event whose auth events were
+            #    dodgy.
+            logger.error(
+                "Failed to find memberships for %s in partial state room "
+                "%s in the auth events of %s.",
+                missing_members,
+                room_id,
+                [
+                    events_with_membership_auth[member].event_id
+                    for member in missing_members
+                ],
+            )
+
+        return additional_state_ids
+
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> NotifCounts:
@@ -1195,10 +1414,10 @@ class SyncHandler:
     async def _generate_sync_entry_for_device_list(
         self,
         sync_result_builder: "SyncResultBuilder",
-        newly_joined_rooms: Set[str],
-        newly_joined_or_invited_or_knocked_users: Set[str],
-        newly_left_rooms: Set[str],
-        newly_left_users: Set[str],
+        newly_joined_rooms: AbstractSet[str],
+        newly_joined_or_invited_or_knocked_users: AbstractSet[str],
+        newly_left_rooms: AbstractSet[str],
+        newly_left_users: AbstractSet[str],
     ) -> DeviceListUpdates:
         """Generate the DeviceListUpdates section of sync
 
@@ -1216,8 +1435,7 @@ class SyncHandler:
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
 
-        # We're going to mutate these fields, so lets copy them rather than
-        # assume they won't get used later.
+        # Take a copy since these fields will be mutated later.
         newly_joined_or_invited_or_knocked_users = set(
             newly_joined_or_invited_or_knocked_users
         )
@@ -1417,8 +1635,8 @@ class SyncHandler:
     async def _generate_sync_entry_for_presence(
         self,
         sync_result_builder: "SyncResultBuilder",
-        newly_joined_rooms: Set[str],
-        newly_joined_or_invited_users: Set[str],
+        newly_joined_rooms: AbstractSet[str],
+        newly_joined_or_invited_users: AbstractSet[str],
     ) -> None:
         """Generates the presence portion of the sync response. Populates the
         `sync_result_builder` with the result.
@@ -1476,7 +1694,7 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         account_data_by_room: Dict[str, Dict[str, JsonDict]],
-    ) -> Tuple[Set[str], Set[str], Set[str], Set[str]]:
+    ) -> Tuple[AbstractSet[str], AbstractSet[str], AbstractSet[str], AbstractSet[str]]:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
 
@@ -1536,15 +1754,13 @@ class SyncHandler:
         ignored_users = await self.store.ignored_users(user_id)
         if since_token:
             room_changes = await self._get_rooms_changed(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
+                sync_result_builder, ignored_users
             )
             tags_by_room = await self.store.get_updated_tags(
                 user_id, since_token.account_data_key
             )
         else:
-            room_changes = await self._get_all_rooms(
-                sync_result_builder, ignored_users, self.rooms_to_exclude
-            )
+            room_changes = await self._get_all_rooms(sync_result_builder, ignored_users)
             tags_by_room = await self.store.get_tags_for_user(user_id)
 
         log_kv({"rooms_changed": len(room_changes.room_entries)})
@@ -1623,13 +1839,14 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        excluded_rooms: List[str],
     ) -> _RoomChanges:
         """Determine the changes in rooms to report to the user.
 
         This function is a first pass at generating the rooms part of the sync response.
         It determines which rooms have changed during the sync period, and categorises
-        them into four buckets: "knock", "invite", "join" and "leave".
+        them into four buckets: "knock", "invite", "join" and "leave". It also excludes
+        from that list any room that appears in the list of rooms to exclude from sync
+        results in the server configuration.
 
         1. Finds all membership changes for the user in the sync period (from
            `since_token` up to `now_token`).
@@ -1655,7 +1872,7 @@ class SyncHandler:
         #       _have_rooms_changed. We could keep the results in memory to avoid a
         #       second query, at the cost of more complicated source code.
         membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key, excluded_rooms
+            user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
         )
 
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
@@ -1696,7 +1913,11 @@ class SyncHandler:
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = await self.get_state_at(room_id, since_token)
+                old_state_ids = await self.get_state_at(
+                    room_id,
+                    since_token,
+                    state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
+                )
                 old_mem_ev_id = old_state_ids.get((EventTypes.Member, user_id), None)
                 old_mem_ev = None
                 if old_mem_ev_id:
@@ -1722,7 +1943,13 @@ class SyncHandler:
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = await self.get_state_at(room_id, since_token)
+                        old_state_ids = await self.get_state_at(
+                            room_id,
+                            since_token,
+                            state_filter=StateFilter.from_types(
+                                [(EventTypes.Member, user_id)]
+                            ),
+                        )
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
                         )
@@ -1862,7 +2089,6 @@ class SyncHandler:
         self,
         sync_result_builder: "SyncResultBuilder",
         ignored_users: FrozenSet[str],
-        ignored_rooms: List[str],
     ) -> _RoomChanges:
         """Returns entries for all rooms for the user.
 
@@ -1884,7 +2110,7 @@ class SyncHandler:
         room_list = await self.store.get_rooms_for_local_user_where_membership_is(
             user_id=user_id,
             membership_list=Membership.LIST,
-            excluded_rooms=ignored_rooms,
+            excluded_rooms=self.rooms_to_exclude,
         )
 
         room_entries = []
@@ -2150,7 +2376,9 @@ class SyncHandler:
                 raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
     async def get_rooms_for_user_at(
-        self, user_id: str, room_key: RoomStreamToken
+        self,
+        user_id: str,
+        room_key: RoomStreamToken,
     ) -> FrozenSet[str]:
         """Get set of joined rooms for a user at the given stream ordering.
 
@@ -2176,7 +2404,12 @@ class SyncHandler:
         # If the membership's stream ordering is after the given stream
         # ordering, we need to go and work out if the user was in the room
         # before.
+        # We also need to check whether the room should be excluded from sync
+        # responses as per the homeserver config.
         for joined_room in joined_rooms:
+            if joined_room.room_id in self.rooms_to_exclude:
+                continue
+
             if not joined_room.event_pos.persisted_after(room_key):
                 joined_room_ids.add(joined_room.room_id)
                 continue
@@ -2188,10 +2421,10 @@ class SyncHandler:
                     joined_room.room_id, joined_room.event_pos.stream
                 )
             )
-            users_in_room = await self.state.get_current_users_in_room(
+            user_ids_in_room = await self.state.get_current_user_ids_in_room(
                 joined_room.room_id, extrems
             )
-            if user_id in users_in_room:
+            if user_id in user_ids_in_room:
                 joined_room_ids.add(joined_room.room_id)
 
         return frozenset(joined_room_ids)
@@ -2211,8 +2444,8 @@ def _action_has_highlight(actions: List[JsonDict]) -> bool:
 def _calculate_state(
     timeline_contains: StateMap[str],
     timeline_start: StateMap[str],
-    previous: StateMap[str],
-    current: StateMap[str],
+    timeline_end: StateMap[str],
+    previous_timeline_end: StateMap[str],
     lazy_load_members: bool,
 ) -> StateMap[str]:
     """Works out what state to include in a sync response.
@@ -2220,45 +2453,50 @@ def _calculate_state(
     Args:
         timeline_contains: state in the timeline
         timeline_start: state at the start of the timeline
-        previous: state at the end of the previous sync (or empty dict
+        timeline_end: state at the end of the timeline
+        previous_timeline_end: state at the end of the previous sync (or empty dict
             if this is an initial sync)
-        current: state at the end of the timeline
         lazy_load_members: whether to return members from timeline_start
             or not.  assumes that timeline_start has already been filtered to
             include only the members the client needs to know about.
     """
-    event_id_to_key = {
-        e: key
-        for key, e in itertools.chain(
+    event_id_to_state_key = {
+        event_id: state_key
+        for state_key, event_id in itertools.chain(
             timeline_contains.items(),
-            previous.items(),
             timeline_start.items(),
-            current.items(),
+            timeline_end.items(),
+            previous_timeline_end.items(),
         )
     }
 
-    c_ids = set(current.values())
-    ts_ids = set(timeline_start.values())
-    p_ids = set(previous.values())
-    tc_ids = set(timeline_contains.values())
+    timeline_end_ids = set(timeline_end.values())
+    timeline_start_ids = set(timeline_start.values())
+    previous_timeline_end_ids = set(previous_timeline_end.values())
+    timeline_contains_ids = set(timeline_contains.values())
 
     # If we are lazyloading room members, we explicitly add the membership events
     # for the senders in the timeline into the state block returned by /sync,
     # as we may not have sent them to the client before.  We find these membership
     # events by filtering them out of timeline_start, which has already been filtered
     # to only include membership events for the senders in the timeline.
-    # In practice, we can do this by removing them from the p_ids list,
-    # which is the list of relevant state we know we have already sent to the client.
+    # In practice, we can do this by removing them from the previous_timeline_end_ids
+    # list, which is the list of relevant state we know we have already sent to the
+    # client.
     # see https://github.com/matrix-org/synapse/pull/2970/files/efcdacad7d1b7f52f879179701c7e0d9b763511f#r204732809
 
     if lazy_load_members:
-        p_ids.difference_update(
+        previous_timeline_end_ids.difference_update(
             e for t, e in timeline_start.items() if t[0] == EventTypes.Member
         )
 
-    state_ids = ((c_ids | ts_ids) - p_ids) - tc_ids
+    state_ids = (
+        (timeline_end_ids | timeline_start_ids)
+        - previous_timeline_end_ids
+        - timeline_contains_ids
+    )
 
-    return {event_id_to_key[e]: e for e in state_ids}
+    return {event_id_to_state_key[e]: e for e in state_ids}
 
 
 @attr.s(slots=True, auto_attribs=True)
@@ -2296,7 +2534,7 @@ class SyncResultBuilder:
     archived: List[ArchivedSyncResult] = attr.Factory(list)
     to_device: List[JsonDict] = attr.Factory(list)
 
-    def calculate_user_changes(self) -> Tuple[Set[str], Set[str]]:
+    def calculate_user_changes(self) -> Tuple[AbstractSet[str], AbstractSet[str]]:
         """Work out which other users have joined or left rooms we are joined to.
 
         This data only is only useful for an incremental sync.
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index d104ea07fe..a4cd8b8f0c 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -26,7 +26,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.replication.tcp.streams import TypingStream
 from synapse.streams import EventSource
-from synapse.types import JsonDict, Requester, StreamKeyType, UserID, get_domain_from_id
+from synapse.types import JsonDict, Requester, StreamKeyType, UserID
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
@@ -253,12 +253,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         self, target_user: UserID, requester: Requester, room_id: str, timeout: int
     ) -> None:
         target_user_id = target_user.to_string()
-        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        if target_user_id != auth_user_id:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's typing state")
 
         if requester.shadow_banned:
@@ -266,7 +265,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
-        await self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, requester)
 
         logger.debug("%s has started typing in %s", target_user_id, room_id)
 
@@ -289,12 +288,11 @@ class TypingWriterHandler(FollowerTypingHandler):
         self, target_user: UserID, requester: Requester, room_id: str
     ) -> None:
         target_user_id = target_user.to_string()
-        auth_user_id = requester.user.to_string()
 
         if not self.is_mine_id(target_user_id):
             raise SynapseError(400, "User is not hosted on this homeserver")
 
-        if target_user_id != auth_user_id:
+        if target_user != requester.user:
             raise AuthError(400, "Cannot set another user's typing state")
 
         if requester.shadow_banned:
@@ -302,7 +300,7 @@ class TypingWriterHandler(FollowerTypingHandler):
             await self.clock.sleep(random.randint(1, 10))
             raise ShadowBanError()
 
-        await self.auth.check_user_in_room(room_id, target_user_id)
+        await self.auth.check_user_in_room(room_id, requester)
 
         logger.debug("%s has stopped typing in %s", target_user_id, room_id)
 
@@ -364,8 +362,9 @@ class TypingWriterHandler(FollowerTypingHandler):
             )
             return
 
-        users = await self.store.get_users_in_room(room_id)
-        domains = {get_domain_from_id(u) for u in users}
+        domains = await self._storage_controllers.state.get_current_hosts_in_room(
+            room_id
+        )
 
         if self.server_name in domains:
             logger.info("Got typing update from %s: %r", user_id, content)
@@ -489,8 +488,15 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
             handler = self.get_typing_handler()
 
             events = []
-            for room_id in handler._room_serials.keys():
-                if handler._room_serials[room_id] <= from_key:
+
+            # Work on a copy of things here as these may change in the handler while
+            # waiting for the AS `is_interested_in_room` call to complete.
+            # Shallow copy is safe as no nested data is present.
+            latest_room_serial = handler._latest_room_serial
+            room_serials = handler._room_serials.copy()
+
+            for room_id, serial in room_serials.items():
+                if serial <= from_key:
                     continue
 
                 if not await service.is_interested_in_room(room_id, self._main_store):
@@ -498,7 +504,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
 
                 events.append(self._make_event_for(room_id))
 
-            return events, handler._latest_room_serial
+            return events, latest_room_serial
 
     async def get_new_events(
         self,
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 05cebb5d4d..a744d68c64 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -19,7 +19,6 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, LoginError, SynapseError
-from synapse.config.emailconfig import ThreepidBehaviour
 from synapse.util import json_decoder
 
 if TYPE_CHECKING:
@@ -153,7 +152,7 @@ class _BaseThreepidAuthChecker:
 
         logger.info("Getting validated threepid. threepidcreds: %r", (threepid_creds,))
 
-        # msisdns are currently always ThreepidBehaviour.REMOTE
+        # msisdns are currently always verified via the IS
         if medium == "msisdn":
             if not self.hs.config.registration.account_threepid_delegate_msisdn:
                 raise SynapseError(
@@ -164,18 +163,7 @@ class _BaseThreepidAuthChecker:
                 threepid_creds,
             )
         elif medium == "email":
-            if (
-                self.hs.config.email.threepid_behaviour_email
-                == ThreepidBehaviour.REMOTE
-            ):
-                assert self.hs.config.registration.account_threepid_delegate_email
-                threepid = await identity_handler.threepid_from_creds(
-                    self.hs.config.registration.account_threepid_delegate_email,
-                    threepid_creds,
-                )
-            elif (
-                self.hs.config.email.threepid_behaviour_email == ThreepidBehaviour.LOCAL
-            ):
+            if self.hs.config.email.can_verify_email:
                 threepid = None
                 row = await self.store.get_threepid_validation_session(
                     medium,
@@ -227,10 +215,7 @@ class EmailIdentityAuthChecker(UserInteractiveAuthChecker, _BaseThreepidAuthChec
         _BaseThreepidAuthChecker.__init__(self, hs)
 
     def is_enabled(self) -> bool:
-        return self.hs.config.email.threepid_behaviour_email in (
-            ThreepidBehaviour.REMOTE,
-            ThreepidBehaviour.LOCAL,
-        )
+        return self.hs.config.email.can_verify_email
 
     async def check_auth(self, authdict: dict, clientip: str) -> Any:
         return await self._check_threepid("email", authdict)