summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/admin.py38
-rw-r--r--synapse/handlers/appservice.py2
-rw-r--r--synapse/handlers/auth.py51
-rw-r--r--synapse/handlers/deactivate_account.py20
-rw-r--r--synapse/handlers/directory.py8
-rw-r--r--synapse/handlers/e2e_room_keys.py1
-rw-r--r--synapse/handlers/event_auth.py1
-rw-r--r--synapse/handlers/initial_sync.py1
-rw-r--r--synapse/handlers/message.py16
-rw-r--r--synapse/handlers/presence.py2
-rw-r--r--synapse/handlers/register.py4
-rw-r--r--synapse/handlers/room.py83
-rw-r--r--synapse/handlers/room_batch.py6
-rw-r--r--synapse/handlers/room_member.py13
-rw-r--r--synapse/handlers/sync.py1
15 files changed, 154 insertions, 93 deletions
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index 8b7760b2cc..b06f25b03c 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -252,16 +252,19 @@ class AdminHandler:
         profile = await self.get_user(UserID.from_string(user_id))
         if profile is not None:
             writer.write_profile(profile)
+            logger.info("[%s] Written profile", user_id)
 
         # Get all devices the user has
         devices = await self._device_handler.get_devices_by_user(user_id)
         writer.write_devices(devices)
+        logger.info("[%s] Written %s devices", user_id, len(devices))
 
         # Get all connections the user has
         connections = await self.get_whois(UserID.from_string(user_id))
         writer.write_connections(
             connections["devices"][""]["sessions"][0]["connections"]
         )
+        logger.info("[%s] Written %s connections", user_id, len(connections))
 
         # Get all account data the user has global and in rooms
         global_data = await self._store.get_global_account_data_for_user(user_id)
@@ -269,6 +272,29 @@ class AdminHandler:
         writer.write_account_data("global", global_data)
         for room_id in by_room_data:
             writer.write_account_data(room_id, by_room_data[room_id])
+        logger.info(
+            "[%s] Written account data for %s rooms", user_id, len(by_room_data)
+        )
+
+        # Get all media ids the user has
+        limit = 100
+        start = 0
+        while True:
+            media_ids, total = await self._store.get_local_media_by_user_paginate(
+                start, limit, user_id
+            )
+            for media in media_ids:
+                writer.write_media_id(media["media_id"], media)
+
+            logger.info(
+                "[%s] Written %d media_ids of %s",
+                user_id,
+                (start + len(media_ids)),
+                total,
+            )
+            if (start + limit) >= total:
+                break
+            start += limit
 
         return writer.finished()
 
@@ -360,6 +386,18 @@ class ExfiltrationWriter(metaclass=abc.ABCMeta):
         raise NotImplementedError()
 
     @abc.abstractmethod
+    def write_media_id(self, media_id: str, media_metadata: JsonDict) -> None:
+        """Write the media's metadata of a user.
+        Exports only the metadata, as this can be fetched from the database via
+        read only. In order to access the files, a connection to the correct
+        media repository would be required.
+
+        Args:
+            media_id: ID of the media.
+            media_metadata: Metadata of one media file.
+        """
+
+    @abc.abstractmethod
     def finished(self) -> Any:
         """Called when all data has successfully been exported and written.
 
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 5d1d21cdc8..ec3ab968e9 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -737,7 +737,7 @@ class ApplicationServicesHandler:
         )
 
         ret = []
-        for (success, result) in results:
+        for success, result in results:
             if success:
                 ret.extend(result)
 
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index cf12b55d21..308e38edea 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -815,7 +815,6 @@ class AuthHandler:
         now_ms = self._clock.time_msec()
 
         if existing_token.expiry_ts is not None and existing_token.expiry_ts < now_ms:
-
             raise SynapseError(
                 HTTPStatus.FORBIDDEN,
                 "The supplied refresh token has expired",
@@ -1543,6 +1542,17 @@ class AuthHandler:
     async def add_threepid(
         self, user_id: str, medium: str, address: str, validated_at: int
     ) -> None:
+        """
+        Adds an association between a user's Matrix ID and a third-party ID (email,
+        phone number).
+
+        Args:
+            user_id: The ID of the user to associate.
+            medium: The medium of the third-party ID (email, msisdn).
+            address: The address of the third-party ID (i.e. an email address).
+            validated_at: The timestamp in ms of when the validation that the user owns
+                this third-party ID occurred.
+        """
         # check if medium has a valid value
         if medium not in ["email", "msisdn"]:
             raise SynapseError(
@@ -1567,42 +1577,44 @@ class AuthHandler:
             user_id, medium, address, validated_at, self.hs.get_clock().time_msec()
         )
 
+        # Inform Synapse modules that a 3PID association has been created.
+        await self._third_party_rules.on_add_user_third_party_identifier(
+            user_id, medium, address
+        )
+
+        # Deprecated method for informing Synapse modules that a 3PID association
+        # has successfully been created.
         await self._third_party_rules.on_threepid_bind(user_id, medium, address)
 
-    async def delete_threepid(
-        self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
-    ) -> bool:
-        """Attempts to unbind the 3pid on the identity servers and deletes it
-        from the local database.
+    async def delete_local_threepid(
+        self, user_id: str, medium: str, address: str
+    ) -> None:
+        """Deletes an association between a third-party ID and a user ID from the local
+        database. This method does not unbind the association from any identity servers.
+
+        If `medium` is 'email' and a pusher is associated with this third-party ID, the
+        pusher will also be deleted.
 
         Args:
             user_id: ID of user to remove the 3pid from.
             medium: The medium of the 3pid being removed: "email" or "msisdn".
             address: The 3pid address to remove.
-            id_server: Use the given identity server when unbinding
-                any threepids. If None then will attempt to unbind using the
-                identity server specified when binding (if known).
-
-        Returns:
-            Returns True if successfully unbound the 3pid on
-            the identity server, False if identity server doesn't support the
-            unbind API.
         """
-
         # 'Canonicalise' email addresses as per above
         if medium == "email":
             address = canonicalise_email(address)
 
-        result = await self.hs.get_identity_handler().try_unbind_threepid(
-            user_id, medium, address, id_server
+        await self.store.user_delete_threepid(user_id, medium, address)
+
+        # Inform Synapse modules that a 3PID association has been deleted.
+        await self._third_party_rules.on_remove_user_third_party_identifier(
+            user_id, medium, address
         )
 
-        await self.store.user_delete_threepid(user_id, medium, address)
         if medium == "email":
             await self.store.delete_pusher_by_app_id_pushkey_user_id(
                 app_id="m.email", pushkey=address, user_id=user_id
             )
-        return result
 
     async def hash(self, password: str) -> str:
         """Computes a secure hash of password.
@@ -2259,7 +2271,6 @@ class PasswordAuthProvider:
     async def on_logged_out(
         self, user_id: str, device_id: Optional[str], access_token: str
     ) -> None:
-
         # call all of the on_logged_out callbacks
         for callback in self.on_logged_out_callbacks:
             try:
diff --git a/synapse/handlers/deactivate_account.py b/synapse/handlers/deactivate_account.py
index d24f649382..d31263c717 100644
--- a/synapse/handlers/deactivate_account.py
+++ b/synapse/handlers/deactivate_account.py
@@ -100,26 +100,28 @@ class DeactivateAccountHandler:
         # unbinding
         identity_server_supports_unbinding = True
 
-        # Retrieve the 3PIDs this user has bound to an identity server
-        threepids = await self.store.user_get_bound_threepids(user_id)
-
-        for threepid in threepids:
+        # Attempt to unbind any known bound threepids to this account from identity
+        # server(s).
+        bound_threepids = await self.store.user_get_bound_threepids(user_id)
+        for threepid in bound_threepids:
             try:
                 result = await self._identity_handler.try_unbind_threepid(
                     user_id, threepid["medium"], threepid["address"], id_server
                 )
-                identity_server_supports_unbinding &= result
             except Exception:
                 # Do we want this to be a fatal error or should we carry on?
                 logger.exception("Failed to remove threepid from ID server")
                 raise SynapseError(400, "Failed to remove threepid from ID server")
-            await self.store.user_delete_threepid(
+
+            identity_server_supports_unbinding &= result
+
+        # Remove any local threepid associations for this account.
+        local_threepids = await self.store.user_get_threepids(user_id)
+        for threepid in local_threepids:
+            await self._auth_handler.delete_local_threepid(
                 user_id, threepid["medium"], threepid["address"]
             )
 
-        # Remove all 3PIDs this user has bound to the homeserver
-        await self.store.user_delete_threepids(user_id)
-
         # delete any devices belonging to the user, which will also
         # delete corresponding access tokens.
         await self._device_handler.delete_all_devices_for_user(user_id)
diff --git a/synapse/handlers/directory.py b/synapse/handlers/directory.py
index a5798e9483..1fb23cc9bf 100644
--- a/synapse/handlers/directory.py
+++ b/synapse/handlers/directory.py
@@ -497,9 +497,11 @@ class DirectoryHandler:
                 raise SynapseError(403, "Not allowed to publish room")
 
             # Check if publishing is blocked by a third party module
-            allowed_by_third_party_rules = await (
-                self.third_party_event_rules.check_visibility_can_be_modified(
-                    room_id, visibility
+            allowed_by_third_party_rules = (
+                await (
+                    self.third_party_event_rules.check_visibility_can_be_modified(
+                        room_id, visibility
+                    )
                 )
             )
             if not allowed_by_third_party_rules:
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 83f53ceb88..50317ec753 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -188,7 +188,6 @@ class E2eRoomKeysHandler:
 
         # XXX: perhaps we should use a finer grained lock here?
         async with self._upload_linearizer.queue(user_id):
-
             # Check that the version we're trying to upload is the current version
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(user_id)
diff --git a/synapse/handlers/event_auth.py b/synapse/handlers/event_auth.py
index 46dd63c3f0..c508861b6a 100644
--- a/synapse/handlers/event_auth.py
+++ b/synapse/handlers/event_auth.py
@@ -236,7 +236,6 @@ class EventAuthHandler:
         # in any of them.
         allowed_rooms = await self.get_rooms_that_allow_join(state_ids)
         if not await self.is_user_in_rooms(allowed_rooms, user_id):
-
             # If this is a remote request, the user might be in an allowed room
             # that we do not know about.
             if get_domain_from_id(user_id) != self._server_name:
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 1a29abde98..aead0b44b9 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -124,7 +124,6 @@ class InitialSyncHandler:
         as_client_event: bool = True,
         include_archived: bool = False,
     ) -> JsonDict:
-
         memberships = [Membership.INVITE, Membership.JOIN]
         if include_archived:
             memberships.append(Membership.LEAVE)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index aa90d0000d..e433d6b01f 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -574,7 +574,7 @@ class EventCreationHandler:
         state_map: Optional[StateMap[str]] = None,
         for_batch: bool = False,
         current_state_group: Optional[int] = None,
-    ) -> Tuple[EventBase, EventContext]:
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """
         Given a dict from a client, create a new event. If bool for_batch is true, will
         create an event using the prev_event_ids, and will create an event context for
@@ -721,8 +721,6 @@ class EventCreationHandler:
             current_state_group=current_state_group,
         )
 
-        context = await unpersisted_context.persist(event)
-
         # In an ideal world we wouldn't need the second part of this condition. However,
         # this behaviour isn't spec'd yet, meaning we should be able to deactivate this
         # behaviour. Another reason is that this code is also evaluated each time a new
@@ -739,7 +737,7 @@ class EventCreationHandler:
                 assert state_map is not None
                 prev_event_id = state_map.get((EventTypes.Member, event.sender))
             else:
-                prev_state_ids = await context.get_prev_state_ids(
+                prev_state_ids = await unpersisted_context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
                 prev_event_id = prev_state_ids.get((EventTypes.Member, event.sender))
@@ -764,8 +762,7 @@ class EventCreationHandler:
                 )
 
         self.validator.validate_new(event, self.config)
-
-        return event, context
+        return event, unpersisted_context
 
     async def _is_exempt_from_privacy_policy(
         self, builder: EventBuilder, requester: Requester
@@ -1005,7 +1002,7 @@ class EventCreationHandler:
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.create_event(
+                event, unpersisted_context = await self.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1016,6 +1013,7 @@ class EventCreationHandler:
                     historical=historical,
                     depth=depth,
                 )
+                context = await unpersisted_context.persist(event)
 
                 assert self.hs.is_mine_id(event.sender), "User must be our own: %s" % (
                     event.sender,
@@ -1190,7 +1188,6 @@ class EventCreationHandler:
         if for_batch:
             assert prev_event_ids is not None
             assert state_map is not None
-            assert current_state_group is not None
             auth_ids = self._event_auth_handler.compute_auth_events(builder, state_map)
             event = await builder.build(
                 prev_event_ids=prev_event_ids, auth_event_ids=auth_ids, depth=depth
@@ -2046,7 +2043,7 @@ class EventCreationHandler:
                 max_retries = 5
                 for i in range(max_retries):
                     try:
-                        event, context = await self.create_event(
+                        event, unpersisted_context = await self.create_event(
                             requester,
                             {
                                 "type": EventTypes.Dummy,
@@ -2055,6 +2052,7 @@ class EventCreationHandler:
                                 "sender": user_id,
                             },
                         )
+                        context = await unpersisted_context.persist(event)
 
                         event.internal_metadata.proactively_send = False
 
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 87af31aa27..4ad2233573 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -777,7 +777,6 @@ class PresenceHandler(BasePresenceHandler):
         )
 
         if self.unpersisted_users_changes:
-
             await self.store.update_presence(
                 [
                     self.user_to_current_state[user_id]
@@ -823,7 +822,6 @@ class PresenceHandler(BasePresenceHandler):
         now = self.clock.time_msec()
 
         with Measure(self.clock, "presence_update_states"):
-
             # NOTE: We purposefully don't await between now and when we've
             # calculated what we want to do with the new states, to avoid races.
 
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index c611efb760..e4e506e62c 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -476,7 +476,7 @@ class RegistrationHandler:
                     # create room expects the localpart of the room alias
                     config["room_alias_name"] = room_alias.localpart
 
-                    info, _ = await room_creation_handler.create_room(
+                    room_id, _, _ = await room_creation_handler.create_room(
                         fake_requester,
                         config=config,
                         ratelimit=False,
@@ -490,7 +490,7 @@ class RegistrationHandler:
                                 user_id, authenticated_entity=self._server_name
                             ),
                             target=UserID.from_string(user_id),
-                            room_id=info["room_id"],
+                            room_id=room_id,
                             # Since it was just created, there are no remote hosts.
                             remote_room_hosts=[],
                             action="join",
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 837dabb3b7..b1784638f4 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -51,6 +51,7 @@ from synapse.api.filtering import Filter
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.event_auth import validate_event_for_room_version
 from synapse.events import EventBase
+from synapse.events.snapshot import UnpersistedEventContext
 from synapse.events.utils import copy_and_fixup_power_levels_contents
 from synapse.handlers.relations import BundledAggregations
 from synapse.module_api import NOT_SPAM
@@ -211,7 +212,7 @@ class RoomCreationHandler:
                 # the required power level to send the tombstone event.
                 (
                     tombstone_event,
-                    tombstone_context,
+                    tombstone_unpersisted_context,
                 ) = await self.event_creation_handler.create_event(
                     requester,
                     {
@@ -225,6 +226,9 @@ class RoomCreationHandler:
                         },
                     },
                 )
+                tombstone_context = await tombstone_unpersisted_context.persist(
+                    tombstone_event
+                )
                 validate_event_for_room_version(tombstone_event)
                 await self._event_auth_handler.check_auth_rules_from_context(
                     tombstone_event
@@ -690,13 +694,14 @@ class RoomCreationHandler:
         config: JsonDict,
         ratelimit: bool = True,
         creator_join_profile: Optional[JsonDict] = None,
-    ) -> Tuple[dict, int]:
+    ) -> Tuple[str, Optional[RoomAlias], int]:
         """Creates a new room.
 
         Args:
-            requester:
-                The user who requested the room creation.
-            config : A dict of configuration options.
+            requester: The user who requested the room creation.
+            config: A dict of configuration options. This will be the body of
+                a /createRoom request; see
+                https://spec.matrix.org/latest/client-server-api/#post_matrixclientv3createroom
             ratelimit: set to False to disable the rate limiter
 
             creator_join_profile:
@@ -707,14 +712,17 @@ class RoomCreationHandler:
                 `avatar_url` and/or `displayname`.
 
         Returns:
-                First, a dict containing the keys `room_id` and, if an alias
-                was, requested, `room_alias`. Secondly, the stream_id of the
-                last persisted event.
+            A 3-tuple containing:
+                - the room ID;
+                - if requested, the room alias, otherwise None; and
+                - the `stream_id` of the last persisted event.
         Raises:
-            SynapseError if the room ID couldn't be stored, 3pid invitation config
-            validation failed, or something went horribly wrong.
-            ResourceLimitError if server is blocked to some resource being
-            exceeded
+            SynapseError:
+                if the room ID couldn't be stored, 3pid invitation config
+                validation failed, or something went horribly wrong.
+            ResourceLimitError:
+                if server is blocked to some resource being
+                exceeded
         """
         user_id = requester.user.to_string()
 
@@ -864,9 +872,11 @@ class RoomCreationHandler:
         )
 
         # Check whether this visibility value is blocked by a third party module
-        allowed_by_third_party_rules = await (
-            self.third_party_event_rules.check_visibility_can_be_modified(
-                room_id, visibility
+        allowed_by_third_party_rules = (
+            await (
+                self.third_party_event_rules.check_visibility_can_be_modified(
+                    room_id, visibility
+                )
             )
         )
         if not allowed_by_third_party_rules:
@@ -1024,11 +1034,6 @@ class RoomCreationHandler:
             last_sent_event_id = member_event_id
             depth += 1
 
-        result = {"room_id": room_id}
-
-        if room_alias:
-            result["room_alias"] = room_alias.to_string()
-
         # Always wait for room creation to propagate before returning
         await self._replication.wait_for_stream_position(
             self.hs.config.worker.events_shard_config.get_instance(room_id),
@@ -1036,7 +1041,7 @@ class RoomCreationHandler:
             last_stream_id,
         )
 
-        return result, last_stream_id
+        return room_id, room_alias, last_stream_id
 
     async def _send_events_for_new_room(
         self,
@@ -1091,7 +1096,7 @@ class RoomCreationHandler:
             content: JsonDict,
             for_batch: bool,
             **kwargs: Any,
-        ) -> Tuple[EventBase, synapse.events.snapshot.EventContext]:
+        ) -> Tuple[EventBase, synapse.events.snapshot.UnpersistedEventContextBase]:
             """
             Creates an event and associated event context.
             Args:
@@ -1110,20 +1115,23 @@ class RoomCreationHandler:
 
             event_dict = create_event_dict(etype, content, **kwargs)
 
-            new_event, new_context = await self.event_creation_handler.create_event(
+            (
+                new_event,
+                new_unpersisted_context,
+            ) = await self.event_creation_handler.create_event(
                 creator,
                 event_dict,
                 prev_event_ids=prev_event,
                 depth=depth,
                 state_map=state_map,
                 for_batch=for_batch,
-                current_state_group=current_state_group,
             )
+
             depth += 1
             prev_event = [new_event.event_id]
             state_map[(new_event.type, new_event.state_key)] = new_event.event_id
 
-            return new_event, new_context
+            return new_event, new_unpersisted_context
 
         try:
             config = self._presets_dict[preset_config]
@@ -1133,10 +1141,10 @@ class RoomCreationHandler:
             )
 
         creation_content.update({"creator": creator_id})
-        creation_event, creation_context = await create_event(
+        creation_event, unpersisted_creation_context = await create_event(
             EventTypes.Create, creation_content, False
         )
-
+        creation_context = await unpersisted_creation_context.persist(creation_event)
         logger.debug("Sending %s in new room", EventTypes.Member)
         ev = await self.event_creation_handler.handle_new_client_event(
             requester=creator,
@@ -1180,7 +1188,6 @@ class RoomCreationHandler:
             power_event, power_context = await create_event(
                 EventTypes.PowerLevels, pl_content, True
             )
-            current_state_group = power_context._state_group
             events_to_send.append((power_event, power_context))
         else:
             power_level_content: JsonDict = {
@@ -1229,14 +1236,12 @@ class RoomCreationHandler:
                 power_level_content,
                 True,
             )
-            current_state_group = pl_context._state_group
             events_to_send.append((pl_event, pl_context))
 
         if room_alias and (EventTypes.CanonicalAlias, "") not in initial_state:
             room_alias_event, room_alias_context = await create_event(
                 EventTypes.CanonicalAlias, {"alias": room_alias.to_string()}, True
             )
-            current_state_group = room_alias_context._state_group
             events_to_send.append((room_alias_event, room_alias_context))
 
         if (EventTypes.JoinRules, "") not in initial_state:
@@ -1245,7 +1250,6 @@ class RoomCreationHandler:
                 {"join_rule": config["join_rules"]},
                 True,
             )
-            current_state_group = join_rules_context._state_group
             events_to_send.append((join_rules_event, join_rules_context))
 
         if (EventTypes.RoomHistoryVisibility, "") not in initial_state:
@@ -1254,7 +1258,6 @@ class RoomCreationHandler:
                 {"history_visibility": config["history_visibility"]},
                 True,
             )
-            current_state_group = visibility_context._state_group
             events_to_send.append((visibility_event, visibility_context))
 
         if config["guest_can_join"]:
@@ -1264,14 +1267,12 @@ class RoomCreationHandler:
                     {EventContentFields.GUEST_ACCESS: GuestAccess.CAN_JOIN},
                     True,
                 )
-                current_state_group = guest_access_context._state_group
                 events_to_send.append((guest_access_event, guest_access_context))
 
         for (etype, state_key), content in initial_state.items():
             event, context = await create_event(
                 etype, content, True, state_key=state_key
             )
-            current_state_group = context._state_group
             events_to_send.append((event, context))
 
         if config["encrypted"]:
@@ -1283,9 +1284,16 @@ class RoomCreationHandler:
             )
             events_to_send.append((encryption_event, encryption_context))
 
+        datastore = self.hs.get_datastores().state
+        events_and_context = (
+            await UnpersistedEventContext.batch_persist_unpersisted_contexts(
+                events_to_send, room_id, current_state_group, datastore
+            )
+        )
+
         last_event = await self.event_creation_handler.handle_new_client_event(
             creator,
-            events_to_send,
+            events_and_context,
             ignore_shadow_ban=True,
             ratelimit=False,
         )
@@ -1825,7 +1833,7 @@ class RoomShutdownHandler:
                 new_room_user_id, authenticated_entity=requester_user_id
             )
 
-            info, stream_id = await self._room_creation_handler.create_room(
+            new_room_id, _, stream_id = await self._room_creation_handler.create_room(
                 room_creator_requester,
                 config={
                     "preset": RoomCreationPreset.PUBLIC_CHAT,
@@ -1834,7 +1842,6 @@ class RoomShutdownHandler:
                 },
                 ratelimit=False,
             )
-            new_room_id = info["room_id"]
 
             logger.info(
                 "Shutting down room %r, joining to new room: %r", room_id, new_room_id
@@ -1887,6 +1894,7 @@ class RoomShutdownHandler:
 
                 # Join users to new room
                 if new_room_user_id:
+                    assert new_room_id is not None
                     await self.room_member_handler.update_membership(
                         requester=target_requester,
                         target=target_requester.user,
@@ -1919,6 +1927,7 @@ class RoomShutdownHandler:
 
             aliases_for_room = await self.store.get_aliases_for_room(room_id)
 
+            assert new_room_id is not None
             await self.store.update_aliases_for_room(
                 room_id, new_room_id, requester_user_id
             )
diff --git a/synapse/handlers/room_batch.py b/synapse/handlers/room_batch.py
index c73d2adaad..bf9df60218 100644
--- a/synapse/handlers/room_batch.py
+++ b/synapse/handlers/room_batch.py
@@ -327,7 +327,7 @@ class RoomBatchHandler:
             # Mark all events as historical
             event_dict["content"][EventContentFields.MSC2716_HISTORICAL] = True
 
-            event, context = await self.event_creation_handler.create_event(
+            event, unpersisted_context = await self.event_creation_handler.create_event(
                 await self.create_requester_for_user_id_from_app_service(
                     ev["sender"], app_service_requester.app_service
                 ),
@@ -345,7 +345,7 @@ class RoomBatchHandler:
                 historical=True,
                 depth=inherited_depth,
             )
-
+            context = await unpersisted_context.persist(event)
             assert context._state_group
 
             # Normally this is done when persisting the event but we have to
@@ -374,7 +374,7 @@ class RoomBatchHandler:
         # correct stream_ordering as they are backfilled (which decrements).
         # Events are sorted by (topological_ordering, stream_ordering)
         # where topological_ordering is just depth.
-        for (event, context) in reversed(events_to_persist):
+        for event, context in reversed(events_to_persist):
             # This call can't raise `PartialStateConflictError` since we forbid
             # use of the historical batch API during partial state
             await self.event_creation_handler.handle_new_client_event(
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index a965c7ec76..de7476f300 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -414,7 +414,10 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.event_creation_handler.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.event_creation_handler.create_event(
                     requester,
                     {
                         "type": EventTypes.Member,
@@ -435,7 +438,7 @@ class RoomMemberHandler(metaclass=abc.ABCMeta):
                     outlier=outlier,
                     historical=historical,
                 )
-
+                context = await unpersisted_context.persist(event)
                 prev_state_ids = await context.get_prev_state_ids(
                     StateFilter.from_types([(EventTypes.Member, None)])
                 )
@@ -1944,7 +1947,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         max_retries = 5
         for i in range(max_retries):
             try:
-                event, context = await self.event_creation_handler.create_event(
+                (
+                    event,
+                    unpersisted_context,
+                ) = await self.event_creation_handler.create_event(
                     requester,
                     event_dict,
                     txn_id=txn_id,
@@ -1952,6 +1958,7 @@ class RoomMemberMasterHandler(RoomMemberHandler):
                     auth_event_ids=auth_event_ids,
                     outlier=True,
                 )
+                context = await unpersisted_context.persist(event)
                 event.internal_metadata.out_of_band_membership = True
 
                 result_event = (
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 4e4595312c..fd6d946c37 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -1297,7 +1297,6 @@ class SyncHandler:
             return RoomNotifCounts.empty()
 
         with Measure(self.clock, "unread_notifs_for_room_id"):
-
             return await self.store.get_unread_event_push_actions_by_room_for_user(
                 room_id,
                 sync_config.user.to_string(),