summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorH. Shay <hillerys@element.io>2023-03-06 12:48:59 -0800
committerH. Shay <hillerys@element.io>2023-03-06 12:48:59 -0800
commit7fc487421f19d8841c36481701655bc31bfcd79f (patch)
treee5c4e634fa957a561675156c868c561c36a71d66 /synapse/handlers
parentadd clearer return values (diff)
parentPass the requester during event serialization. (#15174) (diff)
downloadsynapse-7fc487421f19d8841c36481701655bc31bfcd79f.tar.xz
Merge branch 'develop' into shay/rework_module
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/account_data.py7
-rw-r--r--synapse/handlers/admin.py87
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py53
-rw-r--r--synapse/handlers/deactivate_account.py20
-rw-r--r--synapse/handlers/directory.py8
-rw-r--r--synapse/handlers/e2e_keys.py14
-rw-r--r--synapse/handlers/e2e_room_keys.py1
-rw-r--r--synapse/handlers/event_auth.py1
-rw-r--r--synapse/handlers/events.py20
-rw-r--r--synapse/handlers/federation.py16
-rw-r--r--synapse/handlers/initial_sync.py52
-rw-r--r--synapse/handlers/message.py30
-rw-r--r--synapse/handlers/pagination.py4
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/relations.py88
-rw-r--r--synapse/handlers/room.py84
-rw-r--r--synapse/handlers/room_batch.py10
-rw-r--r--synapse/handlers/room_member.py11
-rw-r--r--synapse/handlers/room_member_worker.py4
-rw-r--r--synapse/handlers/search.py43
-rw-r--r--synapse/handlers/sync.py1
-rw-r--r--synapse/handlers/ui_auth/checkers.py18
24 files changed, 336 insertions, 244 deletions
diff --git a/synapse/handlers/account_data.py b/synapse/handlers/account_data.py
index 797de46dbc..7e01c18c6c 100644
--- a/synapse/handlers/account_data.py
+++ b/synapse/handlers/account_data.py
@@ -155,9 +155,6 @@ class AccountDataHandler:
             max_stream_id = await self._store.remove_account_data_for_room(
                 user_id, room_id, account_data_type
             )
-            if max_stream_id is None:
-                # The referenced account data did not exist, so no delete occurred.
-                return None
 
             self._notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
@@ -230,9 +227,6 @@ class AccountDataHandler:
             max_stream_id = await self._store.remove_account_data_for_user(
                 user_id, account_data_type
             )
-            if max_stream_id is None:
-                # The referenced account data did not exist, so no delete occurred.
-                return None
 
             self._notifier.on_new_event(
                 StreamKeyType.ACCOUNT_DATA, max_stream_id, users=[user_id]
@@ -248,7 +242,6 @@ class AccountDataHandler:
                 instance_name=random.choice(self._account_data_writers),
                 user_id=user_id,
                 account_data_type=account_data_type,
-                content={},
             )
             return response["max_stream_id"]
 
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index b03c214b14..b06f25b03c 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -14,7 +14,7 @@
 
 import abc
 import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set
 
 from synapse.api.constants import Direction, Membership
 from synapse.events import EventBase
@@ -29,7 +29,7 @@ logger = logging.getLogger(__name__)
 
 class AdminHandler:
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastores().main
+        self._store = hs.get_datastores().main
         self._device_handler = hs.get_device_handler()
         self._storage_controllers = hs.get_storage_controllers()
         self._state_storage_controller = self._storage_controllers.state
@@ -38,7 +38,7 @@ class AdminHandler:
     async def get_whois(self, user: UserID) -> JsonDict:
         connections = []
 
-        sessions = await self.store.get_user_ip_and_agents(user)
+        sessions = await self._store.get_user_ip_and_agents(user)
         for session in sessions:
             connections.append(
                 {
@@ -57,7 +57,7 @@ class AdminHandler:
 
     async def get_user(self, user: UserID) -> Optional[JsonDict]:
         """Function to get user details"""
-        user_info_dict = await self.store.get_user_by_id(user.to_string())
+        user_info_dict = await self._store.get_user_by_id(user.to_string())
         if user_info_dict is None:
             return None
 
@@ -89,11 +89,11 @@ class AdminHandler:
         }
 
         # Add additional user metadata
-        profile = await self.store.get_profileinfo(user.localpart)
-        threepids = await self.store.user_get_threepids(user.to_string())
+        profile = await self._store.get_profileinfo(user.localpart)
+        threepids = await self._store.user_get_threepids(user.to_string())
         external_ids = [
             ({"auth_provider": auth_provider, "external_id": external_id})
-            for auth_provider, external_id in await self.store.get_external_ids_by_user(
+            for auth_provider, external_id in await self._store.get_external_ids_by_user(
                 user.to_string()
             )
         ]
@@ -101,7 +101,7 @@ class AdminHandler:
         user_info_dict["avatar_url"] = profile.avatar_url
         user_info_dict["threepids"] = threepids
         user_info_dict["external_ids"] = external_ids
-        user_info_dict["erased"] = await self.store.is_user_erased(user.to_string())
+        user_info_dict["erased"] = await self._store.is_user_erased(user.to_string())
 
         return user_info_dict
 
@@ -117,7 +117,7 @@ class AdminHandler:
             The returned value is that returned by `writer.finished()`.
         """
         # Get all rooms the user is in or has been in
-        rooms = await self.store.get_rooms_for_local_user_where_membership_is(
+        rooms = await self._store.get_rooms_for_local_user_where_membership_is(
             user_id,
             membership_list=(
                 Membership.JOIN,
@@ -131,7 +131,7 @@ class AdminHandler:
         # We only try and fetch events for rooms the user has been in. If
         # they've been e.g. invited to a room without joining then we handle
         # those separately.
-        rooms_user_has_been_in = await self.store.get_rooms_user_has_been_in(user_id)
+        rooms_user_has_been_in = await self._store.get_rooms_user_has_been_in(user_id)
 
         for index, room in enumerate(rooms):
             room_id = room.room_id
@@ -140,7 +140,7 @@ class AdminHandler:
                 "[%s] Handling room %s, %d/%d", user_id, room_id, index + 1, len(rooms)
             )
 
-            forgotten = await self.store.did_forget(user_id, room_id)
+            forgotten = await self._store.did_forget(user_id, room_id)
             if forgotten:
                 logger.info("[%s] User forgot room %d, ignoring", user_id, room_id)
                 continue
@@ -152,14 +152,14 @@ class AdminHandler:
 
                 if room.membership == Membership.INVITE:
                     event_id = room.event_id
-                    invite = await self.store.get_event(event_id, allow_none=True)
+                    invite = await self._store.get_event(event_id, allow_none=True)
                     if invite:
                         invited_state = invite.unsigned["invite_room_state"]
                         writer.write_invite(room_id, invite, invited_state)
 
                 if room.membership == Membership.KNOCK:
                     event_id = room.event_id
-                    knock = await self.store.get_event(event_id, allow_none=True)
+                    knock = await self._store.get_event(event_id, allow_none=True)
                     if knock:
                         knock_state = knock.unsigned["knock_room_state"]
                         writer.write_knock(room_id, knock, knock_state)
@@ -170,7 +170,7 @@ class AdminHandler:
             # were joined. We estimate that point by looking at the
             # stream_ordering of the last membership if it wasn't a join.
             if room.membership == Membership.JOIN:
-                stream_ordering = self.store.get_room_max_stream_ordering()
+                stream_ordering = self._store.get_room_max_stream_ordering()
             else:
                 stream_ordering = room.stream_ordering
 
@@ -197,7 +197,7 @@ class AdminHandler:
             # events that we have and then filtering, this isn't the most
             # efficient method perhaps but it does guarantee we get everything.
             while True:
-                events, _ = await self.store.paginate_room_events(
+                events, _ = await self._store.paginate_room_events(
                     room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
                 )
                 if not events:
@@ -252,16 +252,49 @@ class AdminHandler:
         profile = await self.get_user(UserID.from_string(user_id))
         if profile is not None:
             writer.write_profile(profile)
+            logger.info("[%s] Written profile", user_id)
 
         # Get all devices the user has
         devices = await self._device_handler.get_devices_by_user(user_id)
         writer.write_devices(devices)
+        logger.info("[%s] Written %s devices", user_id, len(devices))
 
         # Get all connections the user has
         connections = await self.get_whois(UserID.from_string(user_id))
         writer.write_connections(
             connections["devices"][""]["sessions"][0]["connections"]
         )
+        logger.info("[%s] Written %s connections", user_id, len(connections))
+
+        # Get all account data the user has global and in rooms
+        global_data = await self._store.get_global_account_data_for_user(user_id)
+        by_room_data = await self._store.get_room_account_data_for_user(user_id)
+        writer.write_account_data("global", global_data)
+        for room_id in by_room_data:
+            writer.write_account_data(room_id, by_room_data[room_id])
+        logger.info(
+            "[%s] Written account data for %s rooms", user_id, len(by_room_data)
+        )
+
+        # Get all media ids the user has
+        limit = 100
+        start = 0
+        while True:
+            media_ids, total = await self._store.get_local_media_by_user_paginate(
+                start, limit, user_id
+            )
+            for media in media_ids:
+                writer.write_media_id(media["media_id"], media)
+
+            logger.info(
+                "[%s] Written %d media_ids of %s",
+                user_id,
+                (start + len(media_ids)),
+                total,
+            )
+            if (start + limit) >= total:
+                break
+            start += limit
 
         return writer.finished()
 
@@ -341,6 +374,30 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    def write_account_data(
+        self, file_name: str, account_data: Mapping[str, JsonDict]
+    ) -> None:
+        """Write the account data of a user.
+
+        Args:
+            file_name: file name to write data
+            account_data: mapping of global or room account_data
+        """
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+        """Write the media's metadata of a user.
+        Exports only the metadata, as this can be fetched from the database via
+        read only. In order to access the files, a connection to the correct
+        media repository would be required.
+
+        Args:
+            media_id: ID of the media.
+            media_metadata: Metadata of one media file.
+        """
+
+    @abc.abstractmethod
     def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5d1d21cdc8..ec3ab968e9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -737,7 +737,7 @@ class ApplicationServicesHandler:
         )
 
         ret = []
-        for (success, result) in results:
+        for success, result in results:
             if success:
                 ret.extend(result)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 57a6854b1e..308e38edea 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -201,7 +201,7 @@ class AuthHandler:
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
             inst = auth_checker_class(hs)
             if inst.is_enabled():
-                self.checkers[inst.AUTH_TYPE] = inst  # type: ignore
+                self.checkers[inst.AUTH_TYPE] = inst
 
         self.bcrypt_rounds = hs.config.registration.bcrypt_rounds
 
@@ -815,7 +815,6 @@ class AuthHandler:
         now_ms = self._clock.time_msec()
 
         if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
-
             raise SynapseError(
                 HTTPStatus.FORBIDDEN,
                 "The supplied refresh token has expired",
@@ -1543,6 +1542,17 @@ class AuthHandler:
     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
     ) -> None:
+        """
+        Adds an association between a user's Matrix ID and a third-party ID (email,
+        phone number).
+
+        Args:
+            user_id: The ID of the user to associate.
+            medium: The medium of the third-party ID (email, msisdn).
+            address: The address of the third-party ID (i.e. an email address).
+            validated_at: The timestamp in ms of when the validation that the user owns
+                this third-party ID occurred.
+        """
         # check if medium has a valid value
         if medium not in ["email", "msisdn"]:
             raise SynapseError(
@@ -1567,42 +1577,44 @@ class AuthHandler:
             user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
         )
 
+        # Inform Synapse modules that a 3PID association has been created.
+        await self._third_party_rules.on_add_user_third_party_identifier(
+            user_id, medium, address
+        )
+
+        # Deprecated method for informing Synapse modules that a 3PID association
+        # has successfully been created.
         await self._third_party_rules.on_threepid_bind(user_id, medium, address)
 
-    async def delete_threepid(
-        self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
-    ) -> bool:
-        """Attempts to unbind the 3pid on the identity servers and deletes it
-        from the local database.
+    async def delete_local_threepid(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
+        """Deletes an association between a third-party ID and a user ID from the local
+        database. This method does not unbind the association from any identity servers.
+
+        If `medium` is 'email' and a pusher is associated with this third-party ID, the
+        pusher will also be deleted.
 
         Args:
             user_id: ID of user to remove the 3pid from.
             medium: The medium of the 3pid being removed: "email" or "msisdn".
             address: The 3pid address to remove.
-            id_server: Use the given identity server when unbinding
-                any threepids. If None then will attempt to unbind using the
-                identity server specified when binding (if known).
-
-        Returns:
-            Returns True if successfully unbound the 3pid on
-            the identity server, False if identity server doesn't support the
-            unbind API.
         """
-
         # 'Canonicalise' email addresses as per above
         if medium == "email":
             address = canonicalise_email(address)
 
-        result = await self.hs.get_identity_handler().try_unbind_threepid(
-            user_id, medium, address, id_server
+        await self.store.user_delete_threepid(user_id, medium, address)
+
+        # Inform Synapse modules that a 3PID association has been deleted.
+        await self._third_party_rules.on_remove_user_third_party_identifier(
+            user_id, medium, address
         )
 
-        await self.store.user_delete_threepid(user_id, medium, address)
         if medium == "email":
             await self.store.delete_pusher_by_app_id_pushkey_user_id(
                 app_id="m.email", pushkey=address, user_id=user_id
             )
-        return result
 
     async def hash(self, password: str) -> str:
         """Computes a secure hash of password.
@@ -2259,7 +2271,6 @@ class PasswordAuthProvider:
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
-
         # call all of the on_logged_out callbacks
         for callback in self.on_logged_out_callbacks:
             try:
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d24f649382..d31263c717 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -100,26 +100,28 @@ class DeactivateAccountHandler:
         # unbinding
         identity_server_supports_unbinding = True
 
-        # Retrieve the 3PIDs this user has bound to an identity server
-        threepids = await self.store.user_get_bound_threepids(user_id)
-
-        for threepid in threepids:
+        # Attempt to unbind any known bound threepids to this account from identity
+        # server(s).
+        bound_threepids = await self.store.user_get_bound_threepids(user_id)
+        for threepid in bound_threepids:
             try:
                 result = await self._identity_handler.try_unbind_threepid(
                     user_id, threepid["medium"], threepid["address"], id_server
                 )
-                identity_server_supports_unbinding &= result
             except Exception:
                 # Do we want this to be a fatal error or should we carry on?
                 logger.exception("Failed to remove threepid from ID server")
                 raise SynapseError(400, "Failed to remove threepid from ID server")
-            await self.store.user_delete_threepid(
+
+            identity_server_supports_unbinding &= result
+
+        # Remove any local threepid associations for this account.
+        local_threepids = await self.store.user_get_threepids(user_id)
+        for threepid in local_threepids:
+            await self._auth_handler.delete_local_threepid(
                 user_id, threepid["medium"], threepid["address"]
             )
 
-        # Remove all 3PIDs this user has bound to the homeserver
-        await self.store.user_delete_threepids(user_id)
-
         # delete any devices belonging to the user, which will also
         # delete corresponding access tokens.
         await self._device_handler.delete_all_devices_for_user(user_id)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a5798e9483..1fb23cc9bf 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -497,9 +497,11 @@ class DirectoryHandler:
                 raise SynapseError(403, "Not allowed to publish room")
 
             # Check if publishing is blocked by a third party module
-            allowed_by_third_party_rules = await (
-                self.third_party_event_rules.check_visibility_can_be_modified(
-                    room_id, visibility
+            allowed_by_third_party_rules = (
+                await (
+                    self.third_party_event_rules.check_visibility_can_be_modified(
+                        room_id, visibility
+                    )
                 )
             )
             if not allowed_by_third_party_rules:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 43cbece21b..4e9c8d8db0 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -1301,6 +1301,20 @@ class E2eKeysHandler:
 
         return desired_key_data
 
+    async def is_cross_signing_set_up_for_user(self, user_id: str) -> bool:
+        """Checks if the user has cross-signing set up
+
+        Args:
+            user_id: The user to check
+
+        Returns:
+            True if the user has cross-signing set up, False otherwise
+        """
+        existing_master_key = await self.store.get_e2e_cross_signing_key(
+            user_id, "master"
+        )
+        return existing_master_key is not None
+
 
 def _check_cross_signing_key(
     key: JsonDict, user_id: str, key_type: str, signing_key: Optional[VerifyKey] = None
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 83f53ceb88..50317ec753 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -188,7 +188,6 @@ class E2eRoomKeysHandler:
 
         # XXX: perhaps we should use a finer grained lock here?
         async with self._upload_linearizer.queue(user_id):
-
             # Check that the version we're trying to upload is the current version
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(user_id)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 46dd63c3f0..c508861b6a 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -236,7 +236,6 @@ class EventAuthHandler:
         # in any of them.
         allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
             # If this is a remote request, the user might be in an allowed room
             # that we do not know about.
             if get_domain_from_id(user_id) != self._server_name:
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index 949b69cb41..68c07f0265 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -23,7 +23,7 @@ from synapse.events.utils import SerializeEventConfig
 from synapse.handlers.presence import format_user_presence_state
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.streams.config import PaginationConfig
-from synapse.types import JsonDict, UserID
+from synapse.types import JsonDict, Requester, UserID
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -46,13 +46,12 @@ class EventStreamHandler:
 
     async def get_stream(
         self,
-        auth_user_id: str,
+        requester: Requester,
         pagin_config: PaginationConfig,
         timeout: int = 0,
         as_client_event: bool = True,
         affect_presence: bool = True,
         room_id: Optional[str] = None,
-        is_guest: bool = False,
     ) -> JsonDict:
         """Fetches the events stream for a given user."""
 
@@ -62,13 +61,12 @@ class EventStreamHandler:
                 raise SynapseError(403, "This room has been blocked on this server")
 
         # send any outstanding server notices to the user.
-        await self._server_notices_sender.on_user_syncing(auth_user_id)
+        await self._server_notices_sender.on_user_syncing(requester.user.to_string())
 
-        auth_user = UserID.from_string(auth_user_id)
         presence_handler = self.hs.get_presence_handler()
 
         context = await presence_handler.user_syncing(
-            auth_user_id,
+            requester.user.to_string(),
             affect_presence=affect_presence,
             presence_state=PresenceState.ONLINE,
         )
@@ -82,10 +80,10 @@ class EventStreamHandler:
                 timeout = random.randint(int(timeout * 0.9), int(timeout * 1.1))
 
             stream_result = await self.notifier.get_events_for(
-                auth_user,
+                requester.user,
                 pagin_config,
                 timeout,
-                is_guest=is_guest,
+                is_guest=requester.is_guest,
                 explicit_room_id=room_id,
             )
             events = stream_result.events
@@ -102,7 +100,7 @@ class EventStreamHandler:
                     if event.membership != Membership.JOIN:
                         continue
                     # Send down presence.
-                    if event.state_key == auth_user_id:
+                    if event.state_key == requester.user.to_string():
                         # Send down presence for everyone in the room.
                         users: Iterable[str] = await self.store.get_users_in_room(
                             event.room_id
@@ -124,7 +122,9 @@ class EventStreamHandler:
             chunks = self._event_serializer.serialize_events(
                 events,
                 time_now,
-                config=SerializeEventConfig(as_client_event=as_client_event),
+                config=SerializeEventConfig(
+                    as_client_event=as_client_event, requester=requester
+                ),
             )
 
             chunk = {
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 5371336529..50f8041f17 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -952,7 +952,20 @@ class FederationHandler:
         #
         # Note that this requires the /send_join request to come back to the
         # same server.
+        prev_event_ids = None
         if room_version.msc3083_join_rules:
+            # Note that the room's state can change out from under us and render our
+            # nice join rules-conformant event non-conformant by the time we build the
+            # event. When this happens, our validation at the end fails and we respond
+            # to the requesting server with a 403, which is misleading — it indicates
+            # that the user is not allowed to join the room and the joining server
+            # should not bother retrying via this homeserver or any others, when
+            # in fact we've just messed up with building the event.
+            #
+            # To reduce the likelihood of this race, we capture the forward extremities
+            # of the room (prev_event_ids) just before fetching the current state, and
+            # hope that the state we fetch corresponds to the prev events we chose.
+            prev_event_ids = await self.store.get_prev_events_for_room(room_id)
             state_ids = await self._state_storage_controller.get_current_state_ids(
                 room_id
             )
@@ -995,7 +1008,8 @@ class FederationHandler:
                 unpersisted_context,
                 _,
             ) = await self.event_creation_handler.create_new_client_event(
-                builder=builder
+                builder=builder,
+                prev_event_ids=prev_event_ids,
             )
         except SynapseError as e:
             logger.warning("Failed to create join to %s because %s", room_id, e)
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 1a29abde98..b3be7a86f0 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,6 @@ class InitialSyncHandler:
         as_client_event: bool = True,
         include_archived: bool = False,
     ) -> JsonDict:
-
         memberships = [Membership.INVITE, Membership.JOIN]
         if include_archived:
             memberships.append(Membership.LEAVE)
@@ -319,11 +318,9 @@ class InitialSyncHandler:
         )
         is_peeking = member_event_id is None
 
-        user_id = requester.user.to_string()
-
         if membership == Membership.JOIN:
             result = await self._room_initial_sync_joined(
-                user_id, room_id, pagin_config, membership, is_peeking
+                requester, room_id, pagin_config, membership, is_peeking
             )
         elif membership == Membership.LEAVE:
             # The member_event_id will always be available if membership is set
@@ -331,10 +328,16 @@ class InitialSyncHandler:
             assert member_event_id
 
             result = await self._room_initial_sync_parted(
-                user_id, room_id, pagin_config, membership, member_event_id, is_peeking
+                requester,
+                room_id,
+                pagin_config,
+                membership,
+                member_event_id,
+                is_peeking,
             )
 
         account_data_events = []
+        user_id = requester.user.to_string()
         tags = await self.store.get_tags_for_room(user_id, room_id)
         if tags:
             account_data_events.append(
@@ -351,7 +354,7 @@ class InitialSyncHandler:
 
     async def _room_initial_sync_parted(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         pagin_config: PaginationConfig,
         membership: str,
@@ -370,13 +373,17 @@ class InitialSyncHandler:
         )
 
         messages = await filter_events_for_client(
-            self._storage_controllers, user_id, messages, is_peeking=is_peeking
+            self._storage_controllers,
+            requester.user.to_string(),
+            messages,
+            is_peeking=is_peeking,
         )
 
         start_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, token)
         end_token = StreamToken.START.copy_and_replace(StreamKeyType.ROOM, stream_token)
 
         time_now = self.clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
 
         return {
             "membership": membership,
@@ -384,14 +391,18 @@ class InitialSyncHandler:
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    self._event_serializer.serialize_events(messages, time_now)
+                    self._event_serializer.serialize_events(
+                        messages, time_now, config=serialize_options
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
             },
             "state": (
                 # Don't bundle aggregations as this is a deprecated API.
-                self._event_serializer.serialize_events(room_state.values(), time_now)
+                self._event_serializer.serialize_events(
+                    room_state.values(), time_now, config=serialize_options
+                )
             ),
             "presence": [],
             "receipts": [],
@@ -399,7 +410,7 @@ class InitialSyncHandler:
 
     async def _room_initial_sync_joined(
         self,
-        user_id: str,
+        requester: Requester,
         room_id: str,
         pagin_config: PaginationConfig,
         membership: str,
@@ -411,9 +422,12 @@ class InitialSyncHandler:
 
         # TODO: These concurrently
         time_now = self.clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
         # Don't bundle aggregations as this is a deprecated API.
         state = self._event_serializer.serialize_events(
-            current_state.values(), time_now
+            current_state.values(),
+            time_now,
+            config=serialize_options,
         )
 
         now_token = self.hs.get_event_sources().get_current_token()
@@ -451,7 +465,10 @@ class InitialSyncHandler:
             if not receipts:
                 return []
 
-            return ReceiptEventSource.filter_out_private_receipts(receipts, user_id)
+            return ReceiptEventSource.filter_out_private_receipts(
+                receipts,
+                requester.user.to_string(),
+            )
 
         presence, receipts, (messages, token) = await make_deferred_yieldable(
             gather_results(
@@ -470,20 +487,23 @@ class InitialSyncHandler:
         )
 
         messages = await filter_events_for_client(
-            self._storage_controllers, user_id, messages, is_peeking=is_peeking
+            self._storage_controllers,
+            requester.user.to_string(),
+            messages,
+            is_peeking=is_peeking,
         )
 
         start_token = now_token.copy_and_replace(StreamKeyType.ROOM, token)
         end_token = now_token
 
-        time_now = self.clock.time_msec()
-
         ret = {
             "room_id": room_id,
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    self._event_serializer.serialize_events(messages, time_now)
+                    self._event_serializer.serialize_events(
+                        messages, time_now, config=serialize_options
+                    )
                 ),
                 "start": await start_token.to_string(self.store),
                 "end": await end_token.to_string(self.store),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 77d92f1574..d283a938c0 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 import logging
 import random
+from builtins import dict
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple
 
@@ -50,7 +51,7 @@ from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase, relation_from_event
 from synapse.events.builder import EventBuilder
 from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
-from synapse.events.utils import maybe_upsert_event_field
+from synapse.events.utils import SerializeEventConfig, maybe_upsert_event_field
 from synapse.events.validator import EventValidator
 from synapse.handlers.directory import DirectoryHandler
 from synapse.logging import opentracing
@@ -245,8 +246,11 @@ class MessageHandler:
                 )
                 room_state = room_state_events[membership_event_id]
 
-        now = self.clock.time_msec()
-        events = self._event_serializer.serialize_events(room_state.values(), now)
+        events = self._event_serializer.serialize_events(
+            room_state.values(),
+            self.clock.time_msec(),
+            config=SerializeEventConfig(requester=requester),
+        )
         return events
 
     async def _user_can_see_state_at_event(
@@ -574,7 +578,7 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext, Optional[dict]]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase, Optional[dict]]:
         """
         Given a dict from a client, create a new event. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
@@ -723,8 +727,6 @@ class EventCreationHandler:
             current_state_group=current_state_group,
         )
 
-        context = await unpersisted_context.persist(event)
-
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
         # behaviour. Another reason is that this code is also evaluated each time a new
@@ -741,7 +743,7 @@ class EventCreationHandler:
                 assert state_map is not None
                 prev_event_id = state_map.get((EventTypes.Member, event.sender))
             else:
-                prev_state_ids = await context.get_prev_state_ids(
+                prev_state_ids = await unpersisted_context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
                 prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -766,8 +768,7 @@ class EventCreationHandler:
                 )
 
         self.validator.validate_new(event, self.config)
-
-        return event, context, new_event
+        return event, unpersisted_context, new_event
 
     async def _is_exempt_from_privacy_policy(
         self, builder: EventBuilder, requester: Requester
@@ -1007,7 +1008,11 @@ class EventCreationHandler:
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context, third_party_event_dict = await self.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                    third_party_event_dict,
+                ) = await self.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1018,6 +1023,7 @@ class EventCreationHandler:
                     historical=historical,
                     depth=depth,
                 )
+                context = await unpersisted_context.persist(event)
 
                 assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
                     event.sender,
@@ -1209,7 +1215,6 @@ class EventCreationHandler:
         if for_batch:
             assert prev_event_ids is not None
             assert state_map is not None
-            assert current_state_group is not None
             auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@@ -2067,7 +2072,7 @@ class EventCreationHandler:
                 max_retries = 5
                 for i in range(max_retries):
                     try:
-                        event, context, _ = await self.create_event(
+                        event, unpersisted_context, _ = await self.create_event(
                             requester,
                             {
                                 "type": EventTypes.Dummy,
@@ -2076,6 +2081,7 @@ class EventCreationHandler:
                                 "sender": user_id,
                             },
                         )
+                        context = await unpersisted_context.persist(event)
 
                         event.internal_metadata.proactively_send = False
 
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index ceefa16b49..8c79c055ba 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -579,7 +579,9 @@ class PaginationHandler:
 
         time_now = self.clock.time_msec()
 
-        serialize_options = SerializeEventConfig(as_client_event=as_client_event)
+        serialize_options = SerializeEventConfig(
+            as_client_event=as_client_event, requester=requester
+        )
 
         chunk = {
             "chunk": (
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 87af31aa27..4ad2233573 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler):
         )
 
         if self.unpersisted_users_changes:
-
             await self.store.update_presence(
                 [
                     self.user_to_current_state[user_id]
@@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler):
         now = self.clock.time_msec()
 
         with Measure(self.clock, "presence_update_states"):
-
             # NOTE: We purposefully don't await between now and when we've
             # calculated what we want to do with the new states, to avoid races.
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c611efb760..e4e506e62c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -476,7 +476,7 @@ class RegistrationHandler:
                     # create room expects the localpart of the room alias
                     config["room_alias_name"] = room_alias.localpart
 
-                    info, _ = await room_creation_handler.create_room(
+                    room_id, _, _ = await room_creation_handler.create_room(
                         fake_requester,
                         config=config,
                         ratelimit=False,
@@ -490,7 +490,7 @@ class RegistrationHandler:
                                 user_id, authenticated_entity=self._server_name
                             ),
                             target=UserID.from_string(user_id),
-                            room_id=info["room_id"],
+                            room_id=room_id,
                             # Since it was just created, there are no remote hosts.
                             remote_room_hosts=[],
                             action="join",
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0fb15391e0..1d09fdf135 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -20,6 +20,7 @@ import attr
 from synapse.api.constants import Direction, EventTypes, RelationTypes
 from synapse.api.errors import SynapseError
 from synapse.events import EventBase, relation_from_event
+from synapse.events.utils import SerializeEventConfig
 from synapse.logging.context import make_deferred_yieldable, run_in_background
 from synapse.logging.opentracing import trace
 from synapse.storage.databases.main.relations import ThreadsNextBatch, _RelatedEvent
@@ -60,13 +61,12 @@ class BundledAggregations:
     Some values require additional processing during serialization.
     """
 
-    annotations: Optional[JsonDict] = None
     references: Optional[JsonDict] = None
     replace: Optional[EventBase] = None
     thread: Optional[_ThreadAggregation] = None
 
     def __bool__(self) -> bool:
-        return bool(self.annotations or self.references or self.replace or self.thread)
+        return bool(self.references or self.replace or self.thread)
 
 
 class RelationsHandler:
@@ -152,16 +152,23 @@ class RelationsHandler:
         )
 
         now = self._clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
         return_value: JsonDict = {
             "chunk": self._event_serializer.serialize_events(
-                events, now, bundle_aggregations=aggregations
+                events,
+                now,
+                bundle_aggregations=aggregations,
+                config=serialize_options,
             ),
         }
         if include_original_event:
             # Do not bundle aggregations when retrieving the original event because
             # we want the content before relations are applied to it.
             return_value["original_event"] = self._event_serializer.serialize_event(
-                event, now, bundle_aggregations=None
+                event,
+                now,
+                bundle_aggregations=None,
+                config=serialize_options,
             )
 
         if next_token:
@@ -227,67 +234,6 @@ class RelationsHandler:
                     e.msg,
                 )
 
-    async def get_annotations_for_events(
-        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
-    ) -> Dict[str, List[JsonDict]]:
-        """Get a list of annotations to the given events, grouped by event type and
-        aggregation key, sorted by count.
-
-        This is used e.g. to get the what and how many reactions have happened
-        on an event.
-
-        Args:
-            event_ids: Fetch events that relate to these event IDs.
-            ignored_users: The users ignored by the requesting user.
-
-        Returns:
-            A map of event IDs to a list of groups of annotations that match.
-            Each entry is a dict with `type`, `key` and `count` fields.
-        """
-        # Get the base results for all users.
-        full_results = await self._main_store.get_aggregation_groups_for_events(
-            event_ids
-        )
-
-        # Avoid additional logic if there are no ignored users.
-        if not ignored_users:
-            return {
-                event_id: results
-                for event_id, results in full_results.items()
-                if results
-            }
-
-        # Then subtract off the results for any ignored users.
-        ignored_results = await self._main_store.get_aggregation_groups_for_users(
-            [event_id for event_id, results in full_results.items() if results],
-            ignored_users,
-        )
-
-        filtered_results = {}
-        for event_id, results in full_results.items():
-            # If no annotations, skip.
-            if not results:
-                continue
-
-            # If there are not ignored results for this event, copy verbatim.
-            if event_id not in ignored_results:
-                filtered_results[event_id] = results
-                continue
-
-            # Otherwise, subtract out the ignored results.
-            event_ignored_results = ignored_results[event_id]
-            for result in results:
-                key = (result["type"], result["key"])
-                if key in event_ignored_results:
-                    # Ensure to not modify the cache.
-                    result = result.copy()
-                    result["count"] -= event_ignored_results[key]
-                    if result["count"] <= 0:
-                        continue
-                filtered_results.setdefault(event_id, []).append(result)
-
-        return filtered_results
-
     async def get_references_for_events(
         self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
     ) -> Dict[str, List[_RelatedEvent]]:
@@ -531,17 +477,6 @@ class RelationsHandler:
                 # (as that is what makes it part of the thread).
                 relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
 
-        async def _fetch_annotations() -> None:
-            """Fetch any annotations (ie, reactions) to bundle with this event."""
-            annotations_by_event_id = await self.get_annotations_for_events(
-                events_by_id.keys(), ignored_users=ignored_users
-            )
-            for event_id, annotations in annotations_by_event_id.items():
-                if annotations:
-                    results.setdefault(event_id, BundledAggregations()).annotations = {
-                        "chunk": annotations
-                    }
-
         async def _fetch_references() -> None:
             """Fetch any references to bundle with this event."""
             references_by_event_id = await self.get_references_for_events(
@@ -575,7 +510,6 @@ class RelationsHandler:
         await make_deferred_yieldable(
             gather_results(
                 (
-                    run_in_background(_fetch_annotations),
                     run_in_background(_fetch_references),
                     run_in_background(_fetch_edits),
                 )
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 9fb7209549..c70afa3176 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
@@ -211,7 +212,7 @@ class RoomCreationHandler:
                 # the required power level to send the tombstone event.
                 (
                     tombstone_event,
-                    tombstone_context,
+                    tombstone_unpersisted_context,
                     _,
                 ) = await self.event_creation_handler.create_event(
                     requester,
@@ -226,6 +227,9 @@ class RoomCreationHandler:
                         },
                     },
                 )
+                tombstone_context = await tombstone_unpersisted_context.persist(
+                    tombstone_event
+                )
                 validate_event_for_room_version(tombstone_event)
                 await self._event_auth_handler.check_auth_rules_from_context(
                     tombstone_event
@@ -691,13 +695,14 @@ class RoomCreationHandler:
         config: JsonDict,
         ratelimit: bool = True,
         creator_join_profile: Optional[JsonDict] = None,
-    ) -> Tuple[dict, int]:
+    ) -> Tuple[str, Optional[RoomAlias], int]:
         """Creates a new room.
 
         Args:
-            requester:
-                The user who requested the room creation.
-            config : A dict of configuration options.
+            requester: The user who requested the room creation.
+            config: A dict of configuration options. This will be the body of
+                a /createRoom request; see
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
             ratelimit: set to False to disable the rate limiter
 
             creator_join_profile:
@@ -708,14 +713,17 @@ class RoomCreationHandler:
                 `avatar_url` and/or `displayname`.
 
         Returns:
-                First, a dict containing the keys `room_id` and, if an alias
-                was, requested, `room_alias`. Secondly, the stream_id of the
-                last persisted event.
+            A 3-tuple containing:
+                - the room ID;
+                - if requested, the room alias, otherwise None; and
+                - the `stream_id` of the last persisted event.
         Raises:
-            SynapseError if the room ID couldn't be stored, 3pid invitation config
-            validation failed, or something went horribly wrong.
-            ResourceLimitError if server is blocked to some resource being
-            exceeded
+            SynapseError:
+                if the room ID couldn't be stored, 3pid invitation config
+                validation failed, or something went horribly wrong.
+            ResourceLimitError:
+                if server is blocked to some resource being
+                exceeded
         """
         user_id = requester.user.to_string()
 
@@ -865,9 +873,11 @@ class RoomCreationHandler:
         )
 
         # Check whether this visibility value is blocked by a third party module
-        allowed_by_third_party_rules = await (
-            self.third_party_event_rules.check_visibility_can_be_modified(
-                room_id, visibility
+        allowed_by_third_party_rules = (
+            await (
+                self.third_party_event_rules.check_visibility_can_be_modified(
+                    room_id, visibility
+                )
             )
         )
         if not allowed_by_third_party_rules:
@@ -1025,11 +1035,6 @@ class RoomCreationHandler:
             last_sent_event_id = member_event_id
             depth += 1
 
-        result = {"room_id": room_id}
-
-        if room_alias:
-            result["room_alias"] = room_alias.to_string()
-
         # Always wait for room creation to propagate before returning
         await self._replication.wait_for_stream_position(
             self.hs.config.worker.events_shard_config.get_instance(room_id),
@@ -1037,7 +1042,7 @@ class RoomCreationHandler:
             last_stream_id,
         )
 
-        return result, last_stream_id
+        return room_id, room_alias, last_stream_id
 
     async def _send_events_for_new_room(
         self,
@@ -1092,7 +1097,11 @@ class RoomCreationHandler:
             content: JsonDict,
             for_batch: bool,
             **kwargs: Any,
-        ) -> Tuple[EventBase, synapse.events.snapshot.EventContext, Optional[dict]]:
+        ) -> Tuple[
+            EventBase,
+            synapse.events.snapshot.UnpersistedEventContextBase,
+            Optional[dict],
+        ]:
             """
             Creates an event and associated event context.
             Args:
@@ -1113,7 +1122,7 @@ class RoomCreationHandler:
 
             (
                 new_event,
-                new_context,
+                new_unpersisted_context,
                 third_party_event,
             ) = await self.event_creation_handler.create_event(
                 creator,
@@ -1122,13 +1131,13 @@ class RoomCreationHandler:
                 depth=depth,
                 state_map=state_map,
                 for_batch=for_batch,
-                current_state_group=current_state_group,
             )
+
             depth += 1
             prev_event = [new_event.event_id]
             state_map[(new_event.type, new_event.state_key)] = new_event.event_id
 
-            return new_event, new_context, third_party_event
+            return new_event, new_unpersisted_context, third_party_event
 
         try:
             config = self._presets_dict[preset_config]
@@ -1138,10 +1147,10 @@ class RoomCreationHandler:
             )
 
         creation_content.update({"creator": creator_id})
-        creation_event, creation_context, _ = await create_event(
+        creation_event, unpersisted_creation_context, _ = await create_event(
             EventTypes.Create, creation_content, False
         )
-
+        creation_context = await unpersisted_creation_context.persist(creation_event)
         logger.debug("Sending %s in new room", EventTypes.Member)
         ev = await self.event_creation_handler.handle_new_client_event(
             requester=creator,
@@ -1186,7 +1195,6 @@ class RoomCreationHandler:
             power_event, power_context, power_tp_event = await create_event(
                 EventTypes.PowerLevels, pl_content, True
             )
-            current_state_group = power_context._state_group
             events_to_send.append((power_event, power_context))
             if power_tp_event:
                 third_party_events_to_append.append(power_tp_event)
@@ -1237,7 +1245,6 @@ class RoomCreationHandler:
                 power_level_content,
                 True,
             )
-            current_state_group = pl_context._state_group
             events_to_send.append((pl_event, pl_context))
             if pl_tp_event:
                 third_party_events_to_append.append(pl_tp_event)
@@ -1246,7 +1253,6 @@ class RoomCreationHandler:
             room_alias_event, room_alias_context, ra_tp_event = await create_event(
                 EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
             )
-            current_state_group = room_alias_context._state_group
             events_to_send.append((room_alias_event, room_alias_context))
             if ra_tp_event:
                 third_party_events_to_append.append(ra_tp_event)
@@ -1257,7 +1263,6 @@ class RoomCreationHandler:
                 {"join_rule": config["join_rules"]},
                 True,
             )
-            current_state_group = join_rules_context._state_group
             events_to_send.append((join_rules_event, join_rules_context))
             if jr_tp_event:
                 third_party_events_to_append.append(jr_tp_event)
@@ -1268,7 +1273,6 @@ class RoomCreationHandler:
                 {"history_visibility": config["history_visibility"]},
                 True,
             )
-            current_state_group = visibility_context._state_group
             events_to_send.append((visibility_event, visibility_context))
             if vis_tp_event:
                 third_party_events_to_append.append(vis_tp_event)
@@ -1284,7 +1288,6 @@ class RoomCreationHandler:
                     {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
                     True,
                 )
-                current_state_group = guest_access_context._state_group
                 events_to_send.append((guest_access_event, guest_access_context))
                 if ga_tp_event:
                     third_party_events_to_append.append(ga_tp_event)
@@ -1293,7 +1296,6 @@ class RoomCreationHandler:
             event, context, tp_event = await create_event(
                 etype, content, True, state_key=state_key
             )
-            current_state_group = context._state_group
             events_to_send.append((event, context))
             if tp_event:
                 third_party_events_to_append.append(tp_event)
@@ -1325,9 +1327,16 @@ class RoomCreationHandler:
             context = await unpersisted_context.persist(event)
             events_to_send.append((event, context))
 
+        datastore = self.hs.get_datastores().state
+        events_and_context = (
+            await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+                events_to_send, room_id, current_state_group, datastore
+            )
+        )
+
         last_event = await self.event_creation_handler.handle_new_client_event(
             creator,
-            events_to_send,
+            events_and_context,
             ignore_shadow_ban=True,
             ratelimit=False,
         )
@@ -1867,7 +1876,7 @@ class RoomShutdownHandler:
                 new_room_user_id, authenticated_entity=requester_user_id
             )
 
-            info, stream_id = await self._room_creation_handler.create_room(
+            new_room_id, _, stream_id = await self._room_creation_handler.create_room(
                 room_creator_requester,
                 config={
                     "preset": RoomCreationPreset.PUBLIC_CHAT,
@@ -1876,7 +1885,6 @@ class RoomShutdownHandler:
                 },
                 ratelimit=False,
             )
-            new_room_id = info["room_id"]
 
             logger.info(
                 "Shutting down room %r, joining to new room: %r", room_id, new_room_id
@@ -1929,6 +1937,7 @@ class RoomShutdownHandler:
 
                 # Join users to new room
                 if new_room_user_id:
+                    assert new_room_id is not None
                     await self.room_member_handler.update_membership(
                         requester=target_requester,
                         target=target_requester.user,
@@ -1961,6 +1970,7 @@ class RoomShutdownHandler:
 
             aliases_for_room = await self.store.get_aliases_for_room(room_id)
 
+            assert new_room_id is not None
             await self.store.update_aliases_for_room(
                 room_id, new_room_id, requester_user_id
             )
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index 3d432ac295..8b5e02af17 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,11 @@ class RoomBatchHandler:
             # Mark all events as historical
             event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
 
-            event, context, _ = await self.event_creation_handler.create_event(
+            (
+                event,
+                unpersisted_context,
+                _,
+            ) = await self.event_creation_handler.create_event(
                 await self.create_requester_for_user_id_from_app_service(
                     ev["sender"], app_service_requester.app_service
                 ),
@@ -345,7 +349,7 @@ class RoomBatchHandler:
                 historical=True,
                 depth=inherited_depth,
             )
-
+            context = await unpersisted_context.persist(event)
             assert context._state_group
 
             # Normally this is done when persisting the event but we have to
@@ -374,7 +378,7 @@ class RoomBatchHandler:
         # correct stream_ordering as they are backfilled (which decrements).
         # Events are sorted by (topological_ordering, stream_ordering)
         # where topological_ordering is just depth.
-        for (event, context) in reversed(events_to_persist):
+        for event, context in reversed(events_to_persist):
             # This call can't raise `PartialStateConflictError` since we forbid
             # use of the historical batch API during partial state
             await self.event_creation_handler.handle_new_client_event(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index f136ee0e97..9d3096df8d 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -207,6 +207,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     async def remote_knock(
         self,
+        requester: Requester,
         remote_room_hosts: List[str],
         room_id: str,
         user: UserID,
@@ -416,7 +417,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
             try:
                 (
                     event,
-                    context,
+                    unpersisted_context,
                     third_party_event,
                 ) = await self.event_creation_handler.create_event(
                     requester,
@@ -439,7 +440,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     outlier=outlier,
                     historical=historical,
                 )
-
+                context = await unpersisted_context.persist(event)
                 prev_state_ids = await context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
@@ -1088,7 +1089,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     )
 
                 return await self.remote_knock(
-                    remote_room_hosts, room_id, target, content
+                    requester, remote_room_hosts, room_id, target, content
                 )
 
         return await self._local_membership_update(
@@ -1964,7 +1965,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             try:
                 (
                     event,
-                    context,
+                    unpersisted_context,
                     third_party_event_dict,
                 ) = await self.event_creation_handler.create_event(
                     requester,
@@ -1974,6 +1975,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                     auth_event_ids=auth_event_ids,
                     outlier=True,
                 )
+                context = await unpersisted_context.persist(event)
                 event.internal_metadata.out_of_band_membership = True
 
                 events_and_context = [(event, context)]
@@ -2013,6 +2015,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
 
     async def remote_knock(
         self,
+        requester: Requester,
         remote_room_hosts: List[str],
         room_id: str,
         user: UserID,
diff --git a/synapse/handlers/room_member_worker.py b/synapse/handlers/room_member_worker.py
index ba261702d4..76e36b8a6d 100644
--- a/synapse/handlers/room_member_worker.py
+++ b/synapse/handlers/room_member_worker.py
@@ -113,6 +113,7 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
 
     async def remote_knock(
         self,
+        requester: Requester,
         remote_room_hosts: List[str],
         room_id: str,
         user: UserID,
@@ -123,9 +124,10 @@ class RoomMemberWorkerHandler(RoomMemberHandler):
         Implements RoomMemberHandler.remote_knock
         """
         ret = await self._remote_knock_client(
+            requester=requester,
             remote_room_hosts=remote_room_hosts,
             room_id=room_id,
-            user=user,
+            user_id=user.to_string(),
             content=content,
         )
         return ret["event_id"], ret["stream_id"]
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 9bbf83047d..aad4706f14 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -23,7 +23,8 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.api.filtering import Filter
 from synapse.events import EventBase
-from synapse.types import JsonDict, StrCollection, StreamKeyType, UserID
+from synapse.events.utils import SerializeEventConfig
+from synapse.types import JsonDict, Requester, StrCollection, StreamKeyType, UserID
 from synapse.types.state import StateFilter
 from synapse.visibility import filter_events_for_client
 
@@ -109,12 +110,12 @@ class SearchHandler:
         return historical_room_ids
 
     async def search(
-        self, user: UserID, content: JsonDict, batch: Optional[str] = None
+        self, requester: Requester, content: JsonDict, batch: Optional[str] = None
     ) -> JsonDict:
         """Performs a full text search for a user.
 
         Args:
-            user: The user performing the search.
+            requester: The user performing the search.
             content: Search parameters
             batch: The next_batch parameter. Used for pagination.
 
@@ -199,7 +200,7 @@ class SearchHandler:
             )
 
         return await self._search(
-            user,
+            requester,
             batch_group,
             batch_group_key,
             batch_token,
@@ -217,7 +218,7 @@ class SearchHandler:
 
     async def _search(
         self,
-        user: UserID,
+        requester: Requester,
         batch_group: Optional[str],
         batch_group_key: Optional[str],
         batch_token: Optional[str],
@@ -235,7 +236,7 @@ class SearchHandler:
         """Performs a full text search for a user.
 
         Args:
-            user: The user performing the search.
+            requester: The user performing the search.
             batch_group: Pagination information.
             batch_group_key: Pagination information.
             batch_token: Pagination information.
@@ -269,7 +270,7 @@ class SearchHandler:
 
         # TODO: Search through left rooms too
         rooms = await self.store.get_rooms_for_local_user_where_membership_is(
-            user.to_string(),
+            requester.user.to_string(),
             membership_list=[Membership.JOIN],
             # membership_list=[Membership.JOIN, Membership.LEAVE, Membership.Ban],
         )
@@ -303,13 +304,13 @@ class SearchHandler:
 
         if order_by == "rank":
             search_result, sender_group = await self._search_by_rank(
-                user, room_ids, search_term, keys, search_filter
+                requester.user, room_ids, search_term, keys, search_filter
             )
             # Unused return values for rank search.
             global_next_batch = None
         elif order_by == "recent":
             search_result, global_next_batch = await self._search_by_recent(
-                user,
+                requester.user,
                 room_ids,
                 search_term,
                 keys,
@@ -334,7 +335,7 @@ class SearchHandler:
             assert after_limit is not None
 
             contexts = await self._calculate_event_contexts(
-                user,
+                requester.user,
                 search_result.allowed_events,
                 before_limit,
                 after_limit,
@@ -363,27 +364,37 @@ class SearchHandler:
                 # The returned events.
                 search_result.allowed_events,
             ),
-            user.to_string(),
+            requester.user.to_string(),
         )
 
         # We're now about to serialize the events. We should not make any
         # blocking calls after this. Otherwise, the 'age' will be wrong.
 
         time_now = self.clock.time_msec()
+        serialize_options = SerializeEventConfig(requester=requester)
 
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now, bundle_aggregations=aggregations
+                context["events_before"],
+                time_now,
+                bundle_aggregations=aggregations,
+                config=serialize_options,
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now, bundle_aggregations=aggregations
+                context["events_after"],
+                time_now,
+                bundle_aggregations=aggregations,
+                config=serialize_options,
             )
 
         results = [
             {
                 "rank": search_result.rank_map[e.event_id],
                 "result": self._event_serializer.serialize_event(
-                    e, time_now, bundle_aggregations=aggregations
+                    e,
+                    time_now,
+                    bundle_aggregations=aggregations,
+                    config=serialize_options,
                 ),
                 "context": contexts.get(e.event_id, {}),
             }
@@ -398,7 +409,9 @@ class SearchHandler:
 
         if state_results:
             rooms_cat_res["state"] = {
-                room_id: self._event_serializer.serialize_events(state_events, time_now)
+                room_id: self._event_serializer.serialize_events(
+                    state_events, time_now, config=serialize_options
+                )
                 for room_id, state_events in state_results.items()
             }
 
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4e4595312c..fd6d946c37 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1297,7 +1297,6 @@ class SyncHandler:
             return RoomNotifCounts.empty()
 
         with Measure(self.clock, "unread_notifs_for_room_id"):
-
             return await self.store.get_unread_event_push_actions_by_room_for_user(
                 room_id,
                 sync_config.user.to_string(),
diff --git a/synapse/handlers/ui_auth/checkers.py b/synapse/handlers/ui_auth/checkers.py
index 332edcca24..78a75bfed6 100644
--- a/synapse/handlers/ui_auth/checkers.py
+++ b/synapse/handlers/ui_auth/checkers.py
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any
+from abc import ABC, abstractmethod
+from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
 
 from twisted.web.client import PartialDownloadError
 
@@ -27,19 +28,28 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class UserInteractiveAuthChecker:
+class UserInteractiveAuthChecker(ABC):
     """Abstract base class for an interactive auth checker"""
 
-    def __init__(self, hs: "HomeServer"):
+    # This should really be an "abstract class property", i.e. it should
+    # be an error to instantiate a subclass that doesn't specify an AUTH_TYPE.
+    # But calling this a `ClassVar` is simpler than a decorator stack of
+    # @property @abstractmethod and @classmethod (if that's even the right order).
+    AUTH_TYPE: ClassVar[str]
+
+    def __init__(self, hs: "HomeServer"):  # noqa: B027
         pass
 
+    @abstractmethod
     def is_enabled(self) -> bool:
         """Check if the configuration of the homeserver allows this checker to work
 
         Returns:
             True if this login type is enabled.
         """
+        raise NotImplementedError()
 
+    @abstractmethod
     async def check_auth(self, authdict: dict, clientip: str) -> Any:
         """Given the authentication dict from the client, attempt to check this step
 
@@ -304,7 +314,7 @@ class RegistrationTokenAuthChecker(UserInteractiveAuthChecker):
             )
 
 
-INTERACTIVE_AUTH_CHECKERS = [
+INTERACTIVE_AUTH_CHECKERS: Sequence[Type[UserInteractiveAuthChecker]] = [
     DummyAuthChecker,
     TermsAuthChecker,
     RecaptchaAuthChecker,