diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9eced23bf..cc9b162ae4 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -45,7 +45,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker,
)
from synapse.synapse_rust.acl import ServerAclEvaluator
-from synapse.types import MutableStateMap, StateMap, get_domain_from_id
+from synapse.types import MutableStateMap, StateMap, StreamToken, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
@@ -372,6 +372,91 @@ class StateStorageController:
)
return state_map[event_id]
+ async def get_state_after_event(
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
+ ) -> StateMap[str]:
+ """
+ Get the room state after the given event
+
+ Args:
+ event_id: event of interest
+ state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
+ """
+ state_ids = await self.get_state_ids_for_event(
+ event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
+ )
+
+ # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+ # with redactions: if `event_id` is a redaction event, and we don't have the
+ # original (possibly because it got purged), get_event will refuse to return
+ # the redaction event, which isn't terribly helpful here.
+ #
+ # (To be fair, in that case we could assume it's *not* a state event, and
+ # therefore we don't need to worry about it. But still, it seems cleaner just
+ # to pull the metadata.)
+ m = (await self.stores.main.get_metadata_for_events([event_id]))[event_id]
+ if m.state_key is not None and m.rejection_reason is None:
+ state_ids = dict(state_ids)
+ state_ids[(m.event_type, m.state_key)] = event_id
+
+ return state_ids
+
+ async def get_state_at(
+ self,
+ room_id: str,
+ stream_position: StreamToken,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
+ ) -> StateMap[str]:
+ """Get the room state at a particular stream position
+
+ Args:
+ room_id: room for which to get state
+ stream_position: point at which to get state
+ state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the last event in the room before `stream_position` and
+ `state_filter` is not satisfied by partial state. Defaults to `True`.
+ """
+ # FIXME: This gets the state at the latest event before the stream ordering,
+ # which might not be the same as the "current state" of the room at the time
+ # of the stream token if there were multiple forward extremities at the time.
+ last_event_id = (
+ await self.stores.main.get_last_event_id_in_room_before_stream_ordering(
+ room_id,
+ end_token=stream_position.room_key,
+ )
+ )
+
+ if last_event_id:
+ state = await self.get_state_after_event(
+ last_event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
+ )
+
+ else:
+ # no events in this room - so presumably no state
+ state = {}
+
+ # (erikj) This should be rarely hit, but we've had some reports that
+ # we get more state down gappy syncs than we should, so let's add
+ # some logging.
+ logger.info(
+ "Failed to find any events in room %s at %s",
+ room_id,
+ stream_position.room_key,
+ )
+ return state
+
@trace
@tag_args
async def get_state_for_groups(
|