diff options
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r-- | synapse/state/__init__.py | 35 |
1 files changed, 29 insertions, 6 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py index 4b4ed42cff..098b5f32ff 100644 --- a/synapse/state/__init__.py +++ b/synapse/state/__init__.py @@ -48,6 +48,7 @@ from synapse.logging.context import ContextResourceUsage from synapse.state import v1, v2 from synapse.storage.databases.main.events_worker import EventRedactBehaviour from synapse.storage.roommember import ProfileInfo +from synapse.storage.state import StateFilter from synapse.types import StateMap from synapse.util.async_helpers import Linearizer from synapse.util.caches.expiringcache import ExpiringCache @@ -177,7 +178,16 @@ class StateHandler: assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + + filter = StateFilter.all() + if event_type: + filter = StateFilter.from_types(((event_type, state_key),)) + + ret = await self.resolve_state_groups_for_events( + room_id, + latest_event_ids, + await_full_state=filter.must_await_full_state(self.hs.is_mine_id), + ) state = ret.state if event_type: @@ -195,7 +205,10 @@ class StateHandler: } async def get_current_state_ids( - self, room_id: str, latest_event_ids: Optional[Collection[str]] = None + self, + room_id: str, + latest_event_ids: Optional[Collection[str]] = None, + await_full_state: bool = True, ) -> StateMap[str]: """Get the current state, or the state at a set of events, for a room @@ -203,6 +216,8 @@ class StateHandler: room_id: latest_event_ids: if given, the forward extremities to resolve. If None, we look them up from the database (via a cache). + await_full_state: if true, will block if we do not yet have complete + state at the latest events. Returns: the state dict, mapping from (event_type, state_key) -> event_id @@ -212,7 +227,9 @@ class StateHandler: assert latest_event_ids is not None logger.debug("calling resolve_state_groups from get_current_state_ids") - ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids) + ret = await self.resolve_state_groups_for_events( + room_id, latest_event_ids, await_full_state=await_full_state + ) return ret.state async def get_current_users_in_room( @@ -323,7 +340,9 @@ class StateHandler: logger.debug("calling resolve_state_groups from compute_event_context") entry = await self.resolve_state_groups_for_events( - event.room_id, event.prev_event_ids() + event.room_id, + event.prev_event_ids(), + await_full_state=False, ) state_ids_before_event = entry.state @@ -404,7 +423,7 @@ class StateHandler: @measure_func() async def resolve_state_groups_for_events( - self, room_id: str, event_ids: Collection[str] + self, room_id: str, event_ids: Collection[str], await_full_state: bool = True ) -> _StateCacheEntry: """Given a list of event_ids this method fetches the state at each event, resolves conflicts between them and returns them. @@ -412,13 +431,17 @@ class StateHandler: Args: room_id event_ids + await_full_state: if true, will block if we do not yet have complete + state at these events. Returns: The resolved state """ logger.debug("resolve_state_groups event_ids %s", event_ids) - state_groups = await self.state_store.get_state_group_for_events(event_ids) + state_groups = await self.state_store.get_state_group_for_events( + event_ids, await_full_state=await_full_state + ) state_group_ids = state_groups.values() |