summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_data.py2
-rw-r--r--synapse/handlers/admin.py6
-rw-r--r--synapse/handlers/appservice.py9
-rw-r--r--synapse/handlers/auth.py75
-rw-r--r--synapse/handlers/cas.py3
-rw-r--r--synapse/handlers/device.py151
-rw-r--r--synapse/handlers/directory.py23
-rw-r--r--synapse/handlers/e2e_keys.py39
-rw-r--r--synapse/handlers/e2e_room_keys.py5
-rw-r--r--synapse/handlers/event_auth.py18
-rw-r--r--synapse/handlers/federation.py301
-rw-r--r--synapse/handlers/federation_event.py137
-rw-r--r--synapse/handlers/identity.py2
-rw-r--r--synapse/handlers/initial_sync.py27
-rw-r--r--synapse/handlers/message.py968
-rw-r--r--synapse/handlers/oidc.py388
-rw-r--r--synapse/handlers/pagination.py5
-rw-r--r--synapse/handlers/presence.py14
-rw-r--r--synapse/handlers/profile.py6
-rw-r--r--synapse/handlers/receipts.py13
-rw-r--r--synapse/handlers/register.py8
-rw-r--r--synapse/handlers/relations.py193
-rw-r--r--synapse/handlers/room.py324
-rw-r--r--synapse/handlers/room_batch.py3
-rw-r--r--synapse/handlers/room_member.py28
-rw-r--r--synapse/handlers/saml.py4
-rw-r--r--synapse/handlers/send_email.py13
-rw-r--r--synapse/handlers/sso.py82
-rw-r--r--synapse/handlers/sync.py242
-rw-r--r--synapse/handlers/typing.py2
-rw-r--r--synapse/handlers/ui_auth/checkers.py3
-rw-r--r--synapse/handlers/user_directory.py36
32 files changed, 2134 insertions, 996 deletions
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 0478448b47..fc21d58001 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -225,7 +225,7 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
         self,
         user: UserID,
         from_key: int,
-        limit: Optional[int],
+        limit: int,
         room_ids: Collection[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index cf9f19608a..5bf8e86387 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -32,6 +32,7 @@ class AdminHandler:
         self.store = hs.get_datastores().main
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
+        self._msc3866_enabled = hs.config.experimental.msc3866.enabled
 
     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
@@ -75,6 +76,10 @@ class AdminHandler:
             "is_guest",
         }
 
+        if self._msc3866_enabled:
+            # Only include the approved flag if support for MSC3866 is enabled.
+            user_info_to_return.add("approved")
+
         # Restrict returned keys to a known set.
         user_info_dict = {
             key: value
@@ -95,6 +100,7 @@ class AdminHandler:
         user_info_dict["avatar_url"] = profile.avatar_url
         user_info_dict["threepids"] = threepids
         user_info_dict["external_ids"] = external_ids
+        user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
 
         return user_info_dict
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 203b62e015..66f5b8d108 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -109,10 +109,13 @@ class ApplicationServicesHandler:
                     last_token = await self.store.get_appservice_last_pos()
                     (
                         upper_bound,
-                        events,
                         event_to_received_ts,
-                    ) = await self.store.get_all_new_events_stream(
-                        last_token, self.current_max, limit=100, get_prev_content=True
+                    ) = await self.store.get_all_new_event_ids_stream(
+                        last_token, self.current_max, limit=100
+                    )
+
+                    events = await self.store.get_events_as_list(
+                        event_to_received_ts.keys(), get_prev_content=True
                     )
 
                     events_by_room: Dict[str, List[EventBase]] = {}
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index eacd631ee0..8b9ef25d29 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -38,6 +38,7 @@ from typing import (
 import attr
 import bcrypt
 import unpaddedbase64
+from prometheus_client import Counter
 
 from twisted.internet.defer import CancelledError
 from twisted.web.server import Request
@@ -48,6 +49,7 @@ from synapse.api.errors import (
     Codes,
     InteractiveAuthIncompleteError,
     LoginError,
+    NotFoundError,
     StoreError,
     SynapseError,
     UserDeactivatedError,
@@ -63,10 +65,14 @@ from synapse.http.server import finish_request, respond_with_html
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.metrics.background_process_metrics import run_as_background_process
+from synapse.storage.databases.main.registration import (
+    LoginTokenExpired,
+    LoginTokenLookupResult,
+    LoginTokenReused,
+)
 from synapse.types import JsonDict, Requester, UserID
 from synapse.util import stringutils as stringutils
 from synapse.util.async_helpers import delay_cancellation, maybe_awaitable
-from synapse.util.macaroons import LoginTokenAttributes
 from synapse.util.msisdn import phone_number_to_msisdn
 from synapse.util.stringutils import base62_encode
 from synapse.util.threepids import canonicalise_email
@@ -80,6 +86,12 @@ logger = logging.getLogger(__name__)
 
 INVALID_USERNAME_OR_PASSWORD = "Invalid username or password"
 
+invalid_login_token_counter = Counter(
+    "synapse_user_login_invalid_login_tokens",
+    "Counts the number of rejected m.login.token on /login",
+    ["reason"],
+)
+
 
 def convert_client_dict_legacy_fields_to_identifier(
     submission: JsonDict,
@@ -883,6 +895,25 @@ class AuthHandler:
 
         return True
 
+    async def create_login_token_for_user_id(
+        self,
+        user_id: str,
+        duration_ms: int = (2 * 60 * 1000),
+        auth_provider_id: Optional[str] = None,
+        auth_provider_session_id: Optional[str] = None,
+    ) -> str:
+        login_token = self.generate_login_token()
+        now = self._clock.time_msec()
+        expiry_ts = now + duration_ms
+        await self.store.add_login_token_to_user(
+            user_id=user_id,
+            token=login_token,
+            expiry_ts=expiry_ts,
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+        return login_token
+
     async def create_refresh_token_for_user_id(
         self,
         user_id: str,
@@ -1009,6 +1040,17 @@ class AuthHandler:
             return res[0]
         return None
 
+    async def is_user_approved(self, user_id: str) -> bool:
+        """Checks if a user is approved and therefore can be allowed to log in.
+
+        Args:
+            user_id: the user to check the approval status of.
+
+        Returns:
+            A boolean that is True if the user is approved, False otherwise.
+        """
+        return await self.store.is_user_approved(user_id)
+
     async def _find_user_id_and_pwd_hash(
         self, user_id: str
     ) -> Optional[Tuple[str, str]]:
@@ -1390,6 +1432,18 @@ class AuthHandler:
             return None
         return user_id
 
+    def generate_login_token(self) -> str:
+        """Generates an opaque string, for use as an short-term login token"""
+
+        # we use the following format for access tokens:
+        #    syl_<random string>_<base62 crc check>
+
+        random_string = stringutils.random_string(20)
+        base = f"syl_{random_string}"
+
+        crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
+        return f"{base}_{crc}"
+
     def generate_access_token(self, for_user: UserID) -> str:
         """Generates an opaque string, for use as an access token"""
 
@@ -1416,16 +1470,17 @@ class AuthHandler:
         crc = base62_encode(crc32(base.encode("ascii")), minwidth=6)
         return f"{base}_{crc}"
 
-    async def validate_short_term_login_token(
-        self, login_token: str
-    ) -> LoginTokenAttributes:
+    async def consume_login_token(self, login_token: str) -> LoginTokenLookupResult:
         try:
-            res = self.macaroon_gen.verify_short_term_login_token(login_token)
-        except Exception:
-            raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
+            return await self.store.consume_login_token(login_token)
+        except LoginTokenExpired:
+            invalid_login_token_counter.labels("expired").inc()
+        except LoginTokenReused:
+            invalid_login_token_counter.labels("reused").inc()
+        except NotFoundError:
+            invalid_login_token_counter.labels("not found").inc()
 
-        await self.auth_blocking.check_auth_blocking(res.user_id)
-        return res
+        raise AuthError(403, "Invalid login token", errcode=Codes.FORBIDDEN)
 
     async def delete_access_token(self, access_token: str) -> None:
         """Invalidate a single access token
@@ -1700,7 +1755,7 @@ class AuthHandler:
             )
 
         # Create a login token
-        login_token = self.macaroon_gen.generate_short_term_login_token(
+        login_token = await self.create_login_token_for_user_id(
             registered_user_id,
             auth_provider_id=auth_provider_id,
             auth_provider_session_id=auth_provider_session_id,
diff --git a/synapse/handlers/cas.py b/synapse/handlers/cas.py
index 7163af8004..fc467bc7c1 100644
--- a/synapse/handlers/cas.py
+++ b/synapse/handlers/cas.py
@@ -130,6 +130,9 @@ class CasHandler:
         except PartialDownloadError as pde:
             # Twisted raises this error if the connection is closed,
             # even if that's being used old-http style to signal end-of-data
+            # Assertion is for mypy's benefit. Error.response is Optional[bytes],
+            # but a PartialDownloadError should always have a non-None response.
+            assert pde.response is not None
             body = pde.response
         except HttpResponseException as e:
             description = (
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 961f8eb186..2567954679 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -273,11 +273,9 @@ class DeviceWorkerHandler:
             possibly_left = possibly_changed | possibly_left
 
             # Double check if we still share rooms with the given user.
-            users_rooms = await self.store.get_rooms_for_users_with_stream_ordering(
-                possibly_left
-            )
+            users_rooms = await self.store.get_rooms_for_users(possibly_left)
             for changed_user_id, entries in users_rooms.items():
-                if any(e.room_id in room_ids for e in entries):
+                if any(rid in room_ids for rid in entries):
                     possibly_left.discard(changed_user_id)
                 else:
                     possibly_joined.discard(changed_user_id)
@@ -309,6 +307,17 @@ class DeviceWorkerHandler:
             "self_signing_key": self_signing_key,
         }
 
+    async def handle_room_un_partial_stated(self, room_id: str) -> None:
+        """Handles sending appropriate device list updates in a room that has
+        gone from partial to full state.
+        """
+
+        # TODO(faster_joins): worker mode support
+        #   https://github.com/matrix-org/synapse/issues/12994
+        logger.error(
+            "Trying handling device list state for partial join: not supported on workers."
+        )
+
 
 class DeviceHandler(DeviceWorkerHandler):
     def __init__(self, hs: "HomeServer"):
@@ -746,6 +755,95 @@ class DeviceHandler(DeviceWorkerHandler):
         finally:
             self._handle_new_device_update_is_processing = False
 
+    async def handle_room_un_partial_stated(self, room_id: str) -> None:
+        """Handles sending appropriate device list updates in a room that has
+        gone from partial to full state.
+        """
+
+        # We defer to the device list updater to handle pending remote device
+        # list updates.
+        await self.device_list_updater.handle_room_un_partial_stated(room_id)
+
+        # Replay local updates.
+        (
+            join_event_id,
+            device_lists_stream_id,
+        ) = await self.store.get_join_event_id_and_device_lists_stream_id_for_partial_state(
+            room_id
+        )
+
+        # Get the local device list changes that have happened in the room since
+        # we started joining. If there are no updates there's nothing left to do.
+        changes = await self.store.get_device_list_changes_in_room(
+            room_id, device_lists_stream_id
+        )
+        local_changes = {(u, d) for u, d in changes if self.hs.is_mine_id(u)}
+        if not local_changes:
+            return
+
+        # Note: We have persisted the full state at this point, we just haven't
+        # cleared the `partial_room` flag.
+        join_state_ids = await self._state_storage.get_state_ids_for_event(
+            join_event_id, await_full_state=False
+        )
+        current_state_ids = await self.store.get_partial_current_state_ids(room_id)
+
+        # Now we need to work out all servers that might have been in the room
+        # at any point during our join.
+
+        # First we look for any membership states that have changed between the
+        # initial join and now...
+        all_keys = set(join_state_ids)
+        all_keys.update(current_state_ids)
+
+        potentially_changed_hosts = set()
+        for etype, state_key in all_keys:
+            if etype != EventTypes.Member:
+                continue
+
+            prev = join_state_ids.get((etype, state_key))
+            current = current_state_ids.get((etype, state_key))
+
+            if prev != current:
+                potentially_changed_hosts.add(get_domain_from_id(state_key))
+
+        # ... then we add all the hosts that are currently joined to the room...
+        current_hosts_in_room = await self.store.get_current_hosts_in_room(room_id)
+        potentially_changed_hosts.update(current_hosts_in_room)
+
+        # ... and finally we remove any hosts that we were told about, as we
+        # will have sent device list updates to those hosts when they happened.
+        known_hosts_at_join = await self.store.get_partial_state_servers_at_join(
+            room_id
+        )
+        potentially_changed_hosts.difference_update(known_hosts_at_join)
+
+        potentially_changed_hosts.discard(self.server_name)
+
+        if not potentially_changed_hosts:
+            # Nothing to do.
+            return
+
+        logger.info(
+            "Found %d changed hosts to send device list updates to",
+            len(potentially_changed_hosts),
+        )
+
+        for user_id, device_id in local_changes:
+            await self.store.add_device_list_outbound_pokes(
+                user_id=user_id,
+                device_id=device_id,
+                room_id=room_id,
+                stream_id=None,
+                hosts=potentially_changed_hosts,
+                context=None,
+            )
+
+        # Notify things that device lists need to be sent out.
+        self.notifier.notify_replication()
+        for host in potentially_changed_hosts:
+            self.federation_sender.send_device_messages(host, immediate=False)
+
 
 def _update_device_from_client_ips(
     device: JsonDict, client_ips: Mapping[Tuple[str, str], Mapping[str, Any]]
@@ -836,6 +934,19 @@ class DeviceListUpdater:
             )
             return
 
+        # Check if we are partially joining any rooms. If so we need to store
+        # all device list updates so that we can handle them correctly once we
+        # know who is in the room.
+        # TODO(faster joins): this fetches and processes a bunch of data that we don't
+        # use. Could be replaced by a tighter query e.g.
+        #   SELECT EXISTS(SELECT 1 FROM partial_state_rooms)
+        partial_rooms = await self.store.get_partial_state_room_resync_info()
+        if partial_rooms:
+            await self.store.add_remote_device_list_to_pending(
+                user_id,
+                device_id,
+            )
+
         room_ids = await self.store.get_rooms_for_user(user_id)
         if not room_ids:
             # We don't share any rooms with this user. Ignore update, as we
@@ -1175,3 +1286,35 @@ class DeviceListUpdater:
             device_ids.append(verify_key.version)
 
         return device_ids
+
+    async def handle_room_un_partial_stated(self, room_id: str) -> None:
+        """Handles sending appropriate device list updates in a room that has
+        gone from partial to full state.
+        """
+
+        pending_updates = (
+            await self.store.get_pending_remote_device_list_updates_for_room(room_id)
+        )
+
+        for user_id, device_id in pending_updates:
+            logger.info(
+                "Got pending device list update in room %s: %s / %s",
+                room_id,
+                user_id,
+                device_id,
+            )
+            position = await self.store.add_device_change_to_streams(
+                user_id,
+                [device_id],
+                room_ids=[room_id],
+            )
+
+            if not position:
+                # This should only happen if there are no updates, which
+                # shouldn't happen when we've passed in a non-empty set of
+                # device IDs.
+                continue
+
+            self.device_handler.notifier.on_new_event(
+                StreamKeyType.DEVICE_LIST, position, rooms=[room_id]
+            )
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 7127d5aefc..2ea52257cb 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -16,6 +16,8 @@ import logging
 import string
 from typing import TYPE_CHECKING, Iterable, List, Optional
 
+from typing_extensions import Literal
+
 from synapse.api.constants import MAX_ALIAS_LENGTH, EventTypes
 from synapse.api.errors import (
     AuthError,
@@ -83,7 +85,7 @@ class DirectoryHandler:
         # TODO(erikj): Add transactions.
         # TODO(erikj): Check if there is a current association.
         if not servers:
-            servers = await self._storage_controllers.state.get_current_hosts_in_room(
+            servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
                 room_id
             )
 
@@ -288,7 +290,7 @@ class DirectoryHandler:
                 Codes.NOT_FOUND,
             )
 
-        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room(
+        extra_servers = await self._storage_controllers.state.get_current_hosts_in_room_or_partial_state_approximation(
             room_id
         )
         servers_set = set(extra_servers) | set(servers)
@@ -429,7 +431,10 @@ class DirectoryHandler:
         return await self.auth.check_can_change_room_list(room_id, requester)
 
     async def edit_published_room_list(
-        self, requester: Requester, room_id: str, visibility: str
+        self,
+        requester: Requester,
+        room_id: str,
+        visibility: Literal["public", "private"],
     ) -> None:
         """Edit the entry of the room in the published room list.
 
@@ -451,9 +456,6 @@ class DirectoryHandler:
         if requester.is_guest:
             raise AuthError(403, "Guests cannot edit the published room list")
 
-        if visibility not in ["public", "private"]:
-            raise SynapseError(400, "Invalid visibility setting")
-
         if visibility == "public" and not self.enable_room_list_search:
             # The room list has been disabled.
             raise AuthError(
@@ -505,7 +507,11 @@ class DirectoryHandler:
         await self.store.set_room_is_public(room_id, making_public)
 
     async def edit_published_appservice_room_list(
-        self, appservice_id: str, network_id: str, room_id: str, visibility: str
+        self,
+        appservice_id: str,
+        network_id: str,
+        room_id: str,
+        visibility: Literal["public", "private"],
     ) -> None:
         """Add or remove a room from the appservice/network specific public
         room list.
@@ -516,9 +522,6 @@ class DirectoryHandler:
             room_id
             visibility: either "public" or "private"
         """
-        if visibility not in ["public", "private"]:
-            raise SynapseError(400, "Invalid visibility setting")
-
         await self.store.set_room_is_public_appservice(
             room_id, appservice_id, network_id, visibility == "public"
         )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index cf788a4a86..5f84f1769b 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -49,6 +49,7 @@ logger = logging.getLogger(__name__)
 
 class E2eKeysHandler:
     def __init__(self, hs: "HomeServer"):
+        self.config = hs.config
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
@@ -431,13 +432,17 @@ class E2eKeysHandler:
     @trace
     @cancellable
     async def query_local_devices(
-        self, query: Mapping[str, Optional[List[str]]]
+        self,
+        query: Mapping[str, Optional[List[str]]],
+        include_displaynames: bool = True,
     ) -> Dict[str, Dict[str, dict]]:
         """Get E2E device keys for local users
 
         Args:
             query: map from user_id to a list
                  of devices to query (None for all devices)
+            include_displaynames: Whether to include device displaynames in the returned
+                device details.
 
         Returns:
             A map from user_id -> device_id -> device details
@@ -469,7 +474,9 @@ class E2eKeysHandler:
             # make sure that each queried user appears in the result dict
             result_dict[user_id] = {}
 
-        results = await self.store.get_e2e_device_keys_for_cs_api(local_query)
+        results = await self.store.get_e2e_device_keys_for_cs_api(
+            local_query, include_displaynames
+        )
 
         # Build the result structure
         for user_id, device_keys in results.items():
@@ -482,11 +489,33 @@ class E2eKeysHandler:
     async def on_federation_query_client_keys(
         self, query_body: Dict[str, Dict[str, Optional[List[str]]]]
     ) -> JsonDict:
-        """Handle a device key query from a federated server"""
+        """Handle a device key query from a federated server:
+
+        Handles the path: GET /_matrix/federation/v1/users/keys/query
+
+        Args:
+            query_body: The body of the query request. Should contain a key
+                "device_keys" that map to a dictionary of user ID's -> list of
+                device IDs. If the list of device IDs is empty, all devices of
+                that user will be queried.
+
+        Returns:
+            A json dictionary containing the following:
+                - device_keys: A dictionary containing the requested device information.
+                - master_keys: An optional dictionary of user ID -> master cross-signing
+                   key info.
+                - self_signing_key: An optional dictionary of user ID -> self-signing
+                    key info.
+        """
         device_keys_query: Dict[str, Optional[List[str]]] = query_body.get(
             "device_keys", {}
         )
-        res = await self.query_local_devices(device_keys_query)
+        res = await self.query_local_devices(
+            device_keys_query,
+            include_displaynames=(
+                self.config.federation.allow_device_name_lookup_over_federation
+            ),
+        )
         ret = {"device_keys": res}
 
         # add in the cross-signing keys
@@ -841,7 +870,7 @@ class E2eKeysHandler:
         - signatures of the user's master key by the user's devices.
 
         Args:
-            user_id (string): the user uploading the keys
+            user_id: the user uploading the keys
             signatures (dict[string, dict]): map of devices to signed keys
 
         Returns:
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 8786534e54..098288f058 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -377,8 +377,9 @@ class E2eRoomKeysHandler:
         """Deletes a given version of the user's e2e_room_keys backup
 
         Args:
-            user_id(str): the user whose current backup version we're deleting
-            version(str): the version id of the backup being deleted
+            user_id: the user whose current backup version we're deleting
+            version: Optional. the version ID of the backup version we're deleting
+                If missing, we delete the current backup version info.
         Raises:
             NotFoundError: if this backup version doesn't exist
         """
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 8249ca1ed2..3bbad0271b 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Collection, List, Optional, Union
+from typing import TYPE_CHECKING, Collection, List, Mapping, Optional, Union
 
 from synapse import event_auth
 from synapse.api.constants import (
@@ -29,7 +29,6 @@ from synapse.event_auth import (
 )
 from synapse.events import EventBase
 from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
 from synapse.types import StateMap, get_domain_from_id
 
 if TYPE_CHECKING:
@@ -51,12 +50,21 @@ class EventAuthHandler:
     async def check_auth_rules_from_context(
         self,
         event: EventBase,
-        context: EventContext,
+        batched_auth_events: Optional[Mapping[str, EventBase]] = None,
     ) -> None:
-        """Check an event passes the auth rules at its own auth events"""
-        await check_state_independent_auth_rules(self._store, event)
+        """Check an event passes the auth rules at its own auth events
+        Args:
+            event: event to be authed
+            batched_auth_events: if the event being authed is part of a batch, any events
+            from the same batch that may be necessary to auth the current event
+        """
+        await check_state_independent_auth_rules(
+            self._store, event, batched_auth_events
+        )
         auth_event_ids = event.auth_event_ids()
         auth_events_by_id = await self._store.get_events(auth_event_ids)
+        if batched_auth_events:
+            auth_events_by_id.update(batched_auth_events)
         check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 73471fe041..79e792395f 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -38,13 +38,14 @@ from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
 
 from synapse import event_auth
-from synapse.api.constants import EventContentFields, EventTypes, Membership
+from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
     CodeMessageException,
     Codes,
     FederationDeniedError,
     FederationError,
+    FederationPullAttemptBackoffError,
     HttpResponseException,
     LimitExceededError,
     NotFoundError,
@@ -155,6 +156,8 @@ class FederationHandler:
         self.http_client = hs.get_proxied_blacklisted_http_client()
         self._replication = hs.get_replication_data_handler()
         self._federation_event_handler = hs.get_federation_event_handler()
+        self._device_handler = hs.get_device_handler()
+        self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
         self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
             hs
@@ -215,7 +218,7 @@ class FederationHandler:
         current_depth: int,
         limit: int,
         *,
-        processing_start_time: int,
+        processing_start_time: Optional[int],
     ) -> bool:
         """
         Checks whether the `current_depth` is at or approaching any backfill
@@ -227,12 +230,23 @@ class FederationHandler:
             room_id: The room to backfill in.
             current_depth: The depth to check at for any upcoming backfill points.
             limit: The max number of events to request from the remote federated server.
-            processing_start_time: The time when `maybe_backfill` started
-                processing. Only used for timing.
+            processing_start_time: The time when `maybe_backfill` started processing.
+                Only used for timing. If `None`, no timing observation will be made.
         """
         backwards_extremities = [
             _BackfillPoint(event_id, depth, _BackfillPointType.BACKWARDS_EXTREMITY)
-            for event_id, depth in await self.store.get_backfill_points_in_room(room_id)
+            for event_id, depth in await self.store.get_backfill_points_in_room(
+                room_id=room_id,
+                current_depth=current_depth,
+                # We only need to end up with 5 extremities combined with the
+                # insertion event extremities to make the `/backfill` request
+                # but fetch an order of magnitude more to make sure there is
+                # enough even after we filter them by whether visible in the
+                # history. This isn't fool-proof as all backfill points within
+                # our limit could be filtered out but seems like a good amount
+                # to try with at least.
+                limit=50,
+            )
         ]
 
         insertion_events_to_be_backfilled: List[_BackfillPoint] = []
@@ -240,7 +254,12 @@ class FederationHandler:
             insertion_events_to_be_backfilled = [
                 _BackfillPoint(event_id, depth, _BackfillPointType.INSERTION_PONT)
                 for event_id, depth in await self.store.get_insertion_event_backward_extremities_in_room(
-                    room_id
+                    room_id=room_id,
+                    current_depth=current_depth,
+                    # We only need to end up with 5 extremities combined with
+                    # the backfill points to make the `/backfill` request ...
+                    # (see the other comment above for more context).
+                    limit=50,
                 )
             ]
         logger.debug(
@@ -249,10 +268,6 @@ class FederationHandler:
             insertion_events_to_be_backfilled,
         )
 
-        if not backwards_extremities and not insertion_events_to_be_backfilled:
-            logger.debug("Not backfilling as no extremeties found.")
-            return False
-
         # we now have a list of potential places to backpaginate from. We prefer to
         # start with the most recent (ie, max depth), so let's sort the list.
         sorted_backfill_points: List[_BackfillPoint] = sorted(
@@ -273,6 +288,33 @@ class FederationHandler:
             sorted_backfill_points,
         )
 
+        # If we have no backfill points lower than the `current_depth` then
+        # either we can a) bail or b) still attempt to backfill. We opt to try
+        # backfilling anyway just in case we do get relevant events.
+        if not sorted_backfill_points and current_depth != MAX_DEPTH:
+            logger.debug(
+                "_maybe_backfill_inner: all backfill points are *after* current depth. Trying again with later backfill points."
+            )
+            return await self._maybe_backfill_inner(
+                room_id=room_id,
+                # We use `MAX_DEPTH` so that we find all backfill points next
+                # time (all events are below the `MAX_DEPTH`)
+                current_depth=MAX_DEPTH,
+                limit=limit,
+                # We don't want to start another timing observation from this
+                # nested recursive call. The top-most call can record the time
+                # overall otherwise the smaller one will throw off the results.
+                processing_start_time=None,
+            )
+
+        # Even after recursing with `MAX_DEPTH`, we didn't find any
+        # backward extremities to backfill from.
+        if not sorted_backfill_points:
+            logger.debug(
+                "_maybe_backfill_inner: Not backfilling as no backward extremeties found."
+            )
+            return False
+
         # If we're approaching an extremity we trigger a backfill, otherwise we
         # no-op.
         #
@@ -282,47 +324,16 @@ class FederationHandler:
         # chose more than one times the limit in case of failure, but choosing a
         # much larger factor will result in triggering a backfill request much
         # earlier than necessary.
-        #
-        # XXX: shouldn't we do this *after* the filter by depth below? Again, we don't
-        # care about events that have happened after our current position.
-        #
-        max_depth = sorted_backfill_points[0].depth
-        if current_depth - 2 * limit > max_depth:
+        max_depth_of_backfill_points = sorted_backfill_points[0].depth
+        if current_depth - 2 * limit > max_depth_of_backfill_points:
             logger.debug(
                 "Not backfilling as we don't need to. %d < %d - 2 * %d",
-                max_depth,
+                max_depth_of_backfill_points,
                 current_depth,
                 limit,
             )
             return False
 
-        # We ignore extremities that have a greater depth than our current depth
-        # as:
-        #    1. we don't really care about getting events that have happened
-        #       after our current position; and
-        #    2. we have likely previously tried and failed to backfill from that
-        #       extremity, so to avoid getting "stuck" requesting the same
-        #       backfill repeatedly we drop those extremities.
-        #
-        # However, we need to check that the filtered extremities are non-empty.
-        # If they are empty then either we can a) bail or b) still attempt to
-        # backfill. We opt to try backfilling anyway just in case we do get
-        # relevant events.
-        #
-        filtered_sorted_backfill_points = [
-            t for t in sorted_backfill_points if t.depth <= current_depth
-        ]
-        if filtered_sorted_backfill_points:
-            logger.debug(
-                "_maybe_backfill_inner: backfill points before current depth: %s",
-                filtered_sorted_backfill_points,
-            )
-            sorted_backfill_points = filtered_sorted_backfill_points
-        else:
-            logger.debug(
-                "_maybe_backfill_inner: all backfill points are *after* current depth. Backfilling anyway."
-            )
-
         # For performance's sake, we only want to paginate from a particular extremity
         # if we can actually see the events we'll get. Otherwise, we'd just spend a lot
         # of resources to get redacted events. We check each extremity in turn and
@@ -413,11 +424,22 @@ class FederationHandler:
             # First we try hosts that are already in the room.
             # TODO: HEURISTIC ALERT.
             likely_domains = (
-                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+                await self._storage_controllers.state.get_current_hosts_in_room_ordered(
+                    room_id
+                )
             )
 
         async def try_backfill(domains: Collection[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
+
+            # Number of contacted remote homeservers that have denied our backfill
+            # request with a 4xx code.
+            denied_count = 0
+
+            # Maximum number of contacted remote homeservers that can deny our
+            # backfill request with 4xx codes before we give up.
+            max_denied_count = 5
+
             for dom in domains:
                 # We don't want to ask our own server for information we don't have
                 if dom == self.server_name:
@@ -431,40 +453,68 @@ class FederationHandler:
                     # appropriate stuff.
                     # TODO: We can probably do something more intelligent here.
                     return True
+                except NotRetryingDestination as e:
+                    logger.info("_maybe_backfill_inner: %s", e)
+                    continue
+                except FederationDeniedError:
+                    logger.info(
+                        "_maybe_backfill_inner: Not attempting to backfill from %s because the homeserver is not on our federation whitelist",
+                        dom,
+                    )
+                    continue
                 except (SynapseError, InvalidResponseError) as e:
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
                 except HttpResponseException as e:
                     if 400 <= e.code < 500:
-                        raise e.to_synapse_error()
+                        logger.warning(
+                            "Backfill denied from %s because %s [%d/%d]",
+                            dom,
+                            e,
+                            denied_count,
+                            max_denied_count,
+                        )
+                        denied_count += 1
+                        if denied_count >= max_denied_count:
+                            return False
+                        continue
 
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
                 except CodeMessageException as e:
                     if 400 <= e.code < 500:
-                        raise
+                        logger.warning(
+                            "Backfill denied from %s because %s [%d/%d]",
+                            dom,
+                            e,
+                            denied_count,
+                            max_denied_count,
+                        )
+                        denied_count += 1
+                        if denied_count >= max_denied_count:
+                            return False
+                        continue
 
                     logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
-                except NotRetryingDestination as e:
-                    logger.info(str(e))
-                    continue
                 except RequestSendFailed as e:
                     logger.info("Failed to get backfill from %s because %s", dom, e)
                     continue
-                except FederationDeniedError as e:
-                    logger.info(e)
-                    continue
                 except Exception as e:
                     logger.exception("Failed to backfill from %s because %s", dom, e)
                     continue
 
             return False
 
-        processing_end_time = self.clock.time_msec()
-        backfill_processing_before_timer.observe(
-            (processing_end_time - processing_start_time) / 1000
-        )
+        # If we have the `processing_start_time`, then we can make an
+        # observation. We wouldn't have the `processing_start_time` in the case
+        # where `_maybe_backfill_inner` is recursively called to find any
+        # backfill points regardless of `current_depth`.
+        if processing_start_time is not None:
+            processing_end_time = self.clock.time_msec()
+            backfill_processing_before_timer.observe(
+                (processing_end_time - processing_start_time) / 1000
+            )
 
         success = await try_backfill(likely_domains)
         if success:
@@ -592,7 +642,12 @@ class FederationHandler:
                 # Mark the room as having partial state.
                 # The background process is responsible for unmarking this flag,
                 # even if the join fails.
-                await self.store.store_partial_state_room(room_id, ret.servers_in_room)
+                await self.store.store_partial_state_room(
+                    room_id=room_id,
+                    servers=ret.servers_in_room,
+                    device_lists_stream_id=self.store.get_device_stream_token(),
+                    joined_via=origin,
+                )
 
             try:
                 max_stream_id = (
@@ -617,6 +672,14 @@ class FederationHandler:
                     room_id,
                 )
                 raise LimitExceededError(msg=e.msg, errcode=e.errcode, retry_after_ms=0)
+            else:
+                # Record the join event id for future use (when we finish the full
+                # join). We have to do this after persisting the event to keep foreign
+                # key constraints intact.
+                if ret.partial_state:
+                    await self.store.write_partial_state_rooms_join_event_id(
+                        room_id, event.event_id
+                    )
             finally:
                 # Always kick off the background process that asynchronously fetches
                 # state for the room.
@@ -734,15 +797,27 @@ class FederationHandler:
 
         # Send the signed event back to the room, and potentially receive some
         # further information about the room in the form of partial state events
-        stripped_room_state = await self.federation_client.send_knock(
-            target_hosts, event
-        )
+        knock_response = await self.federation_client.send_knock(target_hosts, event)
 
         # Store any stripped room state events in the "unsigned" key of the event.
         # This is a bit of a hack and is cribbing off of invites. Basically we
         # store the room state here and retrieve it again when this event appears
         # in the invitee's sync stream. It is stripped out for all other local users.
-        event.unsigned["knock_room_state"] = stripped_room_state["knock_state_events"]
+        stripped_room_state = (
+            knock_response.get("knock_room_state")
+            # Since v1.37, Synapse incorrectly used "knock_state_events" for this field.
+            # Thus, we also check for a 'knock_state_events' to support old instances.
+            # See https://github.com/matrix-org/synapse/issues/14088.
+            or knock_response.get("knock_state_events")
+        )
+
+        if stripped_room_state is None:
+            raise KeyError(
+                "Missing 'knock_room_state' (or legacy 'knock_state_events') field in "
+                "send_knock response"
+            )
+
+        event.unsigned["knock_room_state"] = stripped_room_state
 
         context = EventContext.for_outlier(self._storage_controllers)
         stream_id = await self._federation_event_handler.persist_events_and_notify(
@@ -881,7 +956,7 @@ class FederationHandler:
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        await self._event_auth_handler.check_auth_rules_from_context(event, context)
+        await self._event_auth_handler.check_auth_rules_from_context(event)
         return event
 
     async def on_invite_request(
@@ -955,9 +1030,17 @@ class FederationHandler:
         )
 
         context = EventContext.for_outlier(self._storage_controllers)
-        await self._federation_event_handler.persist_events_and_notify(
-            event.room_id, [(event, context)]
+
+        await self._bulk_push_rule_evaluator.action_for_events_by_user(
+            [(event, context)]
         )
+        try:
+            await self._federation_event_handler.persist_events_and_notify(
+                event.room_id, [(event, context)]
+            )
+        except Exception:
+            await self.store.remove_push_actions_from_staging(event.event_id)
+            raise
 
         return event
 
@@ -1056,7 +1139,7 @@ class FederationHandler:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            await self._event_auth_handler.check_auth_rules_from_context(event, context)
+            await self._event_auth_handler.check_auth_rules_from_context(event)
         except AuthError as e:
             logger.warning("Failed to create new leave %r because %s", event, e)
             raise e
@@ -1115,7 +1198,7 @@ class FederationHandler:
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_knock_request`
-            await self._event_auth_handler.check_auth_rules_from_context(event, context)
+            await self._event_auth_handler.check_auth_rules_from_context(event)
         except AuthError as e:
             logger.warning("Failed to create new knock %r because %s", event, e)
             raise e
@@ -1281,9 +1364,7 @@ class FederationHandler:
 
             try:
                 validate_event_for_room_version(event)
-                await self._event_auth_handler.check_auth_rules_from_context(
-                    event, context
-                )
+                await self._event_auth_handler.check_auth_rules_from_context(event)
             except AuthError as e:
                 logger.warning("Denying new third party invite %r because %s", event, e)
                 raise e
@@ -1333,7 +1414,7 @@ class FederationHandler:
 
         try:
             validate_event_for_room_version(event)
-            await self._event_auth_handler.check_auth_rules_from_context(event, context)
+            await self._event_auth_handler.check_auth_rules_from_context(event)
         except AuthError as e:
             logger.warning("Denying third party invite %r because %s", event, e)
             raise e
@@ -1526,8 +1607,8 @@ class FederationHandler:
         Fetch the complexity of a remote room over federation.
 
         Args:
-            remote_room_hosts (list[str]): The remote servers to ask.
-            room_id (str): The room ID to ask about.
+            remote_room_hosts: The remote servers to ask.
+            room_id: The room ID to ask about.
 
         Returns:
             Dict contains the complexity
@@ -1549,13 +1630,13 @@ class FederationHandler:
         """Resumes resyncing of all partial-state rooms after a restart."""
         assert not self.config.worker.worker_app
 
-        partial_state_rooms = await self.store.get_partial_state_rooms_and_servers()
-        for room_id, servers_in_room in partial_state_rooms.items():
+        partial_state_rooms = await self.store.get_partial_state_room_resync_info()
+        for room_id, resync_info in partial_state_rooms.items():
             run_as_background_process(
                 desc="sync_partial_state_room",
                 func=self._sync_partial_state_room,
-                initial_destination=None,
-                other_destinations=servers_in_room,
+                initial_destination=resync_info.joined_via,
+                other_destinations=resync_info.servers_in_room,
                 room_id=room_id,
             )
 
@@ -1584,28 +1665,12 @@ class FederationHandler:
         #   really leave, that might mean we have difficulty getting the room state over
         #   federation.
         #   https://github.com/matrix-org/synapse/issues/12802
-        #
-        # TODO(faster_joins): we need some way of prioritising which homeservers in
-        #   `other_destinations` to try first, otherwise we'll spend ages trying dead
-        #   homeservers for large rooms.
-        #   https://github.com/matrix-org/synapse/issues/12999
-
-        if initial_destination is None and len(other_destinations) == 0:
-            raise ValueError(
-                f"Cannot resync state of {room_id}: no destinations provided"
-            )
 
         # Make an infinite iterator of destinations to try. Once we find a working
         # destination, we'll stick with it until it flakes.
-        destinations: Collection[str]
-        if initial_destination is not None:
-            # Move `initial_destination` to the front of the list.
-            destinations = list(other_destinations)
-            if initial_destination in destinations:
-                destinations.remove(initial_destination)
-            destinations = [initial_destination] + destinations
-        else:
-            destinations = other_destinations
+        destinations = _prioritise_destinations_for_partial_state_resync(
+            initial_destination, other_destinations, room_id
+        )
         destination_iter = itertools.cycle(destinations)
 
         # `destination` is the current remote homeserver we're pulling from.
@@ -1623,6 +1688,9 @@ class FederationHandler:
                 #   https://github.com/matrix-org/synapse/issues/12994
                 await self.state_handler.update_current_state(room_id)
 
+                logger.info("Handling any pending device list updates")
+                await self._device_handler.handle_room_un_partial_stated(room_id)
+
                 logger.info("Clearing partial-state flag for %s", room_id)
                 success = await self.store.clear_partial_state_room(room_id)
                 if success:
@@ -1652,7 +1720,22 @@ class FederationHandler:
                             destination, event
                         )
                         break
+                    except FederationPullAttemptBackoffError as exc:
+                        # Log a warning about why we failed to process the event (the error message
+                        # for `FederationPullAttemptBackoffError` is pretty good)
+                        logger.warning("_sync_partial_state_room: %s", exc)
+                        # We do not record a failed pull attempt when we backoff fetching a missing
+                        # `prev_event` because not being able to fetch the `prev_events` just means
+                        # we won't be able to de-outlier the pulled event. But we can still use an
+                        # `outlier` in the state/auth chain for another event. So we shouldn't stop
+                        # a downstream event from trying to pull it.
+                        #
+                        # This avoids a cascade of backoff for all events in the DAG downstream from
+                        # one event backoff upstream.
                     except FederationError as e:
+                        # TODO: We should `record_event_failed_pull_attempt` here,
+                        #   see https://github.com/matrix-org/synapse/issues/13700
+
                         if attempt == len(destinations) - 1:
                             # We have tried every remote server for this event. Give up.
                             # TODO(faster_joins) giving up isn't the right thing to do
@@ -1685,3 +1768,29 @@ class FederationHandler:
                             room_id,
                             destination,
                         )
+
+
+def _prioritise_destinations_for_partial_state_resync(
+    initial_destination: Optional[str],
+    other_destinations: Collection[str],
+    room_id: str,
+) -> Collection[str]:
+    """Work out the order in which we should ask servers to resync events.
+
+    If an `initial_destination` is given, it takes top priority. Otherwise
+    all servers are treated equally.
+
+    :raises ValueError: if no destination is provided at all.
+    """
+    if initial_destination is None and len(other_destinations) == 0:
+        raise ValueError(f"Cannot resync state of {room_id}: no destinations provided")
+
+    if initial_destination is None:
+        return other_destinations
+
+    # Move `initial_destination` to the front of the list.
+    destinations = list(other_destinations)
+    if initial_destination in destinations:
+        destinations.remove(initial_destination)
+    destinations = [initial_destination] + destinations
+    return destinations
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 00a8860ff3..378b863c5f 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -44,6 +44,7 @@ from synapse.api.errors import (
     AuthError,
     Codes,
     FederationError,
+    FederationPullAttemptBackoffError,
     HttpResponseException,
     RequestSendFailed,
     SynapseError,
@@ -57,7 +58,7 @@ from synapse.event_auth import (
 )
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
-from synapse.federation.federation_client import InvalidResponseError
+from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
 from synapse.logging.context import nested_logging_context
 from synapse.logging.tracing import (
     SynapseTags,
@@ -414,7 +415,9 @@ class FederationEventHandler:
 
         # First, precalculate the joined hosts so that the federation sender doesn't
         # need to.
-        await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
+        await self._event_creation_handler.cache_joined_hosts_for_events(
+            [(event, context)]
+        )
 
         await self._check_for_soft_fail(event, context=context, origin=origin)
         await self._run_push_actions_and_persist_event(event, context)
@@ -565,6 +568,9 @@ class FederationEventHandler:
             event: partial-state event to be de-partial-stated
 
         Raises:
+            FederationPullAttemptBackoffError if we are are deliberately not attempting
+                to pull the given event over federation because we've already done so
+                recently and are backing off.
             FederationError if we fail to request state from the remote server.
         """
         logger.info("Updating state for %s", event.event_id)
@@ -792,9 +798,42 @@ class FederationEventHandler:
             ],
         )
 
+        # Check if we already any of these have these events.
+        # Note: we currently make a lookup in the database directly here rather than
+        # checking the event cache, due to:
+        # https://github.com/matrix-org/synapse/issues/13476
+        existing_events_map = await self._store._get_events_from_db(
+            [event.event_id for event in events]
+        )
+
+        new_events = []
+        for event in events:
+            event_id = event.event_id
+
+            # If we've already seen this event ID...
+            if event_id in existing_events_map:
+                existing_event = existing_events_map[event_id]
+
+                # ...and the event itself was not previously stored as an outlier...
+                if not existing_event.event.internal_metadata.is_outlier():
+                    # ...then there's no need to persist it. We have it already.
+                    logger.info(
+                        "_process_pulled_event: Ignoring received event %s which we "
+                        "have already seen",
+                        event.event_id,
+                    )
+                    continue
+
+                # While we have seen this event before, it was stored as an outlier.
+                # We'll now persist it as a non-outlier.
+                logger.info("De-outliering event %s", event_id)
+
+            # Continue on with the events that are new to us.
+            new_events.append(event)
+
         # We want to sort these by depth so we process them and
         # tell clients about them in order.
-        sorted_events = sorted(events, key=lambda x: x.depth)
+        sorted_events = sorted(new_events, key=lambda x: x.depth)
         for ev in sorted_events:
             with nested_logging_context(ev.event_id):
                 await self._process_pulled_event(origin, ev, backfilled=backfilled)
@@ -846,18 +885,6 @@ class FederationEventHandler:
 
         event_id = event.event_id
 
-        existing = await self._store.get_event(
-            event_id, allow_none=True, allow_rejected=True
-        )
-        if existing:
-            if not existing.internal_metadata.is_outlier():
-                logger.info(
-                    "_process_pulled_event: Ignoring received event %s which we have already seen",
-                    event_id,
-                )
-                return
-            logger.info("De-outliering event %s", event_id)
-
         try:
             self._sanity_check_event(event)
         except SynapseError as err:
@@ -866,11 +893,6 @@ class FederationEventHandler:
                 event.room_id, event_id, str(err)
             )
             return
-        except Exception as exc:
-            await self._store.record_event_failed_pull_attempt(
-                event.room_id, event_id, str(exc)
-            )
-            raise exc
 
         try:
             try:
@@ -904,6 +926,18 @@ class FederationEventHandler:
                     context,
                     backfilled=backfilled,
                 )
+        except FederationPullAttemptBackoffError as exc:
+            # Log a warning about why we failed to process the event (the error message
+            # for `FederationPullAttemptBackoffError` is pretty good)
+            logger.warning("_process_pulled_event: %s", exc)
+            # We do not record a failed pull attempt when we backoff fetching a missing
+            # `prev_event` because not being able to fetch the `prev_events` just means
+            # we won't be able to de-outlier the pulled event. But we can still use an
+            # `outlier` in the state/auth chain for another event. So we shouldn't stop
+            # a downstream event from trying to pull it.
+            #
+            # This avoids a cascade of backoff for all events in the DAG downstream from
+            # one event backoff upstream.
         except FederationError as e:
             await self._store.record_event_failed_pull_attempt(
                 event.room_id, event_id, str(e)
@@ -913,11 +947,6 @@ class FederationEventHandler:
                 logger.warning("Pulled event %s failed history check.", event_id)
             else:
                 raise
-        except Exception as exc:
-            await self._store.record_event_failed_pull_attempt(
-                event.room_id, event_id, str(exc)
-            )
-            raise exc
 
     @trace
     async def _compute_event_context_with_maybe_missing_prevs(
@@ -955,6 +984,9 @@ class FederationEventHandler:
             The event context.
 
         Raises:
+            FederationPullAttemptBackoffError if we are are deliberately not attempting
+                to pull the given event over federation because we've already done so
+                recently and are backing off.
             FederationError if we fail to get the state from the remote server after any
                 missing `prev_event`s.
         """
@@ -965,6 +997,18 @@ class FederationEventHandler:
         seen = await self._store.have_events_in_timeline(prevs)
         missing_prevs = prevs - seen
 
+        # If we've already recently attempted to pull this missing event, don't
+        # try it again so soon. Since we have to fetch all of the prev_events, we can
+        # bail early here if we find any to ignore.
+        prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff(
+            room_id, missing_prevs
+        )
+        if len(prevs_to_ignore) > 0:
+            raise FederationPullAttemptBackoffError(
+                event_ids=prevs_to_ignore,
+                message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).",
+            )
+
         if not missing_prevs:
             return await self._state_handler.compute_event_context(event)
 
@@ -1021,10 +1065,9 @@ class FederationEventHandler:
                 state_res_store=StateResolutionStore(self._store),
             )
 
-        except Exception:
+        except Exception as e:
             logger.warning(
-                "Error attempting to resolve state at missing prev_events",
-                exc_info=True,
+                "Error attempting to resolve state at missing prev_events: %s", e
             )
             raise FederationError(
                 "ERROR",
@@ -1478,8 +1521,8 @@ class FederationEventHandler:
         )
 
     async def backfill_event_id(
-        self, destination: str, room_id: str, event_id: str
-    ) -> EventBase:
+        self, destinations: List[str], room_id: str, event_id: str
+    ) -> PulledPduInfo:
         """Backfill a single event and persist it as a non-outlier which means
         we also pull in all of the state and auth events necessary for it.
 
@@ -1491,24 +1534,21 @@ class FederationEventHandler:
         Raises:
             FederationError if we are unable to find the event from the destination
         """
-        logger.info(
-            "backfill_event_id: event_id=%s from destination=%s", event_id, destination
-        )
+        logger.info("backfill_event_id: event_id=%s", event_id)
 
         room_version = await self._store.get_room_version(room_id)
 
-        event_from_response = await self._federation_client.get_pdu(
-            [destination],
+        pulled_pdu_info = await self._federation_client.get_pdu(
+            destinations,
             event_id,
             room_version,
         )
 
-        if not event_from_response:
+        if not pulled_pdu_info:
             raise FederationError(
                 "ERROR",
                 404,
-                "Unable to find event_id=%s from destination=%s to backfill."
-                % (event_id, destination),
+                f"Unable to find event_id={event_id} from remote servers to backfill.",
                 affected=event_id,
             )
 
@@ -1516,13 +1556,13 @@ class FederationEventHandler:
         # and auth events to de-outlier it. This also sets up the necessary
         # `state_groups` for the event.
         await self._process_pulled_events(
-            destination,
-            [event_from_response],
+            pulled_pdu_info.pull_origin,
+            [pulled_pdu_info.pdu],
             # Prevent notifications going to clients
             backfilled=True,
         )
 
-        return event_from_response
+        return pulled_pdu_info
 
     @trace
     @tag_args
@@ -1545,19 +1585,19 @@ class FederationEventHandler:
         async def get_event(event_id: str) -> None:
             with nested_logging_context(event_id):
                 try:
-                    event = await self._federation_client.get_pdu(
+                    pulled_pdu_info = await self._federation_client.get_pdu(
                         [destination],
                         event_id,
                         room_version,
                     )
-                    if event is None:
+                    if pulled_pdu_info is None:
                         logger.warning(
                             "Server %s didn't return event %s",
                             destination,
                             event_id,
                         )
                         return
-                    events.append(event)
+                    events.append(pulled_pdu_info.pdu)
 
                 except Exception as e:
                     logger.warning(
@@ -2132,8 +2172,8 @@ class FederationEventHandler:
                     min_depth,
                 )
             else:
-                await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                    event, context
+                await self._bulk_push_rule_evaluator.action_for_events_by_user(
+                    [(event, context)]
                 )
 
         try:
@@ -2175,6 +2215,7 @@ class FederationEventHandler:
         if instance != self._instance_name:
             # Limit the number of events sent over replication. We choose 200
             # here as that is what we default to in `max_request_body_size(..)`
+            result = {}
             try:
                 for batch in batch_iter(event_and_contexts, 200):
                     result = await self._send_events(
@@ -2254,8 +2295,8 @@ class FederationEventHandler:
         event_pos = PersistedEventPosition(
             self._instance_name, event.internal_metadata.stream_ordering
         )
-        await self._notifier.on_new_room_event(
-            event, event_pos, max_stream_token, extra_users=extra_users
+        await self._notifier.on_new_room_events(
+            [(event, event_pos)], max_stream_token, extra_users=extra_users
         )
 
         if event.type == EventTypes.Member and event.membership == Membership.JOIN:
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 93d09e9939..848e46eb9b 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -711,7 +711,7 @@ class IdentityHandler:
             inviter_display_name: The current display name of the
                 inviter.
             inviter_avatar_url: The URL of the inviter's avatar.
-            id_access_token (str): The access token to authenticate to the identity
+            id_access_token: The access token to authenticate to the identity
                 server with
 
         Returns:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 860c82c110..9c335e6863 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -57,13 +57,7 @@ class InitialSyncHandler:
         self.validator = EventValidator()
         self.snapshot_cache: ResponseCache[
             Tuple[
-                str,
-                Optional[StreamToken],
-                Optional[StreamToken],
-                str,
-                Optional[int],
-                bool,
-                bool,
+                str, Optional[StreamToken], Optional[StreamToken], str, int, bool, bool
             ]
         ] = ResponseCache(hs.get_clock(), "initial_sync_cache")
         self._event_serializer = hs.get_event_client_serializer()
@@ -154,11 +148,6 @@ class InitialSyncHandler:
 
         public_room_ids = await self.store.get_public_room_ids()
 
-        if pagin_config.limit is not None:
-            limit = pagin_config.limit
-        else:
-            limit = 10
-
         serializer_options = SerializeEventConfig(as_client_event=as_client_event)
 
         async def handle_room(event: RoomsForUser) -> None:
@@ -210,7 +199,7 @@ class InitialSyncHandler:
                             run_in_background(
                                 self.store.get_recent_events_for_room,
                                 event.room_id,
-                                limit=limit,
+                                limit=pagin_config.limit,
                                 end_token=room_end_token,
                             ),
                             deferred_room_state,
@@ -360,15 +349,11 @@ class InitialSyncHandler:
             member_event_id
         )
 
-        limit = pagin_config.limit if pagin_config else None
-        if limit is None:
-            limit = 10
-
         leave_position = await self.store.get_position_for_event(member_event_id)
         stream_token = leave_position.to_room_stream_token()
 
         messages, token = await self.store.get_recent_events_for_room(
-            room_id, limit=limit, end_token=stream_token
+            room_id, limit=pagin_config.limit, end_token=stream_token
         )
 
         messages = await filter_events_for_client(
@@ -420,10 +405,6 @@ class InitialSyncHandler:
 
         now_token = self.hs.get_event_sources().get_current_token()
 
-        limit = pagin_config.limit if pagin_config else None
-        if limit is None:
-            limit = 10
-
         room_members = [
             m
             for m in current_state.values()
@@ -467,7 +448,7 @@ class InitialSyncHandler:
                     run_in_background(
                         self.store.get_recent_events_for_room,
                         room_id,
-                        limit=limit,
+                        limit=pagin_config.limit,
                         end_token=now_token.room_key,
                     ),
                 ),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 10b5dad030..f2a0101733 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -56,13 +56,16 @@ from synapse.logging import tracing
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
+from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
 from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
 from synapse.types import (
     MutableStateMap,
+    PersistedEventPosition,
     Requester,
     RoomAlias,
+    StateMap,
     StreamToken,
     UserID,
     create_requester,
@@ -492,6 +495,7 @@ class EventCreationHandler:
             self.membership_types_to_include_profile_data_in.add(Membership.INVITE)
 
         self.send_event = ReplicationSendEventRestServlet.make_client(hs)
+        self.send_events = ReplicationSendEventsRestServlet.make_client(hs)
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
 
@@ -567,9 +571,17 @@ class EventCreationHandler:
         outlier: bool = False,
         historical: bool = False,
         depth: Optional[int] = None,
+        state_map: Optional[StateMap[str]] = None,
+        for_batch: bool = False,
+        current_state_group: Optional[int] = None,
     ) -> Tuple[EventBase, EventContext]:
         """
-        Given a dict from a client, create a new event.
+        Given a dict from a client, create a new event. If bool for_batch is true, will
+        create an event using the prev_event_ids, and will create an event context for
+        the event using the parameters state_map and current_state_group, thus these parameters
+        must be provided in this case if for_batch is True. The subsequently created event
+        and context are suitable for being batched up and bulk persisted to the database
+        with other similarly created events.
 
         Creates an FrozenEvent object, filling out auth_events, prev_events,
         etc.
@@ -612,16 +624,27 @@ class EventCreationHandler:
             outlier: Indicates whether the event is an `outlier`, i.e. if
                 it's from an arbitrary point and floating in the DAG as
                 opposed to being inline with the current DAG.
+
             historical: Indicates whether the message is being inserted
                 back in time around some existing events. This is used to skip
                 a few checks and mark the event as backfilled.
+
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
 
+            state_map: A state map of previously created events, used only when creating events
+                for batch persisting
+
+            for_batch: whether the event is being created for batch persisting to the db
+
+            current_state_group: the current state group, used only for creating events for
+                batch persisting
+
         Raises:
             ResourceLimitError if server is blocked to some resource being
             exceeded
+
         Returns:
             Tuple of created event, Context
         """
@@ -693,6 +716,9 @@ class EventCreationHandler:
             auth_event_ids=auth_event_ids,
             state_event_ids=state_event_ids,
             depth=depth,
+            state_map=state_map,
+            for_batch=for_batch,
+            current_state_group=current_state_group,
         )
 
         # In an ideal world we wouldn't need the second part of this condition. However,
@@ -707,10 +733,14 @@ class EventCreationHandler:
             # federation as well as those created locally. As of room v3, aliases events
             # can be created by users that are not in the room, therefore we have to
             # tolerate them in event_auth.check().
-            prev_state_ids = await context.get_prev_state_ids(
-                StateFilter.from_types([(EventTypes.Member, None)])
-            )
-            prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
+            if for_batch:
+                assert state_map is not None
+                prev_event_id = state_map.get((EventTypes.Member, event.sender))
+            else:
+                prev_state_ids = await context.get_prev_state_ids(
+                    StateFilter.from_types([(EventTypes.Member, None)])
+                )
+                prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
             prev_event = (
                 await self.store.get_event(prev_event_id, allow_none=True)
                 if prev_event_id
@@ -847,6 +877,36 @@ class EventCreationHandler:
                 return prev_event
         return None
 
+    async def get_event_from_transaction(
+        self,
+        requester: Requester,
+        txn_id: str,
+        room_id: str,
+    ) -> Optional[EventBase]:
+        """For the given transaction ID and room ID, check if there is a matching event.
+        If so, fetch it and return it.
+
+        Args:
+            requester: The requester making the request in the context of which we want
+                to fetch the event.
+            txn_id: The transaction ID.
+            room_id: The room ID.
+
+        Returns:
+            An event if one could be found, None otherwise.
+        """
+        if requester.access_token_id:
+            existing_event_id = await self.store.get_event_id_from_transaction_id(
+                room_id,
+                requester.user.to_string(),
+                requester.access_token_id,
+                txn_id,
+            )
+            if existing_event_id:
+                return await self.store.get_event(existing_event_id)
+
+        return None
+
     async def create_and_send_nonmember_event(
         self,
         requester: Requester,
@@ -926,18 +986,17 @@ class EventCreationHandler:
         # extremities to pile up, which in turn leads to state resolution
         # taking longer.
         async with self.limiter.queue(event_dict["room_id"]):
-            if txn_id and requester.access_token_id:
-                existing_event_id = await self.store.get_event_id_from_transaction_id(
-                    event_dict["room_id"],
-                    requester.user.to_string(),
-                    requester.access_token_id,
-                    txn_id,
+            if txn_id:
+                event = await self.get_event_from_transaction(
+                    requester, txn_id, event_dict["room_id"]
                 )
-                if existing_event_id:
-                    event = await self.store.get_event(existing_event_id)
+                if event:
                     # we know it was persisted, so must have a stream ordering
                     assert event.internal_metadata.stream_ordering
-                    return event, event.internal_metadata.stream_ordering
+                    return (
+                        event,
+                        event.internal_metadata.stream_ordering,
+                    )
 
             event, context = await self.create_event(
                 requester,
@@ -989,8 +1048,7 @@ class EventCreationHandler:
 
             ev = await self.handle_new_client_event(
                 requester=requester,
-                event=event,
-                context=context,
+                events_and_context=[(event, context)],
                 ratelimit=ratelimit,
                 ignore_shadow_ban=ignore_shadow_ban,
             )
@@ -1009,8 +1067,16 @@ class EventCreationHandler:
         auth_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        state_map: Optional[StateMap[str]] = None,
+        for_batch: bool = False,
+        current_state_group: Optional[int] = None,
     ) -> Tuple[EventBase, EventContext]:
-        """Create a new event for a local client
+        """Create a new event for a local client. If bool for_batch is true, will
+        create an event using the prev_event_ids, and will create an event context for
+        the event using the parameters state_map and current_state_group, thus these parameters
+        must be provided in this case if for_batch is True. The subsequently created event
+        and context are suitable for being batched up and bulk persisted to the database
+        with other similarly created events.
 
         Args:
             builder:
@@ -1043,6 +1109,14 @@ class EventCreationHandler:
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
 
+            state_map: A state map of previously created events, used only when creating events
+                for batch persisting
+
+            for_batch: whether the event is being created for batch persisting to the db
+
+            current_state_group: the current state group, used only for creating events for
+                batch persisting
+
         Returns:
             Tuple of created event, context
         """
@@ -1095,64 +1169,76 @@ class EventCreationHandler:
                 builder.type == EventTypes.Create or prev_event_ids
             ), "Attempting to create a non-m.room.create event with no prev_events"
 
-        event = await builder.build(
-            prev_event_ids=prev_event_ids,
-            auth_event_ids=auth_event_ids,
-            depth=depth,
-        )
+        if for_batch:
+            assert prev_event_ids is not None
+            assert state_map is not None
+            assert current_state_group is not None
+            auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
+            event = await builder.build(
+                prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
+            )
+            context = await self.state.compute_event_context_for_batched(
+                event, state_map, current_state_group
+            )
+        else:
+            event = await builder.build(
+                prev_event_ids=prev_event_ids,
+                auth_event_ids=auth_event_ids,
+                depth=depth,
+            )
 
-        # Pass on the outlier property from the builder to the event
-        # after it is created
-        if builder.internal_metadata.outlier:
-            event.internal_metadata.outlier = True
-            context = EventContext.for_outlier(self._storage_controllers)
-        elif (
-            event.type == EventTypes.MSC2716_INSERTION
-            and state_event_ids
-            and builder.internal_metadata.is_historical()
-        ):
-            # Add explicit state to the insertion event so it has state to derive
-            # from even though it's floating with no `prev_events`. The rest of
-            # the batch can derive from this state and state_group.
-            #
-            # TODO(faster_joins): figure out how this works, and make sure that the
-            #   old state is complete.
-            #   https://github.com/matrix-org/synapse/issues/13003
-            metadata = await self.store.get_metadata_for_events(state_event_ids)
-
-            state_map_for_event: MutableStateMap[str] = {}
-            for state_id in state_event_ids:
-                data = metadata.get(state_id)
-                if data is None:
-                    # We're trying to persist a new historical batch of events
-                    # with the given state, e.g. via
-                    # `RoomBatchSendEventRestServlet`. The state can be inferred
-                    # by Synapse or set directly by the client.
-                    #
-                    # Either way, we should have persisted all the state before
-                    # getting here.
-                    raise Exception(
-                        f"State event {state_id} not found in DB,"
-                        " Synapse should have persisted it before using it."
-                    )
+            # Pass on the outlier property from the builder to the event
+            # after it is created
+            if builder.internal_metadata.outlier:
+                event.internal_metadata.outlier = True
+                context = EventContext.for_outlier(self._storage_controllers)
+            elif (
+                event.type == EventTypes.MSC2716_INSERTION
+                and state_event_ids
+                and builder.internal_metadata.is_historical()
+            ):
+                # Add explicit state to the insertion event so it has state to derive
+                # from even though it's floating with no `prev_events`. The rest of
+                # the batch can derive from this state and state_group.
+                #
+                # TODO(faster_joins): figure out how this works, and make sure that the
+                #   old state is complete.
+                #   https://github.com/matrix-org/synapse/issues/13003
+                metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+                state_map_for_event: MutableStateMap[str] = {}
+                for state_id in state_event_ids:
+                    data = metadata.get(state_id)
+                    if data is None:
+                        # We're trying to persist a new historical batch of events
+                        # with the given state, e.g. via
+                        # `RoomBatchSendEventRestServlet`. The state can be inferred
+                        # by Synapse or set directly by the client.
+                        #
+                        # Either way, we should have persisted all the state before
+                        # getting here.
+                        raise Exception(
+                            f"State event {state_id} not found in DB,"
+                            " Synapse should have persisted it before using it."
+                        )
 
-                if data.state_key is None:
-                    raise Exception(
-                        f"Trying to set non-state event {state_id} as state"
-                    )
+                    if data.state_key is None:
+                        raise Exception(
+                            f"Trying to set non-state event {state_id} as state"
+                        )
 
-                state_map_for_event[(data.event_type, data.state_key)] = state_id
+                    state_map_for_event[(data.event_type, data.state_key)] = state_id
 
-            context = await self.state.compute_event_context(
-                event,
-                state_ids_before_event=state_map_for_event,
-                # TODO(faster_joins): check how MSC2716 works and whether we can have
-                #   partial state here
-                #   https://github.com/matrix-org/synapse/issues/13003
-                partial_state=False,
-            )
-        else:
-            context = await self.state.compute_event_context(event)
+                context = await self.state.compute_event_context(
+                    event,
+                    state_ids_before_event=state_map_for_event,
+                    # TODO(faster_joins): check how MSC2716 works and whether we can have
+                    #   partial state here
+                    #   https://github.com/matrix-org/synapse/issues/13003
+                    partial_state=False,
+                )
+            else:
+                context = await self.state.compute_event_context(event)
 
         if requester:
             context.app_service = requester.app_service
@@ -1238,13 +1324,13 @@ class EventCreationHandler:
     async def handle_new_client_event(
         self,
         requester: Requester,
-        event: EventBase,
-        context: EventContext,
+        events_and_context: List[Tuple[EventBase, EventContext]],
         ratelimit: bool = True,
         extra_users: Optional[List[UserID]] = None,
         ignore_shadow_ban: bool = False,
     ) -> EventBase:
-        """Processes a new event.
+        """Processes new events. Please note that if batch persisting events, an error in
+        handling any one of these events will result in all of the events being dropped.
 
         This includes deduplicating, checking auth, persisting,
         notifying users, sending to remote servers, etc.
@@ -1254,8 +1340,7 @@ class EventCreationHandler:
 
         Args:
             requester
-            event
-            context
+            events_and_context: A list of one or more tuples of event, context to be persisted
             ratelimit
             extra_users: Any extra users to notify about event
 
@@ -1273,67 +1358,76 @@ class EventCreationHandler:
         """
         extra_users = extra_users or []
 
-        # we don't apply shadow-banning to membership events here. Invites are blocked
-        # higher up the stack, and we allow shadow-banned users to send join and leave
-        # events as normal.
-        if (
-            event.type != EventTypes.Member
-            and not ignore_shadow_ban
-            and requester.shadow_banned
-        ):
-            # We randomly sleep a bit just to annoy the requester.
-            await self.clock.sleep(random.randint(1, 10))
-            raise ShadowBanError()
+        for event, context in events_and_context:
+            # we don't apply shadow-banning to membership events here. Invites are blocked
+            # higher up the stack, and we allow shadow-banned users to send join and leave
+            # events as normal.
+            if (
+                event.type != EventTypes.Member
+                and not ignore_shadow_ban
+                and requester.shadow_banned
+            ):
+                # We randomly sleep a bit just to annoy the requester.
+                await self.clock.sleep(random.randint(1, 10))
+                raise ShadowBanError()
 
-        if event.is_state():
-            prev_event = await self.deduplicate_state_event(event, context)
-            if prev_event is not None:
-                logger.info(
-                    "Not bothering to persist state event %s duplicated by %s",
-                    event.event_id,
-                    prev_event.event_id,
-                )
-                return prev_event
+            if event.is_state():
+                prev_event = await self.deduplicate_state_event(event, context)
+                if prev_event is not None:
+                    logger.info(
+                        "Not bothering to persist state event %s duplicated by %s",
+                        event.event_id,
+                        prev_event.event_id,
+                    )
+                    return prev_event
 
-        if event.internal_metadata.is_out_of_band_membership():
-            # the only sort of out-of-band-membership events we expect to see here are
-            # invite rejections and rescinded knocks that we have generated ourselves.
-            assert event.type == EventTypes.Member
-            assert event.content["membership"] == Membership.LEAVE
-        else:
-            try:
-                validate_event_for_room_version(event)
-                await self._event_auth_handler.check_auth_rules_from_context(
-                    event, context
-                )
-            except AuthError as err:
-                logger.warning("Denying new event %r because %s", event, err)
-                raise err
+            if event.internal_metadata.is_out_of_band_membership():
+                # the only sort of out-of-band-membership events we expect to see here are
+                # invite rejections and rescinded knocks that we have generated ourselves.
+                assert event.type == EventTypes.Member
+                assert event.content["membership"] == Membership.LEAVE
+            else:
+                try:
+                    validate_event_for_room_version(event)
+                    # If we are persisting a batch of events the event(s) needed to auth the
+                    # current event may be part of the batch and will not be in the DB yet
+                    event_id_to_event = {e.event_id: e for e, _ in events_and_context}
+                    batched_auth_events = {}
+                    for event_id in event.auth_event_ids():
+                        auth_event = event_id_to_event.get(event_id)
+                        if auth_event:
+                            batched_auth_events[event_id] = auth_event
+                    await self._event_auth_handler.check_auth_rules_from_context(
+                        event, batched_auth_events
+                    )
+                except AuthError as err:
+                    logger.warning("Denying new event %r because %s", event, err)
+                    raise err
 
-        # Ensure that we can round trip before trying to persist in db
-        try:
-            dump = json_encoder.encode(event.content)
-            json_decoder.decode(dump)
-        except Exception:
-            logger.exception("Failed to encode content: %r", event.content)
-            raise
+            # Ensure that we can round trip before trying to persist in db
+            try:
+                dump = json_encoder.encode(event.content)
+                json_decoder.decode(dump)
+            except Exception:
+                logger.exception("Failed to encode content: %r", event.content)
+                raise
 
         # We now persist the event (and update the cache in parallel, since we
         # don't want to block on it).
+        event, context = events_and_context[0]
         try:
             result, _ = await make_deferred_yieldable(
                 gather_results(
                     (
                         run_in_background(
-                            self._persist_event,
+                            self._persist_events,
                             requester=requester,
-                            event=event,
-                            context=context,
+                            events_and_context=events_and_context,
                             ratelimit=ratelimit,
                             extra_users=extra_users,
                         ),
                         run_in_background(
-                            self.cache_joined_hosts_for_event, event, context
+                            self.cache_joined_hosts_for_events, events_and_context
                         ).addErrback(
                             log_failure, "cache_joined_hosts_for_event failed"
                         ),
@@ -1352,45 +1446,40 @@ class EventCreationHandler:
 
         return result
 
-    async def _persist_event(
+    async def _persist_events(
         self,
         requester: Requester,
-        event: EventBase,
-        context: EventContext,
+        events_and_context: List[Tuple[EventBase, EventContext]],
         ratelimit: bool = True,
         extra_users: Optional[List[UserID]] = None,
     ) -> EventBase:
-        """Actually persists the event. Should only be called by
+        """Actually persists new events. Should only be called by
         `handle_new_client_event`, and see its docstring for documentation of
-        the arguments.
+        the arguments. Please note that if batch persisting events, an error in
+        handling any one of these events will result in all of the events being dropped.
 
         PartialStateConflictError: if attempting to persist a partial state event in
             a room that has been un-partial stated.
         """
 
-        # Skip push notification actions for historical messages
-        # because we don't want to notify people about old history back in time.
-        # The historical messages also do not have the proper `context.current_state_ids`
-        # and `state_groups` because they have `prev_events` that aren't persisted yet
-        # (historical messages persisted in reverse-chronological order).
-        if not event.internal_metadata.is_historical():
-            with tracing.start_active_span("calculate_push_actions"):
-                await self._bulk_push_rule_evaluator.action_for_event_by_user(
-                    event, context
-                )
+        with tracing.start_active_span("calculate_push_actions"):
+            await self._bulk_push_rule_evaluator.action_for_events_by_user(
+                events_and_context
+            )
 
         try:
             # If we're a worker we need to hit out to the master.
-            writer_instance = self._events_shard_config.get_instance(event.room_id)
+            first_event, _ = events_and_context[0]
+            writer_instance = self._events_shard_config.get_instance(
+                first_event.room_id
+            )
             if writer_instance != self._instance_name:
                 try:
-                    result = await self.send_event(
+                    result = await self.send_events(
                         instance_name=writer_instance,
-                        event_id=event.event_id,
+                        events_and_context=events_and_context,
                         store=self.store,
                         requester=requester,
-                        event=event,
-                        context=context,
                         ratelimit=ratelimit,
                         extra_users=extra_users,
                     )
@@ -1400,6 +1489,11 @@ class EventCreationHandler:
                     raise
                 stream_id = result["stream_id"]
                 event_id = result["event_id"]
+
+                # If we batch persisted events we return the last persisted event, otherwise
+                # we return the one event that was persisted
+                event, _ = events_and_context[-1]
+
                 if event_id != event.event_id:
                     # If we get a different event back then it means that its
                     # been de-duplicated, so we replace the given event with the
@@ -1412,73 +1506,80 @@ class EventCreationHandler:
                     event.internal_metadata.stream_ordering = stream_id
                 return event
 
-            event = await self.persist_and_notify_client_event(
-                requester, event, context, ratelimit=ratelimit, extra_users=extra_users
+            event = await self.persist_and_notify_client_events(
+                requester,
+                events_and_context,
+                ratelimit=ratelimit,
+                extra_users=extra_users,
             )
 
             return event
         except Exception:
-            # Ensure that we actually remove the entries in the push actions
-            # staging area, if we calculated them.
-            await self.store.remove_push_actions_from_staging(event.event_id)
+            for event, _ in events_and_context:
+                # Ensure that we actually remove the entries in the push actions
+                # staging area, if we calculated them.
+                await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
-    async def cache_joined_hosts_for_event(
-        self, event: EventBase, context: EventContext
+    async def cache_joined_hosts_for_events(
+        self, events_and_context: List[Tuple[EventBase, EventContext]]
     ) -> None:
-        """Precalculate the joined hosts at the event, when using Redis, so that
+        """Precalculate the joined hosts at each of the given events, when using Redis, so that
         external federation senders don't have to recalculate it themselves.
         """
 
-        if not self._external_cache.is_enabled():
-            return
-
-        # If external cache is enabled we should always have this.
-        assert self._external_cache_joined_hosts_updates is not None
+        for event, _ in events_and_context:
+            if not self._external_cache.is_enabled():
+                return
 
-        # We actually store two mappings, event ID -> prev state group,
-        # state group -> joined hosts, which is much more space efficient
-        # than event ID -> joined hosts.
-        #
-        # Note: We have to cache event ID -> prev state group, as we don't
-        # store that in the DB.
-        #
-        # Note: We set the state group -> joined hosts cache if it hasn't been
-        # set for a while, so that the expiry time is reset.
+            # If external cache is enabled we should always have this.
+            assert self._external_cache_joined_hosts_updates is not None
 
-        state_entry = await self.state.resolve_state_groups_for_events(
-            event.room_id, event_ids=event.prev_event_ids()
-        )
+            # We actually store two mappings, event ID -> prev state group,
+            # state group -> joined hosts, which is much more space efficient
+            # than event ID -> joined hosts.
+            #
+            # Note: We have to cache event ID -> prev state group, as we don't
+            # store that in the DB.
+            #
+            # Note: We set the state group -> joined hosts cache if it hasn't been
+            # set for a while, so that the expiry time is reset.
 
-        if state_entry.state_group:
-            await self._external_cache.set(
-                "event_to_prev_state_group",
-                event.event_id,
-                state_entry.state_group,
-                expiry_ms=60 * 60 * 1000,
+            state_entry = await self.state.resolve_state_groups_for_events(
+                event.room_id, event_ids=event.prev_event_ids()
             )
 
-            if state_entry.state_group in self._external_cache_joined_hosts_updates:
-                return
+            if state_entry.state_group:
+                await self._external_cache.set(
+                    "event_to_prev_state_group",
+                    event.event_id,
+                    state_entry.state_group,
+                    expiry_ms=60 * 60 * 1000,
+                )
 
-            state = await state_entry.get_state(
-                self._storage_controllers.state, StateFilter.all()
-            )
-            with tracing.start_active_span("get_joined_hosts"):
-                joined_hosts = await self.store.get_joined_hosts(
-                    event.room_id, state, state_entry
+                if state_entry.state_group in self._external_cache_joined_hosts_updates:
+                    return
+
+                state = await state_entry.get_state(
+                    self._storage_controllers.state, StateFilter.all()
                 )
+                with tracing.start_active_span("get_joined_hosts"):
+                    joined_hosts = await self.store.get_joined_hosts(
+                        event.room_id, state, state_entry
+                    )
 
-            # Note that the expiry times must be larger than the expiry time in
-            # _external_cache_joined_hosts_updates.
-            await self._external_cache.set(
-                "get_joined_hosts",
-                str(state_entry.state_group),
-                list(joined_hosts),
-                expiry_ms=60 * 60 * 1000,
-            )
+                # Note that the expiry times must be larger than the expiry time in
+                # _external_cache_joined_hosts_updates.
+                await self._external_cache.set(
+                    "get_joined_hosts",
+                    str(state_entry.state_group),
+                    list(joined_hosts),
+                    expiry_ms=60 * 60 * 1000,
+                )
 
-            self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+                self._external_cache_joined_hosts_updates[
+                    state_entry.state_group
+                ] = None
 
     async def _validate_canonical_alias(
         self,
@@ -1514,23 +1615,26 @@ class EventCreationHandler:
                 Codes.BAD_ALIAS,
             )
 
-    async def persist_and_notify_client_event(
+    async def persist_and_notify_client_events(
         self,
         requester: Requester,
-        event: EventBase,
-        context: EventContext,
+        events_and_context: List[Tuple[EventBase, EventContext]],
         ratelimit: bool = True,
         extra_users: Optional[List[UserID]] = None,
     ) -> EventBase:
-        """Called when we have fully built the event, have already
-        calculated the push actions for the event, and checked auth.
+        """Called when we have fully built the events, have already
+        calculated the push actions for the events, and checked auth.
 
         This should only be run on the instance in charge of persisting events.
 
+        Please note that if batch persisting events, an error in
+        handling any one of these events will result in all of the events being dropped.
+
         Returns:
-            The persisted event. This may be different than the given event if
-            it was de-duplicated (e.g. because we had already persisted an
-            event with the same transaction ID.)
+            The persisted event, if one event is passed in, or the last event in the
+            list in the case of batch persisting. If only one event was persisted, the
+            returned event may be different than the given event if it was de-duplicated
+            (e.g. because we had already persisted an event with the same transaction ID.)
 
         Raises:
             PartialStateConflictError: if attempting to persist a partial state event in
@@ -1538,277 +1642,296 @@ class EventCreationHandler:
         """
         extra_users = extra_users or []
 
-        assert self._storage_controllers.persistence is not None
-        assert self._events_shard_config.should_handle(
-            self._instance_name, event.room_id
-        )
+        for event, context in events_and_context:
+            assert self._events_shard_config.should_handle(
+                self._instance_name, event.room_id
+            )
 
-        if ratelimit:
-            # We check if this is a room admin redacting an event so that we
-            # can apply different ratelimiting. We do this by simply checking
-            # it's not a self-redaction (to avoid having to look up whether the
-            # user is actually admin or not).
-            is_admin_redaction = False
-            if event.type == EventTypes.Redaction:
-                assert event.redacts is not None
+            if ratelimit:
+                # We check if this is a room admin redacting an event so that we
+                # can apply different ratelimiting. We do this by simply checking
+                # it's not a self-redaction (to avoid having to look up whether the
+                # user is actually admin or not).
+                is_admin_redaction = False
+                if event.type == EventTypes.Redaction:
+                    assert event.redacts is not None
+
+                    original_event = await self.store.get_event(
+                        event.redacts,
+                        redact_behaviour=EventRedactBehaviour.as_is,
+                        get_prev_content=False,
+                        allow_rejected=False,
+                        allow_none=True,
+                    )
 
-                original_event = await self.store.get_event(
-                    event.redacts,
-                    redact_behaviour=EventRedactBehaviour.as_is,
-                    get_prev_content=False,
-                    allow_rejected=False,
-                    allow_none=True,
+                    is_admin_redaction = bool(
+                        original_event and event.sender != original_event.sender
+                    )
+
+                await self.request_ratelimiter.ratelimit(
+                    requester, is_admin_redaction=is_admin_redaction
                 )
 
-                is_admin_redaction = bool(
-                    original_event and event.sender != original_event.sender
+            # run checks/actions on event based on type
+            if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+                (
+                    current_membership,
+                    _,
+                ) = await self.store.get_local_current_membership_for_user_in_room(
+                    event.state_key, event.room_id
                 )
+                if current_membership != Membership.JOIN:
+                    self._notifier.notify_user_joined_room(
+                        event.event_id, event.room_id
+                    )
 
-            await self.request_ratelimiter.ratelimit(
-                requester, is_admin_redaction=is_admin_redaction
-            )
+            await self._maybe_kick_guest_users(event, context)
 
-        if event.type == EventTypes.Member and event.membership == Membership.JOIN:
-            (
-                current_membership,
-                _,
-            ) = await self.store.get_local_current_membership_for_user_in_room(
-                event.state_key, event.room_id
-            )
-            if current_membership != Membership.JOIN:
-                self._notifier.notify_user_joined_room(event.event_id, event.room_id)
+            if event.type == EventTypes.CanonicalAlias:
+                # Validate a newly added alias or newly added alt_aliases.
 
-        await self._maybe_kick_guest_users(event, context)
+                original_alias = None
+                original_alt_aliases: object = []
 
-        if event.type == EventTypes.CanonicalAlias:
-            # Validate a newly added alias or newly added alt_aliases.
+                original_event_id = event.unsigned.get("replaces_state")
+                if original_event_id:
+                    original_alias_event = await self.store.get_event(original_event_id)
 
-            original_alias = None
-            original_alt_aliases: object = []
+                    if original_alias_event:
+                        original_alias = original_alias_event.content.get("alias", None)
+                        original_alt_aliases = original_alias_event.content.get(
+                            "alt_aliases", []
+                        )
 
-            original_event_id = event.unsigned.get("replaces_state")
-            if original_event_id:
-                original_event = await self.store.get_event(original_event_id)
+                # Check the alias is currently valid (if it has changed).
+                room_alias_str = event.content.get("alias", None)
+                directory_handler = self.hs.get_directory_handler()
+                if room_alias_str and room_alias_str != original_alias:
+                    await self._validate_canonical_alias(
+                        directory_handler, room_alias_str, event.room_id
+                    )
 
-                if original_event:
-                    original_alias = original_event.content.get("alias", None)
-                    original_alt_aliases = original_event.content.get("alt_aliases", [])
-
-            # Check the alias is currently valid (if it has changed).
-            room_alias_str = event.content.get("alias", None)
-            directory_handler = self.hs.get_directory_handler()
-            if room_alias_str and room_alias_str != original_alias:
-                await self._validate_canonical_alias(
-                    directory_handler, room_alias_str, event.room_id
-                )
+                # Check that alt_aliases is the proper form.
+                alt_aliases = event.content.get("alt_aliases", [])
+                if not isinstance(alt_aliases, (list, tuple)):
+                    raise SynapseError(
+                        400,
+                        "The alt_aliases property must be a list.",
+                        Codes.INVALID_PARAM,
+                    )
 
-            # Check that alt_aliases is the proper form.
-            alt_aliases = event.content.get("alt_aliases", [])
-            if not isinstance(alt_aliases, (list, tuple)):
-                raise SynapseError(
-                    400, "The alt_aliases property must be a list.", Codes.INVALID_PARAM
-                )
+                # If the old version of alt_aliases is of an unknown form,
+                # completely replace it.
+                if not isinstance(original_alt_aliases, (list, tuple)):
+                    # TODO: check that the original_alt_aliases' entries are all strings
+                    original_alt_aliases = []
+
+                # Check that each alias is currently valid.
+                new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
+                if new_alt_aliases:
+                    for alias_str in new_alt_aliases:
+                        await self._validate_canonical_alias(
+                            directory_handler, alias_str, event.room_id
+                        )
 
-            # If the old version of alt_aliases is of an unknown form,
-            # completely replace it.
-            if not isinstance(original_alt_aliases, (list, tuple)):
-                # TODO: check that the original_alt_aliases' entries are all strings
-                original_alt_aliases = []
+            federation_handler = self.hs.get_federation_handler()
 
-            # Check that each alias is currently valid.
-            new_alt_aliases = set(alt_aliases) - set(original_alt_aliases)
-            if new_alt_aliases:
-                for alias_str in new_alt_aliases:
-                    await self._validate_canonical_alias(
-                        directory_handler, alias_str, event.room_id
+            if event.type == EventTypes.Member:
+                if event.content["membership"] == Membership.INVITE:
+                    event.unsigned[
+                        "invite_room_state"
+                    ] = await self.store.get_stripped_room_state_from_event_context(
+                        context,
+                        self.room_prejoin_state_types,
+                        membership_user_id=event.sender,
                     )
 
-        federation_handler = self.hs.get_federation_handler()
+                    invitee = UserID.from_string(event.state_key)
+                    if not self.hs.is_mine(invitee):
+                        # TODO: Can we add signature from remote server in a nicer
+                        # way? If we have been invited by a remote server, we need
+                        # to get them to sign the event.
 
-        if event.type == EventTypes.Member:
-            if event.content["membership"] == Membership.INVITE:
-                event.unsigned[
-                    "invite_room_state"
-                ] = await self.store.get_stripped_room_state_from_event_context(
-                    context,
-                    self.room_prejoin_state_types,
-                    membership_user_id=event.sender,
-                )
+                        returned_invite = await federation_handler.send_invite(
+                            invitee.domain, event
+                        )
+                        event.unsigned.pop("room_state", None)
 
-                invitee = UserID.from_string(event.state_key)
-                if not self.hs.is_mine(invitee):
-                    # TODO: Can we add signature from remote server in a nicer
-                    # way? If we have been invited by a remote server, we need
-                    # to get them to sign the event.
+                        # TODO: Make sure the signatures actually are correct.
+                        event.signatures.update(returned_invite.signatures)
 
-                    returned_invite = await federation_handler.send_invite(
-                        invitee.domain, event
+                if event.content["membership"] == Membership.KNOCK:
+                    event.unsigned[
+                        "knock_room_state"
+                    ] = await self.store.get_stripped_room_state_from_event_context(
+                        context,
+                        self.room_prejoin_state_types,
                     )
-                    event.unsigned.pop("room_state", None)
 
-                    # TODO: Make sure the signatures actually are correct.
-                    event.signatures.update(returned_invite.signatures)
+            if event.type == EventTypes.Redaction:
+                assert event.redacts is not None
 
-            if event.content["membership"] == Membership.KNOCK:
-                event.unsigned[
-                    "knock_room_state"
-                ] = await self.store.get_stripped_room_state_from_event_context(
-                    context,
-                    self.room_prejoin_state_types,
+                original_event = await self.store.get_event(
+                    event.redacts,
+                    redact_behaviour=EventRedactBehaviour.as_is,
+                    get_prev_content=False,
+                    allow_rejected=False,
+                    allow_none=True,
                 )
 
-        if event.type == EventTypes.Redaction:
-            assert event.redacts is not None
+                room_version = await self.store.get_room_version_id(event.room_id)
+                room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
-            original_event = await self.store.get_event(
-                event.redacts,
-                redact_behaviour=EventRedactBehaviour.as_is,
-                get_prev_content=False,
-                allow_rejected=False,
-                allow_none=True,
-            )
+                # we can make some additional checks now if we have the original event.
+                if original_event:
+                    if original_event.type == EventTypes.Create:
+                        raise AuthError(403, "Redacting create events is not permitted")
 
-            room_version = await self.store.get_room_version_id(event.room_id)
-            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
-            # we can make some additional checks now if we have the original event.
-            if original_event:
-                if original_event.type == EventTypes.Create:
-                    raise AuthError(403, "Redacting create events is not permitted")
-
-                if original_event.room_id != event.room_id:
-                    raise SynapseError(400, "Cannot redact event from a different room")
-
-                if original_event.type == EventTypes.ServerACL:
-                    raise AuthError(403, "Redacting server ACL events is not permitted")
-
-                # Add a little safety stop-gap to prevent people from trying to
-                # redact MSC2716 related events when they're in a room version
-                # which does not support it yet. We allow people to use MSC2716
-                # events in existing room versions but only from the room
-                # creator since it does not require any changes to the auth
-                # rules and in effect, the redaction algorithm . In the
-                # supported room version, we add the `historical` power level to
-                # auth the MSC2716 related events and adjust the redaction
-                # algorthim to keep the `historical` field around (redacting an
-                # event should only strip fields which don't affect the
-                # structural protocol level).
-                is_msc2716_event = (
-                    original_event.type == EventTypes.MSC2716_INSERTION
-                    or original_event.type == EventTypes.MSC2716_BATCH
-                    or original_event.type == EventTypes.MSC2716_MARKER
-                )
-                if not room_version_obj.msc2716_historical and is_msc2716_event:
-                    raise AuthError(
-                        403,
-                        "Redacting MSC2716 events is not supported in this room version",
-                    )
+                    if original_event.room_id != event.room_id:
+                        raise SynapseError(
+                            400, "Cannot redact event from a different room"
+                        )
 
-            event_types = event_auth.auth_types_for_event(event.room_version, event)
-            prev_state_ids = await context.get_prev_state_ids(
-                StateFilter.from_types(event_types)
-            )
+                    if original_event.type == EventTypes.ServerACL:
+                        raise AuthError(
+                            403, "Redacting server ACL events is not permitted"
+                        )
 
-            auth_events_ids = self._event_auth_handler.compute_auth_events(
-                event, prev_state_ids, for_verification=True
-            )
-            auth_events_map = await self.store.get_events(auth_events_ids)
-            auth_events = {(e.type, e.state_key): e for e in auth_events_map.values()}
+                    # Add a little safety stop-gap to prevent people from trying to
+                    # redact MSC2716 related events when they're in a room version
+                    # which does not support it yet. We allow people to use MSC2716
+                    # events in existing room versions but only from the room
+                    # creator since it does not require any changes to the auth
+                    # rules and in effect, the redaction algorithm . In the
+                    # supported room version, we add the `historical` power level to
+                    # auth the MSC2716 related events and adjust the redaction
+                    # algorthim to keep the `historical` field around (redacting an
+                    # event should only strip fields which don't affect the
+                    # structural protocol level).
+                    is_msc2716_event = (
+                        original_event.type == EventTypes.MSC2716_INSERTION
+                        or original_event.type == EventTypes.MSC2716_BATCH
+                        or original_event.type == EventTypes.MSC2716_MARKER
+                    )
+                    if not room_version_obj.msc2716_historical and is_msc2716_event:
+                        raise AuthError(
+                            403,
+                            "Redacting MSC2716 events is not supported in this room version",
+                        )
 
-            if event_auth.check_redaction(
-                room_version_obj, event, auth_events=auth_events
-            ):
-                # this user doesn't have 'redact' rights, so we need to do some more
-                # checks on the original event. Let's start by checking the original
-                # event exists.
-                if not original_event:
-                    raise NotFoundError("Could not find event %s" % (event.redacts,))
-
-                if event.user_id != original_event.user_id:
-                    raise AuthError(403, "You don't have permission to redact events")
-
-                # all the checks are done.
-                event.internal_metadata.recheck_redaction = False
-
-        if event.type == EventTypes.Create:
-            prev_state_ids = await context.get_prev_state_ids()
-            if prev_state_ids:
-                raise AuthError(403, "Changing the room create event is forbidden")
-
-        if event.type == EventTypes.MSC2716_INSERTION:
-            room_version = await self.store.get_room_version_id(event.room_id)
-            room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
-
-            create_event = await self.store.get_create_event_for_room(event.room_id)
-            room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
-
-            # Only check an insertion event if the room version
-            # supports it or the event is from the room creator.
-            if room_version_obj.msc2716_historical or (
-                self.config.experimental.msc2716_enabled
-                and event.sender == room_creator
-            ):
-                next_batch_id = event.content.get(
-                    EventContentFields.MSC2716_NEXT_BATCH_ID
+                event_types = event_auth.auth_types_for_event(event.room_version, event)
+                prev_state_ids = await context.get_prev_state_ids(
+                    StateFilter.from_types(event_types)
+                )
+
+                auth_events_ids = self._event_auth_handler.compute_auth_events(
+                    event, prev_state_ids, for_verification=True
                 )
-                conflicting_insertion_event_id = None
-                if next_batch_id:
-                    conflicting_insertion_event_id = (
-                        await self.store.get_insertion_event_id_by_batch_id(
-                            event.room_id, next_batch_id
+                auth_events_map = await self.store.get_events(auth_events_ids)
+                auth_events = {
+                    (e.type, e.state_key): e for e in auth_events_map.values()
+                }
+
+                if event_auth.check_redaction(
+                    room_version_obj, event, auth_events=auth_events
+                ):
+                    # this user doesn't have 'redact' rights, so we need to do some more
+                    # checks on the original event. Let's start by checking the original
+                    # event exists.
+                    if not original_event:
+                        raise NotFoundError(
+                            "Could not find event %s" % (event.redacts,)
                         )
+
+                    if event.user_id != original_event.user_id:
+                        raise AuthError(
+                            403, "You don't have permission to redact events"
+                        )
+
+                    # all the checks are done.
+                    event.internal_metadata.recheck_redaction = False
+
+            if event.type == EventTypes.Create:
+                prev_state_ids = await context.get_prev_state_ids()
+                if prev_state_ids:
+                    raise AuthError(403, "Changing the room create event is forbidden")
+
+            if event.type == EventTypes.MSC2716_INSERTION:
+                room_version = await self.store.get_room_version_id(event.room_id)
+                room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
+
+                create_event = await self.store.get_create_event_for_room(event.room_id)
+                room_creator = create_event.content.get(EventContentFields.ROOM_CREATOR)
+
+                # Only check an insertion event if the room version
+                # supports it or the event is from the room creator.
+                if room_version_obj.msc2716_historical or (
+                    self.config.experimental.msc2716_enabled
+                    and event.sender == room_creator
+                ):
+                    next_batch_id = event.content.get(
+                        EventContentFields.MSC2716_NEXT_BATCH_ID
                     )
-                if conflicting_insertion_event_id is not None:
-                    # The current insertion event that we're processing is invalid
-                    # because an insertion event already exists in the room with the
-                    # same next_batch_id. We can't allow multiple because the batch
-                    # pointing will get weird, e.g. we can't determine which insertion
-                    # event the batch event is pointing to.
-                    raise SynapseError(
-                        HTTPStatus.BAD_REQUEST,
-                        "Another insertion event already exists with the same next_batch_id",
-                        errcode=Codes.INVALID_PARAM,
-                    )
+                    conflicting_insertion_event_id = None
+                    if next_batch_id:
+                        conflicting_insertion_event_id = (
+                            await self.store.get_insertion_event_id_by_batch_id(
+                                event.room_id, next_batch_id
+                            )
+                        )
+                    if conflicting_insertion_event_id is not None:
+                        # The current insertion event that we're processing is invalid
+                        # because an insertion event already exists in the room with the
+                        # same next_batch_id. We can't allow multiple because the batch
+                        # pointing will get weird, e.g. we can't determine which insertion
+                        # event the batch event is pointing to.
+                        raise SynapseError(
+                            HTTPStatus.BAD_REQUEST,
+                            "Another insertion event already exists with the same next_batch_id",
+                            errcode=Codes.INVALID_PARAM,
+                        )
 
-        # Mark any `m.historical` messages as backfilled so they don't appear
-        # in `/sync` and have the proper decrementing `stream_ordering` as we import
-        backfilled = False
-        if event.internal_metadata.is_historical():
-            backfilled = True
+            # Mark any `m.historical` messages as backfilled so they don't appear
+            # in `/sync` and have the proper decrementing `stream_ordering` as we import
+            backfilled = False
+            if event.internal_metadata.is_historical():
+                backfilled = True
 
-        # Note that this returns the event that was persisted, which may not be
-        # the same as we passed in if it was deduplicated due transaction IDs.
+        assert self._storage_controllers.persistence is not None
         (
-            event,
-            event_pos,
+            persisted_events,
             max_stream_token,
-        ) = await self._storage_controllers.persistence.persist_event(
-            event, context=context, backfilled=backfilled
+        ) = await self._storage_controllers.persistence.persist_events(
+            events_and_context, backfilled=backfilled
         )
 
-        if self._ephemeral_events_enabled:
-            # If there's an expiry timestamp on the event, schedule its expiry.
-            self._message_handler.maybe_schedule_expiry(event)
+        events_and_pos = []
+        for event in persisted_events:
+            if self._ephemeral_events_enabled:
+                # If there's an expiry timestamp on the event, schedule its expiry.
+                self._message_handler.maybe_schedule_expiry(event)
+
+            stream_ordering = event.internal_metadata.stream_ordering
+            assert stream_ordering is not None
+            pos = PersistedEventPosition(self._instance_name, stream_ordering)
+            events_and_pos.append((event, pos))
+
+            if event.type == EventTypes.Message:
+                # We don't want to block sending messages on any presence code. This
+                # matters as sometimes presence code can take a while.
+                run_in_background(self._bump_active_time, requester.user)
 
         async def _notify() -> None:
             try:
-                await self.notifier.on_new_room_event(
-                    event, event_pos, max_stream_token, extra_users=extra_users
+                await self.notifier.on_new_room_events(
+                    events_and_pos, max_stream_token, extra_users=extra_users
                 )
             except Exception:
-                logger.exception(
-                    "Error notifying about new room event %s",
-                    event.event_id,
-                )
+                logger.exception("Error notifying about new room events")
 
         run_in_background(_notify)
 
-        if event.type == EventTypes.Message:
-            # We don't want to block sending messages on any presence code. This
-            # matters as sometimes presence code can take a while.
-            run_in_background(self._bump_active_time, requester.user)
-
-        return event
+        return persisted_events[-1]
 
     async def _maybe_kick_guest_users(
         self, event: EventBase, context: EventContext
@@ -1897,8 +2020,7 @@ class EventCreationHandler:
                 # shadow-banned user.
                 await self.handle_new_client_event(
                     requester,
-                    event,
-                    context,
+                    events_and_context=[(event, context)],
                     ratelimit=False,
                     ignore_shadow_ban=True,
                 )
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index d7a8226900..41c675f408 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -12,14 +12,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import binascii
 import inspect
+import json
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, TypeVar, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 from urllib.parse import urlencode, urlparse
 
 import attr
+import unpaddedbase64
 from authlib.common.security import generate_token
-from authlib.jose import JsonWebToken, jwt
+from authlib.jose import JsonWebToken, JWTClaims
+from authlib.jose.errors import InvalidClaimError, JoseError, MissingClaimError
 from authlib.oauth2.auth import ClientAuth
 from authlib.oauth2.rfc6749.parameters import prepare_grant_uri
 from authlib.oidc.core import CodeIDToken, UserInfo
@@ -35,9 +49,12 @@ from typing_extensions import TypedDict
 from twisted.web.client import readBody
 from twisted.web.http_headers import Headers
 
+from synapse.api.errors import SynapseError
 from synapse.config import ConfigError
 from synapse.config.oidc import OidcProviderClientSecretJwtKey, OidcProviderConfig
 from synapse.handlers.sso import MappingException, UserAttributes
+from synapse.http.server import finish_request
+from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
@@ -88,6 +105,8 @@ class Token(TypedDict):
 #: there is no real point of doing this in our case.
 JWK = Dict[str, str]
 
+C = TypeVar("C")
+
 
 #: A JWK Set, as per RFC7517 sec 5.
 class JWKS(TypedDict):
@@ -247,6 +266,80 @@ class OidcHandler:
 
         await oidc_provider.handle_oidc_callback(request, session_data, code)
 
+    async def handle_backchannel_logout(self, request: SynapseRequest) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        This extracts the logout_token from the request and tries to figure out
+        which OpenID Provider it is comming from. This works by matching the iss claim
+        with the issuer and the aud claim with the client_id.
+
+        Since at this point we don't know who signed the JWT, we can't just
+        decode it using authlib since it will always verifies the signature. We
+        have to decode it manually without validating the signature. The actual JWT
+        verification is done in the `OidcProvider.handler_backchannel_logout` method,
+        once we figured out which provider sent the request.
+
+        Args:
+            request: the incoming request from the browser.
+        """
+        logout_token = parse_string(request, "logout_token")
+        if logout_token is None:
+            raise SynapseError(400, "Missing logout_token in request")
+
+        # A JWT looks like this:
+        #    header.payload.signature
+        # where all parts are encoded with urlsafe base64.
+        # The aud and iss claims we care about are in the payload part, which
+        # is a JSON object.
+        try:
+            # By destructuring the list after splitting, we ensure that we have
+            # exactly 3 segments
+            _, payload, _ = logout_token.split(".")
+        except ValueError:
+            raise SynapseError(400, "Invalid logout_token in request")
+
+        try:
+            payload_bytes = unpaddedbase64.decode_base64(payload)
+            claims = json_decoder.decode(payload_bytes.decode("utf-8"))
+        except (json.JSONDecodeError, binascii.Error, UnicodeError):
+            raise SynapseError(400, "Invalid logout_token payload in request")
+
+        try:
+            # Let's extract the iss and aud claims
+            iss = claims["iss"]
+            aud = claims["aud"]
+            # The aud claim can be either a string or a list of string. Here we
+            # normalize it as a list of strings.
+            if isinstance(aud, str):
+                aud = [aud]
+
+            # Check that we have the right types for the aud and the iss claims
+            if not isinstance(iss, str) or not isinstance(aud, list):
+                raise TypeError()
+            for a in aud:
+                if not isinstance(a, str):
+                    raise TypeError()
+
+            # At this point we properly checked both claims types
+            issuer: str = iss
+            audience: List[str] = aud
+        except (TypeError, KeyError):
+            raise SynapseError(400, "Invalid issuer/audience in logout_token")
+
+        # Now that we know the audience and the issuer, we can figure out from
+        # what provider it is coming from
+        oidc_provider: Optional[OidcProvider] = None
+        for provider in self._providers.values():
+            if provider.issuer == issuer and provider.client_id in audience:
+                oidc_provider = provider
+                break
+
+        if oidc_provider is None:
+            raise SynapseError(400, "Could not find the OP that issued this event")
+
+        # Ask the provider to handle the logout request.
+        await oidc_provider.handle_backchannel_logout(request, logout_token)
+
 
 class OidcError(Exception):
     """Used to catch errors when calling the token_endpoint"""
@@ -275,6 +368,7 @@ class OidcProvider:
         provider: OidcProviderConfig,
     ):
         self._store = hs.get_datastores().main
+        self._clock = hs.get_clock()
 
         self._macaroon_generaton = macaroon_generator
 
@@ -341,6 +435,7 @@ class OidcProvider:
         self.idp_brand = provider.idp_brand
 
         self._sso_handler = hs.get_sso_handler()
+        self._device_handler = hs.get_device_handler()
 
         self._sso_handler.register_identity_provider(self)
 
@@ -399,6 +494,41 @@ class OidcProvider:
             # If we're not using userinfo, we need a valid jwks to validate the ID token
             m.validate_jwks_uri()
 
+        if self._config.backchannel_logout_enabled:
+            if not m.get("backchannel_logout_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled for issuer %r"
+                    "but it does not advertise support for it",
+                    self.issuer,
+                )
+
+            elif not m.get("backchannel_logout_session_supported", False):
+                logger.warning(
+                    "OIDC Back-Channel Logout is enabled and supported "
+                    "by issuer %r but it might not send a session ID with "
+                    "logout tokens, which is required for the logouts to work",
+                    self.issuer,
+                )
+
+            if not self._config.backchannel_logout_ignore_sub:
+                # If OIDC backchannel logouts are enabled, the provider mapping provider
+                # should use the `sub` claim. We verify that by mapping a dumb user and
+                # see if we get back the sub claim
+                user = UserInfo({"sub": "thisisasubject"})
+                try:
+                    subject = self._user_mapping_provider.get_remote_user_id(user)
+                    if subject != user["sub"]:
+                        raise ValueError("Unexpected subject")
+                except Exception:
+                    logger.warning(
+                        f"OIDC Back-Channel Logout is enabled for issuer {self.issuer!r} "
+                        "but it looks like the configured `user_mapping_provider` "
+                        "does not use the `sub` claim as subject. If it is the case, "
+                        "and you want Synapse to ignore the `sub` claim in OIDC "
+                        "Back-Channel Logouts, set `backchannel_logout_ignore_sub` "
+                        "to `true` in the issuer config."
+                    )
+
     @property
     def _uses_userinfo(self) -> bool:
         """Returns True if the ``userinfo_endpoint`` should be used.
@@ -414,6 +544,16 @@ class OidcProvider:
             or self._user_profile_method == "userinfo_endpoint"
         )
 
+    @property
+    def issuer(self) -> str:
+        """The issuer identifying this provider."""
+        return self._config.issuer
+
+    @property
+    def client_id(self) -> str:
+        """The client_id used when interacting with this provider."""
+        return self._config.client_id
+
     async def load_metadata(self, force: bool = False) -> OpenIDProviderMetadata:
         """Return the provider metadata.
 
@@ -647,7 +787,7 @@ class OidcProvider:
                 Must include an ``access_token`` field.
 
         Returns:
-            UserInfo: an object representing the user.
+            an object representing the user.
         """
         logger.debug("Using the OAuth2 access_token to request userinfo")
         metadata = await self.load_metadata()
@@ -661,61 +801,99 @@ class OidcProvider:
 
         return UserInfo(resp)
 
-    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
-        """Return an instance of UserInfo from token's ``id_token``.
+    async def _verify_jwt(
+        self,
+        alg_values: List[str],
+        token: str,
+        claims_cls: Type[C],
+        claims_options: Optional[dict] = None,
+        claims_params: Optional[dict] = None,
+    ) -> C:
+        """Decode and validate a JWT, re-fetching the JWKS as needed.
 
         Args:
-            token: the token given by the ``token_endpoint``.
-                Must include an ``id_token`` field.
-            nonce: the nonce value originally sent in the initial authorization
-                request. This value should match the one inside the token.
+            alg_values: list of `alg` values allowed when verifying the JWT.
+            token: the JWT.
+            claims_cls: the JWTClaims class to use to validate the claims.
+            claims_options: dict of options passed to the `claims_cls` constructor.
+            claims_params: dict of params passed to the `claims_cls` constructor.
 
         Returns:
-            The decoded claims in the ID token.
+            The decoded claims in the JWT.
         """
-        metadata = await self.load_metadata()
-        claims_params = {
-            "nonce": nonce,
-            "client_id": self._client_auth.client_id,
-        }
-        if "access_token" in token:
-            # If we got an `access_token`, there should be an `at_hash` claim
-            # in the `id_token` that we can check against.
-            claims_params["access_token"] = token["access_token"]
-
-        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
         jwt = JsonWebToken(alg_values)
 
-        claim_options = {"iss": {"values": [metadata["issuer"]]}}
-
-        id_token = token["id_token"]
-        logger.debug("Attempting to decode JWT id_token %r", id_token)
+        logger.debug("Attempting to decode JWT (%s) %r", claims_cls.__name__, token)
 
         # Try to decode the keys in cache first, then retry by forcing the keys
         # to be reloaded
         jwk_set = await self.load_jwks()
         try:
             claims = jwt.decode(
-                id_token,
+                token,
                 key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
                 claims_params=claims_params,
             )
         except ValueError:
             logger.info("Reloading JWKS after decode error")
             jwk_set = await self.load_jwks(force=True)  # try reloading the jwks
             claims = jwt.decode(
-                id_token,
+                token,
                 key=jwk_set,
-                claims_cls=CodeIDToken,
-                claims_options=claim_options,
+                claims_cls=claims_cls,
+                claims_options=claims_options,
                 claims_params=claims_params,
             )
 
-        logger.debug("Decoded id_token JWT %r; validating", claims)
+        logger.debug("Decoded JWT (%s) %r; validating", claims_cls.__name__, claims)
+
+        claims.validate(
+            now=self._clock.time(), leeway=120
+        )  # allows 2 min of clock skew
+        return claims
+
+    async def _parse_id_token(self, token: Token, nonce: str) -> CodeIDToken:
+        """Return an instance of UserInfo from token's ``id_token``.
+
+        Args:
+            token: the token given by the ``token_endpoint``.
+                Must include an ``id_token`` field.
+            nonce: the nonce value originally sent in the initial authorization
+                request. This value should match the one inside the token.
+
+        Returns:
+            The decoded claims in the ID token.
+        """
+        id_token = token.get("id_token")
+
+        # That has been theoritically been checked by the caller, so even though
+        # assertion are not enabled in production, it is mainly here to appease mypy
+        assert id_token is not None
+
+        metadata = await self.load_metadata()
+
+        claims_params = {
+            "nonce": nonce,
+            "client_id": self._client_auth.client_id,
+        }
+        if "access_token" in token:
+            # If we got an `access_token`, there should be an `at_hash` claim
+            # in the `id_token` that we can check against.
+            claims_params["access_token"] = token["access_token"]
+
+        claims_options = {"iss": {"values": [metadata["issuer"]]}}
+
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
 
-        claims.validate(leeway=120)  # allows 2 min of clock skew
+        claims = await self._verify_jwt(
+            alg_values=alg_values,
+            token=id_token,
+            claims_cls=CodeIDToken,
+            claims_options=claims_options,
+            claims_params=claims_params,
+        )
 
         return claims
 
@@ -1036,6 +1214,146 @@ class OidcProvider:
         # to be strings.
         return str(remote_user_id)
 
+    async def handle_backchannel_logout(
+        self, request: SynapseRequest, logout_token: str
+    ) -> None:
+        """Handle an incoming request to /_synapse/client/oidc/backchannel_logout
+
+        The OIDC Provider posts a logout token to this endpoint when a user
+        session ends. That token is a JWT signed with the same keys as
+        ID tokens. The OpenID Connect Back-Channel Logout draft explains how to
+        validate the JWT and figure out what session to end.
+
+        Args:
+            request: The request to respond to
+            logout_token: The logout token (a JWT) extracted from the request body
+        """
+        # Back-Channel Logout can be disabled in the config, hence this check.
+        # This is not that important for now since Synapse is registered
+        # manually to the OP, so not specifying the backchannel-logout URI is
+        # as effective than disabling it here. It might make more sense if we
+        # support dynamic registration in Synapse at some point.
+        if not self._config.backchannel_logout_enabled:
+            logger.warning(
+                f"Received an OIDC Back-Channel Logout request from issuer {self.issuer!r} but it is disabled in config"
+            )
+
+            # TODO: this responds with a 400 status code, which is what the OIDC
+            # Back-Channel Logout spec expects, but spec also suggests answering with
+            # a JSON object, with the `error` and `error_description` fields set, which
+            # we are not doing here.
+            # See https://openid.net/specs/openid-connect-backchannel-1_0.html#BCResponse
+            raise SynapseError(
+                400, "OpenID Connect Back-Channel Logout is disabled for this provider"
+            )
+
+        metadata = await self.load_metadata()
+
+        # As per OIDC Back-Channel Logout 1.0 sec. 2.4:
+        #   A Logout Token MUST be signed and MAY also be encrypted. The same
+        #   keys are used to sign and encrypt Logout Tokens as are used for ID
+        #   Tokens. If the Logout Token is encrypted, it SHOULD replicate the
+        #   iss (issuer) claim in the JWT Header Parameters, as specified in
+        #   Section 5.3 of [JWT].
+        alg_values = metadata.get("id_token_signing_alg_values_supported", ["RS256"])
+
+        # As per sec. 2.6:
+        #    3. Validate the iss, aud, and iat Claims in the same way they are
+        #       validated in ID Tokens.
+        # Which means the audience should contain Synapse's client_id and the
+        # issuer should be the IdP issuer
+        claims_options = {
+            "iss": {"values": [metadata["issuer"]]},
+            "aud": {"values": [self.client_id]},
+        }
+
+        try:
+            claims = await self._verify_jwt(
+                alg_values=alg_values,
+                token=logout_token,
+                claims_cls=LogoutToken,
+                claims_options=claims_options,
+            )
+        except JoseError:
+            logger.exception("Invalid logout_token")
+            raise SynapseError(400, "Invalid logout_token")
+
+        # As per sec. 2.6:
+        #    4. Verify that the Logout Token contains a sub Claim, a sid Claim,
+        #       or both.
+        #    5. Verify that the Logout Token contains an events Claim whose
+        #       value is JSON object containing the member name
+        #       http://schemas.openid.net/event/backchannel-logout.
+        #    6. Verify that the Logout Token does not contain a nonce Claim.
+        # This is all verified by the LogoutToken claims class, so at this
+        # point the `sid` claim exists and is a string.
+        sid: str = claims.get("sid")
+
+        # If the `sub` claim was included in the logout token, we check that it matches
+        # that it matches the right user. We can have cases where the `sub` claim is not
+        # the ID saved in database, so we let admins disable this check in config.
+        sub: Optional[str] = claims.get("sub")
+        expected_user_id: Optional[str] = None
+        if sub is not None and not self._config.backchannel_logout_ignore_sub:
+            expected_user_id = await self._store.get_user_by_external_id(
+                self.idp_id, sub
+            )
+
+        # Invalidate any running user-mapping sessions, in-flight login tokens and
+        # active devices
+        await self._sso_handler.revoke_sessions_for_provider_session_id(
+            auth_provider_id=self.idp_id,
+            auth_provider_session_id=sid,
+            expected_user_id=expected_user_id,
+        )
+
+        request.setResponseCode(200)
+        request.setHeader(b"Cache-Control", b"no-cache, no-store")
+        request.setHeader(b"Pragma", b"no-cache")
+        finish_request(request)
+
+
+class LogoutToken(JWTClaims):
+    """
+    Holds and verify claims of a logout token, as per
+    https://openid.net/specs/openid-connect-backchannel-1_0.html#LogoutToken
+    """
+
+    REGISTERED_CLAIMS = ["iss", "sub", "aud", "iat", "jti", "events", "sid"]
+
+    def validate(self, now: Optional[int] = None, leeway: int = 0) -> None:
+        """Validate everything in claims payload."""
+        super().validate(now, leeway)
+        self.validate_sid()
+        self.validate_events()
+        self.validate_nonce()
+
+    def validate_sid(self) -> None:
+        """Ensure the sid claim is present"""
+        sid = self.get("sid")
+        if not sid:
+            raise MissingClaimError("sid")
+
+        if not isinstance(sid, str):
+            raise InvalidClaimError("sid")
+
+    def validate_nonce(self) -> None:
+        """Ensure the nonce claim is absent"""
+        if "nonce" in self:
+            raise InvalidClaimError("nonce")
+
+    def validate_events(self) -> None:
+        """Ensure the events claim is present and with the right value"""
+        events = self.get("events")
+        if not events:
+            raise MissingClaimError("events")
+
+        if not isinstance(events, dict):
+            raise InvalidClaimError("events")
+
+        if "http://schemas.openid.net/event/backchannel-logout" not in events:
+            raise InvalidClaimError("events")
+
 
 # number of seconds a newly-generated client secret should be valid for
 CLIENT_SECRET_VALIDITY_SECONDS = 3600
@@ -1105,6 +1423,7 @@ class JwtClientSecret:
         logger.info(
             "Generating new JWT for %s: %s %s", self._oauth_issuer, header, payload
         )
+        jwt = JsonWebToken(header["alg"])
         self._cached_secret = jwt.encode(header, payload, self._key.key)
         self._cached_secret_replacement_time = (
             expires_at - CLIENT_SECRET_MIN_VALIDITY_SECONDS
@@ -1119,9 +1438,6 @@ class UserAttributeDict(TypedDict):
     emails: List[str]
 
 
-C = TypeVar("C")
-
-
 class OidcMappingProvider(Generic[C]):
     """A mapping provider maps a UserInfo object to user attributes.
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d865ee6e73..fcb8572348 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -458,11 +458,6 @@ class PaginationHandler:
             # `/messages` should still works with live tokens when manually provided.
             assert from_token.room_key.topological is not None
 
-        if pagin_config.limit is None:
-            # This shouldn't happen as we've set a default limit before this
-            # gets called.
-            raise Exception("limit not set")
-
         room_token = from_token.room_key
 
         async with self.pagination_lock.read(room_id):
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 4e575ffbaa..cf08737d11 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -201,7 +201,7 @@ class BasePresenceHandler(abc.ABC):
         """Get the current presence state for multiple users.
 
         Returns:
-            dict: `user_id` -> `UserPresenceState`
+            A mapping of `user_id` -> `UserPresenceState`
         """
         states = {}
         missing = []
@@ -256,7 +256,7 @@ class BasePresenceHandler(abc.ABC):
         with the app.
         """
 
-    async def update_external_syncs_row(
+    async def update_external_syncs_row(  # noqa: B027 (no-op by design)
         self, process_id: str, user_id: str, is_syncing: bool, sync_time_msec: int
     ) -> None:
         """Update the syncing users for an external process as a delta.
@@ -272,7 +272,9 @@ class BasePresenceHandler(abc.ABC):
             sync_time_msec: Time in ms when the user was last syncing
         """
 
-    async def update_external_syncs_clear(self, process_id: str) -> None:
+    async def update_external_syncs_clear(  # noqa: B027 (no-op by design)
+        self, process_id: str
+    ) -> None:
         """Marks all users that had been marked as syncing by a given process
         as offline.
 
@@ -476,7 +478,7 @@ class WorkerPresenceHandler(BasePresenceHandler):
             return _NullContextManager()
 
         prev_state = await self.current_state_for_user(user_id)
-        if prev_state != PresenceState.BUSY:
+        if prev_state.state != PresenceState.BUSY:
             # We set state here but pass ignore_status_msg = True as we don't want to
             # cause the status message to be cleared.
             # Note that this causes last_active_ts to be incremented which is not
@@ -1596,7 +1598,9 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
         self,
         user: UserID,
         from_key: Optional[int],
-        limit: Optional[int] = None,
+        # Having a default limit doesn't match the EventSource API, but some
+        # callers do not provide it. It is unused in this class.
+        limit: int = 0,
         room_ids: Optional[Collection[str]] = None,
         is_guest: bool = False,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index d8ff5289b5..4bf9a047a3 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -307,7 +307,11 @@ class ProfileHandler:
         if not self.max_avatar_size and not self.allowed_avatar_mimetypes:
             return True
 
-        server_name, _, media_id = parse_and_validate_mxc_uri(mxc)
+        host, port, media_id = parse_and_validate_mxc_uri(mxc)
+        if port is not None:
+            server_name = host + ":" + str(port)
+        else:
+            server_name = host
 
         if server_name == self.server_name:
             media_info = await self.store.get_local_media(media_id)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 4768a34c07..ac01582442 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -63,8 +63,6 @@ class ReceiptsHandler:
         self.clock = self.hs.get_clock()
         self.state = hs.get_state_handler()
 
-        self._msc3771_enabled = hs.config.experimental.msc3771_enabled
-
     async def _received_remote_receipt(self, origin: str, content: JsonDict) -> None:
         """Called when we receive an EDU of type m.receipt from a remote HS."""
         receipts = []
@@ -96,11 +94,10 @@ class ReceiptsHandler:
                     # Check if these receipts apply to a thread.
                     thread_id = None
                     data = user_values.get("data", {})
-                    if self._msc3771_enabled and isinstance(data, dict):
-                        thread_id = data.get("thread_id")
-                        # If the thread ID is invalid, consider it missing.
-                        if not isinstance(thread_id, str):
-                            thread_id = None
+                    thread_id = data.get("thread_id")
+                    # If the thread ID is invalid, consider it missing.
+                    if not isinstance(thread_id, str):
+                        thread_id = None
 
                     receipts.append(
                         ReadReceipt(
@@ -260,7 +257,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
         self,
         user: UserID,
         from_key: int,
-        limit: Optional[int],
+        limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index cfcadb34db..ca1c7a1866 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -220,6 +220,7 @@ class RegistrationHandler:
         by_admin: bool = False,
         user_agent_ips: Optional[List[Tuple[str, str]]] = None,
         auth_provider_id: Optional[str] = None,
+        approved: bool = False,
     ) -> str:
         """Registers a new client on the server.
 
@@ -246,6 +247,8 @@ class RegistrationHandler:
             user_agent_ips: Tuples of user-agents and IP addresses used
                 during the registration process.
             auth_provider_id: The SSO IdP the user used, if any.
+            approved: True if the new user should be considered already
+                approved by an administrator.
         Returns:
             The registered user_id.
         Raises:
@@ -307,6 +310,7 @@ class RegistrationHandler:
                 user_type=user_type,
                 address=address,
                 shadow_banned=shadow_banned,
+                approved=approved,
             )
 
             profile = await self.store.get_profileinfo(localpart)
@@ -695,6 +699,7 @@ class RegistrationHandler:
         user_type: Optional[str] = None,
         address: Optional[str] = None,
         shadow_banned: bool = False,
+        approved: bool = False,
     ) -> None:
         """Register user in the datastore.
 
@@ -713,6 +718,7 @@ class RegistrationHandler:
                 api.constants.UserTypes, or None for a normal user.
             address: the IP address used to perform the registration.
             shadow_banned: Whether to shadow-ban the user
+            approved: Whether to mark the user as approved by an administrator
         """
         if self.hs.config.worker.worker_app:
             await self._register_client(
@@ -726,6 +732,7 @@ class RegistrationHandler:
                 user_type=user_type,
                 address=address,
                 shadow_banned=shadow_banned,
+                approved=approved,
             )
         else:
             await self.store.register_user(
@@ -738,6 +745,7 @@ class RegistrationHandler:
                 admin=admin,
                 user_type=user_type,
                 shadow_banned=shadow_banned,
+                approved=approved,
             )
 
             # Only call the account validity module(s) on the main process, to avoid
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 48cc9c1ac5..64d373e9d7 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -11,16 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import enum
 import logging
 from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
 
 import attr
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
 from synapse.logging.tracing import SynapseTags, set_attribute, trace
-from synapse.storage.databases.main.relations import _RelatedEvent
+from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
+from synapse.streams.config import PaginationConfig
 from synapse.types import JsonDict, Requester, StreamToken, UserID
 from synapse.visibility import filter_events_for_client
 
@@ -31,6 +33,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class ThreadsListInclude(str, enum.Enum):
+    """Valid values for the 'include' flag of /threads."""
+
+    all = "all"
+    participated = "participated"
+
+
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
     # The latest event in the thread.
@@ -66,18 +75,17 @@ class RelationsHandler:
         self._clock = hs.get_clock()
         self._event_handler = hs.get_event_handler()
         self._event_serializer = hs.get_event_client_serializer()
+        self._event_creation_handler = hs.get_event_creation_handler()
 
     async def get_relations(
         self,
         requester: Requester,
         event_id: str,
         room_id: str,
+        pagin_config: PaginationConfig,
+        include_original_event: bool,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
-        limit: int = 5,
-        direction: str = "b",
-        from_token: Optional[StreamToken] = None,
-        to_token: Optional[StreamToken] = None,
     ) -> JsonDict:
         """Get related events of a event, ordered by topological ordering.
 
@@ -87,13 +95,10 @@ class RelationsHandler:
             requester: The user requesting the relations.
             event_id: Fetch events that relate to this event ID.
             room_id: The room the event belongs to.
+            pagin_config: The pagination config rules to apply, if any.
+            include_original_event: Whether to include the parent event.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
-            limit: Only fetch the most recent `limit` events.
-            direction: Whether to fetch the most recent first (`"b"`) or the
-                oldest first (`"f"`).
-            from_token: Fetch rows from the given token, or from the start if None.
-            to_token: Fetch rows up to the given token, or up to the end if None.
 
         Returns:
             The pagination chunk.
@@ -121,10 +126,10 @@ class RelationsHandler:
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
-            limit=limit,
-            direction=direction,
-            from_token=from_token,
-            to_token=to_token,
+            limit=pagin_config.limit,
+            direction=pagin_config.direction,
+            from_token=pagin_config.from_token,
+            to_token=pagin_config.to_token,
         )
 
         events = await self._main_store.get_events_as_list(
@@ -138,31 +143,32 @@ class RelationsHandler:
             is_peeking=(member_event_id is None),
         )
 
-        now = self._clock.time_msec()
-        # Do not bundle aggregations when retrieving the original event because
-        # we want the content before relations are applied to it.
-        original_event = self._event_serializer.serialize_event(
-            event, now, bundle_aggregations=None
-        )
         # The relations returned for the requested event do include their
         # bundled aggregations.
         aggregations = await self.get_bundled_aggregations(
             events, requester.user.to_string()
         )
-        serialized_events = self._event_serializer.serialize_events(
-            events, now, bundle_aggregations=aggregations
-        )
 
-        return_value = {
-            "chunk": serialized_events,
-            "original_event": original_event,
+        now = self._clock.time_msec()
+        return_value: JsonDict = {
+            "chunk": self._event_serializer.serialize_events(
+                events, now, bundle_aggregations=aggregations
+            ),
         }
+        if include_original_event:
+            # Do not bundle aggregations when retrieving the original event because
+            # we want the content before relations are applied to it.
+            return_value["original_event"] = self._event_serializer.serialize_event(
+                event, now, bundle_aggregations=None
+            )
 
         if next_token:
             return_value["next_batch"] = await next_token.to_string(self._main_store)
 
-        if from_token:
-            return_value["prev_batch"] = await from_token.to_string(self._main_store)
+        if pagin_config.from_token:
+            return_value["prev_batch"] = await pagin_config.from_token.to_string(
+                self._main_store
+            )
 
         return return_value
 
@@ -201,6 +207,59 @@ class RelationsHandler:
 
         return related_events, next_token
 
+    async def redact_events_related_to(
+        self,
+        requester: Requester,
+        event_id: str,
+        initial_redaction_event: EventBase,
+        relation_types: List[str],
+    ) -> None:
+        """Redacts all events related to the given event ID with one of the given
+        relation types.
+
+        This method is expected to be called when redacting the event referred to by
+        the given event ID.
+
+        If an event cannot be redacted (e.g. because of insufficient permissions), log
+        the error and try to redact the next one.
+
+        Args:
+            requester: The requester to redact events on behalf of.
+            event_id: The event IDs to look and redact relations of.
+            initial_redaction_event: The redaction for the event referred to by
+                event_id.
+            relation_types: The types of relations to look for.
+
+        Raises:
+            ShadowBanError if the requester is shadow-banned
+        """
+        related_event_ids = (
+            await self._main_store.get_all_relations_for_event_with_types(
+                event_id, relation_types
+            )
+        )
+
+        for related_event_id in related_event_ids:
+            try:
+                await self._event_creation_handler.create_and_send_nonmember_event(
+                    requester,
+                    {
+                        "type": EventTypes.Redaction,
+                        "content": initial_redaction_event.content,
+                        "room_id": initial_redaction_event.room_id,
+                        "sender": requester.user.to_string(),
+                        "redacts": related_event_id,
+                    },
+                    ratelimit=False,
+                )
+            except SynapseError as e:
+                logger.warning(
+                    "Failed to redact event %s (related to event %s): %s",
+                    related_event_id,
+                    event_id,
+                    e.msg,
+                )
+
     @trace
     async def get_annotations_for_event(
         self,
@@ -490,3 +549,79 @@ class RelationsHandler:
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
         return results
+
+    async def get_threads(
+        self,
+        requester: Requester,
+        room_id: str,
+        include: ThreadsListInclude,
+        limit: int = 5,
+        from_token: Optional[ThreadsNextBatch] = None,
+    ) -> JsonDict:
+        """Get related events of a event, ordered by topological ordering.
+
+        Args:
+            requester: The user requesting the relations.
+            room_id: The room the event belongs to.
+            include: One of "all" or "participated" to indicate which threads should
+                be returned.
+            limit: Only fetch the most recent `limit` events.
+            from_token: Fetch rows from the given token, or from the start if None.
+
+        Returns:
+            The pagination chunk.
+        """
+
+        user_id = requester.user.to_string()
+
+        # TODO Properly handle a user leaving a room.
+        (_, member_event_id) = await self._auth.check_user_in_room_or_world_readable(
+            room_id, requester, allow_departed_users=True
+        )
+
+        # Note that ignored users are not passed into get_relations_for_event
+        # below. Ignored users are handled in filter_events_for_client (and by
+        # not passing them in here we should get a better cache hit rate).
+        thread_roots, next_batch = await self._main_store.get_threads(
+            room_id=room_id, limit=limit, from_token=from_token
+        )
+
+        events = await self._main_store.get_events_as_list(thread_roots)
+
+        if include == ThreadsListInclude.participated:
+            # Pre-seed thread participation with whether the requester sent the event.
+            participated = {event.event_id: event.sender == user_id for event in events}
+            # For events the requester did not send, check the database for whether
+            # the requester sent a threaded reply.
+            participated.update(
+                await self._main_store.get_threads_participated(
+                    [eid for eid, p in participated.items() if not p],
+                    user_id,
+                )
+            )
+
+            # Limit the returned threads to those the user has participated in.
+            events = [event for event in events if participated[event.event_id]]
+
+        events = await filter_events_for_client(
+            self._storage_controllers,
+            user_id,
+            events,
+            is_peeking=(member_event_id is None),
+        )
+
+        aggregations = await self.get_bundled_aggregations(
+            events, requester.user.to_string()
+        )
+
+        now = self._clock.time_msec()
+        serialized_events = self._event_serializer.serialize_events(
+            events, now, bundle_aggregations=aggregations
+        )
+
+        return_value: JsonDict = {"chunk": serialized_events}
+
+        if next_batch:
+            return_value["next_batch"] = str(next_batch)
+
+        return return_value
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 33e9a87002..6dcfd86fdf 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -49,7 +49,6 @@ from synapse.api.constants import (
 from synapse.api.errors import (
     AuthError,
     Codes,
-    HttpResponseException,
     LimitExceededError,
     NotFoundError,
     StoreError,
@@ -60,7 +59,6 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
 from synapse.events.utils import copy_and_fixup_power_levels_contents
-from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
@@ -229,9 +227,7 @@ class RoomCreationHandler:
             },
         )
         validate_event_for_room_version(tombstone_event)
-        await self._event_auth_handler.check_auth_rules_from_context(
-            tombstone_event, tombstone_context
-        )
+        await self._event_auth_handler.check_auth_rules_from_context(tombstone_event)
 
         # Upgrade the room
         #
@@ -301,8 +297,7 @@ class RoomCreationHandler:
         # now send the tombstone
         await self.event_creation_handler.handle_new_client_event(
             requester=requester,
-            event=tombstone_event,
-            context=tombstone_context,
+            events_and_context=[(tombstone_event, tombstone_context)],
         )
 
         state_filter = StateFilter.from_types(
@@ -562,7 +557,6 @@ class RoomCreationHandler:
             invite_list=[],
             initial_state=initial_state,
             creation_content=creation_content,
-            ratelimit=False,
         )
 
         # Transfer membership events
@@ -716,7 +710,7 @@ class RoomCreationHandler:
 
         if (
             self._server_notices_mxid is not None
-            and requester.user.to_string() == self._server_notices_mxid
+            and user_id == self._server_notices_mxid
         ):
             # allow the server notices mxid to create rooms
             is_requester_admin = True
@@ -756,6 +750,10 @@ class RoomCreationHandler:
                 )
 
         if ratelimit:
+            # Rate limit once in advance, but don't rate limit the individual
+            # events in the room — room creation isn't atomic and it's very
+            # janky if half the events in the initial state don't make it because
+            # of rate limiting.
             await self.request_ratelimiter.ratelimit(requester)
 
         room_version_id = config.get(
@@ -916,7 +914,6 @@ class RoomCreationHandler:
             room_alias=room_alias,
             power_level_content_override=power_level_content_override,
             creator_join_profile=creator_join_profile,
-            ratelimit=ratelimit,
         )
 
         if "name" in config:
@@ -1040,26 +1037,36 @@ class RoomCreationHandler:
         room_alias: Optional[RoomAlias] = None,
         power_level_content_override: Optional[JsonDict] = None,
         creator_join_profile: Optional[JsonDict] = None,
-        ratelimit: bool = True,
     ) -> Tuple[int, str, int]:
-        """Sends the initial events into a new room.
+        """Sends the initial events into a new room. Sends the room creation, membership,
+        and power level events into the room sequentially, then creates and batches up the
+        rest of the events to persist as a batch to the DB.
 
         `power_level_content_override` doesn't apply when initial state has
         power level state event content.
 
+        Rate limiting should already have been applied by this point.
+
         Returns:
             A tuple containing the stream ID, event ID and depth of the last
             event sent to the room.
         """
 
         creator_id = creator.user.to_string()
-
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
-
         depth = 1
-        last_sent_event_id: Optional[str] = None
 
-        def create(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
+        # the most recently created event
+        prev_event: List[str] = []
+        # a map of event types, state keys -> event_ids. We collect these mappings this as events are
+        # created (but not persisted to the db) to determine state for future created events
+        # (as this info can't be pulled from the db)
+        state_map: MutableStateMap[str] = {}
+        # current_state_group of last event created. Used for computing event context of
+        # events to be batched
+        current_state_group = None
+
+        def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
             e = {"type": etype, "content": content}
 
             e.update(event_keys)
@@ -1067,32 +1074,44 @@ class RoomCreationHandler:
 
             return e
 
-        async def send(etype: str, content: JsonDict, **kwargs: Any) -> int:
-            nonlocal last_sent_event_id
+        async def create_event(
+            etype: str,
+            content: JsonDict,
+            for_batch: bool,
+            **kwargs: Any,
+        ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+            """
+            Creates an event and associated event context.
+            Args:
+                etype: the type of event to be created
+                content: content of the event
+                for_batch: whether the event is being created for batch persisting. If
+                bool for_batch is true, this will create an event using the prev_event_ids,
+                and will create an event context for the event using the parameters state_map
+                and current_state_group, thus these parameters must be provided in this
+                case if for_batch is True. The subsequently created event and context
+                are suitable for being batched up and bulk persisted to the database
+                with other similarly created events.
+            """
             nonlocal depth
+            nonlocal prev_event
 
-            event = create(etype, content, **kwargs)
-            logger.debug("Sending %s in new room", etype)
-            # Allow these events to be sent even if the user is shadow-banned to
-            # allow the room creation to complete.
-            (
-                sent_event,
-                last_stream_id,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
+            event_dict = create_event_dict(etype, content, **kwargs)
+
+            new_event, new_context = await self.event_creation_handler.create_event(
                 creator,
-                event,
-                ratelimit=False,
-                ignore_shadow_ban=True,
-                # Note: we don't pass state_event_ids here because this triggers
-                # an additional query per event to look them up from the events table.
-                prev_event_ids=[last_sent_event_id] if last_sent_event_id else [],
+                event_dict,
+                prev_event_ids=prev_event,
                 depth=depth,
+                state_map=state_map,
+                for_batch=for_batch,
+                current_state_group=current_state_group,
             )
-
-            last_sent_event_id = sent_event.event_id
             depth += 1
+            prev_event = [new_event.event_id]
+            state_map[(new_event.type, new_event.state_key)] = new_event.event_id
 
-            return last_stream_id
+            return new_event, new_context
 
         try:
             config = self._presets_dict[preset_config]
@@ -1102,31 +1121,55 @@ class RoomCreationHandler:
             )
 
         creation_content.update({"creator": creator_id})
-        await send(etype=EventTypes.Create, content=creation_content)
+        creation_event, creation_context = await create_event(
+            EventTypes.Create, creation_content, False
+        )
 
         logger.debug("Sending %s in new room", EventTypes.Member)
-        # Room create event must exist at this point
-        assert last_sent_event_id is not None
+        ev = await self.event_creation_handler.handle_new_client_event(
+            requester=creator,
+            events_and_context=[(creation_event, creation_context)],
+            ratelimit=False,
+            ignore_shadow_ban=True,
+        )
+        last_sent_event_id = ev.event_id
+
         member_event_id, _ = await self.room_member_handler.update_membership(
             creator,
             creator.user,
             room_id,
             "join",
-            ratelimit=ratelimit,
+            ratelimit=False,
             content=creator_join_profile,
             new_room=True,
             prev_event_ids=[last_sent_event_id],
             depth=depth,
         )
-        last_sent_event_id = member_event_id
+        prev_event = [member_event_id]
+
+        # update the depth and state map here as the membership event has been created
+        # through a different code path
+        depth += 1
+        state_map[(EventTypes.Member, creator.user.to_string())] = member_event_id
 
+        # we need the state group of the membership event as it is the current state group
+        event_to_state = (
+            await self._storage_controllers.state.get_state_group_for_events(
+                [member_event_id]
+            )
+        )
+        current_state_group = event_to_state[member_event_id]
+
+        events_to_send = []
         # We treat the power levels override specially as this needs to be one
         # of the first events that get sent into a room.
         pl_content = initial_state.pop((EventTypes.PowerLevels, ""), None)
         if pl_content is not None:
-            last_sent_stream_id = await send(
-                etype=EventTypes.PowerLevels, content=pl_content
+            power_event, power_context = await create_event(
+                EventTypes.PowerLevels, pl_content, True
             )
+            current_state_group = power_context._state_group
+            events_to_send.append((power_event, power_context))
         else:
             power_level_content: JsonDict = {
                 "users": {creator_id: 100},
@@ -1169,48 +1212,73 @@ class RoomCreationHandler:
             # apply those.
             if power_level_content_override:
                 power_level_content.update(power_level_content_override)
-
-            last_sent_stream_id = await send(
-                etype=EventTypes.PowerLevels, content=power_level_content
+            pl_event, pl_context = await create_event(
+                EventTypes.PowerLevels,
+                power_level_content,
+                True,
             )
+            current_state_group = pl_context._state_group
+            events_to_send.append((pl_event, pl_context))
 
         if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
-            last_sent_stream_id = await send(
-                etype=EventTypes.CanonicalAlias,
-                content={"alias": room_alias.to_string()},
+            room_alias_event, room_alias_context = await create_event(
+                EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
             )
+            current_state_group = room_alias_context._state_group
+            events_to_send.append((room_alias_event, room_alias_context))
 
         if (EventTypes.JoinRules, "") not in initial_state:
-            last_sent_stream_id = await send(
-                etype=EventTypes.JoinRules, content={"join_rule": config["join_rules"]}
+            join_rules_event, join_rules_context = await create_event(
+                EventTypes.JoinRules,
+                {"join_rule": config["join_rules"]},
+                True,
             )
+            current_state_group = join_rules_context._state_group
+            events_to_send.append((join_rules_event, join_rules_context))
 
         if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
-            last_sent_stream_id = await send(
-                etype=EventTypes.RoomHistoryVisibility,
-                content={"history_visibility": config["history_visibility"]},
+            visibility_event, visibility_context = await create_event(
+                EventTypes.RoomHistoryVisibility,
+                {"history_visibility": config["history_visibility"]},
+                True,
             )
+            current_state_group = visibility_context._state_group
+            events_to_send.append((visibility_event, visibility_context))
 
         if config["guest_can_join"]:
             if (EventTypes.GuestAccess, "") not in initial_state:
-                last_sent_stream_id = await send(
-                    etype=EventTypes.GuestAccess,
-                    content={EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
+                guest_access_event, guest_access_context = await create_event(
+                    EventTypes.GuestAccess,
+                    {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
+                    True,
                 )
+                current_state_group = guest_access_context._state_group
+                events_to_send.append((guest_access_event, guest_access_context))
 
         for (etype, state_key), content in initial_state.items():
-            last_sent_stream_id = await send(
-                etype=etype, state_key=state_key, content=content
+            event, context = await create_event(
+                etype, content, True, state_key=state_key
             )
+            current_state_group = context._state_group
+            events_to_send.append((event, context))
 
         if config["encrypted"]:
-            last_sent_stream_id = await send(
-                etype=EventTypes.RoomEncryption,
+            encryption_event, encryption_context = await create_event(
+                EventTypes.RoomEncryption,
+                {"algorithm": RoomEncryptionAlgorithms.DEFAULT},
+                True,
                 state_key="",
-                content={"algorithm": RoomEncryptionAlgorithms.DEFAULT},
             )
+            events_to_send.append((encryption_event, encryption_context))
 
-        return last_sent_stream_id, last_sent_event_id, depth
+        last_event = await self.event_creation_handler.handle_new_client_event(
+            creator,
+            events_to_send,
+            ignore_shadow_ban=True,
+            ratelimit=False,
+        )
+        assert last_event.internal_metadata.stream_ordering is not None
+        return last_event.internal_metadata.stream_ordering, last_event.event_id, depth
 
     def _generate_room_id(self) -> str:
         """Generates a random room ID.
@@ -1383,7 +1451,7 @@ class RoomContextHandler:
             events_before=events_before,
             event=event,
             events_after=events_after,
-            state=await filter_evts(state_events),
+            state=state_events,
             aggregations=aggregations,
             start=await token.copy_and_replace(
                 StreamKeyType.ROOM, results.start
@@ -1429,7 +1497,12 @@ class TimestampLookupHandler:
         Raises:
             SynapseError if unable to find any event locally in the given direction
         """
-
+        logger.debug(
+            "get_event_for_timestamp(room_id=%s, timestamp=%s, direction=%s) Finding closest event...",
+            room_id,
+            timestamp,
+            direction,
+        )
         local_event_id = await self.store.get_event_id_for_timestamp(
             room_id, timestamp, direction
         )
@@ -1476,88 +1549,59 @@ class TimestampLookupHandler:
             )
 
             likely_domains = (
-                await self._storage_controllers.state.get_current_hosts_in_room(room_id)
+                await self._storage_controllers.state.get_current_hosts_in_room_ordered(
+                    room_id
+                )
             )
 
-            # Loop through each homeserver candidate until we get a succesful response
-            for domain in likely_domains:
-                # We don't want to ask our own server for information we don't have
-                if domain == self.server_name:
-                    continue
+            remote_response = await self.federation_client.timestamp_to_event(
+                destinations=likely_domains,
+                room_id=room_id,
+                timestamp=timestamp,
+                direction=direction,
+            )
+            if remote_response is not None:
+                logger.debug(
+                    "get_event_for_timestamp: remote_response=%s",
+                    remote_response,
+                )
 
-                try:
-                    remote_response = await self.federation_client.timestamp_to_event(
-                        domain, room_id, timestamp, direction
-                    )
-                    logger.debug(
-                        "get_event_for_timestamp: response from domain(%s)=%s",
-                        domain,
-                        remote_response,
-                    )
+                remote_event_id = remote_response.event_id
+                remote_origin_server_ts = remote_response.origin_server_ts
 
-                    remote_event_id = remote_response.event_id
-                    remote_origin_server_ts = remote_response.origin_server_ts
-
-                    # Backfill this event so we can get a pagination token for
-                    # it with `/context` and paginate `/messages` from this
-                    # point.
-                    #
-                    # TODO: The requested timestamp may lie in a part of the
-                    #   event graph that the remote server *also* didn't have,
-                    #   in which case they will have returned another event
-                    #   which may be nowhere near the requested timestamp. In
-                    #   the future, we may need to reconcile that gap and ask
-                    #   other homeservers, and/or extend `/timestamp_to_event`
-                    #   to return events on *both* sides of the timestamp to
-                    #   help reconcile the gap faster.
-                    remote_event = (
-                        await self.federation_event_handler.backfill_event_id(
-                            domain, room_id, remote_event_id
-                        )
-                    )
+                # Backfill this event so we can get a pagination token for
+                # it with `/context` and paginate `/messages` from this
+                # point.
+                pulled_pdu_info = await self.federation_event_handler.backfill_event_id(
+                    likely_domains, room_id, remote_event_id
+                )
+                remote_event = pulled_pdu_info.pdu
 
-                    # XXX: When we see that the remote server is not trustworthy,
-                    # maybe we should not ask them first in the future.
-                    if remote_origin_server_ts != remote_event.origin_server_ts:
-                        logger.info(
-                            "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
-                            domain,
-                            remote_event_id,
-                            remote_origin_server_ts,
-                            remote_event.origin_server_ts,
-                        )
-
-                    # Only return the remote event if it's closer than the local event
-                    if not local_event or (
-                        abs(remote_event.origin_server_ts - timestamp)
-                        < abs(local_event.origin_server_ts - timestamp)
-                    ):
-                        logger.info(
-                            "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
-                            remote_event_id,
-                            remote_event.origin_server_ts,
-                            timestamp,
-                            local_event.event_id if local_event else None,
-                            local_event.origin_server_ts if local_event else None,
-                        )
-                        return remote_event_id, remote_origin_server_ts
-                except (HttpResponseException, InvalidResponseError) as ex:
-                    # Let's not put a high priority on some other homeserver
-                    # failing to respond or giving a random response
-                    logger.debug(
-                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception(%s) %s args=%s",
-                        domain,
-                        type(ex).__name__,
-                        ex,
-                        ex.args,
+                # XXX: When we see that the remote server is not trustworthy,
+                # maybe we should not ask them first in the future.
+                if remote_origin_server_ts != remote_event.origin_server_ts:
+                    logger.info(
+                        "get_event_for_timestamp: Remote server (%s) claimed that remote_event_id=%s occured at remote_origin_server_ts=%s but that isn't true (actually occured at %s). Their claims are dubious and we should consider not trusting them.",
+                        pulled_pdu_info.pull_origin,
+                        remote_event_id,
+                        remote_origin_server_ts,
+                        remote_event.origin_server_ts,
                     )
-                except Exception:
-                    # But we do want to see some exceptions in our code
-                    logger.warning(
-                        "get_event_for_timestamp: Failed to fetch /timestamp_to_event from %s because of exception",
-                        domain,
-                        exc_info=True,
+
+                # Only return the remote event if it's closer than the local event
+                if not local_event or (
+                    abs(remote_event.origin_server_ts - timestamp)
+                    < abs(local_event.origin_server_ts - timestamp)
+                ):
+                    logger.info(
+                        "get_event_for_timestamp: returning remote_event_id=%s (%s) since it's closer to timestamp=%s than local_event=%s (%s)",
+                        remote_event_id,
+                        remote_event.origin_server_ts,
+                        timestamp,
+                        local_event.event_id if local_event else None,
+                        local_event.origin_server_ts if local_event else None,
                     )
+                    return remote_event_id, remote_origin_server_ts
 
         # To appease mypy, we have to add both of these conditions to check for
         # `None`. We only expect `local_event` to be `None` when
@@ -1580,7 +1624,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
         self,
         user: UserID,
         from_key: RoomStreamToken,
-        limit: Optional[int],
+        limit: int,
         room_ids: Collection[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 1414e575d6..411a6fb22f 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -379,8 +379,7 @@ class RoomBatchHandler:
                 await self.create_requester_for_user_id_from_app_service(
                     event.sender, app_service_requester.app_service
                 ),
-                event=event,
-                context=context,
+                events_and_context=[(event, context)],
             )
 
         return event_ids
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index e0d0a8941c..2ebd2e6eb7 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -322,6 +322,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         require_consent: bool = True,
         outlier: bool = False,
         historical: bool = False,
+        origin_server_ts: Optional[int] = None,
     ) -> Tuple[str, int]:
         """
         Internal membership update function to get an existing event or create
@@ -361,6 +362,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             historical: Indicates whether the message is being inserted
                 back in time around some existing events. This is used to skip
                 a few checks and mark the event as backfilled.
+            origin_server_ts: The origin_server_ts to use if a new event is created. Uses
+                the current timestamp if set to None.
 
         Returns:
             Tuple of event ID and stream ordering position
@@ -399,6 +402,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 "state_key": user_id,
                 # For backwards compatibility:
                 "membership": membership,
+                "origin_server_ts": origin_server_ts,
             },
             txn_id=txn_id,
             allow_no_prev_events=allow_no_prev_events,
@@ -432,8 +436,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         with tracing.start_active_span("handle_new_client_event"):
             result_event = await self.event_creation_handler.handle_new_client_event(
                 requester,
-                event,
-                context,
+                events_and_context=[(event, context)],
                 extra_users=[target],
                 ratelimit=ratelimit,
             )
@@ -505,6 +508,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        origin_server_ts: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Update a user's membership in a room.
 
@@ -543,6 +547,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
+            origin_server_ts: The origin_server_ts to use if a new event is created. Uses
+                the current timestamp if set to None.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -584,6 +590,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                         prev_event_ids=prev_event_ids,
                         state_event_ids=state_event_ids,
                         depth=depth,
+                        origin_server_ts=origin_server_ts,
                     )
 
         return result
@@ -607,6 +614,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         prev_event_ids: Optional[List[str]] = None,
         state_event_ids: Optional[List[str]] = None,
         depth: Optional[int] = None,
+        origin_server_ts: Optional[int] = None,
     ) -> Tuple[str, int]:
         """Helper for update_membership.
 
@@ -647,6 +655,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             depth: Override the depth used to order the event in the DAG.
                 Should normally be set to None, which will cause the depth to be calculated
                 based on the prev_events.
+            origin_server_ts: The origin_server_ts to use if a new event is created. Uses
+                the current timestamp if set to None.
 
         Returns:
             A tuple of the new event ID and stream ID.
@@ -786,6 +796,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 require_consent=require_consent,
                 outlier=outlier,
                 historical=historical,
+                origin_server_ts=origin_server_ts,
             )
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
@@ -1031,6 +1042,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             content=content,
             require_consent=require_consent,
             outlier=outlier,
+            origin_server_ts=origin_server_ts,
         )
 
     async def _should_perform_remote_join(
@@ -1151,8 +1163,8 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         logger.info("Transferring room state from %s to %s", old_room_id, room_id)
 
         # Find all local users that were in the old room and copy over each user's state
-        users = await self.store.get_users_in_room(old_room_id)
-        await self.copy_user_state_on_room_upgrade(old_room_id, room_id, users)
+        local_users = await self.store.get_local_users_in_room(old_room_id)
+        await self.copy_user_state_on_room_upgrade(old_room_id, room_id, local_users)
 
         # Add new room to the room directory if the old room was there
         # Remove old room from the room directory
@@ -1252,7 +1264,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 raise SynapseError(403, "This room has been blocked on this server")
 
         event = await self.event_creation_handler.handle_new_client_event(
-            requester, event, context, extra_users=[target_user], ratelimit=ratelimit
+            requester,
+            events_and_context=[(event, context)],
+            extra_users=[target_user],
+            ratelimit=ratelimit,
         )
 
         prev_member_event_id = prev_state_ids.get(
@@ -1860,8 +1875,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
         result_event = await self.event_creation_handler.handle_new_client_event(
             requester,
-            event,
-            context,
+            events_and_context=[(event, context)],
             extra_users=[UserID.from_string(target_user)],
         )
         # we know it was persisted, so must have a stream ordering
diff --git a/synapse/handlers/saml.py b/synapse/handlers/saml.py
index 9602f0d0bb..874860d461 100644
--- a/synapse/handlers/saml.py
+++ b/synapse/handlers/saml.py
@@ -441,7 +441,7 @@ class DefaultSamlMappingProvider:
             client_redirect_url: where the client wants to redirect to
 
         Returns:
-            dict: A dict containing new user attributes. Possible keys:
+            A dict containing new user attributes. Possible keys:
                 * mxid_localpart (str): Required. The localpart of the user's mxid
                 * displayname (str): The displayname of the user
                 * emails (list[str]): Any emails for the user
@@ -483,7 +483,7 @@ class DefaultSamlMappingProvider:
         Args:
             config: A dictionary containing configuration options for this provider
         Returns:
-            SamlConfig: A custom config object for this module
+            A custom config object for this module
         """
         # Parse config options and use defaults where necessary
         mxid_source_attribute = config.get("mxid_source_attribute", "uid")
diff --git a/synapse/handlers/send_email.py b/synapse/handlers/send_email.py
index e2844799e8..804cc6e81e 100644
--- a/synapse/handlers/send_email.py
+++ b/synapse/handlers/send_email.py
@@ -187,6 +187,19 @@ class SendEmailHandler:
         multipart_msg["To"] = email_address
         multipart_msg["Date"] = email.utils.formatdate()
         multipart_msg["Message-ID"] = email.utils.make_msgid()
+        # Discourage automatic responses to Synapse's emails.
+        # Per RFC 3834, automatic responses should not be sent if the "Auto-Submitted"
+        # header is present with any value other than "no". See
+        #     https://www.rfc-editor.org/rfc/rfc3834.html#section-5.1
+        multipart_msg["Auto-Submitted"] = "auto-generated"
+        # Also include a Microsoft-Exchange specific header:
+        #    https://learn.microsoft.com/en-us/openspecs/exchange_server_protocols/ms-oxcmail/ced68690-498a-4567-9d14-5c01f974d8b1
+        # which suggests it can take the value "All" to "suppress all auto-replies",
+        # or a comma separated list of auto-reply classes to suppress.
+        # The following stack overflow question has a little more context:
+        #    https://stackoverflow.com/a/25324691/5252017
+        #    https://stackoverflow.com/a/61646381/5252017
+        multipart_msg["X-Auto-Response-Suppress"] = "All"
         multipart_msg.attach(text_part)
         multipart_msg.attach(html_part)
 
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 6bc1cbd787..749d7e93b0 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -147,6 +147,9 @@ class UsernameMappingSession:
     # A unique identifier for this SSO provider, e.g.  "oidc" or "saml".
     auth_provider_id: str
 
+    # An optional session ID from the IdP.
+    auth_provider_session_id: Optional[str]
+
     # user ID on the IdP server
     remote_user_id: str
 
@@ -188,6 +191,7 @@ class SsoHandler:
         self._server_name = hs.hostname
         self._registration_handler = hs.get_registration_handler()
         self._auth_handler = hs.get_auth_handler()
+        self._device_handler = hs.get_device_handler()
         self._error_template = hs.config.sso.sso_error_template
         self._bad_user_template = hs.config.sso.sso_auth_bad_user_template
         self._profile_handler = hs.get_profile_handler()
@@ -464,6 +468,7 @@ class SsoHandler:
                         client_redirect_url,
                         next_step_url,
                         extra_login_attributes,
+                        auth_provider_session_id,
                     )
 
                 user_id = await self._register_mapped_user(
@@ -585,6 +590,7 @@ class SsoHandler:
         client_redirect_url: str,
         next_step_url: bytes,
         extra_login_attributes: Optional[JsonDict],
+        auth_provider_session_id: Optional[str],
     ) -> NoReturn:
         """Creates a UsernameMappingSession and redirects the browser
 
@@ -607,6 +613,8 @@ class SsoHandler:
             extra_login_attributes: An optional dictionary of extra
                 attributes to be provided to the client in the login response.
 
+            auth_provider_session_id: An optional session ID from the IdP.
+
         Raises:
             RedirectException
         """
@@ -615,6 +623,7 @@ class SsoHandler:
         now = self._clock.time_msec()
         session = UsernameMappingSession(
             auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
             remote_user_id=remote_user_id,
             display_name=attributes.display_name,
             emails=attributes.emails,
@@ -866,7 +875,7 @@ class SsoHandler:
         )
 
     async def handle_terms_accepted(
-        self, request: Request, session_id: str, terms_version: str
+        self, request: SynapseRequest, session_id: str, terms_version: str
     ) -> None:
         """Handle a request to the new-user 'consent' endpoint
 
@@ -968,6 +977,7 @@ class SsoHandler:
             session.client_redirect_url,
             session.extra_login_attributes,
             new_user=True,
+            auth_provider_session_id=session.auth_provider_session_id,
         )
 
     def _expire_old_sessions(self) -> None:
@@ -1017,6 +1027,76 @@ class SsoHandler:
 
         return True
 
+    async def revoke_sessions_for_provider_session_id(
+        self,
+        auth_provider_id: str,
+        auth_provider_session_id: str,
+        expected_user_id: Optional[str] = None,
+    ) -> None:
+        """Revoke any devices and in-flight logins tied to a provider session.
+
+        Args:
+            auth_provider_id: A unique identifier for this SSO provider, e.g.
+                "oidc" or "saml".
+            auth_provider_session_id: The session ID from the provider to logout
+            expected_user_id: The user we're expecting to logout. If set, it will ignore
+                sessions belonging to other users and log an error.
+        """
+        # Invalidate any running user-mapping sessions
+        to_delete = []
+        for session_id, session in self._username_mapping_sessions.items():
+            if (
+                session.auth_provider_id == auth_provider_id
+                and session.auth_provider_session_id == auth_provider_session_id
+            ):
+                to_delete.append(session_id)
+
+        for session_id in to_delete:
+            logger.info("Revoking mapping session %s", session_id)
+            del self._username_mapping_sessions[session_id]
+
+        # Invalidate any in-flight login tokens
+        await self._store.invalidate_login_tokens_by_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # Fetch any device(s) in the store associated with the session ID.
+        devices = await self._store.get_devices_by_auth_provider_session_id(
+            auth_provider_id=auth_provider_id,
+            auth_provider_session_id=auth_provider_session_id,
+        )
+
+        # We have no guarantee that all the devices of that session are for the same
+        # `user_id`. Hence, we have to iterate over the list of devices and log them out
+        # one by one.
+        for device in devices:
+            user_id = device["user_id"]
+            device_id = device["device_id"]
+
+            # If the user_id associated with that device/session is not the one we got
+            # out of the `sub` claim, skip that device and show log an error.
+            if expected_user_id is not None and user_id != expected_user_id:
+                logger.error(
+                    "Received a logout notification from SSO provider "
+                    f"{auth_provider_id!r} for the user {expected_user_id!r}, but with "
+                    f"a session ID ({auth_provider_session_id!r}) which belongs to "
+                    f"{user_id!r}. This may happen when the SSO provider user mapper "
+                    "uses something else than the standard attribute as mapping ID. "
+                    "For OIDC providers, set `backchannel_logout_ignore_sub` to `true` "
+                    "in the provider config if that is the case."
+                )
+                continue
+
+            logger.info(
+                "Logging out %r (device %r) via SSO (%r) logout notification (session %r).",
+                user_id,
+                device_id,
+                auth_provider_id,
+                auth_provider_session_id,
+            )
+            await self._device_handler.delete_devices(user_id, [device_id])
+
 
 def get_username_mapping_session_cookie_from_request(request: IRequest) -> str:
     """Extract the session ID from the cookie
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index f7fd6d7933..c00f30518a 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -45,7 +45,8 @@ from synapse.logging.tracing import (
     start_active_span,
 )
 from synapse.push.clientformat import format_push_rules_for_user
-from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
+from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -133,6 +134,7 @@ class JoinedSyncResult:
     ephemeral: List[JsonDict]
     account_data: List[JsonDict]
     unread_notifications: JsonDict
+    unread_thread_notifications: JsonDict
     summary: Optional[JsonDict]
     unread_count: int
 
@@ -809,18 +811,6 @@ class SyncHandler:
             if canonical_alias and canonical_alias.content.get("alias"):
                 return summary
 
-        me = sync_config.user.to_string()
-
-        joined_user_ids = [
-            r[0] for r in details.get(Membership.JOIN, empty_ms).members if r[0] != me
-        ]
-        invited_user_ids = [
-            r[0] for r in details.get(Membership.INVITE, empty_ms).members if r[0] != me
-        ]
-        gone_user_ids = [
-            r[0] for r in details.get(Membership.LEAVE, empty_ms).members if r[0] != me
-        ] + [r[0] for r in details.get(Membership.BAN, empty_ms).members if r[0] != me]
-
         # FIXME: only build up a member_ids list for our heroes
         member_ids = {}
         for membership in (
@@ -832,11 +822,8 @@ class SyncHandler:
             for user_id, event_id in details.get(membership, empty_ms).members:
                 member_ids[user_id] = event_id
 
-        # FIXME: order by stream ordering rather than as returned by SQL
-        if joined_user_ids or invited_user_ids:
-            summary["m.heroes"] = sorted(joined_user_ids + invited_user_ids)[0:5]
-        else:
-            summary["m.heroes"] = sorted(gone_user_ids)[0:5]
+        me = sync_config.user.to_string()
+        summary["m.heroes"] = extract_heroes_from_room_summary(details, me)
 
         if not sync_config.filter_collection.lazy_load_members():
             return summary
@@ -1196,7 +1183,9 @@ class SyncHandler:
             room_id: The partial state room to find the remaining memberships for.
             members_to_fetch: The memberships to find.
             events_with_membership_auth: A mapping from user IDs to events whose auth
-                events are known to contain their membership.
+                events would contain their prior membership, if one exists.
+                Note that join events will not cite a prior membership if a user has
+                never been in a room before.
             found_state_ids: A dict from (type, state_key) -> state_event_id, containing
                 memberships that have been previously found. Entries in
                 `members_to_fetch` that have a membership in `found_state_ids` are
@@ -1206,6 +1195,10 @@ class SyncHandler:
             A dict from ("m.room.member", state_key) -> state_event_id, containing the
             memberships missing from `found_state_ids`.
 
+            When `events_with_membership_auth` contains a join event for a given user
+            which does not cite a prior membership, no membership is returned for that
+            user.
+
         Raises:
             KeyError: if `events_with_membership_auth` does not have an entry for a
                 missing membership. Memberships in `found_state_ids` do not need an
@@ -1223,8 +1216,18 @@ class SyncHandler:
             if (EventTypes.Member, member) in found_state_ids:
                 continue
 
-            missing_members.add(member)
             event_with_membership_auth = events_with_membership_auth[member]
+            is_join = (
+                event_with_membership_auth.is_state()
+                and event_with_membership_auth.type == EventTypes.Member
+                and event_with_membership_auth.state_key == member
+                and event_with_membership_auth.content.get("membership")
+                == Membership.JOIN
+            )
+            if not is_join:
+                # The event must include the desired membership as an auth event, unless
+                # it's the first join event for a given user.
+                missing_members.add(member)
             auth_event_ids.update(event_with_membership_auth.auth_event_ids())
 
         auth_events = await self.store.get_events(auth_event_ids)
@@ -1248,7 +1251,7 @@ class SyncHandler:
                     auth_event.type == EventTypes.Member
                     and auth_event.state_key == member
                 ):
-                    missing_members.remove(member)
+                    missing_members.discard(member)
                     additional_state_ids[
                         (EventTypes.Member, member)
                     ] = auth_event.event_id
@@ -1277,7 +1280,7 @@ class SyncHandler:
 
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
-    ) -> NotifCounts:
+    ) -> RoomNotifCounts:
         with Measure(self.clock, "unread_notifs_for_room_id"):
 
             return await self.store.get_unread_event_push_actions_by_room_for_user(
@@ -1303,6 +1306,19 @@ class SyncHandler:
         At the end, we transfer data from the `sync_result_builder` to a new `SyncResult`
         instance to signify that the sync calculation is complete.
         """
+
+        user_id = sync_config.user.to_string()
+        app_service = self.store.get_app_service_by_user_id(user_id)
+        if app_service:
+            # We no longer support AS users using /sync directly.
+            # See https://github.com/matrix-org/matrix-doc/issues/1144
+            raise NotImplementedError()
+
+        # Note: we get the users room list *before* we get the current token, this
+        # avoids checking back in history if rooms are joined after the token is fetched.
+        token_before_rooms = self.event_sources.get_current_token()
+        mutable_joined_room_ids = set(await self.store.get_rooms_for_user(user_id))
+
         # NB: The now_token gets changed by some of the generate_sync_* methods,
         # this is due to some of the underlying streams not supporting the ability
         # to query up to a given point.
@@ -1310,6 +1326,57 @@ class SyncHandler:
         now_token = self.event_sources.get_current_token()
         log_kv({"now_token": str(now_token)})
 
+        # Since we fetched the users room list before the token, there's a small window
+        # during which membership events may have been persisted, so we fetch these now
+        # and modify the joined room list for any changes between the get_rooms_for_user
+        # call and the get_current_token call.
+        membership_change_events = []
+        if since_token:
+            membership_change_events = await self.store.get_membership_changes_for_user(
+                user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
+            )
+
+            mem_last_change_by_room_id: Dict[str, EventBase] = {}
+            for event in membership_change_events:
+                mem_last_change_by_room_id[event.room_id] = event
+
+            # For the latest membership event in each room found, add/remove the room ID
+            # from the joined room list accordingly. In this case we only care if the
+            # latest change is JOIN.
+
+            for room_id, event in mem_last_change_by_room_id.items():
+                assert event.internal_metadata.stream_ordering
+                if (
+                    event.internal_metadata.stream_ordering
+                    < token_before_rooms.room_key.stream
+                ):
+                    continue
+
+                logger.info(
+                    "User membership change between getting rooms and current token: %s %s %s",
+                    user_id,
+                    event.membership,
+                    room_id,
+                )
+                # User joined a room - we have to then check the room state to ensure we
+                # respect any bans if there's a race between the join and ban events.
+                if event.membership == Membership.JOIN:
+                    user_ids_in_room = await self.store.get_users_in_room(room_id)
+                    if user_id in user_ids_in_room:
+                        mutable_joined_room_ids.add(room_id)
+                # The user left the room, or left and was re-invited but not joined yet
+                else:
+                    mutable_joined_room_ids.discard(room_id)
+
+        # Now we have our list of joined room IDs, exclude as configured and freeze
+        joined_room_ids = frozenset(
+            (
+                room_id
+                for room_id in mutable_joined_room_ids
+                if room_id not in self.rooms_to_exclude
+            )
+        )
+
         logger.debug(
             "Calculating sync response for %r between %s and %s",
             sync_config.user,
@@ -1317,22 +1384,13 @@ class SyncHandler:
             now_token,
         )
 
-        user_id = sync_config.user.to_string()
-        app_service = self.store.get_app_service_by_user_id(user_id)
-        if app_service:
-            # We no longer support AS users using /sync directly.
-            # See https://github.com/matrix-org/matrix-doc/issues/1144
-            raise NotImplementedError()
-        else:
-            joined_room_ids = await self.get_rooms_for_user_at(
-                user_id, now_token.room_key
-            )
         sync_result_builder = SyncResultBuilder(
             sync_config,
             full_state,
             since_token=since_token,
             now_token=now_token,
             joined_room_ids=joined_room_ids,
+            membership_change_events=membership_change_events,
         )
 
         logger.debug("Fetching account data")
@@ -1479,16 +1537,14 @@ class SyncHandler:
                 since_token.device_list_key
             )
             if changed_users is not None:
-                result = await self.store.get_rooms_for_users_with_stream_ordering(
-                    changed_users
-                )
+                result = await self.store.get_rooms_for_users(changed_users)
 
                 for changed_user_id, entries in result.items():
                     # Check if the changed user shares any rooms with the user,
                     # or if the changed user is the syncing user (as we always
                     # want to include device list updates of their own devices).
                     if user_id == changed_user_id or any(
-                        e.room_id in joined_rooms for e in entries
+                        rid in joined_rooms for rid in entries
                     ):
                         users_that_have_changed.add(changed_user_id)
             else:
@@ -1522,13 +1578,9 @@ class SyncHandler:
                 newly_left_users.update(left_users)
 
             # Remove any users that we still share a room with.
-            left_users_rooms = (
-                await self.store.get_rooms_for_users_with_stream_ordering(
-                    newly_left_users
-                )
-            )
+            left_users_rooms = await self.store.get_rooms_for_users(newly_left_users)
             for user_id, entries in left_users_rooms.items():
-                if any(e.room_id in joined_rooms for e in entries):
+                if any(rid in joined_rooms for rid in entries):
                     newly_left_users.discard(user_id)
 
             return DeviceListUpdates(
@@ -1819,19 +1871,12 @@ class SyncHandler:
 
         Does not modify the `sync_result_builder`.
         """
-        user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
-        now_token = sync_result_builder.now_token
+        membership_change_events = sync_result_builder.membership_change_events
 
         assert since_token
 
-        # Get a list of membership change events that have happened to the user
-        # requesting the sync.
-        membership_changes = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key
-        )
-
-        if membership_changes:
+        if membership_change_events:
             return True
 
         stream_id = since_token.room_key.stream
@@ -1870,16 +1915,10 @@ class SyncHandler:
         since_token = sync_result_builder.since_token
         now_token = sync_result_builder.now_token
         sync_config = sync_result_builder.sync_config
+        membership_change_events = sync_result_builder.membership_change_events
 
         assert since_token
 
-        # TODO: we've already called this function and ran this query in
-        #       _have_rooms_changed. We could keep the results in memory to avoid a
-        #       second query, at the cost of more complicated source code.
-        membership_change_events = await self.store.get_membership_changes_for_user(
-            user_id, since_token.room_key, now_token.room_key, self.rooms_to_exclude
-        )
-
         mem_change_events_by_room_id: Dict[str, List[EventBase]] = {}
         for event in membership_change_events:
             mem_change_events_by_room_id.setdefault(event.room_id, []).append(event)
@@ -2348,6 +2387,7 @@ class SyncHandler:
                     ephemeral=ephemeral,
                     account_data=account_data_events,
                     unread_notifications=unread_notifications,
+                    unread_thread_notifications={},
                     summary=summary,
                     unread_count=0,
                 )
@@ -2355,10 +2395,33 @@ class SyncHandler:
                 if room_sync or always_include:
                     notifs = await self.unread_notifs_for_room_id(room_id, sync_config)
 
-                    unread_notifications["notification_count"] = notifs.notify_count
-                    unread_notifications["highlight_count"] = notifs.highlight_count
-
-                    room_sync.unread_count = notifs.unread_count
+                    # Notifications for the main timeline.
+                    notify_count = notifs.main_timeline.notify_count
+                    highlight_count = notifs.main_timeline.highlight_count
+                    unread_count = notifs.main_timeline.unread_count
+
+                    # Check the sync configuration.
+                    if sync_config.filter_collection.unread_thread_notifications():
+                        # And add info for each thread.
+                        room_sync.unread_thread_notifications = {
+                            thread_id: {
+                                "notification_count": thread_notifs.notify_count,
+                                "highlight_count": thread_notifs.highlight_count,
+                            }
+                            for thread_id, thread_notifs in notifs.threads.items()
+                            if thread_id is not None
+                        }
+
+                    else:
+                        # Combine the unread counts for all threads and main timeline.
+                        for thread_notifs in notifs.threads.values():
+                            notify_count += thread_notifs.notify_count
+                            highlight_count += thread_notifs.highlight_count
+                            unread_count += thread_notifs.unread_count
+
+                    unread_notifications["notification_count"] = notify_count
+                    unread_notifications["highlight_count"] = highlight_count
+                    room_sync.unread_count = unread_count
 
                     sync_result_builder.joined.append(room_sync)
 
@@ -2380,60 +2443,6 @@ class SyncHandler:
             else:
                 raise Exception("Unrecognized rtype: %r", room_builder.rtype)
 
-    async def get_rooms_for_user_at(
-        self,
-        user_id: str,
-        room_key: RoomStreamToken,
-    ) -> FrozenSet[str]:
-        """Get set of joined rooms for a user at the given stream ordering.
-
-        The stream ordering *must* be recent, otherwise this may throw an
-        exception if older than a month. (This function is called with the
-        current token, which should be perfectly fine).
-
-        Args:
-            user_id
-            stream_ordering
-
-        ReturnValue:
-            Set of room_ids the user is in at given stream_ordering.
-        """
-        joined_rooms = await self.store.get_rooms_for_user_with_stream_ordering(user_id)
-
-        joined_room_ids = set()
-
-        # We need to check that the stream ordering of the join for each room
-        # is before the stream_ordering asked for. This might not be the case
-        # if the user joins a room between us getting the current token and
-        # calling `get_rooms_for_user_with_stream_ordering`.
-        # If the membership's stream ordering is after the given stream
-        # ordering, we need to go and work out if the user was in the room
-        # before.
-        # We also need to check whether the room should be excluded from sync
-        # responses as per the homeserver config.
-        for joined_room in joined_rooms:
-            if joined_room.room_id in self.rooms_to_exclude:
-                continue
-
-            if not joined_room.event_pos.persisted_after(room_key):
-                joined_room_ids.add(joined_room.room_id)
-                continue
-
-            logger.info("User joined room after current token: %s", joined_room.room_id)
-
-            extrems = (
-                await self.store.get_forward_extremities_for_room_at_stream_ordering(
-                    joined_room.room_id, joined_room.event_pos.stream
-                )
-            )
-            user_ids_in_room = await self.state.get_current_user_ids_in_room(
-                joined_room.room_id, extrems
-            )
-            if user_id in user_ids_in_room:
-                joined_room_ids.add(joined_room.room_id)
-
-        return frozenset(joined_room_ids)
-
 
 def _action_has_highlight(actions: List[JsonDict]) -> bool:
     for action in actions:
@@ -2530,6 +2539,7 @@ class SyncResultBuilder:
     since_token: Optional[StreamToken]
     now_token: StreamToken
     joined_room_ids: FrozenSet[str]
+    membership_change_events: List[EventBase]
 
     presence: List[UserPresenceState] = attr.Factory(list)
     account_data: List[JsonDict] = attr.Factory(list)
diff --git a/synapse/handlers/typing.py b/synapse/handlers/typing.py
index f953691669..a0ea719430 100644
--- a/synapse/handlers/typing.py
+++ b/synapse/handlers/typing.py
@@ -513,7 +513,7 @@ class TypingNotificationEventSource(EventSource[int, JsonDict]):
         self,
         user: UserID,
         from_key: int,
-        limit: Optional[int],
+        limit: int,
         room_ids: Iterable[str],
         is_guest: bool,
         explicit_room_id: Optional[str] = None,
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index a744d68c64..332edcca24 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -119,6 +119,9 @@ class RecaptchaAuthChecker(UserInteractiveAuthChecker):
         except PartialDownloadError as pde:
             # Twisted is silly
             data = pde.response
+            # For mypy's benefit. A general Error.response is Optional[bytes], but
+            # a PartialDownloadError.response should be bytes AFAICS.
+            assert data is not None
             resp_body = json_decoder.decode(data.decode("utf-8"))
 
         if "success" in resp_body:
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 8c3c52e1ca..3610b6bf78 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
 
 import synapse.metrics
 from synapse.api.constants import EventTypes, HistoryVisibility, JoinRules, Membership
@@ -379,7 +379,7 @@ class UserDirectoryHandler(StateDeltasHandler):
             user_id, event.content.get("displayname"), event.content.get("avatar_url")
         )
 
-    async def _track_user_joined_room(self, room_id: str, user_id: str) -> None:
+    async def _track_user_joined_room(self, room_id: str, joining_user_id: str) -> None:
         """Someone's just joined a room. Update `users_in_public_rooms` or
         `users_who_share_private_rooms` as appropriate.
 
@@ -390,32 +390,44 @@ class UserDirectoryHandler(StateDeltasHandler):
             room_id
         )
         if is_public:
-            await self.store.add_users_in_public_rooms(room_id, (user_id,))
+            await self.store.add_users_in_public_rooms(room_id, (joining_user_id,))
         else:
             users_in_room = await self.store.get_users_in_room(room_id)
             other_users_in_room = [
                 other
                 for other in users_in_room
-                if other != user_id
+                if other != joining_user_id
                 and (
+                    # We can't apply any special rules to remote users so
+                    # they're always included
                     not self.is_mine_id(other)
+                    # Check the special rules whether the local user should be
+                    # included in the user directory
                     or await self.store.should_include_local_user_in_dir(other)
                 )
             ]
-            to_insert = set()
+            updates_to_users_who_share_rooms: Set[Tuple[str, str]] = set()
 
-            # First, if they're our user then we need to update for every user
-            if self.is_mine_id(user_id):
+            # First, if the joining user is our local user then we need an
+            # update for every other user in the room.
+            if self.is_mine_id(joining_user_id):
                 for other_user_id in other_users_in_room:
-                    to_insert.add((user_id, other_user_id))
+                    updates_to_users_who_share_rooms.add(
+                        (joining_user_id, other_user_id)
+                    )
 
-            # Next we need to update for every local user in the room
+            # Next, we need an update for every other local user in the room
+            # that they now share a room with the joining user.
             for other_user_id in other_users_in_room:
                 if self.is_mine_id(other_user_id):
-                    to_insert.add((other_user_id, user_id))
+                    updates_to_users_who_share_rooms.add(
+                        (other_user_id, joining_user_id)
+                    )
 
-            if to_insert:
-                await self.store.add_users_who_share_private_room(room_id, to_insert)
+            if updates_to_users_who_share_rooms:
+                await self.store.add_users_who_share_private_room(
+                    room_id, updates_to_users_who_share_rooms
+                )
 
     async def _handle_remove_user(self, room_id: str, user_id: str) -> None:
         """Called when when someone leaves a room. The user may be local or remote.