diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 9501e7f1b7..7ca126dbd1 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -55,7 +55,14 @@ from synapse.metrics.background_process_metrics import run_as_background_process
from synapse.replication.http.send_event import ReplicationSendEventRestServlet
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
from synapse.storage.state import StateFilter
-from synapse.types import Requester, RoomAlias, StreamToken, UserID, create_requester
+from synapse.types import (
+ MutableStateMap,
+ Requester,
+ RoomAlias,
+ StreamToken,
+ UserID,
+ create_requester,
+)
from synapse.util import json_decoder, json_encoder, log_failure, unwrapFirstError
from synapse.util.async_helpers import Linearizer, gather_results
from synapse.util.caches.expiringcache import ExpiringCache
@@ -1022,8 +1029,35 @@ class EventCreationHandler:
#
# TODO(faster_joins): figure out how this works, and make sure that the
# old state is complete.
- old_state = await self.store.get_events_as_list(state_event_ids)
- context = await self.state.compute_event_context(event, old_state=old_state)
+ metadata = await self.store.get_metadata_for_events(state_event_ids)
+
+ state_map_for_event: MutableStateMap[str] = {}
+ for state_id in state_event_ids:
+ data = metadata.get(state_id)
+ if data is None:
+ # We're trying to persist a new historical batch of events
+ # with the given state, e.g. via
+ # `RoomBatchSendEventRestServlet`. The state can be inferred
+ # by Synapse or set directly by the client.
+ #
+ # Either way, we should have persisted all the state before
+ # getting here.
+ raise Exception(
+ f"State event {state_id} not found in DB,"
+ " Synapse should have persisted it before using it."
+ )
+
+ if data.state_key is None:
+ raise Exception(
+ f"Trying to set non-state event {state_id} as state"
+ )
+
+ state_map_for_event[(data.event_type, data.state_key)] = state_id
+
+ context = await self.state.compute_event_context(
+ event,
+ state_ids_before_event=state_map_for_event,
+ )
else:
context = await self.state.compute_event_context(event)
|