diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index e58301a8f0..a7e721aef2 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -609,13 +609,18 @@ class StateGroupStorage:
return state_group_delta.prev_group, state_group_delta.delta_ids
async def get_state_groups_ids(
- self, _room_id: str, event_ids: Collection[str]
+ self,
+ _room_id: str,
+ event_ids: Collection[str],
+ await_full_state: bool = True,
) -> Dict[int, MutableStateMap[str]]:
"""Get the event IDs of all the state for the state groups for the given events
Args:
_room_id: id of the room for these events
event_ids: ids of the events
+ await_full_state: if true, will block if we do not yet have complete
+ state at these events.
Returns:
dict of state_group_id -> (dict of (type, state_key) -> event id)
@@ -627,7 +632,9 @@ class StateGroupStorage:
if not event_ids:
return {}
- event_to_groups = await self.get_state_group_for_events(event_ids)
+ event_to_groups = await self.get_state_group_for_events(
+ event_ids, await_full_state=await_full_state
+ )
groups = set(event_to_groups.values())
group_to_state = await self.stores.state._get_state_for_groups(groups)
@@ -700,7 +707,10 @@ class StateGroupStorage:
return self.stores.state._get_state_groups_from_groups(groups, state_filter)
async def get_state_for_events(
- self, event_ids: Collection[str], state_filter: Optional[StateFilter] = None
+ self,
+ event_ids: Collection[str],
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[EventBase]]:
"""Given a list of event_ids and type tuples, return a list of state
dicts for each event.
@@ -708,6 +718,8 @@ class StateGroupStorage:
Args:
event_ids: The events to fetch the state of.
state_filter: The state filter used to fetch state.
+ await_full_state: if true, will block if the state_filter includes state
+ which is not yet complete.
Returns:
A dict of (event_id) -> (type, state_key) -> [state_events]
@@ -716,8 +728,11 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -749,6 +764,7 @@ class StateGroupStorage:
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -757,6 +773,8 @@ class StateGroupStorage:
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if true, will block if the state_filter includes state
+ which is not yet complete.
Returns:
A dict from event_id -> (type, state_key) -> event_id
@@ -765,8 +783,12 @@ class StateGroupStorage:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -808,7 +830,10 @@ class StateGroupStorage:
return state_map[event_id]
async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -816,6 +841,8 @@ class StateGroupStorage:
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if true, will block if the state_filter includes state
+ which is not yet complete.
Returns:
A dict from (type, state_key) -> state_event_id
@@ -825,7 +852,9 @@ class StateGroupStorage:
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
+ [event_id],
+ state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
return state_map[event_id]
@@ -857,7 +886,7 @@ class StateGroupStorage:
Args:
event_ids: events to get state groups for
await_full_state: if true, will block if we do not yet have complete
- state at these events.
+ state at these event.
"""
if await_full_state:
await self._partial_state_events_tracker.await_full_state(event_ids)
|