summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py35
1 files changed, 29 insertions, 6 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 4b4ed42cff..098b5f32ff 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -48,6 +48,7 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
 from synapse.storage.roommember import ProfileInfo
+from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
@@ -177,7 +178,16 @@ class StateHandler:
         assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+
+        filter = StateFilter.all()
+        if event_type:
+            filter = StateFilter.from_types(((event_type, state_key),))
+
+        ret = await self.resolve_state_groups_for_events(
+            room_id,
+            latest_event_ids,
+            await_full_state=filter.must_await_full_state(self.hs.is_mine_id),
+        )
         state = ret.state
 
         if event_type:
@@ -195,7 +205,10 @@ class StateHandler:
         }
 
     async def get_current_state_ids(
-        self, room_id: str, latest_event_ids: Optional[Collection[str]] = None
+        self,
+        room_id: str,
+        latest_event_ids: Optional[Collection[str]] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """Get the current state, or the state at a set of events, for a room
 
@@ -203,6 +216,8 @@ class StateHandler:
             room_id:
             latest_event_ids: if given, the forward extremities to resolve. If
                 None, we look them up from the database (via a cache).
+            await_full_state: if true, will block if we do not yet have complete
+               state at the latest events.
 
         Returns:
             the state dict, mapping from (event_type, state_key) -> event_id
@@ -212,7 +227,9 @@ class StateHandler:
         assert latest_event_ids is not None
 
         logger.debug("calling resolve_state_groups from get_current_state_ids")
-        ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+        ret = await self.resolve_state_groups_for_events(
+            room_id, latest_event_ids, await_full_state=await_full_state
+        )
         return ret.state
 
     async def get_current_users_in_room(
@@ -323,7 +340,9 @@ class StateHandler:
 
             logger.debug("calling resolve_state_groups from compute_event_context")
             entry = await self.resolve_state_groups_for_events(
-                event.room_id, event.prev_event_ids()
+                event.room_id,
+                event.prev_event_ids(),
+                await_full_state=False,
             )
 
             state_ids_before_event = entry.state
@@ -404,7 +423,7 @@ class StateHandler:
 
     @measure_func()
     async def resolve_state_groups_for_events(
-        self, room_id: str, event_ids: Collection[str]
+        self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
     ) -> _StateCacheEntry:
         """Given a list of event_ids this method fetches the state at each
         event, resolves conflicts between them and returns them.
@@ -412,13 +431,17 @@ class StateHandler:
         Args:
             room_id
             event_ids
+            await_full_state: if true, will block if we do not yet have complete
+               state at these events.
 
         Returns:
             The resolved state
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
-        state_groups = await self.state_store.get_state_group_for_events(event_ids)
+        state_groups = await self.state_store.get_state_group_for_events(
+            event_ids, await_full_state=await_full_state
+        )
 
         state_group_ids = state_groups.values()