diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c355e4f98a..833ffec3de 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
@@ -191,6 +190,7 @@ class StateHandler:
room_id: str,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""Fetch the state after each of the given event IDs. Resolve them and return.
@@ -201,20 +201,25 @@ class StateHandler:
Args:
room_id: the room_id containing the given events.
event_ids: the events whose state should be fetched and resolved.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the given `event_id`s, regardless of whether `state_filter` is
+ satisfied by partial state.
Returns:
the state dict (a mapping from (event_type, state_key) -> event_id) which
holds the resolution of the states after the given event IDs.
"""
logger.debug("calling resolve_state_groups from compute_state_after_events")
- ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+ ret = await self.resolve_state_groups_for_events(
+ room_id, event_ids, await_full_state
+ )
return await ret.get_state(self._state_storage_controller, state_filter)
- async def get_current_users_in_room(
+ async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
"""
- Get the users who are currently in a room.
+ Get the users IDs who are currently in a room.
Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -225,15 +230,15 @@ class StateHandler:
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Dictionary of user IDs to their profileinfo.
+ Set of user IDs in the room.
"""
assert latest_event_ids is not None
- logger.debug("calling resolve_state_groups from get_current_users_in_room")
+ logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
- return await self.store.get_joined_users_from_state(room_id, state, entry)
+ return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -421,6 +426,69 @@ class StateHandler:
partial_state=partial_state,
)
+ async def compute_event_context_for_batched(
+ self,
+ event: EventBase,
+ state_ids_before_event: StateMap[str],
+ current_state_group: int,
+ ) -> EventContext:
+ """
+ Generate an event context for an event that has not yet been persisted to the
+ database. Intended for use with events that are created to be persisted in a batch.
+ Args:
+ event: the event the context is being computed for
+ state_ids_before_event: a state map consisting of the state ids of the events
+ created prior to this event.
+ current_state_group: the current state group before the event.
+ """
+ state_group_before_event_prev_group = None
+ deltas_to_state_group_before_event = None
+
+ state_group_before_event = current_state_group
+
+ # if the event is not state, we are set
+ if not event.is_state():
+ return EventContext.with_state(
+ storage=self._storage_controllers,
+ state_group_before_event=state_group_before_event,
+ state_group=state_group_before_event,
+ state_delta_due_to_event={},
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ partial_state=False,
+ )
+
+ # otherwise, we'll need to create a new state group for after the event
+ key = (event.type, event.state_key)
+
+ if state_ids_before_event is not None:
+ replaces = state_ids_before_event.get(key)
+
+ if replaces and replaces != event.event_id:
+ event.unsigned["replaces_state"] = replaces
+
+ delta_ids = {key: event.event_id}
+
+ state_group_after_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ current_state_ids=None,
+ )
+ )
+
+ return EventContext.with_state(
+ storage=self._storage_controllers,
+ state_group=state_group_after_event,
+ state_group_before_event=state_group_before_event,
+ state_delta_due_to_event=delta_ids,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ partial_state=False,
+ )
+
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
|