summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_data.py17
-rw-r--r--synapse/handlers/admin.py87
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py54
-rw-r--r--synapse/handlers/deactivate_account.py27
-rw-r--r--synapse/handlers/directory.py15
-rw-r--r--synapse/handlers/e2e_keys.py25
-rw-r--r--synapse/handlers/e2e_room_keys.py1
-rw-r--r--synapse/handlers/event_auth.py30
-rw-r--r--synapse/handlers/events.py20
-rw-r--r--synapse/handlers/federation.py82
-rw-r--r--synapse/handlers/federation_event.py65
-rw-r--r--synapse/handlers/identity.py47
-rw-r--r--synapse/handlers/initial_sync.py57
-rw-r--r--synapse/handlers/message.py77
-rw-r--r--synapse/handlers/pagination.py6
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/receipts.py4
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/relations.py88
-rw-r--r--synapse/handlers/room.py199
-rw-r--r--synapse/handlers/room_batch.py6
-rw-r--r--synapse/handlers/room_member.py137
-rw-r--r--synapse/handlers/room_member_worker.py9
-rw-r--r--synapse/handlers/room_summary.py4
-rw-r--r--synapse/handlers/search.py43
-rw-r--r--synapse/handlers/sync.py115
-rw-r--r--synapse/handlers/ui_auth/checkers.py18
28 files changed, 751 insertions, 490 deletions
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 67e789eef7..7e01c18c6c 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -155,9 +155,6 @@ class AccountDataHandler:
             max_stream_id = await self._store.remove_account_data_for_room(
                 user_id, room_id, account_data_type
             )
-            if max_stream_id is None:
-                # The referenced account data did not exist, so no delete occurred.
-                return None
 
             self._notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
@@ -230,9 +227,6 @@ class AccountDataHandler:
             max_stream_id = await self._store.remove_account_data_for_user(
                 user_id, account_data_type
             )
-            if max_stream_id is None:
-                # The referenced account data did not exist, so no delete occurred.
-                return None
 
             self._notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
@@ -248,7 +242,6 @@ class AccountDataHandler:
                 instance_name=random.choice(self._account_data_writers),
                 user_id=user_id,
                 account_data_type=account_data_type,
-                content={},
             )
             return response["max_stream_id"]
 
@@ -343,10 +336,12 @@ class AccountDataEventSource(EventSource[int, JsonDict]):
                 }
             )
 
-        (
-            account_data,
-            room_account_data,
-        ) = await self.store.get_updated_account_data_for_user(user_id, last_stream_id)
+        account_data = await self.store.get_updated_global_account_data_for_user(
+            user_id, last_stream_id
+        )
+        room_account_data = await self.store.get_updated_room_account_data_for_user(
+            user_id, last_stream_id
+        )
 
         for account_data_type, content in account_data.items():
             results.append({"type": account_data_type, "content": content})
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b03c214b14..b06f25b03c 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,7 +14,7 @@
 
 import abc
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
 
 from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 class AdminHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastores().main
+        self._store = hs.get_datastores().main
         self._device_handler = hs.get_device_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
@@ -38,7 +38,7 @@ class AdminHandler:
     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
 
-        sessions = await self.store.get_user_ip_and_agents(user)
+        sessions = await self._store.get_user_ip_and_agents(user)
         for session in sessions:
             connections.append(
                 {
@@ -57,7 +57,7 @@ class AdminHandler:
 
     async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
-        user_info_dict = await self.store.get_user_by_id(user.to_string())
+        user_info_dict = await self._store.get_user_by_id(user.to_string())
         if user_info_dict is None:
             return None
 
@@ -89,11 +89,11 @@ class AdminHandler:
         }
 
         # Add additional user metadata
-        profile = await self.store.get_profileinfo(user.localpart)
-        threepids = await self.store.user_get_threepids(user.to_string())
+        profile = await self._store.get_profileinfo(user.localpart)
+        threepids = await self._store.user_get_threepids(user.to_string())
         external_ids = [
             ({"auth_provider": auth_provider, "external_id": external_id})
-            for auth_provider, external_id in await self.store.get_external_ids_by_user(
+            for auth_provider, external_id in await self._store.get_external_ids_by_user(
                 user.to_string()
             )
         ]
@@ -101,7 +101,7 @@ class AdminHandler:
         user_info_dict["avatar_url"] = profile.avatar_url
         user_info_dict["threepids"] = threepids
         user_info_dict["external_ids"] = external_ids
-        user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
+        user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
 
         return user_info_dict
 
@@ -117,7 +117,7 @@ class AdminHandler:
             The returned value is that returned by `writer.finished()`.
         """
         # Get all rooms the user is in or has been in
-        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
+        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
             user_id,
             membership_list=(
                 Membership.JOIN,
@@ -131,7 +131,7 @@ class AdminHandler:
         # We only try and fetch events for rooms the user has been in. If
         # they've been e.g. invited to a room without joining then we handle
         # those separately.
-        rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
+        rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id)
 
         for index, room in enumerate(rooms):
             room_id = room.room_id
@@ -140,7 +140,7 @@ class AdminHandler:
                 "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
             )
 
-            forgotten = await self.store.did_forget(user_id, room_id)
+            forgotten = await self._store.did_forget(user_id, room_id)
             if forgotten:
                 logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
                 continue
@@ -152,14 +152,14 @@ class AdminHandler:
 
                 if room.membership == Membership.INVITE:
                     event_id = room.event_id
-                    invite = await self.store.get_event(event_id, allow_none=True)
+                    invite = await self._store.get_event(event_id, allow_none=True)
                     if invite:
                         invited_state = invite.unsigned["invite_room_state"]
                         writer.write_invite(room_id, invite, invited_state)
 
                 if room.membership == Membership.KNOCK:
                     event_id = room.event_id
-                    knock = await self.store.get_event(event_id, allow_none=True)
+                    knock = await self._store.get_event(event_id, allow_none=True)
                     if knock:
                         knock_state = knock.unsigned["knock_room_state"]
                         writer.write_knock(room_id, knock, knock_state)
@@ -170,7 +170,7 @@ class AdminHandler:
             # were joined. We estimate that point by looking at the
             # stream_ordering of the last membership if it wasn't a join.
             if room.membership == Membership.JOIN:
-                stream_ordering = self.store.get_room_max_stream_ordering()
+                stream_ordering = self._store.get_room_max_stream_ordering()
             else:
                 stream_ordering = room.stream_ordering
 
@@ -197,7 +197,7 @@ class AdminHandler:
             # events that we have and then filtering, this isn't the most
             # efficient method perhaps but it does guarantee we get everything.
             while True:
-                events, _ = await self.store.paginate_room_events(
+                events, _ = await self._store.paginate_room_events(
                     room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
                 )
                 if not events:
@@ -252,16 +252,49 @@ class AdminHandler:
         profile = await self.get_user(UserID.from_string(user_id))
         if profile is not None:
             writer.write_profile(profile)
+            logger.info("[%s] Written profile", user_id)
 
         # Get all devices the user has
         devices = await self._device_handler.get_devices_by_user(user_id)
         writer.write_devices(devices)
+        logger.info("[%s] Written %s devices", user_id, len(devices))
 
         # Get all connections the user has
         connections = await self.get_whois(UserID.from_string(user_id))
         writer.write_connections(
             connections["devices"][""]["sessions"][0]["connections"]
         )
+        logger.info("[%s] Written %s connections", user_id, len(connections))
+
+        # Get all account data the user has global and in rooms
+        global_data = await self._store.get_global_account_data_for_user(user_id)
+        by_room_data = await self._store.get_room_account_data_for_user(user_id)
+        writer.write_account_data("global", global_data)
+        for room_id in by_room_data:
+            writer.write_account_data(room_id, by_room_data[room_id])
+        logger.info(
+            "[%s] Written account data for %s rooms", user_id, len(by_room_data)
+        )
+
+        # Get all media ids the user has
+        limit = 100
+        start = 0
+        while True:
+            media_ids, total = await self._store.get_local_media_by_user_paginate(
+                start, limit, user_id
+            )
+            for media in media_ids:
+                writer.write_media_id(media["media_id"], media)
+
+            logger.info(
+                "[%s] Written %d media_ids of %s",
+                user_id,
+                (start + len(media_ids)),
+                total,
+            )
+            if (start + limit) >= total:
+                break
+            start += limit
 
         return writer.finished()
 
@@ -341,6 +374,30 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    def write_account_data(
+        self, file_name: str, account_data: Mapping[str, JsonDict]
+    ) -> None:
+        """Write the account data of a user.
+
+        Args:
+            file_name: file name to write data
+            account_data: mapping of global or room account_data
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+        """Write the media's metadata of a user.
+        Exports only the metadata, as this can be fetched from the database via
+        read only. In order to access the files, a connection to the correct
+        media repository would be required.
+
+        Args:
+            media_id: ID of the media.
+            media_metadata: Metadata of one media file.
+        """
+
+    @abc.abstractmethod
     def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index b4a3ad217a..455d4005df 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -737,7 +737,7 @@ class ApplicationServicesHandler:
         )
 
         ret = []
-        for (success, result) in results:
+        for success, result in results:
             if success:
                 ret.extend(result)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 30f2d46c3c..308e38edea 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -201,7 +201,7 @@ class AuthHandler:
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
             inst = auth_checker_class(hs)
             if inst.is_enabled():
-                self.checkers[inst.AUTH_TYPE] = inst  # type: ignore
+                self.checkers[inst.AUTH_TYPE] = inst
 
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
 
@@ -815,7 +815,6 @@ class AuthHandler:
         now_ms = self._clock.time_msec()
 
         if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
-
             raise SynapseError(
                 HTTPStatus.FORBIDDEN,
                 "The supplied refresh token has expired",
@@ -1543,6 +1542,17 @@ class AuthHandler:
     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
     ) -> None:
+        """
+        Adds an association between a user's Matrix ID and a third-party ID (email,
+        phone number).
+
+        Args:
+            user_id: The ID of the user to associate.
+            medium: The medium of the third-party ID (email, msisdn).
+            address: The address of the third-party ID (i.e. an email address).
+            validated_at: The timestamp in ms of when the validation that the user owns
+                this third-party ID occurred.
+        """
         # check if medium has a valid value
         if medium not in ["email", "msisdn"]:
             raise SynapseError(
@@ -1567,43 +1577,44 @@ class AuthHandler:
             user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
         )
 
+        # Inform Synapse modules that a 3PID association has been created.
+        await self._third_party_rules.on_add_user_third_party_identifier(
+            user_id, medium, address
+        )
+
+        # Deprecated method for informing Synapse modules that a 3PID association
+        # has successfully been created.
         await self._third_party_rules.on_threepid_bind(user_id, medium, address)
 
-    async def delete_threepid(
-        self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
-    ) -> bool:
-        """Attempts to unbind the 3pid on the identity servers and deletes it
-        from the local database.
+    async def delete_local_threepid(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
+        """Deletes an association between a third-party ID and a user ID from the local
+        database. This method does not unbind the association from any identity servers.
+
+        If `medium` is 'email' and a pusher is associated with this third-party ID, the
+        pusher will also be deleted.
 
         Args:
             user_id: ID of user to remove the 3pid from.
             medium: The medium of the 3pid being removed: "email" or "msisdn".
             address: The 3pid address to remove.
-            id_server: Use the given identity server when unbinding
-                any threepids. If None then will attempt to unbind using the
-                identity server specified when binding (if known).
-
-        Returns:
-            Returns True if successfully unbound the 3pid on
-            the identity server, False if identity server doesn't support the
-            unbind API.
         """
-
         # 'Canonicalise' email addresses as per above
         if medium == "email":
             address = canonicalise_email(address)
 
-        identity_handler = self.hs.get_identity_handler()
-        result = await identity_handler.try_unbind_threepid(
-            user_id, {"medium": medium, "address": address, "id_server": id_server}
+        await self.store.user_delete_threepid(user_id, medium, address)
+
+        # Inform Synapse modules that a 3PID association has been deleted.
+        await self._third_party_rules.on_remove_user_third_party_identifier(
+            user_id, medium, address
         )
 
-        await self.store.user_delete_threepid(user_id, medium, address)
         if medium == "email":
             await self.store.delete_pusher_by_app_id_pushkey_user_id(
                 app_id="m.email", pushkey=address, user_id=user_id
             )
-        return result
 
     async def hash(self, password: str) -> str:
         """Computes a secure hash of password.
@@ -2260,7 +2271,6 @@ class PasswordAuthProvider:
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
-
         # call all of the on_logged_out callbacks
         for callback in self.on_logged_out_callbacks:
             try:
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index ba58f150d1..c1850521e4 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -100,31 +100,28 @@ class DeactivateAccountHandler:
         # unbinding
         identity_server_supports_unbinding = True
 
-        # Retrieve the 3PIDs this user has bound to an identity server
-        threepids = await self.store.user_get_bound_threepids(user_id)
-
-        for threepid in threepids:
+        # Attempt to unbind any known bound threepids to this account from identity
+        # server(s).
+        bound_threepids = await self.store.user_get_bound_threepids(user_id)
+        for threepid in bound_threepids:
             try:
                 result = await self._identity_handler.try_unbind_threepid(
-                    user_id,
-                    {
-                        "medium": threepid["medium"],
-                        "address": threepid["address"],
-                        "id_server": id_server,
-                    },
+                    user_id, threepid["medium"], threepid["address"], id_server
                 )
-                identity_server_supports_unbinding &= result
             except Exception:
                 # Do we want this to be a fatal error or should we carry on?
                 logger.exception("Failed to remove threepid from ID server")
                 raise SynapseError(400, "Failed to remove threepid from ID server")
-            await self.store.user_delete_threepid(
+
+            identity_server_supports_unbinding &= result
+
+        # Remove any local threepid associations for this account.
+        local_threepids = await self.store.user_get_threepids(user_id)
+        for threepid in local_threepids:
+            await self._auth_handler.delete_local_threepid(
                 user_id, threepid["medium"], threepid["address"]
             )
 
-        # Remove all 3PIDs this user has bound to the homeserver
-        await self.store.user_delete_threepids(user_id)
-
         # delete any devices belonging to the user, which will also
         # delete corresponding access tokens.
         await self._device_handler.delete_all_devices_for_user(user_id)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index 2ea52257cb..1fb23cc9bf 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -14,7 +14,7 @@
 
 import logging
 import string
-from typing import TYPE_CHECKING, Iterable, List, Optional
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence
 
 from typing_extensions import Literal
 
@@ -485,7 +485,8 @@ class DirectoryHandler:
                 )
             )
             if canonical_alias:
-                room_aliases.append(canonical_alias)
+                # Ensure we do not mutate room_aliases.
+                room_aliases = list(room_aliases) + [canonical_alias]
 
             if not self.config.roomdirectory.is_publishing_room_allowed(
                 user_id, room_id, room_aliases
@@ -496,9 +497,11 @@ class DirectoryHandler:
                 raise SynapseError(403, "Not allowed to publish room")
 
             # Check if publishing is blocked by a third party module
-            allowed_by_third_party_rules = await (
-                self.third_party_event_rules.check_visibility_can_be_modified(
-                    room_id, visibility
+            allowed_by_third_party_rules = (
+                await (
+                    self.third_party_event_rules.check_visibility_can_be_modified(
+                        room_id, visibility
+                    )
                 )
             )
             if not allowed_by_third_party_rules:
@@ -528,7 +531,7 @@ class DirectoryHandler:
 
     async def get_aliases_for_room(
         self, requester: Requester, room_id: str
-    ) -> List[str]:
+    ) -> Sequence[str]:
         """
         Get a list of the aliases that currently point to this room on this server
         """
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d2188ca08f..4e9c8d8db0 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -159,19 +159,22 @@ class E2eKeysHandler:
             # A map of destination -> user ID -> device IDs.
             remote_queries_not_in_cache: Dict[str, Dict[str, Iterable[str]]] = {}
             if remote_queries:
-                query_list: List[Tuple[str, Optional[str]]] = []
+                user_ids = set()
+                user_and_device_ids: List[Tuple[str, str]] = []
                 for user_id, device_ids in remote_queries.items():
                     if device_ids:
-                        query_list.extend(
+                        user_and_device_ids.extend(
                             (user_id, device_id) for device_id in device_ids
                         )
                     else:
-                        query_list.append((user_id, None))
+                        user_ids.add(user_id)
 
                 (
                     user_ids_not_in_cache,
                     remote_results,
-                ) = await self.store.get_user_devices_from_cache(query_list)
+                ) = await self.store.get_user_devices_from_cache(
+                    user_ids, user_and_device_ids
+                )
 
                 # Check that the homeserver still shares a room with all cached users.
                 # Note that this check may be slightly racy when a remote user leaves a
@@ -1298,6 +1301,20 @@ class E2eKeysHandler:
 
         return desired_key_data
 
+    async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool:
+        """Checks if the user has cross-signing set up
+
+        Args:
+            user_id: The user to check
+
+        Returns:
+            True if the user has cross-signing set up, False otherwise
+        """
+        existing_master_key = await self.store.get_e2e_cross_signing_key(
+            user_id, "master"
+        )
+        return existing_master_key is not None
+
 
 def _check_cross_signing_key(
     key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 83f53ceb88..50317ec753 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -188,7 +188,6 @@ class E2eRoomKeysHandler:
 
         # XXX: perhaps we should use a finer grained lock here?
         async with self._upload_linearizer.queue(user_id):
-
             # Check that the version we're trying to upload is the current version
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(user_id)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index a23a8ce2a1..0db0bd7304 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -63,9 +63,18 @@ class EventAuthHandler:
             self._store, event, batched_auth_events
         )
         auth_event_ids = event.auth_event_ids()
-        auth_events_by_id = await self._store.get_events(auth_event_ids)
+
         if batched_auth_events:
-            auth_events_by_id.update(batched_auth_events)
+            # Copy the batched auth events to avoid mutating them.
+            auth_events_by_id = dict(batched_auth_events)
+            needed_auth_event_ids = set(auth_event_ids) - set(batched_auth_events)
+            if needed_auth_event_ids:
+                auth_events_by_id.update(
+                    await self._store.get_events(needed_auth_event_ids)
+                )
+        else:
+            auth_events_by_id = await self._store.get_events(auth_event_ids)
+
         check_state_dependent_auth_rules(event, auth_events_by_id.values())
 
     def compute_auth_events(
@@ -202,7 +211,7 @@ class EventAuthHandler:
         state_ids: StateMap[str],
         room_version: RoomVersion,
         user_id: str,
-        prev_member_event: Optional[EventBase],
+        prev_membership: Optional[str],
     ) -> None:
         """
         Check whether a user can join a room without an invite due to restricted join rules.
@@ -214,15 +223,14 @@ class EventAuthHandler:
             state_ids: The state of the room as it currently is.
             room_version: The room version of the room being joined.
             user_id: The user joining the room.
-            prev_member_event: The current membership event for this user.
+            prev_membership: The current membership state for this user. `None` if the
+                user has never joined the room (equivalent to "leave").
 
         Raises:
             AuthError if the user cannot join the room.
         """
         # If the member is invited or currently joined, then nothing to do.
-        if prev_member_event and (
-            prev_member_event.membership in (Membership.JOIN, Membership.INVITE)
-        ):
+        if prev_membership in (Membership.JOIN, Membership.INVITE):
             return
 
         # This is not a room with a restricted join rule, so we don't need to do the
@@ -237,7 +245,6 @@ class EventAuthHandler:
         # in any of them.
         allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
             # If this is a remote request, the user might be in an allowed room
             # that we do not know about.
             if get_domain_from_id(user_id) != self._server_name:
@@ -255,13 +262,14 @@ class EventAuthHandler:
             )
 
     async def has_restricted_join_rules(
-        self, state_ids: StateMap[str], room_version: RoomVersion
+        self, partial_state_ids: StateMap[str], room_version: RoomVersion
     ) -> bool:
         """
         Return if the room has the proper join rules set for access via rooms.
 
         Args:
-            state_ids: The state of the room as it currently is.
+            state_ids: The state of the room as it currently is. May be full or partial
+                state.
             room_version: The room version of the room to query.
 
         Returns:
@@ -272,7 +280,7 @@ class EventAuthHandler:
             return False
 
         # If there's no join rule, then it defaults to invite (so this doesn't apply).
-        join_rules_event_id = state_ids.get((EventTypes.JoinRules, ""), None)
+        join_rules_event_id = partial_state_ids.get((EventTypes.JoinRules, ""), None)
         if not join_rules_event_id:
             return False
 
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 949b69cb41..68c07f0265 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.presence import format_user_presence_state
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, Requester, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -46,13 +46,12 @@ class EventStreamHandler:
 
     async def get_stream(
         self,
-        auth_user_id: str,
+        requester: Requester,
         pagin_config: PaginationConfig,
         timeout: int = 0,
         as_client_event: bool = True,
         affect_presence: bool = True,
         room_id: Optional[str] = None,
-        is_guest: bool = False,
     ) -> JsonDict:
         """Fetches the events stream for a given user."""
 
@@ -62,13 +61,12 @@ class EventStreamHandler:
                 raise SynapseError(403, "This room has been blocked on this server")
 
         # send any outstanding server notices to the user.
-        await self._server_notices_sender.on_user_syncing(auth_user_id)
+        await self._server_notices_sender.on_user_syncing(requester.user.to_string())
 
-        auth_user = UserID.from_string(auth_user_id)
         presence_handler = self.hs.get_presence_handler()
 
         context = await presence_handler.user_syncing(
-            auth_user_id,
+            requester.user.to_string(),
             affect_presence=affect_presence,
             presence_state=PresenceState.ONLINE,
         )
@@ -82,10 +80,10 @@ class EventStreamHandler:
                 timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
             stream_result = await self.notifier.get_events_for(
-                auth_user,
+                requester.user,
                 pagin_config,
                 timeout,
-                is_guest=is_guest,
+                is_guest=requester.is_guest,
                 explicit_room_id=room_id,
             )
             events = stream_result.events
@@ -102,7 +100,7 @@ class EventStreamHandler:
                     if event.membership != Membership.JOIN:
                         continue
                     # Send down presence.
-                    if event.state_key == auth_user_id:
+                    if event.state_key == requester.user.to_string():
                         # Send down presence for everyone in the room.
                         users: Iterable[str] = await self.store.get_users_in_room(
                             event.room_id
@@ -124,7 +122,9 @@ class EventStreamHandler:
             chunks = self._event_serializer.serialize_events(
                 events,
                 time_now,
-                config=SerializeEventConfig(as_client_event=as_client_event),
+                config=SerializeEventConfig(
+                    as_client_event=as_client_event, requester=requester
+                ),
             )
 
             chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 4e77bfa55e..2944deaa00 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -49,6 +49,7 @@ from synapse.api.errors import (
     FederationPullAttemptBackoffError,
     HttpResponseException,
     NotFoundError,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -56,7 +57,7 @@ from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.crypto.event_signing import compute_event_signature
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.events.validator import EventValidator
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.http.servlet import assert_params_in_dict
@@ -68,7 +69,6 @@ from synapse.replication.http.federation import (
     ReplicationCleanRoomRestServlet,
     ReplicationStoreRoomOnOutlierMembershipRestServlet,
 )
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import JsonDict, StrCollection, get_domain_from_id
 from synapse.types.state import StateFilter
@@ -952,7 +952,20 @@ class FederationHandler:
         #
         # Note that this requires the /send_join request to come back to the
         # same server.
+        prev_event_ids = None
         if room_version.msc3083_join_rules:
+            # Note that the room's state can change out from under us and render our
+            # nice join rules-conformant event non-conformant by the time we build the
+            # event. When this happens, our validation at the end fails and we respond
+            # to the requesting server with a 403, which is misleading — it indicates
+            # that the user is not allowed to join the room and the joining server
+            # should not bother retrying via this homeserver or any others, when
+            # in fact we've just messed up with building the event.
+            #
+            # To reduce the likelihood of this race, we capture the forward extremities
+            # of the room (prev_event_ids) just before fetching the current state, and
+            # hope that the state we fetch corresponds to the prev events we chose.
+            prev_event_ids = await self.store.get_prev_events_for_room(room_id)
             state_ids = await self._state_storage_controller.get_current_state_ids(
                 room_id
             )
@@ -990,15 +1003,21 @@ class FederationHandler:
         )
 
         try:
-            event, context = await self.event_creation_handler.create_new_client_event(
-                builder=builder
+            (
+                event,
+                unpersisted_context,
+            ) = await self.event_creation_handler.create_new_client_event(
+                builder=builder,
+                prev_event_ids=prev_event_ids,
             )
         except SynapseError as e:
             logger.warning("Failed to create join to %s because %s", room_id, e)
             raise
 
         # Ensure the user can even join the room.
-        await self._federation_event_handler.check_join_restrictions(context, event)
+        await self._federation_event_handler.check_join_restrictions(
+            unpersisted_context, event
+        )
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
@@ -1178,7 +1197,7 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
+        event, _ = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
@@ -1228,12 +1247,13 @@ class FederationHandler:
             },
         )
 
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
 
         event_allowed, _ = await self.third_party_event_rules.check_event_allowed(
-            event, context
+            event, unpersisted_context
         )
         if not event_allowed:
             logger.warning("Creation of knock %s forbidden by third-party rules", event)
@@ -1406,15 +1426,20 @@ class FederationHandler:
                 try:
                     (
                         event,
-                        context,
+                        unpersisted_context,
                     ) = await self.event_creation_handler.create_new_client_event(
                         builder=builder
                     )
 
-                    event, context = await self.add_display_name_to_third_party_invite(
-                        room_version_obj, event_dict, event, context
+                    (
+                        event,
+                        unpersisted_context,
+                    ) = await self.add_display_name_to_third_party_invite(
+                        room_version_obj, event_dict, event, unpersisted_context
                     )
 
+                    context = await unpersisted_context.persist(event)
+
                     EventValidator().validate_new(event, self.config)
 
                     # We need to tell the transaction queue to send this out, even
@@ -1483,14 +1508,19 @@ class FederationHandler:
             try:
                 (
                     event,
-                    context,
+                    unpersisted_context,
                 ) = await self.event_creation_handler.create_new_client_event(
                     builder=builder
                 )
-                event, context = await self.add_display_name_to_third_party_invite(
-                    room_version_obj, event_dict, event, context
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.add_display_name_to_third_party_invite(
+                    room_version_obj, event_dict, event, unpersisted_context
                 )
 
+                context = await unpersisted_context.persist(event)
+
                 try:
                     validate_event_for_room_version(event)
                     await self._event_auth_handler.check_auth_rules_from_context(event)
@@ -1522,8 +1552,8 @@ class FederationHandler:
         room_version_obj: RoomVersion,
         event_dict: JsonDict,
         event: EventBase,
-        context: EventContext,
-    ) -> Tuple[EventBase, EventContext]:
+        context: UnpersistedEventContextBase,
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -1557,11 +1587,14 @@ class FederationHandler:
             room_version_obj, event_dict
         )
         EventValidator().validate_builder(builder)
-        event, context = await self.event_creation_handler.create_new_client_event(
-            builder=builder
-        )
+
+        (
+            event,
+            unpersisted_context,
+        ) = await self.event_creation_handler.create_new_client_event(builder=builder)
+
         EventValidator().validate_new(event, self.config)
-        return event, context
+        return event, unpersisted_context
 
     async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
@@ -1861,6 +1894,11 @@ class FederationHandler:
                 logger.info("Updating current state for %s", room_id)
                 # TODO(faster_joins): notify workers in notify_room_un_partial_stated
                 #   https://github.com/matrix-org/synapse/issues/12994
+                #
+                # NB: there's a potential race here. If room is purged just before we
+                # call this, we _might_ end up inserting rows into current_state_events.
+                # (The logic is hard to chase through.) We think this is fine, but if
+                # not the HS admin should purge the room again.
                 await self.state_handler.update_current_state(room_id)
 
                 logger.info("Handling any pending device list updates")
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 2e19df0976..3a65ccbb55 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -47,6 +47,7 @@ from synapse.api.errors import (
     FederationError,
     FederationPullAttemptBackoffError,
     HttpResponseException,
+    PartialStateConflictError,
     RequestSendFailed,
     SynapseError,
 )
@@ -58,7 +59,7 @@ from synapse.event_auth import (
     validate_event_for_room_version,
 )
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.federation.federation_client import InvalidResponseError, PulledPduInfo
 from synapse.logging.context import nested_logging_context
 from synapse.logging.opentracing import (
@@ -74,7 +75,6 @@ from synapse.replication.http.federation import (
     ReplicationFederationSendEventsRestServlet,
 )
 from synapse.state import StateResolutionStore
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     PersistedEventPosition,
@@ -426,7 +426,9 @@ class FederationEventHandler:
         return event, context
 
     async def check_join_restrictions(
-        self, context: EventContext, event: EventBase
+        self,
+        context: UnpersistedEventContextBase,
+        event: EventBase,
     ) -> None:
         """Check that restrictions in restricted join rules are matched
 
@@ -439,16 +441,17 @@ class FederationEventHandler:
         # Check if the user is already in the room or invited to the room.
         user_id = event.state_key
         prev_member_event_id = prev_state_ids.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
+        prev_membership = None
         if prev_member_event_id:
             prev_member_event = await self._store.get_event(prev_member_event_id)
+            prev_membership = prev_member_event.membership
 
         # Check if the member should be allowed access via membership in a space.
         await self._event_auth_handler.check_restricted_join_rules(
             prev_state_ids,
             event.room_version,
             user_id,
-            prev_member_event,
+            prev_membership,
         )
 
     @trace
@@ -524,11 +527,57 @@ class FederationEventHandler:
             "Peristing join-via-remote %s (partial_state: %s)", event, partial_state
         )
         with nested_logging_context(suffix=event.event_id):
+            if partial_state:
+                # When handling a second partial state join into a partial state room,
+                # the returned state will exclude the membership from the first join. To
+                # preserve prior memberships, we try to compute the partial state before
+                # the event ourselves if we know about any of the prev events.
+                #
+                # When we don't know about any of the prev events, it's fine to just use
+                # the returned state, since the new join will create a new forward
+                # extremity, and leave the forward extremity containing our prior
+                # memberships alone.
+                prev_event_ids = set(event.prev_event_ids())
+                seen_event_ids = await self._store.have_events_in_timeline(
+                    prev_event_ids
+                )
+                missing_event_ids = prev_event_ids - seen_event_ids
+
+                state_maps_to_resolve: List[StateMap[str]] = []
+
+                # Fetch the state after the prev events that we know about.
+                state_maps_to_resolve.extend(
+                    (
+                        await self._state_storage_controller.get_state_groups_ids(
+                            room_id, seen_event_ids, await_full_state=False
+                        )
+                    ).values()
+                )
+
+                # When there are prev events we do not have the state for, we state
+                # resolve with the state returned by the remote homeserver.
+                if missing_event_ids or len(state_maps_to_resolve) == 0:
+                    state_maps_to_resolve.append(
+                        {(e.type, e.state_key): e.event_id for e in state}
+                    )
+
+                state_ids_before_event = (
+                    await self._state_resolution_handler.resolve_events_with_store(
+                        event.room_id,
+                        room_version.identifier,
+                        state_maps_to_resolve,
+                        event_map=None,
+                        state_res_store=StateResolutionStore(self._store),
+                    )
+                )
+            else:
+                state_ids_before_event = {
+                    (e.type, e.state_key): e.event_id for e in state
+                }
+
             context = await self._state_handler.compute_event_context(
                 event,
-                state_ids_before_event={
-                    (e.type, e.state_key): e.event_id for e in state
-                },
+                state_ids_before_event=state_ids_before_event,
                 partial_state=partial_state,
             )
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 848e46eb9b..bf0f7acf80 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -219,28 +219,31 @@ class IdentityHandler:
             data = json_decoder.decode(e.msg)  # XXX WAT?
             return data
 
-    async def try_unbind_threepid(self, mxid: str, threepid: dict) -> bool:
-        """Attempt to remove a 3PID from an identity server, or if one is not provided, all
-        identity servers we're aware the binding is present on
+    async def try_unbind_threepid(
+        self, mxid: str, medium: str, address: str, id_server: Optional[str]
+    ) -> bool:
+        """Attempt to remove a 3PID from one or more identity servers.
 
         Args:
             mxid: Matrix user ID of binding to be removed
-            threepid: Dict with medium & address of binding to be
-                removed, and an optional id_server.
+            medium: The medium of the third-party ID.
+            address: The address of the third-party ID.
+            id_server: An identity server to attempt to unbind from. If None,
+                attempt to remove the association from all identity servers
+                known to potentially have it.
 
         Raises:
-            SynapseError: If we failed to contact the identity server
+            SynapseError: If we failed to contact one or more identity servers.
 
         Returns:
-            True on success, otherwise False if the identity
-            server doesn't support unbinding (or no identity server found to
-            contact).
+            True on success, otherwise False if the identity server doesn't
+            support unbinding (or no identity server to contact was found).
         """
-        if threepid.get("id_server"):
-            id_servers = [threepid["id_server"]]
+        if id_server:
+            id_servers = [id_server]
         else:
             id_servers = await self.store.get_id_servers_user_bound(
-                user_id=mxid, medium=threepid["medium"], address=threepid["address"]
+                mxid, medium, address
             )
 
         # We don't know where to unbind, so we don't have a choice but to return
@@ -249,20 +252,21 @@ class IdentityHandler:
 
         changed = True
         for id_server in id_servers:
-            changed &= await self.try_unbind_threepid_with_id_server(
-                mxid, threepid, id_server
+            changed &= await self._try_unbind_threepid_with_id_server(
+                mxid, medium, address, id_server
             )
 
         return changed
 
-    async def try_unbind_threepid_with_id_server(
-        self, mxid: str, threepid: dict, id_server: str
+    async def _try_unbind_threepid_with_id_server(
+        self, mxid: str, medium: str, address: str, id_server: str
     ) -> bool:
         """Removes a binding from an identity server
 
         Args:
             mxid: Matrix user ID of binding to be removed
-            threepid: Dict with medium & address of binding to be removed
+            medium: The medium of the third-party ID
+            address: The address of the third-party ID
             id_server: Identity server to unbind from
 
         Raises:
@@ -286,7 +290,7 @@ class IdentityHandler:
 
         content = {
             "mxid": mxid,
-            "threepid": {"medium": threepid["medium"], "address": threepid["address"]},
+            "threepid": {"medium": medium, "address": address},
         }
 
         # we abuse the federation http client to sign the request, but we have to send it
@@ -319,12 +323,7 @@ class IdentityHandler:
         except RequestTimedOutError:
             raise SynapseError(500, "Timed out contacting identity server")
 
-        await self.store.remove_user_bound_threepid(
-            user_id=mxid,
-            medium=threepid["medium"],
-            address=threepid["address"],
-            id_server=id_server,
-        )
+        await self.store.remove_user_bound_threepid(mxid, medium, address, id_server)
 
         return changed
 
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 191529bd8e..b3be7a86f0 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,6 @@ class InitialSyncHandler:
         as_client_event: bool = True,
         include_archived: bool = False,
     ) -> JsonDict:
-
         memberships = [Membership.INVITE, Membership.JOIN]
         if include_archived:
             memberships.append(Membership.LEAVE)
@@ -154,9 +153,8 @@ class InitialSyncHandler:
 
         tags_by_room = await self.store.get_tags_for_user(user_id)
 
-        account_data, account_data_by_room = await self.store.get_account_data_for_user(
-            user_id
-        )
+        account_data = await self.store.get_global_account_data_for_user(user_id)
+        account_data_by_room = await self.store.get_room_account_data_for_user(user_id)
 
         public_room_ids = await self.store.get_public_room_ids()
 
@@ -320,11 +318,9 @@ class InitialSyncHandler:
         )
         is_peeking = member_event_id is None
 
-        user_id = requester.user.to_string()
-
         if membership == Membership.JOIN:
             result = await self._room_initial_sync_joined(
-                user_id, room_id, pagin_config, membership, is_peeking
+                requester, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
             # The member_event_id will always be available if membership is set
@@ -332,10 +328,16 @@ class InitialSyncHandler:
             assert member_event_id
 
             result = await self._room_initial_sync_parted(
-                user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+                requester,
+                room_id,
+                pagin_config,
+                membership,
+                member_event_id,
+                is_peeking,
             )
 
         account_data_events = []
+        user_id = requester.user.to_string()
         tags = await self.store.get_tags_for_room(user_id, room_id)
         if tags:
             account_data_events.append(
@@ -352,7 +354,7 @@ class InitialSyncHandler:
 
     async def _room_initial_sync_parted(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         pagin_config: PaginationConfig,
         membership: str,
@@ -371,13 +373,17 @@ class InitialSyncHandler:
         )
 
         messages = await filter_events_for_client(
-            self._storage_controllers, user_id, messages, is_peeking=is_peeking
+            self._storage_controllers,
+            requester.user.to_string(),
+            messages,
+            is_peeking=is_peeking,
         )
 
         start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
         end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
 
         time_now = self.clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
 
         return {
             "membership": membership,
@@ -385,14 +391,18 @@ class InitialSyncHandler:
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    self._event_serializer.serialize_events(messages, time_now)
+                    self._event_serializer.serialize_events(
+                        messages, time_now, config=serialize_options
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
             },
             "state": (
                 # Don't bundle aggregations as this is a deprecated API.
-                self._event_serializer.serialize_events(room_state.values(), time_now)
+                self._event_serializer.serialize_events(
+                    room_state.values(), time_now, config=serialize_options
+                )
             ),
             "presence": [],
             "receipts": [],
@@ -400,7 +410,7 @@ class InitialSyncHandler:
 
     async def _room_initial_sync_joined(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         pagin_config: PaginationConfig,
         membership: str,
@@ -412,9 +422,12 @@ class InitialSyncHandler:
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
         # Don't bundle aggregations as this is a deprecated API.
         state = self._event_serializer.serialize_events(
-            current_state.values(), time_now
+            current_state.values(),
+            time_now,
+            config=serialize_options,
         )
 
         now_token = self.hs.get_event_sources().get_current_token()
@@ -452,7 +465,10 @@ class InitialSyncHandler:
             if not receipts:
                 return []
 
-            return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
+            return ReceiptEventSource.filter_out_private_receipts(
+                receipts,
+                requester.user.to_string(),
+            )
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
             gather_results(
@@ -471,20 +487,23 @@ class InitialSyncHandler:
         )
 
         messages = await filter_events_for_client(
-            self._storage_controllers, user_id, messages, is_peeking=is_peeking
+            self._storage_controllers,
+            requester.user.to_string(),
+            messages,
+            is_peeking=is_peeking,
         )
 
         start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
         end_token = now_token
 
-        time_now = self.clock.time_msec()
-
         ret = {
             "room_id": room_id,
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    self._event_serializer.serialize_events(messages, time_now)
+                    self._event_serializer.serialize_events(
+                        messages, time_now, config=serialize_options
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 1c5fdca12a..29ec7e3544 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -38,6 +38,7 @@ from synapse.api.errors import (
     Codes,
     ConsentNotGivenError,
     NotFoundError,
+    PartialStateConflictError,
     ShadowBanError,
     SynapseError,
     UnstableSpecAuthError,
@@ -48,8 +49,8 @@ from synapse.api.urls import ConsentURIBuilder
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
-from synapse.events.snapshot import EventContext
-from synapse.events.utils import maybe_upsert_event_field
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
+from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
 from synapse.logging import opentracing
@@ -57,7 +58,6 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.replication.http.send_events import ReplicationSendEventsRestServlet
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.types import (
     MutableStateMap,
@@ -245,8 +245,11 @@ class MessageHandler:
                 )
                 room_state = room_state_events[membership_event_id]
 
-        now = self.clock.time_msec()
-        events = self._event_serializer.serialize_events(room_state.values(), now)
+        events = self._event_serializer.serialize_events(
+            room_state.values(),
+            self.clock.time_msec(),
+            config=SerializeEventConfig(requester=requester),
+        )
         return events
 
     async def _user_can_see_state_at_event(
@@ -499,9 +502,9 @@ class EventCreationHandler:
 
         self.request_ratelimiter = hs.get_request_ratelimiter()
 
-        # We arbitrarily limit concurrent event creation for a room to 5.
-        # This is to stop us from diverging history *too* much.
-        self.limiter = Linearizer(max_count=5, name="room_event_creation_limit")
+        # We limit concurrent event creation for a room to 1. This prevents state resolution
+        # from occurring when sending bursts of events to a local room
+        self.limiter = Linearizer(max_count=1, name="room_event_creation_limit")
 
         self._bulk_push_rule_evaluator = hs.get_bulk_push_rule_evaluator()
 
@@ -574,7 +577,7 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """
         Given a dict from a client, create a new event. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
@@ -708,7 +711,7 @@ class EventCreationHandler:
 
         builder.internal_metadata.historical = historical
 
-        event, context = await self.create_new_client_event(
+        event, unpersisted_context = await self.create_new_client_event(
             builder=builder,
             requester=requester,
             allow_no_prev_events=allow_no_prev_events,
@@ -737,7 +740,7 @@ class EventCreationHandler:
                 assert state_map is not None
                 prev_event_id = state_map.get((EventTypes.Member, event.sender))
             else:
-                prev_state_ids = await context.get_prev_state_ids(
+                prev_state_ids = await unpersisted_context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
                 prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -762,8 +765,7 @@ class EventCreationHandler:
                 )
 
         self.validator.validate_new(event, self.config)
-
-        return event, context
+        return event, unpersisted_context
 
     async def _is_exempt_from_privacy_policy(
         self, builder: EventBuilder, requester: Requester
@@ -1003,7 +1005,7 @@ class EventCreationHandler:
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.create_event(
+                event, unpersisted_context = await self.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1014,6 +1016,7 @@ class EventCreationHandler:
                     historical=historical,
                     depth=depth,
                 )
+                context = await unpersisted_context.persist(event)
 
                 assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
                     event.sender,
@@ -1083,13 +1086,14 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event for a local client. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
         the event using the parameters state_map and current_state_group, thus these parameters
         must be provided in this case if for_batch is True. The subsequently created event
         and context are suitable for being batched up and bulk persisted to the database
-        with other similarly created events.
+        with other similarly created events. Note that this returns an UnpersistedEventContext,
+        which must be converted to an EventContext before it can be sent to the DB.
 
         Args:
             builder:
@@ -1131,7 +1135,7 @@ class EventCreationHandler:
                 batch persisting
 
         Returns:
-            Tuple of created event, context
+            Tuple of created event, UnpersistedEventContext
         """
         # Strip down the state_event_ids to only what we need to auth the event.
         # For example, we don't need extra m.room.member that don't match event.sender
@@ -1187,14 +1191,20 @@ class EventCreationHandler:
         if for_batch:
             assert prev_event_ids is not None
             assert state_map is not None
-            assert current_state_group is not None
             auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
             )
-            context = await self.state.compute_event_context_for_batched(
-                event, state_map, current_state_group
+
+            context: UnpersistedEventContextBase = (
+                await self.state.calculate_context_info(
+                    event,
+                    state_ids_before_event=state_map,
+                    partial_state=False,
+                    state_group_before_event=current_state_group,
+                )
             )
+
         else:
             event = await builder.build(
                 prev_event_ids=prev_event_ids,
@@ -1244,16 +1254,17 @@ class EventCreationHandler:
 
                     state_map_for_event[(data.event_type, data.state_key)] = state_id
 
-                context = await self.state.compute_event_context(
+                # TODO(faster_joins): check how MSC2716 works and whether we can have
+                #   partial state here
+                #   https://github.com/matrix-org/synapse/issues/13003
+                context = await self.state.calculate_context_info(
                     event,
                     state_ids_before_event=state_map_for_event,
-                    # TODO(faster_joins): check how MSC2716 works and whether we can have
-                    #   partial state here
-                    #   https://github.com/matrix-org/synapse/issues/13003
                     partial_state=False,
                 )
+
             else:
-                context = await self.state.compute_event_context(event)
+                context = await self.state.calculate_context_info(event)
 
         if requester:
             context.app_service = requester.app_service
@@ -1326,7 +1337,11 @@ class EventCreationHandler:
                 relation.parent_id, event.type, aggregation_key, event.sender
             )
             if already_exists:
-                raise SynapseError(400, "Can't send same reaction twice")
+                raise SynapseError(
+                    400,
+                    "Can't send same reaction twice",
+                    errcode=Codes.DUPLICATE_ANNOTATION,
+                )
 
         # Don't attempt to start a thread if the parent event is a relation.
         elif relation.rel_type == RelationTypes.THREAD:
@@ -2031,7 +2046,7 @@ class EventCreationHandler:
                 max_retries = 5
                 for i in range(max_retries):
                     try:
-                        event, context = await self.create_event(
+                        event, unpersisted_context = await self.create_event(
                             requester,
                             {
                                 "type": EventTypes.Dummy,
@@ -2040,6 +2055,7 @@ class EventCreationHandler:
                                 "sender": user_id,
                             },
                         )
+                        context = await unpersisted_context.persist(event)
 
                         event.internal_metadata.proactively_send = False
 
@@ -2082,9 +2098,9 @@ class EventCreationHandler:
 
     async def _rebuild_event_after_third_party_rules(
         self, third_party_result: dict, original_event: EventBase
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         # the third_party_event_rules want to replace the event.
-        # we do some basic checks, and then return the replacement event and context.
+        # we do some basic checks, and then return the replacement event.
 
         # Construct a new EventBuilder and validate it, which helps with the
         # rest of these checks.
@@ -2138,5 +2154,6 @@ class EventCreationHandler:
 
         # we rebuild the event context, to be on the safe side. If nothing else,
         # delta_ids might need an update.
-        context = await self.state.compute_event_context(event)
+        context = await self.state.calculate_context_info(event)
+
         return event, context
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index f2095ce164..d7085c001d 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -579,7 +579,9 @@ class PaginationHandler:
 
         time_now = self.clock.time_msec()
 
-        serialize_options = SerializeEventConfig(as_client_event=as_client_event)
+        serialize_options = SerializeEventConfig(
+            as_client_event=as_client_event, requester=requester
+        )
 
         chunk = {
             "chunk": (
@@ -681,7 +683,7 @@ class PaginationHandler:
 
                     await self._storage_controllers.purge_events.purge_room(room_id)
 
-            logger.info("complete")
+            logger.info("purge complete for room_id %s", room_id)
             self._delete_by_id[delete_id].status = DeleteStatus.STATUS_COMPLETE
         except Exception:
             f = Failure()
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index b4c0577e4d..b289b6cb23 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler):
         )
 
         if self.unpersisted_users_changes:
-
             await self.store.update_presence(
                 [
                     self.user_to_current_state[user_id]
@@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler):
         now = self.clock.time_msec()
 
         with Measure(self.clock, "presence_update_states"):
-
             # NOTE: We purposefully don't await between now and when we've
             # calculated what we want to do with the new states, to avoid races.
 
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 04c61ae3dd..2bacdebfb5 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple
+from typing import TYPE_CHECKING, Iterable, List, Optional, Sequence, Tuple
 
 from synapse.api.constants import EduTypes, ReceiptTypes
 from synapse.appservice import ApplicationService
@@ -189,7 +189,7 @@ class ReceiptEventSource(EventSource[int, JsonDict]):
 
     @staticmethod
     def filter_out_private_receipts(
-        rooms: List[JsonDict], user_id: str
+        rooms: Sequence[JsonDict], user_id: str
     ) -> List[JsonDict]:
         """
         Filters a list of serialized receipts (as returned by /sync and /initialSync)
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c611efb760..e4e506e62c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -476,7 +476,7 @@ class RegistrationHandler:
                     # create room expects the localpart of the room alias
                     config["room_alias_name"] = room_alias.localpart
 
-                    info, _ = await room_creation_handler.create_room(
+                    room_id, _, _ = await room_creation_handler.create_room(
                         fake_requester,
                         config=config,
                         ratelimit=False,
@@ -490,7 +490,7 @@ class RegistrationHandler:
                                 user_id, authenticated_entity=self._server_name
                             ),
                             target=UserID.from_string(user_id),
-                            room_id=info["room_id"],
+                            room_id=room_id,
                             # Since it was just created, there are no remote hosts.
                             remote_room_hosts=[],
                             action="join",
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0fb15391e0..1d09fdf135 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -20,6 +20,7 @@ import attr
 from synapse.api.constants import Direction, EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
+from synapse.events.utils import SerializeEventConfig
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
@@ -60,13 +61,12 @@ class BundledAggregations:
     Some values require additional processing during serialization.
     """
 
-    annotations: Optional[JsonDict] = None
     references: Optional[JsonDict] = None
     replace: Optional[EventBase] = None
     thread: Optional[_ThreadAggregation] = None
 
     def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
+        return bool(self.references or self.replace or self.thread)
 
 
 class RelationsHandler:
@@ -152,16 +152,23 @@ class RelationsHandler:
         )
 
         now = self._clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
         return_value: JsonDict = {
             "chunk": self._event_serializer.serialize_events(
-                events, now, bundle_aggregations=aggregations
+                events,
+                now,
+                bundle_aggregations=aggregations,
+                config=serialize_options,
             ),
         }
         if include_original_event:
             # Do not bundle aggregations when retrieving the original event because
             # we want the content before relations are applied to it.
             return_value["original_event"] = self._event_serializer.serialize_event(
-                event, now, bundle_aggregations=None
+                event,
+                now,
+                bundle_aggregations=None,
+                config=serialize_options,
             )
 
         if next_token:
@@ -227,67 +234,6 @@ class RelationsHandler:
                     e.msg,
                 )
 
-    async def get_annotations_for_events(
-        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
-    ) -> Dict[str, List[JsonDict]]:
-        """Get a list of annotations to the given events, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get the what and how many reactions have happened
-        on an event.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-            ignored_users: The users ignored by the requesting user.
-
-        Returns:
-            A map of event IDs to a list of groups of annotations that match.
-            Each entry is a dict with `type`, `key` and `count` fields.
-        """
-        # Get the base results for all users.
-        full_results = await self._main_store.get_aggregation_groups_for_events(
-            event_ids
-        )
-
-        # Avoid additional logic if there are no ignored users.
-        if not ignored_users:
-            return {
-                event_id: results
-                for event_id, results in full_results.items()
-                if results
-            }
-
-        # Then subtract off the results for any ignored users.
-        ignored_results = await self._main_store.get_aggregation_groups_for_users(
-            [event_id for event_id, results in full_results.items() if results],
-            ignored_users,
-        )
-
-        filtered_results = {}
-        for event_id, results in full_results.items():
-            # If no annotations, skip.
-            if not results:
-                continue
-
-            # If there are not ignored results for this event, copy verbatim.
-            if event_id not in ignored_results:
-                filtered_results[event_id] = results
-                continue
-
-            # Otherwise, subtract out the ignored results.
-            event_ignored_results = ignored_results[event_id]
-            for result in results:
-                key = (result["type"], result["key"])
-                if key in event_ignored_results:
-                    # Ensure to not modify the cache.
-                    result = result.copy()
-                    result["count"] -= event_ignored_results[key]
-                    if result["count"] <= 0:
-                        continue
-                filtered_results.setdefault(event_id, []).append(result)
-
-        return filtered_results
-
     async def get_references_for_events(
         self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
     ) -> Dict[str, List[_RelatedEvent]]:
@@ -531,17 +477,6 @@ class RelationsHandler:
                 # (as that is what makes it part of the thread).
                 relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
 
-        async def _fetch_annotations() -> None:
-            """Fetch any annotations (ie, reactions) to bundle with this event."""
-            annotations_by_event_id = await self.get_annotations_for_events(
-                events_by_id.keys(), ignored_users=ignored_users
-            )
-            for event_id, annotations in annotations_by_event_id.items():
-                if annotations:
-                    results.setdefault(event_id, BundledAggregations()).annotations = {
-                        "chunk": annotations
-                    }
-
         async def _fetch_references() -> None:
             """Fetch any references to bundle with this event."""
             references_by_event_id = await self.get_references_for_events(
@@ -575,7 +510,6 @@ class RelationsHandler:
         await make_deferred_yieldable(
             gather_results(
                 (
-                    run_in_background(_fetch_annotations),
                     run_in_background(_fetch_references),
                     run_in_background(_fetch_edits),
                 )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 7ba7c4ff07..be120cb12f 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -43,6 +43,7 @@ from synapse.api.errors import (
     Codes,
     LimitExceededError,
     NotFoundError,
+    PartialStateConflictError,
     StoreError,
     SynapseError,
 )
@@ -50,11 +51,11 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
 from synapse.rest.admin._base import assert_user_is_admin
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.streams import EventSource
 from synapse.types import (
     JsonDict,
@@ -211,7 +212,7 @@ class RoomCreationHandler:
                 # the required power level to send the tombstone event.
                 (
                     tombstone_event,
-                    tombstone_context,
+                    tombstone_unpersisted_context,
                 ) = await self.event_creation_handler.create_event(
                     requester,
                     {
@@ -225,6 +226,9 @@ class RoomCreationHandler:
                         },
                     },
                 )
+                tombstone_context = await tombstone_unpersisted_context.persist(
+                    tombstone_event
+                )
                 validate_event_for_room_version(tombstone_event)
                 await self._event_auth_handler.check_auth_rules_from_context(
                     tombstone_event
@@ -565,7 +569,7 @@ class RoomCreationHandler:
             new_room_id,
             # we expect to override all the presets with initial_state, so this is
             # somewhat arbitrary.
-            preset_config=RoomCreationPreset.PRIVATE_CHAT,
+            room_config={"preset": RoomCreationPreset.PRIVATE_CHAT},
             invite_list=[],
             initial_state=initial_state,
             creation_content=creation_content,
@@ -690,13 +694,14 @@ class RoomCreationHandler:
         config: JsonDict,
         ratelimit: bool = True,
         creator_join_profile: Optional[JsonDict] = None,
-    ) -> Tuple[dict, int]:
+    ) -> Tuple[str, Optional[RoomAlias], int]:
         """Creates a new room.
 
         Args:
-            requester:
-                The user who requested the room creation.
-            config : A dict of configuration options.
+            requester: The user who requested the room creation.
+            config: A dict of configuration options. This will be the body of
+                a /createRoom request; see
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
             ratelimit: set to False to disable the rate limiter
 
             creator_join_profile:
@@ -707,14 +712,17 @@ class RoomCreationHandler:
                 `avatar_url` and/or `displayname`.
 
         Returns:
-                First, a dict containing the keys `room_id` and, if an alias
-                was, requested, `room_alias`. Secondly, the stream_id of the
-                last persisted event.
+            A 3-tuple containing:
+                - the room ID;
+                - if requested, the room alias, otherwise None; and
+                - the `stream_id` of the last persisted event.
         Raises:
-            SynapseError if the room ID couldn't be stored, 3pid invitation config
-            validation failed, or something went horribly wrong.
-            ResourceLimitError if server is blocked to some resource being
-            exceeded
+            SynapseError:
+                if the room ID couldn't be stored, 3pid invitation config
+                validation failed, or something went horribly wrong.
+            ResourceLimitError:
+                if server is blocked to some resource being
+                exceeded
         """
         user_id = requester.user.to_string()
 
@@ -864,9 +872,11 @@ class RoomCreationHandler:
         )
 
         # Check whether this visibility value is blocked by a third party module
-        allowed_by_third_party_rules = await (
-            self.third_party_event_rules.check_visibility_can_be_modified(
-                room_id, visibility
+        allowed_by_third_party_rules = (
+            await (
+                self.third_party_event_rules.check_visibility_can_be_modified(
+                    room_id, visibility
+                )
             )
         )
         if not allowed_by_third_party_rules:
@@ -894,13 +904,6 @@ class RoomCreationHandler:
                 check_membership=False,
             )
 
-        preset_config = config.get(
-            "preset",
-            RoomCreationPreset.PRIVATE_CHAT
-            if visibility == "private"
-            else RoomCreationPreset.PUBLIC_CHAT,
-        )
-
         raw_initial_state = config.get("initial_state", [])
 
         initial_state = OrderedDict()
@@ -919,7 +922,7 @@ class RoomCreationHandler:
         ) = await self._send_events_for_new_room(
             requester,
             room_id,
-            preset_config=preset_config,
+            room_config=config,
             invite_list=invite_list,
             initial_state=initial_state,
             creation_content=creation_content,
@@ -928,48 +931,6 @@ class RoomCreationHandler:
             creator_join_profile=creator_join_profile,
         )
 
-        if "name" in config:
-            name = config["name"]
-            (
-                name_event,
-                last_stream_id,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester,
-                {
-                    "type": EventTypes.Name,
-                    "room_id": room_id,
-                    "sender": user_id,
-                    "state_key": "",
-                    "content": {"name": name},
-                },
-                ratelimit=False,
-                prev_event_ids=[last_sent_event_id],
-                depth=depth,
-            )
-            last_sent_event_id = name_event.event_id
-            depth += 1
-
-        if "topic" in config:
-            topic = config["topic"]
-            (
-                topic_event,
-                last_stream_id,
-            ) = await self.event_creation_handler.create_and_send_nonmember_event(
-                requester,
-                {
-                    "type": EventTypes.Topic,
-                    "room_id": room_id,
-                    "sender": user_id,
-                    "state_key": "",
-                    "content": {"topic": topic},
-                },
-                ratelimit=False,
-                prev_event_ids=[last_sent_event_id],
-                depth=depth,
-            )
-            last_sent_event_id = topic_event.event_id
-            depth += 1
-
         # we avoid dropping the lock between invites, as otherwise joins can
         # start coming in and making the createRoom slow.
         #
@@ -1024,11 +985,6 @@ class RoomCreationHandler:
             last_sent_event_id = member_event_id
             depth += 1
 
-        result = {"room_id": room_id}
-
-        if room_alias:
-            result["room_alias"] = room_alias.to_string()
-
         # Always wait for room creation to propagate before returning
         await self._replication.wait_for_stream_position(
             self.hs.config.worker.events_shard_config.get_instance(room_id),
@@ -1036,13 +992,13 @@ class RoomCreationHandler:
             last_stream_id,
         )
 
-        return result, last_stream_id
+        return room_id, room_alias, last_stream_id
 
     async def _send_events_for_new_room(
         self,
         creator: Requester,
         room_id: str,
-        preset_config: str,
+        room_config: JsonDict,
         invite_list: List[str],
         initial_state: MutableStateMap,
         creation_content: JsonDict,
@@ -1059,11 +1015,33 @@ class RoomCreationHandler:
 
         Rate limiting should already have been applied by this point.
 
+        Args:
+            creator:
+                the user requesting the room creation
+            room_id:
+                room id for the room being created
+            room_config:
+                A dict of configuration options. This will be the body of
+                a /createRoom request; see
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
+            invite_list:
+                a list of user ids to invite to the room
+            initial_state:
+                A list of state events to set in the new room.
+            creation_content:
+                Extra keys, such as m.federate, to be added to the content of the m.room.create event.
+            room_alias:
+                alias for the room
+            power_level_content_override:
+                The power level content to override in the default power level event.
+            creator_join_profile:
+                Set to override the displayname and avatar for the creating
+                user in this room.
+
         Returns:
             A tuple containing the stream ID, event ID and depth of the last
             event sent to the room.
         """
-
         creator_id = creator.user.to_string()
         event_keys = {"room_id": room_id, "sender": creator_id, "state_key": ""}
         depth = 1
@@ -1074,9 +1052,6 @@ class RoomCreationHandler:
         # created (but not persisted to the db) to determine state for future created events
         # (as this info can't be pulled from the db)
         state_map: MutableStateMap[str] = {}
-        # current_state_group of last event created. Used for computing event context of
-        # events to be batched
-        current_state_group = None
 
         def create_event_dict(etype: str, content: JsonDict, **kwargs: Any) -> JsonDict:
             e = {"type": etype, "content": content}
@@ -1091,7 +1066,7 @@ class RoomCreationHandler:
             content: JsonDict,
             for_batch: bool,
             **kwargs: Any,
-        ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+        ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
             """
             Creates an event and associated event context.
             Args:
@@ -1110,20 +1085,33 @@ class RoomCreationHandler:
 
             event_dict = create_event_dict(etype, content, **kwargs)
 
-            new_event, new_context = await self.event_creation_handler.create_event(
+            (
+                new_event,
+                new_unpersisted_context,
+            ) = await self.event_creation_handler.create_event(
                 creator,
                 event_dict,
                 prev_event_ids=prev_event,
                 depth=depth,
-                state_map=state_map,
+                # Take a copy to ensure each event gets a unique copy of
+                # state_map since it is modified below.
+                state_map=dict(state_map),
                 for_batch=for_batch,
-                current_state_group=current_state_group,
             )
+
             depth += 1
             prev_event = [new_event.event_id]
             state_map[(new_event.type, new_event.state_key)] = new_event.event_id
 
-            return new_event, new_context
+            return new_event, new_unpersisted_context
+
+        visibility = room_config.get("visibility", "private")
+        preset_config = room_config.get(
+            "preset",
+            RoomCreationPreset.PRIVATE_CHAT
+            if visibility == "private"
+            else RoomCreationPreset.PUBLIC_CHAT,
+        )
 
         try:
             config = self._presets_dict[preset_config]
@@ -1133,10 +1121,10 @@ class RoomCreationHandler:
             )
 
         creation_content.update({"creator": creator_id})
-        creation_event, creation_context = await create_event(
+        creation_event, unpersisted_creation_context = await create_event(
             EventTypes.Create, creation_content, False
         )
-
+        creation_context = await unpersisted_creation_context.persist(creation_event)
         logger.debug("Sending %s in new room", EventTypes.Member)
         ev = await self.event_creation_handler.handle_new_client_event(
             requester=creator,
@@ -1180,7 +1168,6 @@ class RoomCreationHandler:
             power_event, power_context = await create_event(
                 EventTypes.PowerLevels, pl_content, True
             )
-            current_state_group = power_context._state_group
             events_to_send.append((power_event, power_context))
         else:
             power_level_content: JsonDict = {
@@ -1229,14 +1216,12 @@ class RoomCreationHandler:
                 power_level_content,
                 True,
             )
-            current_state_group = pl_context._state_group
             events_to_send.append((pl_event, pl_context))
 
         if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
             room_alias_event, room_alias_context = await create_event(
                 EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
             )
-            current_state_group = room_alias_context._state_group
             events_to_send.append((room_alias_event, room_alias_context))
 
         if (EventTypes.JoinRules, "") not in initial_state:
@@ -1245,7 +1230,6 @@ class RoomCreationHandler:
                 {"join_rule": config["join_rules"]},
                 True,
             )
-            current_state_group = join_rules_context._state_group
             events_to_send.append((join_rules_event, join_rules_context))
 
         if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@@ -1254,7 +1238,6 @@ class RoomCreationHandler:
                 {"history_visibility": config["history_visibility"]},
                 True,
             )
-            current_state_group = visibility_context._state_group
             events_to_send.append((visibility_event, visibility_context))
 
         if config["guest_can_join"]:
@@ -1264,14 +1247,12 @@ class RoomCreationHandler:
                     {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
                     True,
                 )
-                current_state_group = guest_access_context._state_group
                 events_to_send.append((guest_access_event, guest_access_context))
 
         for (etype, state_key), content in initial_state.items():
             event, context = await create_event(
                 etype, content, True, state_key=state_key
             )
-            current_state_group = context._state_group
             events_to_send.append((event, context))
 
         if config["encrypted"]:
@@ -1283,9 +1264,34 @@ class RoomCreationHandler:
             )
             events_to_send.append((encryption_event, encryption_context))
 
+        if "name" in room_config:
+            name = room_config["name"]
+            name_event, name_context = await create_event(
+                EventTypes.Name,
+                {"name": name},
+                True,
+            )
+            events_to_send.append((name_event, name_context))
+
+        if "topic" in room_config:
+            topic = room_config["topic"]
+            topic_event, topic_context = await create_event(
+                EventTypes.Topic,
+                {"topic": topic},
+                True,
+            )
+            events_to_send.append((topic_event, topic_context))
+
+        datastore = self.hs.get_datastores().state
+        events_and_context = (
+            await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+                events_to_send, room_id, current_state_group, datastore
+            )
+        )
+
         last_event = await self.event_creation_handler.handle_new_client_event(
             creator,
-            events_to_send,
+            events_and_context,
             ignore_shadow_ban=True,
             ratelimit=False,
         )
@@ -1825,7 +1831,7 @@ class RoomShutdownHandler:
                 new_room_user_id, authenticated_entity=requester_user_id
             )
 
-            info, stream_id = await self._room_creation_handler.create_room(
+            new_room_id, _, stream_id = await self._room_creation_handler.create_room(
                 room_creator_requester,
                 config={
                     "preset": RoomCreationPreset.PUBLIC_CHAT,
@@ -1834,7 +1840,6 @@ class RoomShutdownHandler:
                 },
                 ratelimit=False,
             )
-            new_room_id = info["room_id"]
 
             logger.info(
                 "Shutting down room %r, joining to new room: %r", room_id, new_room_id
@@ -1887,6 +1892,7 @@ class RoomShutdownHandler:
 
                 # Join users to new room
                 if new_room_user_id:
+                    assert new_room_id is not None
                     await self.room_member_handler.update_membership(
                         requester=target_requester,
                         target=target_requester.user,
@@ -1919,6 +1925,7 @@ class RoomShutdownHandler:
 
             aliases_for_room = await self.store.get_aliases_for_room(room_id)
 
+            assert new_room_id is not None
             await self.store.update_aliases_for_room(
                 room_id, new_room_id, requester_user_id
             )
@@ -1928,6 +1935,6 @@ class RoomShutdownHandler:
         return {
             "kicked_users": kicked_users,
             "failed_to_kick_users": failed_to_kick_users,
-            "local_aliases": aliases_for_room,
+            "local_aliases": list(aliases_for_room),
             "new_room_id": new_room_id,
         }
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index c73d2adaad..bf9df60218 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,7 @@ class RoomBatchHandler:
             # Mark all events as historical
             event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
 
-            event, context = await self.event_creation_handler.create_event(
+            event, unpersisted_context = await self.event_creation_handler.create_event(
                 await self.create_requester_for_user_id_from_app_service(
                     ev["sender"], app_service_requester.app_service
                 ),
@@ -345,7 +345,7 @@ class RoomBatchHandler:
                 historical=True,
                 depth=inherited_depth,
             )
-
+            context = await unpersisted_context.persist(event)
             assert context._state_group
 
             # Normally this is done when persisting the event but we have to
@@ -374,7 +374,7 @@ class RoomBatchHandler:
         # correct stream_ordering as they are backfilled (which decrements).
         # Events are sorted by (topological_ordering, stream_ordering)
         # where topological_ordering is just depth.
-        for (event, context) in reversed(events_to_persist):
+        for event, context in reversed(events_to_persist):
             # This call can't raise `PartialStateConflictError` since we forbid
             # use of the historical batch API during partial state
             await self.event_creation_handler.handle_new_client_event(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index d236cc09b5..509c557889 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -26,7 +26,13 @@ from synapse.api.constants import (
     GuestAccess,
     Membership,
 )
-from synapse.api.errors import AuthError, Codes, ShadowBanError, SynapseError
+from synapse.api.errors import (
+    AuthError,
+    Codes,
+    PartialStateConflictError,
+    ShadowBanError,
+    SynapseError,
+)
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.event_auth import get_named_level, get_power_level_event
 from synapse.events import EventBase
@@ -34,7 +40,6 @@ from synapse.events.snapshot import EventContext
 from synapse.handlers.profile import MAX_AVATAR_URL_LEN, MAX_DISPLAYNAME_LEN
 from synapse.logging import opentracing
 from synapse.module_api import NOT_SPAM
-from synapse.storage.databases.main.events import PartialStateConflictError
 from synapse.types import (
     JsonDict,
     Requester,
@@ -56,6 +61,13 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+class NoKnownServersError(SynapseError):
+    """No server already resident to the room was provided to the join/knock operation."""
+
+    def __init__(self, msg: str = "No known servers"):
+        super().__init__(404, msg)
+
+
 class RoomMemberHandler(metaclass=abc.ABCMeta):
     # TODO(paul): This handler currently contains a messy conflation of
     #   low-level API that works on UserID objects and so on, and REST-level
@@ -185,12 +197,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             room_id: Room that we are trying to join
             user: User who is trying to join
             content: A dict that should be used as the content of the join event.
+
+        Raises:
+            NoKnownServersError: if remote_room_hosts does not contain a server joined to
+                the room.
         """
         raise NotImplementedError()
 
     @abc.abstractmethod
     async def remote_knock(
         self,
+        requester: Requester,
         remote_room_hosts: List[str],
         room_id: str,
         user: UserID,
@@ -398,7 +415,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.event_creation_handler.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.event_creation_handler.create_event(
                     requester,
                     {
                         "type": EventTypes.Member,
@@ -419,7 +439,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     outlier=outlier,
                     historical=historical,
                 )
-
+                context = await unpersisted_context.persist(event)
                 prev_state_ids = await context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
@@ -484,7 +504,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             user_id: The user's ID.
         """
         # Retrieve user account data for predecessor room
-        user_account_data, _ = await self.store.get_account_data_for_user(user_id)
+        user_account_data = await self.store.get_global_account_data_for_user(user_id)
 
         # Copy direct message state if applicable
         direct_rooms = user_account_data.get(AccountDataTypes.DIRECT, {})
@@ -823,14 +843,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         latest_event_ids = await self.store.get_prev_events_for_room(room_id)
 
-        state_before_join = await self.state_handler.compute_state_after_events(
-            room_id, latest_event_ids
+        is_partial_state_room = await self.store.is_partial_state_room(room_id)
+        partial_state_before_join = await self.state_handler.compute_state_after_events(
+            room_id, latest_event_ids, await_full_state=False
         )
+        # `is_partial_state_room` also indicates whether `partial_state_before_join` is
+        # partial.
 
         # TODO: Refactor into dictionary of explicitly allowed transitions
         # between old and new state, with specific error messages for some
         # transitions and generic otherwise
-        old_state_id = state_before_join.get((EventTypes.Member, target.to_string()))
+        old_state_id = partial_state_before_join.get(
+            (EventTypes.Member, target.to_string())
+        )
         if old_state_id:
             old_state = await self.store.get_event(old_state_id, allow_none=True)
             old_membership = old_state.content.get("membership") if old_state else None
@@ -881,11 +906,11 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             if action == "kick":
                 raise AuthError(403, "The target user is not in the room")
 
-        is_host_in_room = await self._is_host_in_room(state_before_join)
+        is_host_in_room = await self._is_host_in_room(partial_state_before_join)
 
         if effective_membership_state == Membership.JOIN:
             if requester.is_guest:
-                guest_can_join = await self._can_guest_join(state_before_join)
+                guest_can_join = await self._can_guest_join(partial_state_before_join)
                 if not guest_can_join:
                     # This should be an auth check, but guests are a local concept,
                     # so don't really fit into the general auth process.
@@ -927,8 +952,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 room_id,
                 remote_room_hosts,
                 content,
+                is_partial_state_room,
                 is_host_in_room,
-                state_before_join,
+                partial_state_before_join,
             )
             if remote_join:
                 if ratelimit:
@@ -1048,7 +1074,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     )
 
                 return await self.remote_knock(
-                    remote_room_hosts, room_id, target, content
+                    requester, remote_room_hosts, room_id, target, content
                 )
 
         return await self._local_membership_update(
@@ -1073,8 +1099,9 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         room_id: str,
         remote_room_hosts: List[str],
         content: JsonDict,
+        is_partial_state_room: bool,
         is_host_in_room: bool,
-        state_before_join: StateMap[str],
+        partial_state_before_join: StateMap[str],
     ) -> Tuple[bool, List[str]]:
         """
         Check whether the server should do a remote join (as opposed to a local
@@ -1093,9 +1120,12 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             remote_room_hosts: A list of remote room hosts.
             content: The content to use as the event body of the join. This may
                 be modified.
-            is_host_in_room: True if the host is in the room.
-            state_before_join: The state before the join event (i.e. the resolution of
-                the states after its parent events).
+            is_partial_state_room: `True` if the server currently doesn't hold the full
+                state of the room.
+            is_host_in_room: `True` if the host is in the room.
+            partial_state_before_join: The state before the join event (i.e. the
+                resolution of the states after its parent events). May be full or
+                partial state, depending on `is_partial_state_room`.
 
         Returns:
             A tuple of:
@@ -1109,6 +1139,23 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         if not is_host_in_room:
             return True, remote_room_hosts
 
+        prev_member_event_id = partial_state_before_join.get(
+            (EventTypes.Member, user_id), None
+        )
+        previous_membership = None
+        if prev_member_event_id:
+            prev_member_event = await self.store.get_event(prev_member_event_id)
+            previous_membership = prev_member_event.membership
+
+        # If we are not fully joined yet, and the target is not already in the room,
+        # let's do a remote join so another server with the full state can validate
+        # that the user has not been banned for example.
+        # We could just accept the join and wait for state res to resolve that later on
+        # but we would then leak room history to this person until then, which is pretty
+        # bad.
+        if is_partial_state_room and previous_membership != Membership.JOIN:
+            return True, remote_room_hosts
+
         # If the host is in the room, but not one of the authorised hosts
         # for restricted join rules, a remote join must be used.
         room_version = await self.store.get_room_version(room_id)
@@ -1116,21 +1163,19 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         # If restricted join rules are not being used, a local join can always
         # be used.
         if not await self.event_auth_handler.has_restricted_join_rules(
-            state_before_join, room_version
+            partial_state_before_join, room_version
         ):
             return False, []
 
         # If the user is invited to the room or already joined, the join
         # event can always be issued locally.
-        prev_member_event_id = state_before_join.get((EventTypes.Member, user_id), None)
-        prev_member_event = None
-        if prev_member_event_id:
-            prev_member_event = await self.store.get_event(prev_member_event_id)
-            if prev_member_event.membership in (
-                Membership.JOIN,
-                Membership.INVITE,
-            ):
-                return False, []
+        if previous_membership in (Membership.JOIN, Membership.INVITE):
+            return False, []
+
+        # All the partial state cases are covered above. We have been given the full
+        # state of the room.
+        assert not is_partial_state_room
+        state_before_join = partial_state_before_join
 
         # If the local host has a user who can issue invites, then a local
         # join can be done.
@@ -1154,7 +1199,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
 
         # Ensure the member should be allowed access via membership in a room.
         await self.event_auth_handler.check_restricted_join_rules(
-            state_before_join, room_version, user_id, prev_member_event
+            state_before_join, room_version, user_id, previous_membership
         )
 
         # If this is going to be a local join, additional information must
@@ -1304,11 +1349,17 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                 if prev_member_event.membership == Membership.JOIN:
                     await self._user_left_room(target_user, room_id)
 
-    async def _can_guest_join(self, current_state_ids: StateMap[str]) -> bool:
+    async def _can_guest_join(self, partial_current_state_ids: StateMap[str]) -> bool:
         """
         Returns whether a guest can join a room based on its current state.
+
+        Args:
+            partial_current_state_ids: The current state of the room. May be full or
+                partial state.
         """
-        guest_access_id = current_state_ids.get((EventTypes.GuestAccess, ""), None)
+        guest_access_id = partial_current_state_ids.get(
+            (EventTypes.GuestAccess, ""), None
+        )
         if not guest_access_id:
             return False
 
@@ -1634,19 +1685,25 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         )
         return event, stream_id
 
-    async def _is_host_in_room(self, current_state_ids: StateMap[str]) -> bool:
+    async def _is_host_in_room(self, partial_current_state_ids: StateMap[str]) -> bool:
+        """Returns whether the homeserver is in the room based on its current state.
+
+        Args:
+            partial_current_state_ids: The current state of the room. May be full or
+                partial state.
+        """
         # Have we just created the room, and is this about to be the very
         # first member event?
-        create_event_id = current_state_ids.get(("m.room.create", ""))
-        if len(current_state_ids) == 1 and create_event_id:
+        create_event_id = partial_current_state_ids.get(("m.room.create", ""))
+        if len(partial_current_state_ids) == 1 and create_event_id:
             # We can only get here if we're in the process of creating the room
             return True
 
-        for etype, state_key in current_state_ids:
+        for etype, state_key in partial_current_state_ids:
             if etype != EventTypes.Member or not self.hs.is_mine_id(state_key):
                 continue
 
-            event_id = current_state_ids[(etype, state_key)]
+            event_id = partial_current_state_ids[(etype, state_key)]
             event = await self.store.get_event(event_id, allow_none=True)
             if not event:
                 continue
@@ -1715,8 +1772,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(
-                404,
+            raise NoKnownServersError(
                 "Can't join remote room because no servers "
                 "that are in the room have been provided.",
             )
@@ -1892,7 +1948,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.event_creation_handler.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.event_creation_handler.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1900,6 +1959,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                     auth_event_ids=auth_event_ids,
                     outlier=True,
                 )
+                context = await unpersisted_context.persist(event)
                 event.internal_metadata.out_of_band_membership = True
 
                 result_event = (
@@ -1925,6 +1985,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
     async def remote_knock(
         self,
+        requester: Requester,
         remote_room_hosts: List[str],
         room_id: str,
         user: UserID,
@@ -1947,7 +2008,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         ]
 
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise NoKnownServersError()
 
         return await self.federation_handler.do_knock(
             remote_room_hosts, room_id, user.to_string(), content=content
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index 221552a2a6..76e36b8a6d 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -15,8 +15,7 @@
 import logging
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
-from synapse.api.errors import SynapseError
-from synapse.handlers.room_member import RoomMemberHandler
+from synapse.handlers.room_member import NoKnownServersError, RoomMemberHandler
 from synapse.replication.http.membership import (
     ReplicationRemoteJoinRestServlet as ReplRemoteJoin,
     ReplicationRemoteKnockRestServlet as ReplRemoteKnock,
@@ -52,7 +51,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
     ) -> Tuple[str, int]:
         """Implements RoomMemberHandler._remote_join"""
         if len(remote_room_hosts) == 0:
-            raise SynapseError(404, "No known servers")
+            raise NoKnownServersError()
 
         ret = await self._remote_join_client(
             requester=requester,
@@ -114,6 +113,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
 
     async def remote_knock(
         self,
+        requester: Requester,
         remote_room_hosts: List[str],
         room_id: str,
         user: UserID,
@@ -124,9 +124,10 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         Implements RoomMemberHandler.remote_knock
         """
         ret = await self._remote_knock_client(
+            requester=requester,
             remote_room_hosts=remote_room_hosts,
             room_id=room_id,
-            user=user,
+            user_id=user.to_string(),
             content=content,
         )
         return ret["event_id"], ret["stream_id"]
diff --git a/synapse/handlers/room_summary.py b/synapse/handlers/room_summary.py
index 4472019fbc..807245160d 100644
--- a/synapse/handlers/room_summary.py
+++ b/synapse/handlers/room_summary.py
@@ -521,8 +521,8 @@ class RoomSummaryHandler:
 
         It should return true if:
 
-        * The requester is joined or can join the room (per MSC3173).
-        * The origin server has any user that is joined or can join the room.
+        * The requesting user is joined or can join the room (per MSC3173); or
+        * The origin server has any user that is joined or can join the room; or
         * The history visibility is set to world readable.
 
         Args:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bbf83047d..aad4706f14 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
-from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
+from synapse.events.utils import SerializeEventConfig
+from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
 from synapse.types.state import StateFilter
 from synapse.visibility import filter_events_for_client
 
@@ -109,12 +110,12 @@ class SearchHandler:
         return historical_room_ids
 
     async def search(
-        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+        self, requester: Requester, content: JsonDict, batch: Optional[str] = None
     ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user: The user performing the search.
+            requester: The user performing the search.
             content: Search parameters
             batch: The next_batch parameter. Used for pagination.
 
@@ -199,7 +200,7 @@ class SearchHandler:
             )
 
         return await self._search(
-            user,
+            requester,
             batch_group,
             batch_group_key,
             batch_token,
@@ -217,7 +218,7 @@ class SearchHandler:
 
     async def _search(
         self,
-        user: UserID,
+        requester: Requester,
         batch_group: Optional[str],
         batch_group_key: Optional[str],
         batch_token: Optional[str],
@@ -235,7 +236,7 @@ class SearchHandler:
         """Performs a full text search for a user.
 
         Args:
-            user: The user performing the search.
+            requester: The user performing the search.
             batch_group: Pagination information.
             batch_group_key: Pagination information.
             batch_token: Pagination information.
@@ -269,7 +270,7 @@ class SearchHandler:
 
         # TODO: Search through left rooms too
         rooms = await self.store.get_rooms_for_local_user_where_membership_is(
-            user.to_string(),
+            requester.user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
         )
@@ -303,13 +304,13 @@ class SearchHandler:
 
         if order_by == "rank":
             search_result, sender_group = await self._search_by_rank(
-                user, room_ids, search_term, keys, search_filter
+                requester.user, room_ids, search_term, keys, search_filter
             )
             # Unused return values for rank search.
             global_next_batch = None
         elif order_by == "recent":
             search_result, global_next_batch = await self._search_by_recent(
-                user,
+                requester.user,
                 room_ids,
                 search_term,
                 keys,
@@ -334,7 +335,7 @@ class SearchHandler:
             assert after_limit is not None
 
             contexts = await self._calculate_event_contexts(
-                user,
+                requester.user,
                 search_result.allowed_events,
                 before_limit,
                 after_limit,
@@ -363,27 +364,37 @@ class SearchHandler:
                 # The returned events.
                 search_result.allowed_events,
             ),
-            user.to_string(),
+            requester.user.to_string(),
         )
 
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise, the 'age' will be wrong.
 
         time_now = self.clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
 
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations
+                context["events_before"],
+                time_now,
+                bundle_aggregations=aggregations,
+                config=serialize_options,
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations
+                context["events_after"],
+                time_now,
+                bundle_aggregations=aggregations,
+                config=serialize_options,
             )
 
         results = [
             {
                 "rank": search_result.rank_map[e.event_id],
                 "result": self._event_serializer.serialize_event(
-                    e, time_now, bundle_aggregations=aggregations
+                    e,
+                    time_now,
+                    bundle_aggregations=aggregations,
+                    config=serialize_options,
                 ),
                 "context": contexts.get(e.event_id, {}),
             }
@@ -398,7 +409,9 @@ class SearchHandler:
 
         if state_results:
             rooms_cat_res["state"] = {
-                room_id: self._event_serializer.serialize_events(state_events, time_now)
+                room_id: self._event_serializer.serialize_events(
+                    state_events, time_now, config=serialize_options
+                )
                 for room_id, state_events in state_results.items()
             }
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 3566537894..9f5b83ed54 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -269,6 +269,8 @@ class SyncHandler:
         self._state_storage_controller = self._storage_controllers.state
         self._device_handler = hs.get_device_handler()
 
+        self.should_calculate_push_rules = hs.config.push.enable_push
+
         # TODO: flush cache entries on subsequent sync request.
         #    Once we get the next /sync request (ie, one with the same access token
         #    that sets 'since' to 'next_batch'), we know that device won't need a
@@ -1224,6 +1226,10 @@ class SyncHandler:
                 continue
 
             event_with_membership_auth = events_with_membership_auth[member]
+            is_create = (
+                event_with_membership_auth.is_state()
+                and event_with_membership_auth.type == EventTypes.Create
+            )
             is_join = (
                 event_with_membership_auth.is_state()
                 and event_with_membership_auth.type == EventTypes.Member
@@ -1231,9 +1237,10 @@ class SyncHandler:
                 and event_with_membership_auth.content.get("membership")
                 == Membership.JOIN
             )
-            if not is_join:
+            if not is_create and not is_join:
                 # The event must include the desired membership as an auth event, unless
-                # it's the first join event for a given user.
+                # it's the `m.room.create` event for a room or the first join event for
+                # a given user.
                 missing_members.add(member)
             auth_event_ids.update(event_with_membership_auth.auth_event_ids())
 
@@ -1288,8 +1295,13 @@ class SyncHandler:
     async def unread_notifs_for_room_id(
         self, room_id: str, sync_config: SyncConfig
     ) -> RoomNotifCounts:
-        with Measure(self.clock, "unread_notifs_for_room_id"):
+        if not self.should_calculate_push_rules:
+            # If push rules have been universally disabled then we know we won't
+            # have any unread counts in the DB, so we may as well skip asking
+            # the DB.
+            return RoomNotifCounts.empty()
 
+        with Measure(self.clock, "unread_notifs_for_room_id"):
             return await self.store.get_unread_event_push_actions_by_room_for_user(
                 room_id,
                 sync_config.user.to_string(),
@@ -1391,6 +1403,11 @@ class SyncHandler:
                 for room_id, is_partial_state in results.items()
                 if is_partial_state
             )
+            membership_change_events = [
+                event
+                for event in membership_change_events
+                if not results.get(event.room_id, False)
+            ]
 
         # Incremental eager syncs should additionally include rooms that
         # - we are joined to
@@ -1444,9 +1461,9 @@ class SyncHandler:
 
         logger.debug("Fetching account data")
 
-        account_data_by_room = await self._generate_sync_entry_for_account_data(
-            sync_result_builder
-        )
+        # Global account data is included if it is not filtered out.
+        if not sync_config.filter_collection.blocks_all_global_account_data():
+            await self._generate_sync_entry_for_account_data(sync_result_builder)
 
         # Presence data is included if the server has it enabled and not filtered out.
         include_presence_data = bool(
@@ -1472,9 +1489,7 @@ class SyncHandler:
             (
                 newly_joined_rooms,
                 newly_left_rooms,
-            ) = await self._generate_sync_entry_for_rooms(
-                sync_result_builder, account_data_by_room
-            )
+            ) = await self._generate_sync_entry_for_rooms(sync_result_builder)
 
             # Work out which users have joined or left rooms we're in. We use this
             # to build the presence and device_list parts of the sync response in
@@ -1521,7 +1536,7 @@ class SyncHandler:
             one_time_keys_count = await self.store.count_e2e_one_time_keys(
                 user_id, device_id
             )
-            unused_fallback_key_types = (
+            unused_fallback_key_types = list(
                 await self.store.get_e2e_unused_fallback_key_types(user_id, device_id)
             )
 
@@ -1717,35 +1732,29 @@ class SyncHandler:
 
     async def _generate_sync_entry_for_account_data(
         self, sync_result_builder: "SyncResultBuilder"
-    ) -> Dict[str, Dict[str, JsonDict]]:
-        """Generates the account data portion of the sync response.
+    ) -> None:
+        """Generates the global account data portion of the sync response.
 
         Account data (called "Client Config" in the spec) can be set either globally
         or for a specific room. Account data consists of a list of events which
         accumulate state, much like a room.
 
-        This function retrieves global and per-room account data. The former is written
-        to the given `sync_result_builder`. The latter is returned directly, to be
-        later written to the `sync_result_builder` on a room-by-room basis.
+        This function retrieves global account data and writes it to the given
+        `sync_result_builder`. See `_generate_sync_entry_for_rooms` for handling
+         of per-room account data.
 
         Args:
             sync_result_builder
-
-        Returns:
-            A dictionary whose keys (room ids) map to the per room account data for that
-            room.
         """
         sync_config = sync_result_builder.sync_config
         user_id = sync_result_builder.sync_config.user.to_string()
         since_token = sync_result_builder.since_token
 
         if since_token and not sync_result_builder.full_state:
-            # TODO Do not fetch room account data if it will be unused.
-            (
-                global_account_data,
-                account_data_by_room,
-            ) = await self.store.get_updated_account_data_for_user(
-                user_id, since_token.account_data_key
+            global_account_data = (
+                await self.store.get_updated_global_account_data_for_user(
+                    user_id, since_token.account_data_key
+                )
             )
 
             push_rules_changed = await self.store.have_push_rules_changed_for_user(
@@ -1753,31 +1762,31 @@ class SyncHandler:
             )
 
             if push_rules_changed:
+                global_account_data = dict(global_account_data)
                 global_account_data["m.push_rules"] = await self.push_rules_for_user(
                     sync_config.user
                 )
         else:
-            # TODO Do not fetch room account data if it will be unused.
-            (
-                global_account_data,
-                account_data_by_room,
-            ) = await self.store.get_account_data_for_user(sync_config.user.to_string())
+            all_global_account_data = await self.store.get_global_account_data_for_user(
+                user_id
+            )
 
+            global_account_data = dict(all_global_account_data)
             global_account_data["m.push_rules"] = await self.push_rules_for_user(
                 sync_config.user
             )
 
-        account_data_for_user = await sync_config.filter_collection.filter_account_data(
-            [
-                {"type": account_data_type, "content": content}
-                for account_data_type, content in global_account_data.items()
-            ]
+        account_data_for_user = (
+            await sync_config.filter_collection.filter_global_account_data(
+                [
+                    {"type": account_data_type, "content": content}
+                    for account_data_type, content in global_account_data.items()
+                ]
+            )
         )
 
         sync_result_builder.account_data = account_data_for_user
 
-        return account_data_by_room
-
     async def _generate_sync_entry_for_presence(
         self,
         sync_result_builder: "SyncResultBuilder",
@@ -1837,9 +1846,7 @@ class SyncHandler:
         sync_result_builder.presence = presence
 
     async def _generate_sync_entry_for_rooms(
-        self,
-        sync_result_builder: "SyncResultBuilder",
-        account_data_by_room: Dict[str, Dict[str, JsonDict]],
+        self, sync_result_builder: "SyncResultBuilder"
     ) -> Tuple[AbstractSet[str], AbstractSet[str]]:
         """Generates the rooms portion of the sync response. Populates the
         `sync_result_builder` with the result.
@@ -1850,7 +1857,6 @@ class SyncHandler:
 
         Args:
             sync_result_builder
-            account_data_by_room: Dictionary of per room account data
 
         Returns:
             Returns a 2-tuple describing rooms the user has joined or left.
@@ -1863,9 +1869,30 @@ class SyncHandler:
         since_token = sync_result_builder.since_token
         user_id = sync_result_builder.sync_config.user.to_string()
 
+        blocks_all_rooms = (
+            sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+        )
+
+        # 0. Start by fetching room account data (if required).
+        if (
+            blocks_all_rooms
+            or sync_result_builder.sync_config.filter_collection.blocks_all_room_account_data()
+        ):
+            account_data_by_room: Mapping[str, Mapping[str, JsonDict]] = {}
+        elif since_token and not sync_result_builder.full_state:
+            account_data_by_room = (
+                await self.store.get_updated_room_account_data_for_user(
+                    user_id, since_token.account_data_key
+                )
+            )
+        else:
+            account_data_by_room = await self.store.get_room_account_data_for_user(
+                user_id
+            )
+
         # 1. Start by fetching all ephemeral events in rooms we've joined (if required).
         block_all_room_ephemeral = (
-            sync_result_builder.sync_config.filter_collection.blocks_all_rooms()
+            blocks_all_rooms
             or sync_result_builder.sync_config.filter_collection.blocks_all_room_ephemeral()
         )
         if block_all_room_ephemeral:
@@ -2291,8 +2318,8 @@ class SyncHandler:
         sync_result_builder: "SyncResultBuilder",
         room_builder: "RoomSyncResultBuilder",
         ephemeral: List[JsonDict],
-        tags: Optional[Dict[str, Dict[str, Any]]],
-        account_data: Dict[str, JsonDict],
+        tags: Optional[Mapping[str, Mapping[str, Any]]],
+        account_data: Mapping[str, JsonDict],
         always_include: bool = False,
     ) -> None:
         """Populates the `joined` and `archived` section of `sync_result_builder`
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 332edcca24..78a75bfed6 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
 
 from twisted.web.client import PartialDownloadError
 
@@ -27,19 +28,28 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class UserInteractiveAuthChecker:
+class UserInteractiveAuthChecker(ABC):
     """Abstract base class for an interactive auth checker"""
 
-    def __init__(self, hs: "HomeServer"):
+    # This should really be an "abstract class property", i.e. it should
+    # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE.
+    # But calling this a `ClassVar` is simpler than a decorator stack of
+    # @property @abstractmethod and @classmethod (if that's even the right order).
+    AUTH_TYPE: ClassVar[str]
+
+    def __init__(self, hs: "HomeServer"):  # noqa: B027
         pass
 
+    @abstractmethod
     def is_enabled(self) -> bool:
         """Check if the configuration of the homeserver allows this checker to work
 
         Returns:
             True if this login type is enabled.
         """
+        raise NotImplementedError()
 
+    @abstractmethod
     async def check_auth(self, authdict: dict, clientip: str) -> Any:
         """Given the authentication dict from the client, attempt to check this step
 
@@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
             )
 
 
-INTERACTIVE_AUTH_CHECKERS = [
+INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
     DummyAuthChecker,
     TermsAuthChecker,
     RecaptchaAuthChecker,