summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/controllers/persist_events.py252
-rw-r--r--synapse/storage/databases/main/events.py384
2 files changed, 324 insertions, 312 deletions
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f39ae2d635..1529c86cc5 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -542,13 +542,15 @@ class EventsPersistenceStorageController:
         return await res.get_state(self._state_controller, StateFilter.all())
 
     async def _persist_event_batch(
-        self, _room_id: str, task: _PersistEventsTask
+        self, room_id: str, task: _PersistEventsTask
     ) -> Dict[str, str]:
         """Callback for the _event_persist_queue
 
         Calculates the change to current state and forward extremities, and
         persists the given events and with those updates.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Returns:
             A dictionary of event ID to event ID we didn't persist as we already
             had another event persisted with the same TXN ID.
@@ -594,140 +596,23 @@ class EventsPersistenceStorageController:
             # We can't easily parallelize these since different chunks
             # might contain the same event. :(
 
-            # NB: Assumes that we are only persisting events for one room
-            # at a time.
-
-            # map room_id->set[event_ids] giving the new forward
-            # extremities in each room
-            new_forward_extremities: Dict[str, Set[str]] = {}
-
-            # map room_id->(to_delete, to_insert) where to_delete is a list
-            # of type/state keys to remove from current state, and to_insert
-            # is a map (type,key)->event_id giving the state delta in each
-            # room
-            state_delta_for_room: Dict[str, DeltaState] = {}
+            new_forward_extremities = None
+            state_delta_for_room = None
 
             if not backfilled:
                 with Measure(self._clock, "_calculate_state_and_extrem"):
-                    # Work out the new "current state" for each room.
+                    # Work out the new "current state" for the room.
                     # We do this by working out what the new extremities are and then
                     # calculating the state from that.
-                    events_by_room: Dict[str, List[Tuple[EventBase, EventContext]]] = {}
-                    for event, context in chunk:
-                        events_by_room.setdefault(event.room_id, []).append(
-                            (event, context)
-                        )
-
-                    for room_id, ev_ctx_rm in events_by_room.items():
-                        latest_event_ids = (
-                            await self.main_store.get_latest_event_ids_in_room(room_id)
-                        )
-                        new_latest_event_ids = await self._calculate_new_extremities(
-                            room_id, ev_ctx_rm, latest_event_ids
-                        )
-
-                        if new_latest_event_ids == latest_event_ids:
-                            # No change in extremities, so no change in state
-                            continue
-
-                        # there should always be at least one forward extremity.
-                        # (except during the initial persistence of the send_join
-                        # results, in which case there will be no existing
-                        # extremities, so we'll `continue` above and skip this bit.)
-                        assert new_latest_event_ids, "No forward extremities left!"
-
-                        new_forward_extremities[room_id] = new_latest_event_ids
-
-                        len_1 = (
-                            len(latest_event_ids) == 1
-                            and len(new_latest_event_ids) == 1
-                        )
-                        if len_1:
-                            all_single_prev_not_state = all(
-                                len(event.prev_event_ids()) == 1
-                                and not event.is_state()
-                                for event, ctx in ev_ctx_rm
-                            )
-                            # Don't bother calculating state if they're just
-                            # a long chain of single ancestor non-state events.
-                            if all_single_prev_not_state:
-                                continue
-
-                        state_delta_counter.inc()
-                        if len(new_latest_event_ids) == 1:
-                            state_delta_single_event_counter.inc()
-
-                            # This is a fairly handwavey check to see if we could
-                            # have guessed what the delta would have been when
-                            # processing one of these events.
-                            # What we're interested in is if the latest extremities
-                            # were the same when we created the event as they are
-                            # now. When this server creates a new event (as opposed
-                            # to receiving it over federation) it will use the
-                            # forward extremities as the prev_events, so we can
-                            # guess this by looking at the prev_events and checking
-                            # if they match the current forward extremities.
-                            for ev, _ in ev_ctx_rm:
-                                prev_event_ids = set(ev.prev_event_ids())
-                                if latest_event_ids == prev_event_ids:
-                                    state_delta_reuse_delta_counter.inc()
-                                    break
-
-                        logger.debug("Calculating state delta for room %s", room_id)
-                        with Measure(
-                            self._clock, "persist_events.get_new_state_after_events"
-                        ):
-                            res = await self._get_new_state_after_events(
-                                room_id,
-                                ev_ctx_rm,
-                                latest_event_ids,
-                                new_latest_event_ids,
-                            )
-                            current_state, delta_ids, new_latest_event_ids = res
-
-                            # there should always be at least one forward extremity.
-                            # (except during the initial persistence of the send_join
-                            # results, in which case there will be no existing
-                            # extremities, so we'll `continue` above and skip this bit.)
-                            assert new_latest_event_ids, "No forward extremities left!"
-
-                            new_forward_extremities[room_id] = new_latest_event_ids
-
-                        # If either are not None then there has been a change,
-                        # and we need to work out the delta (or use that
-                        # given)
-                        delta = None
-                        if delta_ids is not None:
-                            # If there is a delta we know that we've
-                            # only added or replaced state, never
-                            # removed keys entirely.
-                            delta = DeltaState([], delta_ids)
-                        elif current_state is not None:
-                            with Measure(
-                                self._clock, "persist_events.calculate_state_delta"
-                            ):
-                                delta = await self._calculate_state_delta(
-                                    room_id, current_state
-                                )
-
-                        if delta:
-                            # If we have a change of state then lets check
-                            # whether we're actually still a member of the room,
-                            # or if our last user left. If we're no longer in
-                            # the room then we delete the current state and
-                            # extremities.
-                            is_still_joined = await self._is_server_still_joined(
-                                room_id,
-                                ev_ctx_rm,
-                                delta,
-                            )
-                            if not is_still_joined:
-                                logger.info("Server no longer in room %s", room_id)
-                                delta.no_longer_in_room = True
-
-                            state_delta_for_room[room_id] = delta
+                    (
+                        new_forward_extremities,
+                        state_delta_for_room,
+                    ) = await self._calculate_new_forward_extremities_and_state_delta(
+                        room_id, chunk
+                    )
 
             await self.persist_events_store._persist_events_and_state_updates(
+                room_id,
                 chunk,
                 state_delta_for_room=state_delta_for_room,
                 new_forward_extremities=new_forward_extremities,
@@ -737,6 +622,117 @@ class EventsPersistenceStorageController:
 
         return replaced_events
 
+    async def _calculate_new_forward_extremities_and_state_delta(
+        self, room_id: str, ev_ctx_rm: List[Tuple[EventBase, EventContext]]
+    ) -> Tuple[Optional[Set[str]], Optional[DeltaState]]:
+        """Calculates the new forward extremities and state delta for a room
+        given events to persist.
+
+        Assumes that we are only persisting events for one room at a time.
+
+        Returns:
+            A tuple of:
+                A set of str giving the new forward extremities the room
+
+                The state delta for the room.
+        """
+
+        latest_event_ids = await self.main_store.get_latest_event_ids_in_room(room_id)
+        new_latest_event_ids = await self._calculate_new_extremities(
+            room_id, ev_ctx_rm, latest_event_ids
+        )
+
+        if new_latest_event_ids == latest_event_ids:
+            # No change in extremities, so no change in state
+            return (None, None)
+
+        # there should always be at least one forward extremity.
+        # (except during the initial persistence of the send_join
+        # results, in which case there will be no existing
+        # extremities, so we'll `continue` above and skip this bit.)
+        assert new_latest_event_ids, "No forward extremities left!"
+
+        new_forward_extremities = new_latest_event_ids
+
+        len_1 = len(latest_event_ids) == 1 and len(new_latest_event_ids) == 1
+        if len_1:
+            all_single_prev_not_state = all(
+                len(event.prev_event_ids()) == 1 and not event.is_state()
+                for event, ctx in ev_ctx_rm
+            )
+            # Don't bother calculating state if they're just
+            # a long chain of single ancestor non-state events.
+            if all_single_prev_not_state:
+                return (new_forward_extremities, None)
+
+        state_delta_counter.inc()
+        if len(new_latest_event_ids) == 1:
+            state_delta_single_event_counter.inc()
+
+            # This is a fairly handwavey check to see if we could
+            # have guessed what the delta would have been when
+            # processing one of these events.
+            # What we're interested in is if the latest extremities
+            # were the same when we created the event as they are
+            # now. When this server creates a new event (as opposed
+            # to receiving it over federation) it will use the
+            # forward extremities as the prev_events, so we can
+            # guess this by looking at the prev_events and checking
+            # if they match the current forward extremities.
+            for ev, _ in ev_ctx_rm:
+                prev_event_ids = set(ev.prev_event_ids())
+                if latest_event_ids == prev_event_ids:
+                    state_delta_reuse_delta_counter.inc()
+                    break
+
+        logger.debug("Calculating state delta for room %s", room_id)
+        with Measure(self._clock, "persist_events.get_new_state_after_events"):
+            res = await self._get_new_state_after_events(
+                room_id,
+                ev_ctx_rm,
+                latest_event_ids,
+                new_latest_event_ids,
+            )
+            current_state, delta_ids, new_latest_event_ids = res
+
+            # there should always be at least one forward extremity.
+            # (except during the initial persistence of the send_join
+            # results, in which case there will be no existing
+            # extremities, so we'll `continue` above and skip this bit.)
+            assert new_latest_event_ids, "No forward extremities left!"
+
+            new_forward_extremities = new_latest_event_ids
+
+        # If either are not None then there has been a change,
+        # and we need to work out the delta (or use that
+        # given)
+        delta = None
+        if delta_ids is not None:
+            # If there is a delta we know that we've
+            # only added or replaced state, never
+            # removed keys entirely.
+            delta = DeltaState([], delta_ids)
+        elif current_state is not None:
+            with Measure(self._clock, "persist_events.calculate_state_delta"):
+                delta = await self._calculate_state_delta(room_id, current_state)
+
+        if delta:
+            # If we have a change of state then lets check
+            # whether we're actually still a member of the room,
+            # or if our last user left. If we're no longer in
+            # the room then we delete the current state and
+            # extremities.
+            is_still_joined = await self._is_server_still_joined(
+                room_id,
+                ev_ctx_rm,
+                delta,
+            )
+            if not is_still_joined:
+                logger.info("Server no longer in room %s", room_id)
+                delta.no_longer_in_room = True
+
+        return (new_forward_extremities, delta)
+
     async def _calculate_new_extremities(
         self,
         room_id: str,
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index b74ff1c498..647ba182f6 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -79,7 +79,7 @@ class DeltaState:
     Attributes:
         to_delete: List of type/state_keys to delete from current state
         to_insert: Map of state to upsert into current state
-        no_longer_in_room: The server is not longer in the room, so the room
+        no_longer_in_room: The server is no longer in the room, so the room
             should e.g. be removed from `current_state_events` table.
     """
 
@@ -131,22 +131,25 @@ class PersistEventsStore:
     @trace
     async def _persist_events_and_state_updates(
         self,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         *,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
         use_negative_stream_ordering: bool = False,
         inhibit_local_membership_updates: bool = False,
     ) -> None:
         """Persist a set of events alongside updates to the current state and
-        forward extremities tables.
+                forward extremities tables.
+
+        Assumes that we are only persisting events for one room at a time.
 
         Args:
+            room_id:
             events_and_contexts:
-            state_delta_for_room: Map from room_id to the delta to apply to
-                room state
-            new_forward_extremities: Map from room_id to set of event IDs
-                that are the new forward extremities of the room.
+            state_delta_for_room: The delta to apply to the room state
+            new_forward_extremities: A set of event IDs that are the new forward
+                extremities of the room.
             use_negative_stream_ordering: Whether to start stream_ordering on
                 the negative side and decrement. This should be set as True
                 for backfilled events because backfilled events get a negative
@@ -196,6 +199,7 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "persist_events",
                 self._persist_events_txn,
+                room_id=room_id,
                 events_and_contexts=events_and_contexts,
                 inhibit_local_membership_updates=inhibit_local_membership_updates,
                 state_delta_for_room=state_delta_for_room,
@@ -221,9 +225,9 @@ class PersistEventsStore:
 
                 event_counter.labels(event.type, origin_type, origin_entity).inc()
 
-            for room_id, latest_event_ids in new_forward_extremities.items():
+            if new_forward_extremities:
                 self.store.get_latest_event_ids_in_room.prefill(
-                    (room_id,), frozenset(latest_event_ids)
+                    (room_id,), frozenset(new_forward_extremities)
                 )
 
     async def _get_events_which_are_prevs(self, event_ids: Iterable[str]) -> List[str]:
@@ -336,10 +340,11 @@ class PersistEventsStore:
         self,
         txn: LoggingTransaction,
         *,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
         inhibit_local_membership_updates: bool,
-        state_delta_for_room: Dict[str, DeltaState],
-        new_forward_extremities: Dict[str, Set[str]],
+        state_delta_for_room: Optional[DeltaState],
+        new_forward_extremities: Optional[Set[str]],
     ) -> None:
         """Insert some number of room events into the necessary database tables.
 
@@ -347,8 +352,11 @@ class PersistEventsStore:
         and the rejections table. Things reading from those table will need to check
         whether the event was rejected.
 
+        Assumes that we are only persisting events for one room at a time.
+
         Args:
             txn
+            room_id: The room the events are from
             events_and_contexts: events to persist
             inhibit_local_membership_updates: Stop the local_current_membership
                 from being updated by these events. This should be set to True
@@ -357,10 +365,9 @@ class PersistEventsStore:
             delete_existing True to purge existing table rows for the events
                 from the database. This is useful when retrying due to
                 IntegrityError.
-            state_delta_for_room: The current-state delta for each room.
-            new_forward_extremities: The new forward extremities for each room.
-                For each room, a list of the event ids which are the forward
-                extremities.
+            state_delta_for_room: The current-state delta for the room.
+            new_forward_extremities: The new forward extremities for the room:
+                a set of the event ids which are the forward extremities.
 
         Raises:
             PartialStateConflictError: if attempting to persist a partial state event in
@@ -376,14 +383,13 @@ class PersistEventsStore:
         #
         # Annoyingly SQLite doesn't support row level locking.
         if isinstance(self.database_engine, PostgresEngine):
-            for room_id in {e.room_id for e, _ in events_and_contexts}:
-                txn.execute(
-                    "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
-                    (room_id,),
-                )
-                row = txn.fetchone()
-                if row is None:
-                    raise Exception(f"Room does not exist {room_id}")
+            txn.execute(
+                "SELECT room_version FROM rooms WHERE room_id = ? FOR SHARE",
+                (room_id,),
+            )
+            row = txn.fetchone()
+            if row is None:
+                raise Exception(f"Room does not exist {room_id}")
 
         # stream orderings should have been assigned by now
         assert min_stream_order
@@ -419,7 +425,9 @@ class PersistEventsStore:
             events_and_contexts
         )
 
-        self._update_room_depths_txn(txn, events_and_contexts=events_and_contexts)
+        self._update_room_depths_txn(
+            txn, room_id, events_and_contexts=events_and_contexts
+        )
 
         # _update_outliers_txn filters out any events which have already been
         # persisted, and returns the filtered list.
@@ -432,11 +440,13 @@ class PersistEventsStore:
 
         self._store_event_txn(txn, events_and_contexts=events_and_contexts)
 
-        self._update_forward_extremities_txn(
-            txn,
-            new_forward_extremities=new_forward_extremities,
-            max_stream_order=max_stream_order,
-        )
+        if new_forward_extremities:
+            self._update_forward_extremities_txn(
+                txn,
+                room_id,
+                new_forward_extremities=new_forward_extremities,
+                max_stream_order=max_stream_order,
+            )
 
         self._persist_transaction_ids_txn(txn, events_and_contexts)
 
@@ -464,7 +474,10 @@ class PersistEventsStore:
         # We call this last as it assumes we've inserted the events into
         # room_memberships, where applicable.
         # NB: This function invalidates all state related caches
-        self._update_current_state_txn(txn, state_delta_for_room, min_stream_order)
+        if state_delta_for_room:
+            self._update_current_state_txn(
+                txn, room_id, state_delta_for_room, min_stream_order
+            )
 
     def _persist_event_auth_chain_txn(
         self,
@@ -1026,74 +1039,75 @@ class PersistEventsStore:
             await self.db_pool.runInteraction(
                 "update_current_state",
                 self._update_current_state_txn,
-                state_delta_by_room={room_id: state_delta},
+                room_id,
+                delta_state=state_delta,
                 stream_id=stream_ordering,
             )
 
     def _update_current_state_txn(
         self,
         txn: LoggingTransaction,
-        state_delta_by_room: Dict[str, DeltaState],
+        room_id: str,
+        delta_state: DeltaState,
         stream_id: int,
     ) -> None:
-        for room_id, delta_state in state_delta_by_room.items():
-            to_delete = delta_state.to_delete
-            to_insert = delta_state.to_insert
-
-            # Figure out the changes of membership to invalidate the
-            # `get_rooms_for_user` cache.
-            # We find out which membership events we may have deleted
-            # and which we have added, then we invalidate the caches for all
-            # those users.
-            members_changed = {
-                state_key
-                for ev_type, state_key in itertools.chain(to_delete, to_insert)
-                if ev_type == EventTypes.Member
-            }
+        to_delete = delta_state.to_delete
+        to_insert = delta_state.to_insert
+
+        # Figure out the changes of membership to invalidate the
+        # `get_rooms_for_user` cache.
+        # We find out which membership events we may have deleted
+        # and which we have added, then we invalidate the caches for all
+        # those users.
+        members_changed = {
+            state_key
+            for ev_type, state_key in itertools.chain(to_delete, to_insert)
+            if ev_type == EventTypes.Member
+        }
 
-            if delta_state.no_longer_in_room:
-                # Server is no longer in the room so we delete the room from
-                # current_state_events, being careful we've already updated the
-                # rooms.room_version column (which gets populated in a
-                # background task).
-                self._upsert_room_version_txn(txn, room_id)
+        if delta_state.no_longer_in_room:
+            # Server is no longer in the room so we delete the room from
+            # current_state_events, being careful we've already updated the
+            # rooms.room_version column (which gets populated in a
+            # background task).
+            self._upsert_room_version_txn(txn, room_id)
 
-                # Before deleting we populate the current_state_delta_stream
-                # so that async background tasks get told what happened.
-                sql = """
+            # Before deleting we populate the current_state_delta_stream
+            # so that async background tasks get told what happened.
+            sql = """
                     INSERT INTO current_state_delta_stream
                         (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, room_id, type, state_key, null, event_id
                         FROM current_state_events
                         WHERE room_id = ?
                 """
-                txn.execute(sql, (stream_id, self._instance_name, room_id))
+            txn.execute(sql, (stream_id, self._instance_name, room_id))
 
-                # We also want to invalidate the membership caches for users
-                # that were in the room.
-                users_in_room = self.store.get_users_in_room_txn(txn, room_id)
-                members_changed.update(users_in_room)
+            # We also want to invalidate the membership caches for users
+            # that were in the room.
+            users_in_room = self.store.get_users_in_room_txn(txn, room_id)
+            members_changed.update(users_in_room)
 
-                self.db_pool.simple_delete_txn(
-                    txn,
-                    table="current_state_events",
-                    keyvalues={"room_id": room_id},
-                )
-            else:
-                # We're still in the room, so we update the current state as normal.
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="current_state_events",
+                keyvalues={"room_id": room_id},
+            )
+        else:
+            # We're still in the room, so we update the current state as normal.
 
-                # First we add entries to the current_state_delta_stream. We
-                # do this before updating the current_state_events table so
-                # that we can use it to calculate the `prev_event_id`. (This
-                # allows us to not have to pull out the existing state
-                # unnecessarily).
-                #
-                # The stream_id for the update is chosen to be the minimum of the stream_ids
-                # for the batch of the events that we are persisting; that means we do not
-                # end up in a situation where workers see events before the
-                # current_state_delta updates.
-                #
-                sql = """
+            # First we add entries to the current_state_delta_stream. We
+            # do this before updating the current_state_events table so
+            # that we can use it to calculate the `prev_event_id`. (This
+            # allows us to not have to pull out the existing state
+            # unnecessarily).
+            #
+            # The stream_id for the update is chosen to be the minimum of the stream_ids
+            # for the batch of the events that we are persisting; that means we do not
+            # end up in a situation where workers see events before the
+            # current_state_delta updates.
+            #
+            sql = """
                     INSERT INTO current_state_delta_stream
                     (stream_id, instance_name, room_id, type, state_key, event_id, prev_event_id)
                     SELECT ?, ?, ?, ?, ?, ?, (
@@ -1101,39 +1115,39 @@ class PersistEventsStore:
                         WHERE room_id = ? AND type = ? AND state_key = ?
                     )
                 """
-                txn.execute_batch(
-                    sql,
+            txn.execute_batch(
+                sql,
+                (
                     (
-                        (
-                            stream_id,
-                            self._instance_name,
-                            room_id,
-                            etype,
-                            state_key,
-                            to_insert.get((etype, state_key)),
-                            room_id,
-                            etype,
-                            state_key,
-                        )
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
-                # Now we actually update the current_state_events table
+                        stream_id,
+                        self._instance_name,
+                        room_id,
+                        etype,
+                        state_key,
+                        to_insert.get((etype, state_key)),
+                        room_id,
+                        etype,
+                        state_key,
+                    )
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
+            # Now we actually update the current_state_events table
 
-                txn.execute_batch(
-                    "DELETE FROM current_state_events"
-                    " WHERE room_id = ? AND type = ? AND state_key = ?",
-                    (
-                        (room_id, etype, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                    ),
-                )
+            txn.execute_batch(
+                "DELETE FROM current_state_events"
+                " WHERE room_id = ? AND type = ? AND state_key = ?",
+                (
+                    (room_id, etype, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                ),
+            )
 
-                # We include the membership in the current state table, hence we do
-                # a lookup when we insert. This assumes that all events have already
-                # been inserted into room_memberships.
-                txn.execute_batch(
-                    """INSERT INTO current_state_events
+            # We include the membership in the current state table, hence we do
+            # a lookup when we insert. This assumes that all events have already
+            # been inserted into room_memberships.
+            txn.execute_batch(
+                """INSERT INTO current_state_events
                         (room_id, type, state_key, event_id, membership, event_stream_ordering)
                     VALUES (
                         ?, ?, ?, ?,
@@ -1141,34 +1155,34 @@ class PersistEventsStore:
                         (SELECT stream_ordering FROM events WHERE event_id = ?)
                     )
                     """,
-                    [
-                        (room_id, key[0], key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                    ],
-                )
+                [
+                    (room_id, key[0], key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                ],
+            )
 
-            # We now update `local_current_membership`. We do this regardless
-            # of whether we're still in the room or not to handle the case where
-            # e.g. we just got banned (where we need to record that fact here).
-
-            # Note: Do we really want to delete rows here (that we do not
-            # subsequently reinsert below)? While technically correct it means
-            # we have no record of the fact the user *was* a member of the
-            # room but got, say, state reset out of it.
-            if to_delete or to_insert:
-                txn.execute_batch(
-                    "DELETE FROM local_current_membership"
-                    " WHERE room_id = ? AND user_id = ?",
-                    (
-                        (room_id, state_key)
-                        for etype, state_key in itertools.chain(to_delete, to_insert)
-                        if etype == EventTypes.Member and self.is_mine_id(state_key)
-                    ),
-                )
+        # We now update `local_current_membership`. We do this regardless
+        # of whether we're still in the room or not to handle the case where
+        # e.g. we just got banned (where we need to record that fact here).
 
-            if to_insert:
-                txn.execute_batch(
-                    """INSERT INTO local_current_membership
+        # Note: Do we really want to delete rows here (that we do not
+        # subsequently reinsert below)? While technically correct it means
+        # we have no record of the fact the user *was* a member of the
+        # room but got, say, state reset out of it.
+        if to_delete or to_insert:
+            txn.execute_batch(
+                "DELETE FROM local_current_membership"
+                " WHERE room_id = ? AND user_id = ?",
+                (
+                    (room_id, state_key)
+                    for etype, state_key in itertools.chain(to_delete, to_insert)
+                    if etype == EventTypes.Member and self.is_mine_id(state_key)
+                ),
+            )
+
+        if to_insert:
+            txn.execute_batch(
+                """INSERT INTO local_current_membership
                         (room_id, user_id, event_id, membership, event_stream_ordering)
                     VALUES (
                         ?, ?, ?,
@@ -1176,29 +1190,27 @@ class PersistEventsStore:
                         (SELECT stream_ordering FROM events WHERE event_id = ?)
                     )
                     """,
-                    [
-                        (room_id, key[1], ev_id, ev_id, ev_id)
-                        for key, ev_id in to_insert.items()
-                        if key[0] == EventTypes.Member and self.is_mine_id(key[1])
-                    ],
-                )
-
-            txn.call_after(
-                self.store._curr_state_delta_stream_cache.entity_has_changed,
-                room_id,
-                stream_id,
+                [
+                    (room_id, key[1], ev_id, ev_id, ev_id)
+                    for key, ev_id in to_insert.items()
+                    if key[0] == EventTypes.Member and self.is_mine_id(key[1])
+                ],
             )
 
-            # Invalidate the various caches
-            self.store._invalidate_state_caches_and_stream(
-                txn, room_id, members_changed
-            )
+        txn.call_after(
+            self.store._curr_state_delta_stream_cache.entity_has_changed,
+            room_id,
+            stream_id,
+        )
 
-            # Check if any of the remote membership changes requires us to
-            # unsubscribe from their device lists.
-            self.store.handle_potentially_left_users_txn(
-                txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
-            )
+        # Invalidate the various caches
+        self.store._invalidate_state_caches_and_stream(txn, room_id, members_changed)
+
+        # Check if any of the remote membership changes requires us to
+        # unsubscribe from their device lists.
+        self.store.handle_potentially_left_users_txn(
+            txn, {m for m in members_changed if not self.hs.is_mine_id(m)}
+        )
 
     def _upsert_room_version_txn(self, txn: LoggingTransaction, room_id: str) -> None:
         """Update the room version in the database based off current state
@@ -1232,23 +1244,19 @@ class PersistEventsStore:
     def _update_forward_extremities_txn(
         self,
         txn: LoggingTransaction,
-        new_forward_extremities: Dict[str, Set[str]],
+        room_id: str,
+        new_forward_extremities: Set[str],
         max_stream_order: int,
     ) -> None:
-        for room_id in new_forward_extremities.keys():
-            self.db_pool.simple_delete_txn(
-                txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
-            )
+        self.db_pool.simple_delete_txn(
+            txn, table="event_forward_extremities", keyvalues={"room_id": room_id}
+        )
 
         self.db_pool.simple_insert_many_txn(
             txn,
             table="event_forward_extremities",
             keys=("event_id", "room_id"),
-            values=[
-                (ev_id, room_id)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for ev_id in new_extrem
-            ],
+            values=[(ev_id, room_id) for ev_id in new_forward_extremities],
         )
         # We now insert into stream_ordering_to_exterm a mapping from room_id,
         # new stream_ordering to new forward extremeties in the room.
@@ -1260,8 +1268,7 @@ class PersistEventsStore:
             keys=("room_id", "event_id", "stream_ordering"),
             values=[
                 (room_id, event_id, max_stream_order)
-                for room_id, new_extrem in new_forward_extremities.items()
-                for event_id in new_extrem
+                for event_id in new_forward_extremities
             ],
         )
 
@@ -1298,36 +1305,45 @@ class PersistEventsStore:
     def _update_room_depths_txn(
         self,
         txn: LoggingTransaction,
+        room_id: str,
         events_and_contexts: List[Tuple[EventBase, EventContext]],
     ) -> None:
         """Update min_depth for each room
 
         Args:
             txn: db connection
+            room_id: The room ID
             events_and_contexts: events we are persisting
         """
-        depth_updates: Dict[str, int] = {}
+        stream_ordering: Optional[int] = None
+        depth_update = 0
         for event, context in events_and_contexts:
-            # Then update the `stream_ordering` position to mark the latest
-            # event as the front of the room. This should not be done for
-            # backfilled events because backfilled events have negative
-            # stream_ordering and happened in the past so we know that we don't
-            # need to update the stream_ordering tip/front for the room.
+            # Don't update the stream ordering for backfilled events because
+            # backfilled events have negative stream_ordering and happened in the
+            # past, so we know that we don't need to update the stream_ordering
+            # tip/front for the room.
             assert event.internal_metadata.stream_ordering is not None
             if event.internal_metadata.stream_ordering >= 0:
-                txn.call_after(
-                    self.store._events_stream_cache.entity_has_changed,
-                    event.room_id,
-                    event.internal_metadata.stream_ordering,
-                )
+                if stream_ordering is None:
+                    stream_ordering = event.internal_metadata.stream_ordering
+                else:
+                    stream_ordering = max(
+                        stream_ordering, event.internal_metadata.stream_ordering
+                    )
 
             if not event.internal_metadata.is_outlier() and not context.rejected:
-                depth_updates[event.room_id] = max(
-                    event.depth, depth_updates.get(event.room_id, event.depth)
-                )
+                depth_update = max(event.depth, depth_update)
+
+        # Then update the `stream_ordering` position to mark the latest event as
+        # the front of the room.
+        if stream_ordering is not None:
+            txn.call_after(
+                self.store._events_stream_cache.entity_has_changed,
+                room_id,
+                stream_ordering,
+            )
 
-        for room_id, depth in depth_updates.items():
-            self._update_min_depth_for_room_txn(txn, room_id, depth)
+        self._update_min_depth_for_room_txn(txn, room_id, depth_update)
 
     def _update_outliers_txn(
         self,