diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..e877e6f1a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+ EventContext,
+ UnpersistedEventContext,
+ UnpersistedEventContextBase,
+)
from synapse.logging.context import ContextResourceUsage
from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
from synapse.state import v1, v2
@@ -262,31 +266,31 @@ class StateHandler:
state = await entry.get_state(self._state_storage_controller, StateFilter.all())
return await self.store.get_joined_hosts(room_id, state, entry)
- async def compute_event_context(
+ async def calculate_context_info(
self,
event: EventBase,
state_ids_before_event: Optional[StateMap[str]] = None,
partial_state: Optional[bool] = None,
- ) -> EventContext:
- """Build an EventContext structure for a non-outlier event.
-
- (for an outlier, call EventContext.for_outlier directly)
-
- This works out what the current state should be for the event, and
- generates a new state group if necessary.
-
- Args:
- event:
- state_ids_before_event: The event ids of the state before the event if
- it can't be calculated from existing events. This is normally
- only specified when receiving an event from federation where we
- don't have the prev events, e.g. when backfilling.
- partial_state:
- `True` if `state_ids_before_event` is partial and omits non-critical
- membership events.
- `False` if `state_ids_before_event` is the full state.
- `None` when `state_ids_before_event` is not provided. In this case, the
- flag will be calculated based on `event`'s prev events.
+ state_group_before_event: Optional[int] = None,
+ ) -> UnpersistedEventContextBase:
+ """
+ Calulates the contents of an unpersisted event context, other than the current
+ state group (which is either provided or calculated when the event context is persisted)
+
+ state_ids_before_event:
+ The event ids of the full state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling or when the event
+ is being created for batch persisting.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ state_group_before_event:
+ the current state group at the time of event, if known
Returns:
The event context.
@@ -294,7 +298,6 @@ class StateHandler:
RuntimeError if `state_ids_before_event` is not provided and one or more
prev events are missing or outliers.
"""
-
assert not event.internal_metadata.is_outlier()
#
@@ -306,17 +309,6 @@ class StateHandler:
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
- # .. though we need to get a state group for it.
- state_group_before_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=None,
- delta_ids=None,
- current_state_ids=state_ids_before_event,
- )
- )
-
# the partial_state flag must be provided
assert partial_state is not None
else:
@@ -345,6 +337,7 @@ class StateHandler:
logger.debug("calling resolve_state_groups from compute_event_context")
# we've already taken into account partial state, so no need to wait for
# complete state here.
+
entry = await self.resolve_state_groups_for_events(
event.room_id,
event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
#
if not event.is_state():
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
+ state_group_after_event=state_group_before_event,
state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
#
- # otherwise, we'll need to create a new state group for after the event
+ # otherwise, we'll need to set up creating a new state group for after the event
#
key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
delta_ids = {key: event.event_id}
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
- )
-
- return EventContext.with_state(
+ return UnpersistedEventContext(
storage=self._storage_controllers,
- state_group=state_group_after_event,
state_group_before_event=state_group_before_event,
+ state_group_after_event=None,
state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
+ prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+ delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
partial_state=partial_state,
+ state_map_before_event=state_ids_before_event,
)
- async def compute_event_context_for_batched(
+ async def compute_event_context(
self,
event: EventBase,
- state_ids_before_event: StateMap[str],
- current_state_group: int,
+ state_ids_before_event: Optional[StateMap[str]] = None,
+ partial_state: Optional[bool] = None,
) -> EventContext:
- """
- Generate an event context for an event that has not yet been persisted to the
- database. Intended for use with events that are created to be persisted in a batch.
- Args:
- event: the event the context is being computed for
- state_ids_before_event: a state map consisting of the state ids of the events
- created prior to this event.
- current_state_group: the current state group before the event.
- """
- state_group_before_event_prev_group = None
- deltas_to_state_group_before_event = None
-
- state_group_before_event = current_state_group
-
- # if the event is not state, we are set
- if not event.is_state():
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group_before_event=state_group_before_event,
- state_group=state_group_before_event,
- state_delta_due_to_event={},
- prev_group=state_group_before_event_prev_group,
- delta_ids=deltas_to_state_group_before_event,
- partial_state=False,
- )
+ """Build an EventContext structure for a non-outlier event.
- # otherwise, we'll need to create a new state group for after the event
- key = (event.type, event.state_key)
+ (for an outlier, call EventContext.for_outlier directly)
- if state_ids_before_event is not None:
- replaces = state_ids_before_event.get(key)
+ This works out what the current state should be for the event, and
+ generates a new state group if necessary.
- if replaces and replaces != event.event_id:
- event.unsigned["replaces_state"] = replaces
+ Args:
+ event:
+ state_ids_before_event: The event ids of the state before the event if
+ it can't be calculated from existing events. This is normally
+ only specified when receiving an event from federation where we
+ don't have the prev events, e.g. when backfilling.
+ partial_state:
+ `True` if `state_ids_before_event` is partial and omits non-critical
+ membership events.
+ `False` if `state_ids_before_event` is the full state.
+ `None` when `state_ids_before_event` is not provided. In this case, the
+ flag will be calculated based on `event`'s prev events.
+ entry:
+ A state cache entry for the resolved state across the prev events. We may
+ have already calculated this, so if it's available pass it in
+ Returns:
+ The event context.
- delta_ids = {key: event.event_id}
+ Raises:
+ RuntimeError if `state_ids_before_event` is not provided and one or more
+ prev events are missing or outliers.
+ """
- state_group_after_event = (
- await self._state_storage_controller.store_state_group(
- event.event_id,
- event.room_id,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- current_state_ids=None,
- )
+ unpersisted_context = await self.calculate_context_info(
+ event=event,
+ state_ids_before_event=state_ids_before_event,
+ partial_state=partial_state,
)
- return EventContext.with_state(
- storage=self._storage_controllers,
- state_group=state_group_after_event,
- state_group_before_event=state_group_before_event,
- state_delta_due_to_event=delta_ids,
- prev_group=state_group_before_event,
- delta_ids=delta_ids,
- partial_state=False,
- )
+ return await unpersisted_context.persist(event)
@measure_func()
async def resolve_state_groups_for_events(
|