summary refs log tree commit diff
path: root/synapse/storage/controllers/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/controllers/state.py')
-rw-r--r--synapse/storage/controllers/state.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 1ad002f57b..f9ffd0e29e 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -234,6 +234,7 @@ class StateStorageController:
         self,
         event_ids: Collection[str],
         state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> Dict[str, StateMap[str]]:
         """
         Get the state dicts corresponding to a list of events, containing the event_ids
@@ -242,6 +243,9 @@ class StateStorageController:
         Args:
             event_ids: events whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at these events and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from event_id -> (type, state_key) -> event_id
@@ -250,8 +254,12 @@ class StateStorageController:
             RuntimeError if we don't have a state group for one or more of the events
                 (ie they are outliers or unknown)
         """
-        await_full_state = True
-        if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+        if (
+            await_full_state
+            and state_filter
+            and not state_filter.must_await_full_state(self._is_mine_id)
+        ):
+            # Full state is not required if the state filter is restrictive enough.
             await_full_state = False
 
         event_to_groups = await self.get_state_group_for_events(
@@ -294,7 +302,10 @@ class StateStorageController:
 
     @trace
     async def get_state_ids_for_event(
-        self, event_id: str, state_filter: Optional[StateFilter] = None
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
     ) -> StateMap[str]:
         """
         Get the state dict corresponding to a particular event
@@ -302,6 +313,9 @@ class StateStorageController:
         Args:
             event_id: event whose state should be returned
             state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
 
         Returns:
             A dict from (type, state_key) -> state_event_id
@@ -311,7 +325,9 @@ class StateStorageController:
                 outlier or is unknown)
         """
         state_map = await self.get_state_ids_for_events(
-            [event_id], state_filter or StateFilter.all()
+            [event_id],
+            state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
         )
         return state_map[event_id]