diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c355e4f98a..833ffec3de 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
from synapse.storage.state import StateFilter
from synapse.types import StateMap
from synapse.util.async_helpers import Linearizer
@@ -191,6 +190,7 @@ class StateHandler:
room_id: str,
event_ids: Collection[str],
state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
) -> StateMap[str]:
"""Fetch the state after each of the given event IDs. Resolve them and return.
@@ -201,20 +201,25 @@ class StateHandler:
Args:
room_id: the room_id containing the given events.
event_ids: the events whose state should be fetched and resolved.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the given `event_id`s, regardless of whether `state_filter` is
+ satisfied by partial state.
Returns:
the state dict (a mapping from (event_type, state_key) -> event_id) which
holds the resolution of the states after the given event IDs.
"""
logger.debug("calling resolve_state_groups from compute_state_after_events")
- ret = await self.resolve_state_groups_for_events(room_id, event_ids)
+ ret = await self.resolve_state_groups_for_events(
+ room_id, event_ids, await_full_state
+ )
return await ret.get_state(self._state_storage_controller, state_filter)
- async def get_current_users_in_room(
+ async def get_current_user_ids_in_room(
self, room_id: str, latest_event_ids: List[str]
- ) -> Dict[str, ProfileInfo]:
+ ) -> Set[str]:
"""
- Get the users who are currently in a room.
+ Get the user IDs who are currently in a room.
Note: This is much slower than using the equivalent method
`DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -225,15 +230,15 @@ class StateHandler:
room_id: The ID of the room.
latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Dictionary of user IDs to their profileinfo.
+ Set of user IDs in the room.
"""
assert latest_event_ids is not None
- logger.debug("calling resolve_state_groups from get_current_users_in_room")
+ logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
- return await self.store.get_joined_users_from_state(room_id, state, entry)
+ return await self.store.get_joined_user_ids_from_state(room_id, state)
async def get_hosts_in_room_at_events(
self, room_id: str, event_ids: Collection[str]
@@ -421,6 +426,69 @@ class StateHandler:
partial_state=partial_state,
)
+ async def compute_event_context_for_batched(
+ self,
+ event: EventBase,
+ state_ids_before_event: StateMap[str],
+ current_state_group: int,
+ ) -> EventContext:
+ """
+ Generate an event context for an event that has not yet been persisted to the
+ database. Intended for use with events that are created to be persisted in a batch.
+ Args:
+ event: the event the context is being computed for
+ state_ids_before_event: a state map consisting of the state ids of the events
+ created prior to this event.
+ current_state_group: the current state group before the event.
+ """
+ state_group_before_event_prev_group = None
+ deltas_to_state_group_before_event = None
+
+ state_group_before_event = current_state_group
+
+ # if the event is not state, we are set
+ if not event.is_state():
+ return EventContext.with_state(
+ storage=self._storage_controllers,
+ state_group_before_event=state_group_before_event,
+ state_group=state_group_before_event,
+ state_delta_due_to_event={},
+ prev_group=state_group_before_event_prev_group,
+ delta_ids=deltas_to_state_group_before_event,
+ partial_state=False,
+ )
+
+ # otherwise, we'll need to create a new state group for after the event
+ key = (event.type, event.state_key)
+
+ if state_ids_before_event is not None:
+ replaces = state_ids_before_event.get(key)
+
+ if replaces and replaces != event.event_id:
+ event.unsigned["replaces_state"] = replaces
+
+ delta_ids = {key: event.event_id}
+
+ state_group_after_event = (
+ await self._state_storage_controller.store_state_group(
+ event.event_id,
+ event.room_id,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ current_state_ids=None,
+ )
+ )
+
+ return EventContext.with_state(
+ storage=self._storage_controllers,
+ state_group=state_group_after_event,
+ state_group_before_event=state_group_before_event,
+ state_delta_due_to_event=delta_ids,
+ prev_group=state_group_before_event,
+ delta_ids=delta_ids,
+ partial_state=False,
+ )
+
@measure_func()
async def resolve_state_groups_for_events(
self, room_id: str, event_ids: Collection[str], await_full_state: bool = True
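(Illustration, not part of the patch: a hedged usage sketch for the new `compute_event_context_for_batched` above. The `state_handler`, `event`, `state_ids_before_event` and `current_state_group` names are assumptions standing in for whatever the real caller uses, not identifiers from this change.)

# Hypothetical sketch only; assumes `state_handler` is the StateHandler above,
# `event` is a not-yet-persisted EventBase, `state_ids_before_event` is a
# StateMap[str] of event IDs, and `current_state_group` is the state group
# those IDs belong to.
async def context_for_batched_event(
    state_handler, event, state_ids_before_event, current_state_group
):
    context = await state_handler.compute_event_context_for_batched(
        event,
        state_ids_before_event=state_ids_before_event,
        current_state_group=current_state_group,
    )
    # For a non-state event the context reuses `current_state_group` unchanged;
    # for a state event a new state group is stored whose delta is
    # {(event.type, event.state_key): event.event_id}.
    return context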
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 7db032203b..1b9d7d8457 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -271,40 +271,41 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[Mapping[Any, str]],
- event_map: Dict[str, EventBase],
+ unpersisted_events: Dict[str, EventBase],
state_res_store: StateResolutionStore,
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
- that only appear in some but not all of the auth chains.
+ that only appear in some, but not all, of the auth chains.
Args:
- state_sets
- event_map
- state_res_store
+ state_sets: The input state sets we are trying to resolve across.
+ unpersisted_events: A map from event ID to EventBase containing all unpersisted
+ events involved in this resolution.
+ state_res_store:
Returns:
- Set of event IDs
+ The auth difference of the given state sets, as a set of event IDs.
"""
# The `StateResolutionStore.get_auth_chain_difference` function assumes that
# all events passed to it (and their auth chains) have been persisted
- # previously. This is not the case for any events in the `event_map`, and so
- # we need to manually handle those events.
+ # previously. We need to manually handle any events that have not yet been
+ # persisted.
#
- # We do this by:
- # 1. calculating the auth chain difference for the state sets based on the
- # events in `event_map` alone
- # 2. replacing any events in the state_sets that are also in `event_map`
- # with their auth events (recursively), and then calling
- # `store.get_auth_chain_difference` as normal
- # 3. adding the results of 1 and 2 together.
-
- # Map from event ID in `event_map` to their auth event IDs, and their auth
- # event IDs if they appear in the `event_map`. This is the intersection of
- # the event's auth chain with the events in the `event_map` *plus* their
+ # We do this in three steps:
+ # 1. Compute the set of unpersisted events belonging to the auth difference.
+ # 2. Replace any unpersisted events in the state_sets with their auth events,
+ # recursively, until the state_sets contain only persisted events.
+ # Then we call `store.get_auth_chain_difference` as normal, which computes
+ # the set of persisted events belonging to the auth difference.
+ # 3. Add the results of 1 and 2 together.
+
+ # Map from each event ID in `unpersisted_events` to its auth chain, restricted
+ # to events in `unpersisted_events` plus their direct auth event IDs. That is,
+ # the intersection of the event's auth chain with `unpersisted_events`, *plus*
+ # those events' auth event IDs.
events_to_auth_chain: Dict[str, Set[str]] = {}
- for event in event_map.values():
+ for event in unpersisted_events.values():
chain = {event.event_id}
events_to_auth_chain[event.event_id] = chain
@@ -312,16 +313,16 @@ async def _get_auth_chain_difference(
while to_search:
for auth_id in to_search.pop().auth_event_ids():
chain.add(auth_id)
- auth_event = event_map.get(auth_id)
+ auth_event = unpersisted_events.get(auth_id)
if auth_event:
to_search.append(auth_event)
- # We now a) calculate the auth chain difference for the unpersisted events
- # and b) work out the state sets to pass to the store.
+ # We now 1) calculate the auth chain difference for the unpersisted events
+ # and 2) work out the state sets to pass to the store.
#
- # Note: If the `event_map` is empty (which is the common case), we can do a
+ # Note: If `unpersisted_events` is empty (which is the common case), we can do a
# much simpler calculation.
- if event_map:
+ if unpersisted_events:
# The list of state sets to pass to the store, where each state set is a set
# of the event ids making up the state. This is similar to `state_sets`,
# except that (a) we only have event ids, not the complete
@@ -344,14 +345,18 @@ async def _get_auth_chain_difference(
for event_id in state_set.values():
event_chain = events_to_auth_chain.get(event_id)
if event_chain is not None:
- # We have an event in `event_map`. We add all the auth
- # events that it references (that aren't also in `event_map`).
- set_ids.update(e for e in event_chain if e not in event_map)
+ # We have an unpersisted event. We add all the auth
+ # events that it references which have already been persisted.
+ set_ids.update(
+ e for e in event_chain if e not in unpersisted_events
+ )
# We also add the full chain of unpersisted event IDs
# referenced by this state set, so that we can work out the
# auth chain difference of the unpersisted events.
- unpersisted_ids.update(e for e in event_chain if e in event_map)
+ unpersisted_ids.update(
+ e for e in event_chain if e in unpersisted_events
+ )
else:
set_ids.add(event_id)
@@ -361,15 +366,15 @@ async def _get_auth_chain_difference(
union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
- difference_from_event_map: Collection[str] = union - intersection
+ auth_difference_unpersisted_part: Collection[str] = union - intersection
else:
- difference_from_event_map = ()
+ auth_difference_unpersisted_part = ()
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
difference = await state_res_store.get_auth_chain_difference(
room_id, state_sets_ids
)
- difference.update(difference_from_event_map)
+ difference.update(auth_difference_unpersisted_part)
return difference
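(Illustration, not part of the patch: a self-contained sketch of the unpersisted half of the auth difference described in the three-step comment above, using plain dicts and sets in place of Synapse's event and storage types; the function and parameter names are invented for the example.)

def unpersisted_auth_difference(state_sets, events_to_auth_chain):
    # `state_sets`: one set of event IDs per input state set.
    # `events_to_auth_chain`: map from each unpersisted event ID to its auth
    # chain, restricted to unpersisted events plus their direct auth event IDs.
    unpersisted_set_ids = []
    for state_set in state_sets:
        ids = set()
        for event_id in state_set:
            chain = events_to_auth_chain.get(event_id)
            if chain is not None:
                # Keep only the unpersisted part of this event's auth chain.
                ids.update(e for e in chain if e in events_to_auth_chain)
        unpersisted_set_ids.append(ids)
    if not unpersisted_set_ids:
        return set()
    union = set().union(*unpersisted_set_ids)
    intersection = set.intersection(*unpersisted_set_ids)
    # The unpersisted events appearing in some, but not all, of the auth chains.
    return union - intersection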
@@ -434,7 +439,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: str,
event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore,
- auth_diff: Set[str],
+ full_conflicted_set: Set[str],
) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
@@ -445,7 +450,7 @@ async def _add_event_and_auth_chain_to_graph(
event_id: Event to add to the graph
event_map
state_res_store
- auth_diff: Set of event IDs that are in the auth difference.
+ full_conflicted_set: Set of event IDs that are in the full conflicted set.
"""
state = [event_id]
@@ -455,7 +460,7 @@ async def _add_event_and_auth_chain_to_graph(
event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
- if aid in auth_diff:
+ if aid in full_conflicted_set:
if aid not in graph:
state.append(aid)
@@ -468,7 +473,7 @@ async def _reverse_topological_power_sort(
event_ids: Iterable[str],
event_map: Dict[str, EventBase],
state_res_store: StateResolutionStore,
- auth_diff: Set[str],
+ full_conflicted_set: Set[str],
) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
@@ -479,7 +484,7 @@ async def _reverse_topological_power_sort(
event_ids: The events to sort
event_map
state_res_store
- auth_diff: Set of event IDs that are in the auth difference.
+ full_conflicted_set: Set of event IDs that are in the full conflicted set.
Returns:
The sorted list
@@ -488,7 +493,7 @@ async def _reverse_topological_power_sort(
graph: Dict[str, Set[str]] = {}
for idx, event_id in enumerate(event_ids, start=1):
await _add_event_and_auth_chain_to_graph(
- graph, room_id, event_id, event_map, state_res_store, auth_diff
+ graph, room_id, event_id, event_map, state_res_store, full_conflicted_set
)
# We await occasionally when we're working with large data sets to
@@ -572,6 +577,21 @@ async def _iterative_auth_checks(
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
+ if event.rejected_reason is not None:
+ # Do not admit previously rejected events into state.
+ # TODO: This isn't spec compliant. Events that were previously rejected due
+ # to failing auth checks at their state, but which pass auth checks during
+ # state resolution, should be accepted.
+ # change of rejection status well, so we preserve the previous
+ # rejection status for now.
+ #
+ # Note that events rejected for non-state reasons, such as having the
+ # wrong auth events, should remain rejected.
+ #
+ # https://spec.matrix.org/v1.2/rooms/v9/#rejected-events
+ # https://github.com/matrix-org/synapse/issues/13797
+ continue
+
try:
event_auth.check_state_dependent_auth_rules(
event,
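(Illustration, not part of the patch: the rejected-event handling added above amounts to the following filter over the sorted event list; the generator name and `sorted_events` parameter are invented for the example.)

def skip_previously_rejected(sorted_events):
    # Assumes each element exposes `rejected_reason`, as the events in the loop
    # above do.
    for event in sorted_events:
        if event.rejected_reason is not None:
            # Preserve the earlier rejection rather than re-running auth checks
            # against the resolved state.
            continue
        yield event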