summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-11-03 07:30:31 -0400
committerGitHub <noreply@github.com>2023-11-03 07:30:31 -0400
commit92828a7f958b2cb1925e2a64ed08c2efb6293787 (patch)
treec942454d63d072476ce62c10d26080ce4b691d50 /synapse/storage
parentUse simple_select_many_txn in event persistance code. (#16585) (diff)
downloadsynapse-92828a7f958b2cb1925e2a64ed08c2efb6293787.tar.xz
Simplify event persistence code (#16584)
The event persistence code used to handle multiple rooms
at a time, but was simplified to only ever be called with a
single room at a time (different rooms are now handled in
parallel). The code is still generic to multiple rooms causing
a lot of work that is unnecessary (e.g. unnecessary loops, and
partitioning data by room).

This strips out the ability to handle multiple rooms at once, greatly
simplifying the code.
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/controllers/persist_events.py252
-rw-r--r--synapse/storage/databases/main/events.py384
2 files changed, 324 insertions, 312 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController:
         return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
-        self, _room_id: str, task: _PersistEventsTask
+        self, room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue
 
         Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Returns:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
@@ -594,140 +596,23 @@ class EventsPersistenceStorageController:
             # We can't easily parallelize these since different chunks
             # might contain the same event. :(
 
-            # NB: Assumes that we are only persisting events for one room
-            # at a time.
-
-            # map room_id->set[event_ids] giving the new forward
-            # extremities in each room
-            new_forward_extremities: Dict[str, Set[str]] = {}
-
-            # map room_id->(to_delete, to_insert) where to_delete is a list
-            # of type/state keys to remove from current state, and to_insert
-            # is a map (type,key)->event_id giving the state delta in each
-            # room
-            state_delta_for_room: Dict[str, DeltaState] = {}
+            new_forward_extremities = None
+            state_delta_for_room = None
 
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
-                    # Work out the new "current state" for each room.
+                    # Work out the new "current state" for the room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in events_by_room.items():
-                        latest_event_ids = (
-                            await self.main_store.get_latest_event_ids_in_room(room_id)
-                        )
-                        new_latest_event_ids = await self._calculate_new_extremities(
-                            room_id, ev_ctx_rm, latest_event_ids
-                        )
-
-                        if new_latest_event_ids == latest_event_ids:
-                            # No change in extremities, so no change in state
-                            continue
-
-                        # there should always be at least one forward extremity.
-                        # (except during the initial persistence of the send_join
-                        # results, in which case there will be no existing
-                        # extremities, so we'll `continue` above and skip this bit.)
-                        assert new_latest_event_ids, "No forward extremities left!"
-
-                        new_forward_extremities[room_id] = new_latest_event_ids
-
-                        len_1 = (
-                            len(latest_event_ids) == 1
-                            and len(new_latest_event_ids) == 1
-                        )
-                        if len_1:
-                            all_single_prev_not_state = all(
-                                len(event.prev_event_ids()) == 1
-                                and not event.is_state()
-                                for event, ctx in ev_ctx_rm
-                            )
-                            # Don't bother calculating state if they're just
-                            # a long chain of single ancestor non-state events.
-                            if all_single_prev_not_state:
-                                continue
-
-                        state_delta_counter.inc()
-                        if len(new_latest_event_ids) == 1:
-                            state_delta_single_event_counter.inc()
-
-                            # This is a fairly handwavey check to see if we could
-                            # have guessed what the delta would have been when
-                            # processing one of these events.
-                            # What we're interested in is if the latest extremities
-                            # were the same when we created the event as they are
-                            # now. When this server creates a new event (as opposed
-                            # to receiving it over federation) it will use the
-                            # forward extremities as the prev_events, so we can
-                            # guess this by looking at the prev_events and checking
-                            # if they match the current forward extremities.
-                            for ev, _ in ev_ctx_rm:
-                                prev_event_ids = set(ev.prev_event_ids())
-                                if latest_event_ids == prev_event_ids:
-                                    state_delta_reuse_delta_counter.inc()
-                                    break
-
-                        logger.debug("Calculating state delta for room %s", room_id)
-                        with Measure(
-                            self._clock, "persist_events.get_new_state_after_events"
-                        ):
-                            res = await self._get_new_state_after_events(
-                                room_id,
-                                ev_ctx_rm,
-                                latest_event_ids,
-                                new_latest_event_ids,
-                            )
-                            current_state, delta_ids, new_latest_event_ids = res
-
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremities[room_id] = new_latest_event_ids
-
-                        # If either are not None then there has been a change,
-                        # and we need to work out the delta (or use that
-                        # given)
-                        delta = None
-                        if delta_ids is not None:
-                            # If there is a delta we know that we've
-                            # only added or replaced state, never
-                            # removed keys entirely.
-                            delta = DeltaState([], delta_ids)
-                        elif current_state is not None:
-                            with Measure(
-                                self._clock, "persist_events.calculate_state_delta"
-                            ):
-                                delta = await self._calculate_state_delta(
-                                    room_id, current_state
-                                )
-
-                        if delta:
-                            # If we have a change of state then lets check
-                            # whether we're actually still a member of the room,
-                            # or if our last user left. If we're no longer in
-                            # the room then we delete the current state and
-                            # extremities.
-                            is_still_joined = await self._is_server_still_joined(
-                                room_id,
-                                ev_ctx_rm,
-                                delta,
-                            )
-                            if not is_still_joined:
-                                logger.info("Server no longer in room %s", room_id)
-                                delta.no_longer_in_room = True
-
-                            state_delta_for_room[room_id] = delta
+                    (
+                        new_forward_extremities,
+                        state_delta_for_room,
+                    ) = await self._calculate_new_forward_extremities_and_state_delta(
+                        room_id, chunk
+                    )
 
             await self.persist_events_store._persist_events_and_state_updates(
+                room_id,
                 chunk,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
@@ -737,6 +622,117 @@ class EventsPersistenceStorageController:
 
         return replaced_events
 
+    async def _calculate_new_forward_extremities_and_state_delta(
+        self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]]
+    ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
+        """Calculates the new forward extremities and state delta for a room
+        given events to persist.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            A tuple of:
+                A set of str giving the new forward extremities the room
+
+                The state delta for the room.
+        """
+
+        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+        new_latest_event_ids = await self._calculate_new_extremities(
+            room_id, ev_ctx_rm, latest_event_ids
+        )
+
+        if new_latest_event_ids == latest_event_ids:
+            # No change in extremities, so no change in state
+            return (None, None)
+
+        # there should always be at least one forward extremity.
+        # (except during the initial persistence of the send_join
+        # results, in which case there will be no existing
+        # extremities, so we'll `continue` above and skip this bit.)
+        assert new_latest_event_ids, "No forward extremities left!"
+
+        new_forward_extremities = new_latest_event_ids
+
+        len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1
+        if len_1:
+            all_single_prev_not_state = all(
+                len(event.prev_event_ids()) == 1 and not event.is_state()
+                for event, ctx in ev_ctx_rm
+            )
+            # Don't bother calculating state if they're just
+            # a long chain of single ancestor non-state events.
+            if all_single_prev_not_state:
+                return (new_forward_extremities, None)
+
+        state_delta_counter.inc()
+        if len(new_latest_event_ids) == 1:
+            state_delta_single_event_counter.inc()
+
+            # This is a fairly handwavey check to see if we could
+            # have guessed what the delta would have been when
+            # processing one of these events.
+            # What we're interested in is if the latest extremities
+            # were the same when we created the event as they are
+            # now. When this server creates a new event (as opposed
+            # to receiving it over federation) it will use the
+            # forward extremities as the prev_events, so we can
+            # guess this by looking at the prev_events and checking
+            # if they match the current forward extremities.
+            for ev, _ in ev_ctx_rm:
+                prev_event_ids = set(ev.prev_event_ids())
+                if latest_event_ids == prev_event_ids:
+                    state_delta_reuse_delta_counter.inc()
+                    break
+
+        logger.debug("Calculating state delta for room %s", room_id)
+        with Measure(self._clock, "persist_events.get_new_state_after_events"):
+            res = await self._get_new_state_after_events(
+                room_id,
+                ev_ctx_rm,
+                latest_event_ids,
+                new_latest_event_ids,
+            )
+            current_state, delta_ids, new_latest_event_ids = res
+
+            # there should always be at least one forward extremity.
+            # (except during the initial persistence of the send_join
+            # results, in which case there will be no existing
+            # extremities, so we'll `continue` above and skip this bit.)
+            assert new_latest_event_ids, "No forward extremities left!"
+
+            new_forward_extremities = new_latest_event_ids
+
+        # If either are not None then there has been a change,
+        # and we need to work out the delta (or use that
+        # given)
+        delta = None
+        if delta_ids is not None:
+            # If there is a delta we know that we've
+            # only added or replaced state, never
+            # removed keys entirely.
+            delta = DeltaState([], delta_ids)
+        elif current_state is not None:
+            with Measure(self._clock, "persist_events.calculate_state_delta"):
+                delta = await self._calculate_state_delta(room_id, current_state)
+
+        if delta:
+            # If we have a change of state then lets check
+            # whether we're actually still a member of the room,
+            # or if our last user left. If we're no longer in
+            # the room then we delete the current state and
+            # extremities.
+            is_still_joined = await self._is_server_still_joined(
+                room_id,
+                ev_ctx_rm,
+                delta,
+            )
+            if not is_still_joined:
+                logger.info("Server no longer in room %s", room_id)
+                delta.no_longer_in_room = True
+
+        return (new_forward_extremities, delta)
+
     async def _calculate_new_extremities(
         self,
         room_id: str,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b74ff1c498..647ba182f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -79,7 +79,7 @@ class DeltaState:
     Attributes:
         to_delete: List of type/state_keys to delete from current state
         to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
+        no_longer_in_room: The server is no longer in the room, so the room
             should e.g. be removed from `current_state_events` table.
     """
 
@@ -131,22 +131,25 @@ class PersistEventsStore:
     @trace
     async def _persist_events_and_state_updates(
         self,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
-        forward extremities tables.
+                forward extremities tables.
+
+        Assumes that we are only persisting events for one room at a time.
 
         Args:
+            room_id:
             events_and_contexts:
-            state_delta_for_room: Map from room_id to the delta to apply to
-                room state
-            new_forward_extremities: Map from room_id to set of event IDs
-                that are the new forward extremities of the room.
+            state_delta_for_room: The delta to apply to the room state
+            new_forward_extremities: A set of event IDs that are the new forward
+                extremities of the room.
             use_negative_stream_ordering: Whether to start stream_ordering on
                 the negative side and decrement. This should be set as True
                 for backfilled events because backfilled events get a negative
@@ -196,6 +199,7 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
+                room_id=room_id,
                 events_and_contexts=events_and_contexts,
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
@@ -221,9 +225,9 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, latest_event_ids in new_forward_extremities.items():
+            if new_forward_extremities:
                 self.store.get_latest_event_ids_in_room.prefill(
-                    (room_id,), frozenset(latest_event_ids)
+                    (room_id,), frozenset(new_forward_extremities)
                 )
 
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -336,10 +340,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         *,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -347,8 +352,11 @@ class PersistEventsStore:
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Args:
             txn
+            room_id: The room the events are from
             events_and_contexts: events to persist
             inhibit_local_membership_updates: Stop the local_current_membership
                 from being updated by these events. This should be set to True
@@ -357,10 +365,9 @@ class PersistEventsStore:
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room: The current-state delta for each room.
-            new_forward_extremities: The new forward extremities for each room.
-                For each room, a list of the event ids which are the forward
-                extremities.
+            state_delta_for_room: The current-state delta for the room.
+            new_forward_extremities: The new forward extremities for the room:
+                a set of the event ids which are the forward extremities.
 
         Raises:
             PartialStateConflictError: if attempting to persist a partial state event in
@@ -376,14 +383,13 @@ class PersistEventsStore:
         #
         # Annoyingly SQLite doesn't support row level locking.
         if isinstance(self.database_engine, PostgresEngine):
-            for room_id in {e.room_id for e, _ in events_and_contexts}:
-                txn.execute(
-                    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
-                    (room_id,),
-                )
-                row = txn.fetchone()
-                if row is None:
-                    raise Exception(f"Room does not exist {room_id}")
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+                (room_id,),
+            )
+            row = txn.fetchone()
+            if row is None:
+                raise Exception(f"Room does not exist {room_id}")
 
         # stream orderings should have been assigned by now
         assert min_stream_order
@@ -419,7 +425,9 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+        self._update_room_depths_txn(
+            txn, room_id, events_and_contexts=events_and_contexts
+        )
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -432,11 +440,13 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
+        if new_forward_extremities:
+            self._update_forward_extremities_txn(
+                txn,
+                room_id,
+                new_forward_extremities=new_forward_extremities,
+                max_stream_order=max_stream_order,
+            )
 
         self._persist_transaction_ids_txn(txn, events_and_contexts)
 
@@ -464,7 +474,10 @@ class PersistEventsStore:
         # We call this last as it assumes we've inserted the events into
         # room_memberships, where applicable.
         # NB: This function invalidates all state related caches
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+        if state_delta_for_room:
+            self._update_current_state_txn(
+                txn, room_id, state_delta_for_room, min_stream_order
+            )
 
     def _persist_event_auth_chain_txn(
         self,
@@ -1026,74 +1039,75 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "update_current_state",
                 self._update_current_state_txn,
-                state_delta_by_room={room_id: state_delta},
+                room_id,
+                delta_state=state_delta,
                 stream_id=stream_ordering,
             )
 
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
-        state_delta_by_room: Dict[str, DeltaState],
+        room_id: str,
+        delta_state: DeltaState,
         stream_id: int,
     ) -> None:
-        for room_id, delta_state in state_delta_by_room.items():
-            to_delete = delta_state.to_delete
-            to_insert = delta_state.to_insert
-
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
+        to_delete = delta_state.to_delete
+        to_insert = delta_state.to_insert
+
+        # Figure out the changes of membership to invalidate the
+        # `get_rooms_for_user` cache.
+        # We find out which membership events we may have deleted
+        # and which we have added, then we invalidate the caches for all
+        # those users.
+        members_changed = {
+            state_key
+            for ev_type, state_key in itertools.chain(to_delete, to_insert)
+            if ev_type == EventTypes.Member
+        }
 
-            if delta_state.no_longer_in_room:
-                # Server is no longer in the room so we delete the room from
-                # current_state_events, being careful we've already updated the
-                # rooms.room_version column (which gets populated in a
-                # background task).
-                self._upsert_room_version_txn(txn, room_id)
+        if delta_state.no_longer_in_room:
+            # Server is no longer in the room so we delete the room from
+            # current_state_events, being careful we've already updated the
+            # rooms.room_version column (which gets populated in a
+            # background task).
+            self._upsert_room_version_txn(txn, room_id)
 
-                # Before deleting we populate the current_state_delta_stream
-                # so that async background tasks get told what happened.
-                sql = """
+            # Before deleting we populate the current_state_delta_stream
+            # so that async background tasks get told what happened.
+            sql = """
                     INSERT INTO current_state_delta_stream
                         (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, self._instance_name, room_id))
+            txn.execute(sql, (stream_id, self._instance_name, room_id))
 
-                # We also want to invalidate the membership caches for users
-                # that were in the room.
-                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
-                members_changed.update(users_in_room)
+            # We also want to invalidate the membership caches for users
+            # that were in the room.
+            users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+            members_changed.update(users_in_room)
 
-                self.db_pool.simple_delete_txn(
-                    txn,
-                    table="current_state_events",
-                    keyvalues={"room_id": room_id},
-                )
-            else:
-                # We're still in the room, so we update the current state as normal.
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": room_id},
+            )
+        else:
+            # We're still in the room, so we update the current state as normal.
 
-                # First we add entries to the current_state_delta_stream. We
-                # do this before updating the current_state_events table so
-                # that we can use it to calculate the `prev_event_id`. (This
-                # allows us to not have to pull out the existing state
-                # unnecessarily).
-                #
-                # The stream_id for the update is chosen to be the minimum of the stream_ids
-                # for the batch of the events that we are persisting; that means we do not
-                # end up in a situation where workers see events before the
-                # current_state_delta updates.
-                #
-                sql = """
+            # First we add entries to the current_state_delta_stream. We
+            # do this before updating the current_state_events table so
+            # that we can use it to calculate the `prev_event_id`. (This
+            # allows us to not have to pull out the existing state
+            # unnecessarily).
+            #
+            # The stream_id for the update is chosen to be the minimum of the stream_ids
+            # for the batch of the events that we are persisting; that means we do not
+            # end up in a situation where workers see events before the
+            # current_state_delta updates.
+            #
+            sql = """
                     INSERT INTO current_state_delta_stream
                     (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, ?, ?, ?, ?, (
@@ -1101,39 +1115,39 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.execute_batch(
-                    sql,
+            txn.execute_batch(
+                sql,
+                (
                     (
-                        (
-                            stream_id,
-                            self._instance_name,
-                            room_id,
-                            etype,
-                            state_key,
-                            to_insert.get((etype, state_key)),
-                            room_id,
-                            etype,
-                            state_key,
-                        )
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
-                # Now we actually update the current_state_events table
+                        stream_id,
+                        self._instance_name,
+                        room_id,
+                        etype,
+                        state_key,
+                        to_insert.get((etype, state_key)),
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
+            # Now we actually update the current_state_events table
 
-                txn.execute_batch(
-                    "DELETE FROM current_state_events"
-                    " WHERE room_id = ? AND type = ? AND state_key = ?",
-                    (
-                        (room_id, etype, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
+            txn.execute_batch(
+                "DELETE FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?",
+                (
+                    (room_id, etype, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
 
-                # We include the membership in the current state table, hence we do
-                # a lookup when we insert. This assumes that all events have already
-                # been inserted into room_memberships.
-                txn.execute_batch(
-                    """INSERT INTO current_state_events
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.execute_batch(
+                """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership, event_stream_ordering)
                     VALUES (
                         ?, ?, ?, ?,
@@ -1141,34 +1155,34 @@ class PersistEventsStore:
                         (SELECT stream_ordering FROM events WHERE event_id = ?)
                     )
                     """,
-                    [
-                        (room_id, key[0], key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                    ],
-                )
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                ],
+            )
 
-            # We now update `local_current_membership`. We do this regardless
-            # of whether we're still in the room or not to handle the case where
-            # e.g. we just got banned (where we need to record that fact here).
-
-            # Note: Do we really want to delete rows here (that we do not
-            # subsequently reinsert below)? While technically correct it means
-            # we have no record of the fact the user *was* a member of the
-            # room but got, say, state reset out of it.
-            if to_delete or to_insert:
-                txn.execute_batch(
-                    "DELETE FROM local_current_membership"
-                    " WHERE room_id = ? AND user_id = ?",
-                    (
-                        (room_id, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                        if etype == EventTypes.Member and self.is_mine_id(state_key)
-                    ),
-                )
+        # We now update `local_current_membership`. We do this regardless
+        # of whether we're still in the room or not to handle the case where
+        # e.g. we just got banned (where we need to record that fact here).
 
-            if to_insert:
-                txn.execute_batch(
-                    """INSERT INTO local_current_membership
+        # Note: Do we really want to delete rows here (that we do not
+        # subsequently reinsert below)? While technically correct it means
+        # we have no record of the fact the user *was* a member of the
+        # room but got, say, state reset out of it.
+        if to_delete or to_insert:
+            txn.execute_batch(
+                "DELETE FROM local_current_membership"
+                " WHERE room_id = ? AND user_id = ?",
+                (
+                    (room_id, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                    if etype == EventTypes.Member and self.is_mine_id(state_key)
+                ),
+            )
+
+        if to_insert:
+            txn.execute_batch(
+                """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership, event_stream_ordering)
                     VALUES (
                         ?, ?, ?,
@@ -1176,29 +1190,27 @@ class PersistEventsStore:
                         (SELECT stream_ordering FROM events WHERE event_id = ?)
                     )
                     """,
-                    [
-                        (room_id, key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
-                    ],
-                )
-
-            txn.call_after(
-                self.store._curr_state_delta_stream_cache.entity_has_changed,
-                room_id,
-                stream_id,
+                [
+                    (room_id, key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                    if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                ],
             )
 
-            # Invalidate the various caches
-            self.store._invalidate_state_caches_and_stream(
-                txn, room_id, members_changed
-            )
+        txn.call_after(
+            self.store._curr_state_delta_stream_cache.entity_has_changed,
+            room_id,
+            stream_id,
+        )
 
-            # Check if any of the remote membership changes requires us to
-            # unsubscribe from their device lists.
-            self.store.handle_potentially_left_users_txn(
-                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
-            )
+        # Invalidate the various caches
+        self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+        # Check if any of the remote membership changes requires us to
+        # unsubscribe from their device lists.
+        self.store.handle_potentially_left_users_txn(
+            txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+        )
 
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
@@ -1232,23 +1244,19 @@ class PersistEventsStore:
     def _update_forward_extremities_txn(
         self,
         txn: LoggingTransaction,
-        new_forward_extremities: Dict[str, Set[str]],
+        room_id: str,
+        new_forward_extremities: Set[str],
         max_stream_order: int,
     ) -> None:
-        for room_id in new_forward_extremities.keys():
-            self.db_pool.simple_delete_txn(
-                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
-            )
+        self.db_pool.simple_delete_txn(
+            txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+        )
 
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             keys=("event_id", "room_id"),
-            values=[
-                (ev_id, room_id)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for ev_id in new_extrem
-            ],
+            values=[(ev_id, room_id) for ev_id in new_forward_extremities],
         )
         # We now insert into stream_ordering_to_exterm a mapping from room_id,
         # new stream_ordering to new forward extremeties in the room.
@@ -1260,8 +1268,7 @@ class PersistEventsStore:
             keys=("room_id", "event_id", "stream_ordering"),
             values=[
                 (room_id, event_id, max_stream_order)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for event_id in new_extrem
+                for event_id in new_forward_extremities
             ],
         )
 
@@ -1298,36 +1305,45 @@ class PersistEventsStore:
     def _update_room_depths_txn(
         self,
         txn: LoggingTransaction,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ) -> None:
         """Update min_depth for each room
 
         Args:
             txn: db connection
+            room_id: The room ID
             events_and_contexts: events we are persisting
         """
-        depth_updates: Dict[str, int] = {}
+        stream_ordering: Optional[int] = None
+        depth_update = 0
         for event, context in events_and_contexts:
-            # Then update the `stream_ordering` position to mark the latest
-            # event as the front of the room. This should not be done for
-            # backfilled events because backfilled events have negative
-            # stream_ordering and happened in the past so we know that we don't
-            # need to update the stream_ordering tip/front for the room.
+            # Don't update the stream ordering for backfilled events because
+            # backfilled events have negative stream_ordering and happened in the
+            # past, so we know that we don't need to update the stream_ordering
+            # tip/front for the room.
             assert event.internal_metadata.stream_ordering is not None
             if event.internal_metadata.stream_ordering >= 0:
-                txn.call_after(
-                    self.store._events_stream_cache.entity_has_changed,
-                    event.room_id,
-                    event.internal_metadata.stream_ordering,
-                )
+                if stream_ordering is None:
+                    stream_ordering = event.internal_metadata.stream_ordering
+                else:
+                    stream_ordering = max(
+                        stream_ordering, event.internal_metadata.stream_ordering
+                    )
 
             if not event.internal_metadata.is_outlier() and not context.rejected:
-                depth_updates[event.room_id] = max(
-                    event.depth, depth_updates.get(event.room_id, event.depth)
-                )
+                depth_update = max(event.depth, depth_update)
+
+        # Then update the `stream_ordering` position to mark the latest event as
+        # the front of the room.
+        if stream_ordering is not None:
+            txn.call_after(
+                self.store._events_stream_cache.entity_has_changed,
+                room_id,
+                stream_ordering,
+            )
 
-        for room_id, depth in depth_updates.items():
-            self._update_min_depth_for_room_txn(txn, room_id, depth)
+        self._update_min_depth_for_room_txn(txn, room_id, depth_update)
 
     def _update_outliers_txn(
         self,