summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-05-21 14:13:10 +0100
committerErik Johnston <erik@matrix.org>2022-05-21 14:13:10 +0100
commitda10dfc311ba61f4e2437c385089b83eb84b78c8 (patch)
treebe6807cb55fb7922477e721fe8c042fc02194c30
parentIgnore display name push (diff)
parentFix tests (diff)
downloadsynapse-da10dfc311ba61f4e2437c385089b83eb84b78c8.tar.xz
Merge branch 'erikj/less_state_on_missing' into erikj/push_hack
-rw-r--r--changelog.d/12672.feature1
-rw-r--r--changelog.d/12672.misc1
-rw-r--r--changelog.d/12703.misc1
-rw-r--r--changelog.d/12809.feature1
-rw-r--r--changelog.d/12828.misc1
-rw-r--r--synapse/api/errors.py11
-rw-r--r--synapse/handlers/federation_event.py140
-rw-r--r--synapse/handlers/message.py17
-rw-r--r--synapse/replication/tcp/commands.py12
-rw-r--r--synapse/replication/tcp/redis.py6
-rw-r--r--synapse/state/__init__.py14
-rw-r--r--synapse/storage/databases/main/state.py42
-rw-r--r--tests/handlers/test_federation.py4
-rw-r--r--tests/storage/test_events.py43
-rw-r--r--tests/test_state.py14
15 files changed, 198 insertions, 110 deletions
diff --git a/changelog.d/12672.feature b/changelog.d/12672.feature
new file mode 100644
index 0000000000..b989e0d208
--- /dev/null
+++ b/changelog.d/12672.feature
@@ -0,0 +1 @@
+Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.
\ No newline at end of file
diff --git a/changelog.d/12672.misc b/changelog.d/12672.misc
deleted file mode 100644
index 265e0a801f..0000000000
--- a/changelog.d/12672.misc
+++ /dev/null
@@ -1 +0,0 @@
-Lay some foundation work to allow workers to only subscribe to some kinds of messages, reducing replication traffic.
\ No newline at end of file
diff --git a/changelog.d/12703.misc b/changelog.d/12703.misc
new file mode 100644
index 0000000000..9aaa1bbaa3
--- /dev/null
+++ b/changelog.d/12703.misc
@@ -0,0 +1 @@
+Convert namespace class `Codes` into a string enum.
\ No newline at end of file
diff --git a/changelog.d/12809.feature b/changelog.d/12809.feature
new file mode 100644
index 0000000000..b989e0d208
--- /dev/null
+++ b/changelog.d/12809.feature
@@ -0,0 +1 @@
+Send `USER_IP` commands on a different Redis channel, in order to reduce traffic to workers that do not process these commands.
\ No newline at end of file
diff --git a/changelog.d/12828.misc b/changelog.d/12828.misc
new file mode 100644
index 0000000000..afca32471f
--- /dev/null
+++ b/changelog.d/12828.misc
@@ -0,0 +1 @@
+Pull out less state when handling gaps in room DAG.
diff --git a/synapse/api/errors.py b/synapse/api/errors.py
index cb3b7323d5..9614be6b4e 100644
--- a/synapse/api/errors.py
+++ b/synapse/api/errors.py
@@ -17,6 +17,7 @@
 
 import logging
 import typing
+from enum import Enum
 from http import HTTPStatus
 from typing import Any, Dict, List, Optional, Union
 
@@ -30,7 +31,11 @@ if typing.TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
-class Codes:
+class Codes(str, Enum):
+    """
+    All known error codes, as an enum of strings.
+    """
+
     UNRECOGNIZED = "M_UNRECOGNIZED"
     UNAUTHORIZED = "M_UNAUTHORIZED"
     FORBIDDEN = "M_FORBIDDEN"
@@ -265,7 +270,9 @@ class UnrecognizedRequestError(SynapseError):
     """An error indicating we don't understand the request you're trying to make"""
 
     def __init__(
-        self, msg: str = "Unrecognized request", errcode: str = Codes.UNRECOGNIZED
+        self,
+        msg: str = "Unrecognized request",
+        errcode: str = Codes.UNRECOGNIZED,
     ):
         super().__init__(400, msg, errcode)
 
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 3c7ae2ea39..b65bbaff69 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -463,7 +463,9 @@ class FederationEventHandler:
         with nested_logging_context(suffix=event.event_id):
             context = await self._state_handler.compute_event_context(
                 event,
-                old_state=state,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in state
+                },
                 partial_state=partial_state,
             )
 
@@ -501,7 +503,7 @@ class FederationEventHandler:
             # build a new state group for it if need be
             context = await self._state_handler.compute_event_context(
                 event,
-                old_state=state,
+                state_ids_before_event=state,
             )
             if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
@@ -765,7 +767,7 @@ class FederationEventHandler:
 
     async def _resolve_state_at_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Optional[Iterable[EventBase]]:
+    ) -> Optional[StateMap[str]]:
         """Calculate the state at an event with missing prev_events.
 
         This is used when we have pulled a batch of events from a remote server, and
@@ -792,8 +794,8 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None`. Otherwise, returns a list of
-            the events in the state at `event`.
+            if we already had all the prev events, `None`. Otherwise, returns
+            the state at `event`.
         """
         room_id = event.room_id
         event_id = event.event_id
@@ -837,13 +839,7 @@ class FederationEventHandler:
                         dest, room_id, p
                     )
 
-                    remote_state_map = {
-                        (x.type, x.state_key): x.event_id for x in remote_state
-                    }
-                    state_maps.append(remote_state_map)
-
-                    for x in remote_state:
-                        event_map[x.event_id] = x
+                    state_maps.append(remote_state)
 
             room_version = await self._store.get_room_version_id(room_id)
             state_map = await self._state_resolution_handler.resolve_events_with_store(
@@ -854,19 +850,6 @@ class FederationEventHandler:
                 state_res_store=StateResolutionStore(self._store),
             )
 
-            # We need to give _process_received_pdu the actual state events
-            # rather than event ids, so generate that now.
-
-            # First though we need to fetch all the events that are in
-            # state_map, so we can build up the state below.
-            evs = await self._store.get_events(
-                list(state_map.values()),
-                get_prev_content=False,
-                redact_behaviour=EventRedactBehaviour.as_is,
-            )
-            event_map.update(evs)
-
-            state = [event_map[e] for e in state_map.values()]
         except Exception:
             logger.warning(
                 "Error attempting to resolve state at missing prev_events",
@@ -878,14 +861,14 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state
+        return state_map
 
     async def _get_state_after_missing_prev_event(
         self,
         destination: str,
         room_id: str,
         event_id: str,
-    ) -> List[EventBase]:
+    ) -> StateMap[str]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
@@ -894,7 +877,7 @@ class FederationEventHandler:
             event_id: The id of the event we want the state at.
 
         Returns:
-            A list of events in the state, including the event itself
+            The state *after* the given event.
         """
         (
             state_event_ids,
@@ -913,15 +896,13 @@ class FederationEventHandler:
         desired_events = set(state_event_ids)
         desired_events.add(event_id)
         logger.debug("Fetching %i events from cache/store", len(desired_events))
-        fetched_events = await self._store.get_events(
-            desired_events, allow_rejected=True
-        )
+        have_events = await self._store.have_seen_events(room_id, desired_events)
 
-        missing_desired_events = desired_events - fetched_events.keys()
+        missing_desired_events = desired_events - have_events
         logger.debug(
             "We are missing %i events (got %i)",
             len(missing_desired_events),
-            len(fetched_events),
+            len(have_events),
         )
 
         # We probably won't need most of the auth events, so let's just check which
@@ -932,7 +913,7 @@ class FederationEventHandler:
         #   already have a bunch of the state events. It would be nice if the
         #   federation api gave us a way of finding out which we actually need.
 
-        missing_auth_events = set(auth_event_ids) - fetched_events.keys()
+        missing_auth_events = set(auth_event_ids) - have_events
         missing_auth_events.difference_update(
             await self._store.have_seen_events(room_id, missing_auth_events)
         )
@@ -958,47 +939,54 @@ class FederationEventHandler:
                 destination=destination, room_id=room_id, event_ids=missing_events
             )
 
-        # we need to make sure we re-load from the database to get the rejected
-        # state correct.
-        fetched_events.update(
-            await self._store.get_events(missing_desired_events, allow_rejected=True)
-        )
+        event_metadata = await self._store.get_metadata_for_events(state_event_ids)
 
         # check for events which were in the wrong room.
         #
         # this can happen if a remote server claims that the state or
         # auth_events at an event in room A are actually events in room B
 
-        bad_events = [
-            (event_id, event.room_id)
-            for event_id, event in fetched_events.items()
-            if event.room_id != room_id
-        ]
+        event_metadata = await self._store.get_metadata_for_events(state_event_ids)
 
-        for bad_event_id, bad_room_id in bad_events:
-            # This is a bogus situation, but since we may only discover it a long time
-            # after it happened, we try our best to carry on, by just omitting the
-            # bad events from the returned state set.
-            logger.warning(
-                "Remote server %s claims event %s in room %s is an auth/state "
-                "event in room %s",
-                destination,
-                bad_event_id,
-                bad_room_id,
-                room_id,
-            )
+        state_map = {}
+
+        for state_event_id, metadata in event_metadata.items():
+            if metadata.room_id != room_id:
+                # This is a bogus situation, but since we may only discover it a long time
+                # after it happened, we try our best to carry on, by just omitting the
+                # bad events from the returned state set.
+                logger.warning(
+                    "Remote server %s claims event %s in room %s is an auth/state "
+                    "event in room %s",
+                    destination,
+                    state_event_id,
+                    metadata.room_id,
+                    room_id,
+                )
+                continue
+
+            if metadata.state_key is None:
+                logger.warning(
+                    "Remote server gave us non-state event in state: %s", state_event_id
+                )
+                continue
 
-            del fetched_events[bad_event_id]
+            state_map[(metadata.event_type, metadata.state_key)] = state_event_id
 
         # if we couldn't get the prev event in question, that's a problem.
-        remote_event = fetched_events.get(event_id)
+        remote_event = await self._store.get_event(
+            event_id,
+            allow_none=True,
+            allow_rejected=True,
+            redact_behaviour=EventRedactBehaviour.as_is,
+        )
         if not remote_event:
             raise Exception("Unable to get missing prev_event %s" % (event_id,))
 
         # missing state at that event is a warning, not a blocker
         # XXX: this doesn't sound right? it means that we'll end up with incomplete
         #   state.
-        failed_to_fetch = desired_events - fetched_events.keys()
+        failed_to_fetch = desired_events - event_metadata.keys()
         if failed_to_fetch:
             logger.warning(
                 "Failed to fetch missing state events for %s %s",
@@ -1006,14 +994,12 @@ class FederationEventHandler:
                 failed_to_fetch,
             )
 
-        remote_state = [
-            fetched_events[e_id] for e_id in state_event_ids if e_id in fetched_events
-        ]
-
         if remote_event.is_state() and remote_event.rejected_reason is None:
-            remote_state.append(remote_event)
+            state_map[
+                (remote_event.type, remote_event.state_key)
+            ] = remote_event.event_id
 
-        return remote_state
+        return state_map
 
     async def _get_state_and_persist(
         self, destination: str, room_id: str, event_id: str
@@ -1040,7 +1026,7 @@ class FederationEventHandler:
         self,
         origin: str,
         event: EventBase,
-        state: Optional[Iterable[EventBase]],
+        state: Optional[StateMap[str]],
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1074,7 +1060,7 @@ class FederationEventHandler:
 
         try:
             context = await self._state_handler.compute_event_context(
-                event, old_state=state
+                event, state_ids_before_event=state
             )
             context = await self._check_event_auth(
                 origin,
@@ -1565,7 +1551,7 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state: Optional[Iterable[EventBase]],
+        state: Optional[StateMap[str]],
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
@@ -1605,17 +1591,21 @@ class FederationEventHandler:
             # given state at the event. This should correctly handle cases
             # like bans, especially with state res v2.
 
-            state_sets_d = await self._state_store.get_state_groups(
+            state_sets_d = await self._state_store.get_state_groups_ids(
                 event.room_id, extrem_ids
             )
-            state_sets: List[Iterable[EventBase]] = list(state_sets_d.values())
+            state_sets: List[StateMap[str]] = list(state_sets_d.values())
             state_sets.append(state)
-            current_states = await self._state_handler.resolve_events(
-                room_version, state_sets, event
+
+            current_state_ids = (
+                await self._state_resolution_handler.resolve_events_with_store(
+                    event.room_id,
+                    room_version,
+                    state_sets,
+                    event_map={},
+                    state_res_store=StateResolutionStore(self._store),
+                )
             )
-            current_state_ids: StateMap[str] = {
-                k: e.event_id for k, e in current_states.items()
-            }
         else:
             current_state_ids = await self._store.get_filtered_current_state_ids(
                 event.room_id, StateFilter.from_types(auth_types)
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index f712e8cf75..7656e6c2f4 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1023,8 +1023,21 @@ class EventCreationHandler:
             #
             # TODO(faster_joins): figure out how this works, and make sure that the
             #   old state is complete.
-            old_state = await self.store.get_events_as_list(state_event_ids)
-            context = await self.state.compute_event_context(event, old_state=old_state)
+            metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+            state_map = {}
+            for event_id, data in metadata.items():
+                if data.state_key is None:
+                    raise Exception(
+                        "Trying to set non-state event as state: %s", event_id
+                    )
+
+                state_map[(data.event_type, data.state_key)] = event_id
+
+            context = await self.state.compute_event_context(
+                event,
+                state_ids_before_event=state_map,
+            )
         else:
             context = await self.state.compute_event_context(event)
 
diff --git a/synapse/replication/tcp/commands.py b/synapse/replication/tcp/commands.py
index fe34948168..32f52e54d8 100644
--- a/synapse/replication/tcp/commands.py
+++ b/synapse/replication/tcp/commands.py
@@ -58,6 +58,15 @@ class Command(metaclass=abc.ABCMeta):
         # by default, we just use the command name.
         return self.NAME
 
+    def redis_channel_name(self, prefix: str) -> str:
+        """
+        Returns the Redis channel name upon which to publish this command.
+
+        Args:
+            prefix: The prefix for the channel.
+        """
+        return prefix
+
 
 SC = TypeVar("SC", bound="_SimpleCommand")
 
@@ -395,6 +404,9 @@ class UserIpCommand(Command):
             f"{self.user_agent!r}, {self.device_id!r}, {self.last_seen})"
         )
 
+    def redis_channel_name(self, prefix: str) -> str:
+        return f"{prefix}/USER_IP"
+
 
 class RemoteServerUpCommand(_SimpleCommand):
     """Sent when a worker has detected that a remote server is no longer
diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 73294654ef..fd1c0ec6af 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -221,10 +221,10 @@ class RedisSubscriber(txredisapi.SubscriberProtocol):
         # remote instances.
         tcp_outbound_commands_counter.labels(cmd.NAME, "redis").inc()
 
+        channel_name = cmd.redis_channel_name(self.synapse_stream_prefix)
+
         await make_deferred_yieldable(
-            self.synapse_outbound_redis_connection.publish(
-                self.synapse_stream_prefix, encoded_string
-            )
+            self.synapse_outbound_redis_connection.publish(channel_name, encoded_string)
         )
 
 
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 347fa96ce7..96b5e48571 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -194,7 +194,7 @@ class StateHandler:
     async def compute_event_context(
         self,
         event: EventBase,
-        old_state: Optional[Iterable[EventBase]] = None,
+        state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: bool = False,
     ) -> EventContext:
         """Build an EventContext structure for a non-outlier event.
@@ -206,12 +206,12 @@ class StateHandler:
 
         Args:
             event:
-            old_state: The state at the event if it can't be
+            state_ids_before_event: The state at the event if it can't be
                 calculated from existing events. This is normally only specified
                 when receiving an event from federation where we don't have the
                 prev events for, e.g. when backfilling.
-            partial_state: True if `old_state` is partial and omits non-critical
-                membership events
+            partial_state: True if `state_ids_before_event` is partial and omits
+                non-critical membership events
         Returns:
             The event context.
         """
@@ -221,11 +221,7 @@ class StateHandler:
         #
         # first of all, figure out the state before the event
         #
-        if old_state:
-            # if we're given the state before the event, then we use that
-            state_ids_before_event: StateMap[str] = {
-                (s.type, s.state_key): s.event_id for s in old_state
-            }
+        if state_ids_before_event:
             state_group_before_event = None
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b8a277ae31..7b237adf8b 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -16,6 +16,8 @@ import collections.abc
 import logging
 from typing import TYPE_CHECKING, Collection, Dict, Iterable, Optional, Set, Tuple
 
+import attr
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
@@ -26,6 +28,7 @@ from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
     LoggingTransaction,
+    make_in_list_sql_clause,
 )
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
@@ -43,6 +46,15 @@ logger = logging.getLogger(__name__)
 MAX_STATE_DELTA_HOPS = 100
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventMetadata:
+    """Returned by `get_metadata_for_events`"""
+
+    room_id: str
+    event_type: str
+    state_key: Optional[str]
+
+
 def _retrieve_and_check_room_version(room_id: str, room_version_id: str) -> RoomVersion:
     v = KNOWN_ROOM_VERSIONS.get(room_version_id)
     if not v:
@@ -133,6 +145,36 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return room_version
 
+    async def get_metadata_for_events(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, EventMetadata]:
+        """Get some metadata (room_id, type, state_key) for the given events."""
+
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "e.event_id", event_ids
+        )
+
+        sql = f"""
+            SELECT e.event_id, e.room_id, e.type, e.state_key FROM events AS e
+            LEFT JOIN state_events USING (event_id)
+            WHERE {clause}
+        """
+
+        def get_metadata_for_events_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[str, EventMetadata]:
+            txn.execute(sql, args)
+            return {
+                event_id: EventMetadata(
+                    room_id=room_id, event_type=event_type, state_key=state_key
+                )
+                for event_id, room_id, event_type, state_key in txn
+            }
+
+        return await self.db_pool.runInteraction(
+            "get_metadata_for_events", get_metadata_for_events_txn
+        )
+
     async def get_room_predecessor(self, room_id: str) -> Optional[JsonMapping]:
         """Get the predecessor of an upgraded room if it exists.
         Otherwise return None.
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index e95dfdce20..183a6635fa 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -276,7 +276,9 @@ class FederationTestCase(unittest.FederatingHomeserverTestCase):
             # federation handler wanting to backfill the fake event.
             self.get_success(
                 federation_event_handler._process_received_pdu(
-                    self.OTHER_SERVER_NAME, event, state=current_state
+                    self.OTHER_SERVER_NAME,
+                    event,
+                    state={(e.type, e.state_key): e.event_id for e in current_state},
                 )
             )
 
diff --git a/tests/storage/test_events.py b/tests/storage/test_events.py
index 18a5907dab..2c1bf875ee 100644
--- a/tests/storage/test_events.py
+++ b/tests/storage/test_events.py
@@ -69,7 +69,7 @@ class ExtremPruneTestCase(HomeserverTestCase):
     def persist_event(self, event, state=None):
         """Persist the event, with optional state"""
         context = self.get_success(
-            self.state.compute_event_context(event, old_state=state)
+            self.state.compute_event_context(event, state_ids_before_event=state)
         )
         self.get_success(self.persistence.persist_event(event, context))
 
@@ -103,9 +103,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.store.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -135,13 +137,14 @@ class ExtremPruneTestCase(HomeserverTestCase):
         # setting. The state resolution across the old and new event will then
         # include it, and so the resolved state won't match the new state.
         state_before_gap = dict(
-            self.get_success(self.store.get_current_state(self.room_id))
+            self.get_success(self.store.get_current_state_ids(self.room_id))
         )
         state_before_gap.pop(("m.room.history_visibility", ""))
 
         context = self.get_success(
             self.state.compute_event_context(
-                remote_event_2, old_state=state_before_gap.values()
+                remote_event_2,
+                state_ids_before_event=state_before_gap,
             )
         )
 
@@ -177,9 +180,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.store.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -207,9 +212,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.store.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([self.remote_event_1.event_id, remote_event_2.event_id])
@@ -247,9 +254,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.store.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id])
@@ -289,9 +298,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.store.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([remote_event_2.event_id, local_message_event_id])
@@ -323,9 +334,11 @@ class ExtremPruneTestCase(HomeserverTestCase):
             RoomVersions.V6,
         )
 
-        state_before_gap = self.get_success(self.store.get_current_state(self.room_id))
+        state_before_gap = self.get_success(
+            self.store.get_current_state_ids(self.room_id)
+        )
 
-        self.persist_event(remote_event_2, state=state_before_gap.values())
+        self.persist_event(remote_event_2, state=state_before_gap)
 
         # Check the new extremity is just the new remote event.
         self.assert_extremities([local_message_event_id, remote_event_2.event_id])
diff --git a/tests/test_state.py b/tests/test_state.py
index c6baea3d76..84694d368d 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -442,7 +442,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())
@@ -467,7 +472,12 @@ class StateTestCase(unittest.TestCase):
         ]
 
         context = yield defer.ensureDeferred(
-            self.state.compute_event_context(event, old_state=old_state)
+            self.state.compute_event_context(
+                event,
+                state_ids_before_event={
+                    (e.type, e.state_key): e.event_id for e in old_state
+                },
+            )
         )
 
         prev_state_ids = yield defer.ensureDeferred(context.get_prev_state_ids())