diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 7b2084664c..5964835ea3 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -29,7 +29,8 @@ from typing import (
from synapse.api.constants import EventTypes
from synapse.events import EventBase
-from synapse.logging.tracing import trace
+from synapse.logging.tracing import tag_args, trace
+from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.storage.util.partial_state_events_tracker import (
PartialCurrentStateTracker,
@@ -228,10 +229,12 @@ class StateStorageController:
return {event: event_to_state[event] for event in event_ids}
@trace
+ @tag_args
async def get_state_ids_for_events(
self,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> Dict[str, StateMap[str]]:
"""
Get the state dicts corresponding to a list of events, containing the event_ids
@@ -240,6 +243,9 @@ class StateStorageController:
Args:
event_ids: events whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at these events and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from event_id -> (type, state_key) -> event_id
@@ -248,8 +254,12 @@ class StateStorageController:
RuntimeError if we don't have a state group for one or more of the events
(ie they are outliers or unknown)
"""
- await_full_state = True
- if state_filter and not state_filter.must_await_full_state(self._is_mine_id):
+ if (
+ await_full_state
+ and state_filter
+ and not state_filter.must_await_full_state(self._is_mine_id)
+ ):
+ # Full state is not required if the state filter is restrictive enough.
await_full_state = False
event_to_groups = await self.get_state_group_for_events(
@@ -292,7 +302,10 @@ class StateStorageController:
@trace
async def get_state_ids_for_event(
- self, event_id: str, state_filter: Optional[StateFilter] = None
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""
Get the state dict corresponding to a particular event
@@ -300,6 +313,9 @@ class StateStorageController:
Args:
event_id: event whose state should be returned
state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
Returns:
A dict from (type, state_key) -> state_event_id
@@ -309,7 +325,9 @@ class StateStorageController:
outlier or is unknown)
"""
state_map = await self.get_state_ids_for_events(
- [event_id], state_filter or StateFilter.all()
+ [event_id],
+ state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
)
return state_map[event_id]
@@ -332,6 +350,7 @@ class StateStorageController:
)
@trace
+ @tag_args
async def get_state_group_for_events(
self,
event_ids: Collection[str],
@@ -473,6 +492,7 @@ class StateStorageController:
prev_stream_id, max_stream_id
)
+ @trace
async def get_current_state(
self, room_id: str, state_filter: Optional[StateFilter] = None
) -> StateMap[EventBase]:
@@ -506,3 +526,15 @@ class StateStorageController:
await self._partial_state_room_tracker.await_full_state(room_id)
return await self.stores.main.get_current_hosts_in_room(room_id)
+
+ async def get_users_in_room_with_profiles(
+ self, room_id: str
+ ) -> Dict[str, ProfileInfo]:
+ """
+ Get the current users in the room with their profiles.
+ If the room is currently partial-stated, this will block until the room has
+ full state.
+ """
+ await self._partial_state_room_tracker.await_full_state(room_id)
+
+ return await self.stores.main.get_users_in_room_with_profiles(room_id)
|