summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12852.misc1
-rw-r--r--synapse/handlers/federation_event.py178
-rw-r--r--synapse/handlers/message.py40
-rw-r--r--synapse/state/__init__.py22
-rw-r--r--synapse/storage/databases/main/state.py59
-rw-r--r--tests/handlers/test_federation.py6
-rw-r--r--tests/storage/test_events.py43
-rw-r--r--tests/test_state.py14
8 files changed, 236 insertions, 127 deletions
diff --git a/changelog.d/12852.misc b/changelog.d/12852.misc
new file mode 100644
index 0000000000..afca32471f
--- /dev/null
+++ b/changelog.d/12852.misc
@@ -0,0 +1 @@
+Pull out less state when handling gaps in room DAG.
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 8ce7187bef..a1361af272 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -274,7 +274,7 @@ class FederationEventHandler:
                     affected=pdu.event_id,
                 )
 
-        await self._process_received_pdu(origin, pdu, state=None)
+        await self._process_received_pdu(origin, pdu, state_ids=None)
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -463,7 +463,9 @@ class FederationEventHandler:
         with nested_logging_context(suffix=event.event_id):
             context = await self._state_handler.compute_event_context(
                 event,
-                old_state=state,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in state
+                },
                 partial_state=partial_state,
             )
 
@@ -512,12 +514,12 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state = await self._resolve_state_at_missing_prevs(destination, event)
+            state_ids = await self._resolve_state_at_missing_prevs(destination, event)
 
             # build a new state group for it if need be
             context = await self._state_handler.compute_event_context(
                 event,
-                old_state=state,
+                state_ids_before_event=state_ids,
             )
             if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
@@ -767,11 +769,12 @@ class FederationEventHandler:
             return
 
         try:
-            state = await self._resolve_state_at_missing_prevs(origin, event)
+            state_ids = await self._resolve_state_at_missing_prevs(origin, event)
             # TODO(faster_joins): make sure that _resolve_state_at_missing_prevs does
             #   not return partial state
+
             await self._process_received_pdu(
-                origin, event, state=state, backfilled=backfilled
+                origin, event, state_ids=state_ids, backfilled=backfilled
             )
         except FederationError as e:
             if e.code == 403:
@@ -781,7 +784,7 @@ class FederationEventHandler:
 
     async def _resolve_state_at_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[Iterable[EventBase]]:
+    ) -> Optional[StateMap[str]]:
         """Calculate the state at an event with missing prev_events.
 
         This is used when we have pulled a batch of events from a remote server, and
@@ -808,8 +811,8 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns a list of
-            the events in the state at `event`.
+            if we already had all the prev events, `None`. Otherwise, returns
+            the event ids of the state at `event`.
         """
         room_id = event.room_id
         event_id = event.event_id
@@ -829,7 +832,7 @@ class FederationEventHandler:
         )
         # Calculate the state after each of the previous events, and
         # resolve them to find the correct state at the current event.
-        event_map = {event_id: event}
+
         try:
             # Get the state of the events we know about
             ours = await self._state_storage.get_state_groups_ids(room_id, seen)
@@ -849,40 +852,23 @@ class FederationEventHandler:
                     # note that if any of the missing prevs share missing state or
                     # auth events, the requests to fetch those events are deduped
                     # by the get_pdu_cache in federation_client.
-                    remote_state = await self._get_state_after_missing_prev_event(
-                        dest, room_id, p
+                    remote_state_map = (
+                        await self._get_state_ids_after_missing_prev_event(
+                            dest, room_id, p
+                        )
                     )
 
-                    remote_state_map = {
-                        (x.type, x.state_key): x.event_id for x in remote_state
-                    }
                     state_maps.append(remote_state_map)
 
-                    for x in remote_state:
-                        event_map[x.event_id] = x
-
             room_version = await self._store.get_room_version_id(room_id)
             state_map = await self._state_resolution_handler.resolve_events_with_store(
                 room_id,
                 room_version,
                 state_maps,
-                event_map,
+                event_map={event_id: event},
                 state_res_store=StateResolutionStore(self._store),
             )
 
-            # We need to give _process_received_pdu the actual state events
-            # rather than event ids, so generate that now.
-
-            # First though we need to fetch all the events that are in
-            # state_map, so we can build up the state below.
-            evs = await self._store.get_events(
-                list(state_map.values()),
-                get_prev_content=False,
-                redact_behaviour=EventRedactBehaviour.as_is,
-            )
-            event_map.update(evs)
-
-            state = [event_map[e] for e in state_map.values()]
         except Exception:
             logger.warning(
                 "Error attempting to resolve state at missing prev_events",
@@ -894,14 +880,14 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state
+        return state_map
 
-    async def _get_state_after_missing_prev_event(
+    async def _get_state_ids_after_missing_prev_event(
         self,
         destination: str,
         room_id: str,
         event_id: str,
-    ) -> List[EventBase]:
+    ) -> StateMap[str]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
@@ -910,7 +896,7 @@ class FederationEventHandler:
             event_id: The id of the event we want the state at.
 
         Returns:
-            A list of events in the state, including the event itself
+            The event ids of the state *after* the given event.
         """
         (
             state_event_ids,
@@ -925,19 +911,17 @@ class FederationEventHandler:
             len(auth_event_ids),
         )
 
-        # start by just trying to fetch the events from the store
+        # Start by checking events we already have in the DB
         desired_events = set(state_event_ids)
         desired_events.add(event_id)
         logger.debug("Fetching %i events from cache/store", len(desired_events))
-        fetched_events = await self._store.get_events(
-            desired_events, allow_rejected=True
-        )
+        have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - fetched_events.keys()
+        missing_desired_events = desired_events - have_events
         logger.debug(
             "We are missing %i events (got %i)",
             len(missing_desired_events),
-            len(fetched_events),
+            len(have_events),
         )
 
         # We probably won't need most of the auth events, so let's just check which
@@ -948,7 +932,7 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+        missing_auth_events = set(auth_event_ids) - have_events
         missing_auth_events.difference_update(
             await self._store.have_seen_events(room_id, missing_auth_events)
         )
@@ -974,47 +958,51 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=missing_events
             )
 
-        # we need to make sure we re-load from the database to get the rejected
-        # state correct.
-        fetched_events.update(
-            await self._store.get_events(missing_desired_events, allow_rejected=True)
-        )
+        # We now need to fill out the state map, which involves fetching the
+        # type and state key for each event ID in the state.
+        state_map = {}
 
-        # check for events which were in the wrong room.
-        #
-        # this can happen if a remote server claims that the state or
-        # auth_events at an event in room A are actually events in room B
-
-        bad_events = [
-            (event_id, event.room_id)
-            for event_id, event in fetched_events.items()
-            if event.room_id != room_id
-        ]
+        event_metadata = await self._store.get_metadata_for_events(state_event_ids)
+        for state_event_id, metadata in event_metadata.items():
+            if metadata.room_id != room_id:
+                # This is a bogus situation, but since we may only discover it a long time
+                # after it happened, we try our best to carry on, by just omitting the
+                # bad events from the returned state set.
+                #
+                # This can happen if a remote server claims that the state or
+                # auth_events at an event in room A are actually events in room B
+                logger.warning(
+                    "Remote server %s claims event %s in room %s is an auth/state "
+                    "event in room %s",
+                    destination,
+                    state_event_id,
+                    metadata.room_id,
+                    room_id,
+                )
+                continue
 
-        for bad_event_id, bad_room_id in bad_events:
-            # This is a bogus situation, but since we may only discover it a long time
-            # after it happened, we try our best to carry on, by just omitting the
-            # bad events from the returned state set.
-            logger.warning(
-                "Remote server %s claims event %s in room %s is an auth/state "
-                "event in room %s",
-                destination,
-                bad_event_id,
-                bad_room_id,
-                room_id,
-            )
+            if metadata.state_key is None:
+                logger.warning(
+                    "Remote server gave us non-state event in state: %s", state_event_id
+                )
+                continue
 
-            del fetched_events[bad_event_id]
+            state_map[(metadata.event_type, metadata.state_key)] = state_event_id
 
         # if we couldn't get the prev event in question, that's a problem.
-        remote_event = fetched_events.get(event_id)
+        remote_event = await self._store.get_event(
+            event_id,
+            allow_none=True,
+            allow_rejected=True,
+            redact_behaviour=EventRedactBehaviour.as_is,
+        )
         if not remote_event:
             raise Exception("Unable to get missing prev_event %s" % (event_id,))
 
         # missing state at that event is a warning, not a blocker
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
-        failed_to_fetch = desired_events - fetched_events.keys()
+        failed_to_fetch = desired_events - event_metadata.keys()
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state events for %s %s",
@@ -1022,14 +1010,12 @@ class FederationEventHandler:
                 failed_to_fetch,
             )
 
-        remote_state = [
-            fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
-        ]
-
         if remote_event.is_state() and remote_event.rejected_reason is None:
-            remote_state.append(remote_event)
+            state_map[
+                (remote_event.type, remote_event.state_key)
+            ] = remote_event.event_id
 
-        return remote_state
+        return state_map
 
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
@@ -1056,7 +1042,7 @@ class FederationEventHandler:
         self,
         origin: str,
         event: EventBase,
-        state: Optional[Iterable[EventBase]],
+        state_ids: Optional[StateMap[str]],
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1078,7 +1064,7 @@ class FederationEventHandler:
 
             event: event to be persisted
 
-            state: Normally None, but if we are handling a gap in the graph
+            state_ids: Normally None, but if we are handling a gap in the graph
                 (ie, we are missing one or more prev_events), the resolved state at the
                 event
 
@@ -1090,7 +1076,8 @@ class FederationEventHandler:
 
         try:
             context = await self._state_handler.compute_event_context(
-                event, old_state=state
+                event,
+                state_ids_before_event=state_ids,
             )
             context = await self._check_event_auth(
                 origin,
@@ -1107,7 +1094,7 @@ class FederationEventHandler:
             # For new (non-backfilled and non-outlier) events we check if the event
             # passes auth based on the current state. If it doesn't then we
             # "soft-fail" the event.
-            await self._check_for_soft_fail(event, state, origin=origin)
+            await self._check_for_soft_fail(event, state_ids, origin=origin)
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
@@ -1589,7 +1576,7 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state: Optional[Iterable[EventBase]],
+        state_ids: Optional[StateMap[str]],
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
@@ -1597,7 +1584,7 @@ class FederationEventHandler:
 
         Args:
             event
-            state: The state at the event if we don't have all the event's prev events
+            state_ids: The state at the event if we don't have all the event's prev events
             origin: The host the event originates from.
         """
         extrem_ids_list = await self._store.get_latest_event_ids_in_room(event.room_id)
@@ -1613,7 +1600,7 @@ class FederationEventHandler:
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
         # Calculate the "current state".
-        if state is not None:
+        if state_ids is not None:
             # If we're explicitly given the state then we won't have all the
             # prev events, and so we have a gap in the graph. In this case
             # we want to be a little careful as we might have been down for
@@ -1626,17 +1613,20 @@ class FederationEventHandler:
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets_d = await self._state_storage.get_state_groups(
+            state_sets_d = await self._state_storage.get_state_groups_ids(
                 event.room_id, extrem_ids
             )
-            state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
-            state_sets.append(state)
-            current_states = await self._state_handler.resolve_events(
-                room_version, state_sets, event
+            state_sets: List[StateMap[str]] = list(state_sets_d.values())
+            state_sets.append(state_ids)
+            current_state_ids = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    state_sets,
+                    event_map=None,
+                    state_res_store=StateResolutionStore(self._store),
+                )
             )
-            current_state_ids: StateMap[str] = {
-                k: e.event_id for k, e in current_states.items()
-            }
         else:
             current_state_ids = await self._state_handler.get_current_state_ids(
                 event.room_id, latest_event_ids=extrem_ids
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9501e7f1b7..7ca126dbd1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.replication.http.send_event import ReplicationSendEventRestServlet
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.state import StateFilter
-from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.types import (
+    MutableStateMap,
+    Requester,
+    RoomAlias,
+    StreamToken,
+    UserID,
+    create_requester,
+)
 from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, gather_results
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -1022,8 +1029,35 @@ class EventCreationHandler:
             #
             # TODO(faster_joins): figure out how this works, and make sure that the
             #   old state is complete.
-            old_state = await self.store.get_events_as_list(state_event_ids)
-            context = await self.state.compute_event_context(event, old_state=old_state)
+            metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+            state_map_for_event: MutableStateMap[str] = {}
+            for state_id in state_event_ids:
+                data = metadata.get(state_id)
+                if data is None:
+                    # We're trying to persist a new historical batch of events
+                    # with the given state, e.g. via
+                    # `RoomBatchSendEventRestServlet`. The state can be inferred
+                    # by Synapse or set directly by the client.
+                    #
+                    # Either way, we should have persisted all the state before
+                    # getting here.
+                    raise Exception(
+                        f"State event {state_id} not found in DB,"
+                        " Synapse should have persisted it before using it."
+                    )
+
+                if data.state_key is None:
+                    raise Exception(
+                        f"Trying to set non-state event {state_id} as state"
+                    )
+
+                state_map_for_event[(data.event_type, data.state_key)] = state_id
+
+            context = await self.state.compute_event_context(
+                event,
+                state_ids_before_event=state_map_for_event,
+            )
         else:
             context = await self.state.compute_event_context(event)
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 536564b7ff..9c9d946f38 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -261,7 +261,7 @@ class StateHandler:
     async def compute_event_context(
         self,
         event: EventBase,
-        old_state: Optional[Iterable[EventBase]] = None,
+        state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: bool = False,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
@@ -273,12 +273,12 @@ class StateHandler:
 
         Args:
             event:
-            old_state: The state at the event if it can't be
-                calculated from existing events. This is normally only specified
-                when receiving an event from federation where we don't have the
-                prev events for, e.g. when backfilling.
-            partial_state: True if `old_state` is partial and omits non-critical
-                membership events
+            state_ids_before_event: The event ids of the state before the event if
+                it can't be calculated from existing events. This is normally
+                only specified when receiving an event from federation where we
+                don't have the prev events, e.g. when backfilling.
+            partial_state: True if `state_ids_before_event` is partial and omits
+                non-critical membership events
         Returns:
             The event context.
         """
@@ -286,13 +286,11 @@ class StateHandler:
         assert not event.internal_metadata.is_outlier()
 
         #
-        # first of all, figure out the state before the event
+        # first of all, figure out the state before the event, unless we
+        # already have it.
         #
-        if old_state:
+        if state_ids_before_event:
             # if we're given the state before the event, then we use that
-            state_ids_before_event: StateMap[str] = {
-                (s.type, s.state_key): s.event_id for s in old_state
-            }
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 18ae8aee29..ea5cbdac08 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -16,6 +16,8 @@ import collections.abc
 import logging
 from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
 
+import attr
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -33,6 +36,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -43,6 +47,15 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventMetadata:
+    """Returned by `get_metadata_for_events`"""
+
+    room_id: str
+    event_type: str
+    state_key: Optional[str]
+
+
 def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
     v = KNOWN_ROOM_VERSIONS.get(room_version_id)
     if not v:
@@ -133,6 +146,52 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return room_version
 
+    async def get_metadata_for_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventMetadata]:
+        """Get some metadata (room_id, type, state_key) for the given events.
+
+        This method is a faster alternative than fetching the full events from
+        the DB, and should be used when the full event is not needed.
+
+        Returns metadata for rejected and redacted events. Events that have not
+        been persisted are omitted from the returned dict.
+        """
+
+        def get_metadata_for_events_txn(
+            txn: LoggingTransaction,
+            batch_ids: Collection[str],
+        ) -> Dict[str, EventMetadata]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "e.event_id", batch_ids
+            )
+
+            sql = f"""
+                SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
+                LEFT JOIN state_events USING (event_id)
+                WHERE {clause}
+            """
+
+            txn.execute(sql, args)
+            return {
+                event_id: EventMetadata(
+                    room_id=room_id, event_type=event_type, state_key=state_key
+                )
+                for event_id, room_id, event_type, state_key in txn
+            }
+
+        result_map: Dict[str, EventMetadata] = {}
+        for batch_ids in batch_iter(event_ids, 1000):
+            result_map.update(
+                await self.db_pool.runInteraction(
+                    "get_metadata_for_events",
+                    get_metadata_for_events_txn,
+                    batch_ids=batch_ids,
+                )
+            )
+
+        return result_map
+
     async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
         """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index bef6c2b776..ec00900621 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -276,7 +276,11 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
             # federation handler wanting to backfill the fake event.
             self.get_success(
                 federation_event_handler._process_received_pdu(
-                    self.OTHER_SERVER_NAME, event, state=current_state
+                    self.OTHER_SERVER_NAME,
+                    event,
+                    state_ids={
+                        (e.type, e.state_key): e.event_id for e in current_state
+                    },
                 )
             )
 
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index ef5e25873c..aaa3189b16 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def persist_event(self, event, state=None):
         """Persist the event, with optional state"""
         context = self.get_success(
-            self.state.compute_event_context(event, old_state=state)
+            self.state.compute_event_context(event, state_ids_before_event=state)
         )
         self.get_success(self.persistence.persist_event(event, context))
 
@@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.state.get_current_state(self.room_id))
+            self.get_success(self.state.get_current_state_ids(self.room_id))
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
         context = self.get_success(
             self.state.compute_event_context(
-                remote_event_2, old_state=state_before_gap.values()
+                remote_event_2,
+                state_ids_before_event=state_before_gap,
             )
         )
 
@@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.state.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.state.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/test_state.py b/tests/test_state.py
index c6baea3d76..84694d368d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())