summary refs log tree commit diff
path: root/synapse/state/__init__.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/__init__.py')
-rw-r--r--synapse/state/__init__.py176
1 files changed, 71 insertions, 105 deletions
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index fdfb46ab82..e877e6f1a1 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -39,7 +39,11 @@ from prometheus_client import Counter, Histogram
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import (
+    EventContext,
+    UnpersistedEventContext,
+    UnpersistedEventContextBase,
+)
 from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
@@ -262,31 +266,31 @@ class StateHandler:
         state = await entry.get_state(self._state_storage_controller, StateFilter.all())
         return await self.store.get_joined_hosts(room_id, state, entry)
 
-    async def compute_event_context(
+    async def calculate_context_info(
         self,
         event: EventBase,
         state_ids_before_event: Optional[StateMap[str]] = None,
         partial_state: Optional[bool] = None,
-    ) -> EventContext:
-        """Build an EventContext structure for a non-outlier event.
-
-        (for an outlier, call EventContext.for_outlier directly)
-
-        This works out what the current state should be for the event, and
-        generates a new state group if necessary.
-
-        Args:
-            event:
-            state_ids_before_event: The event ids of the state before the event if
-                it can't be calculated from existing events. This is normally
-                only specified when receiving an event from federation where we
-                don't have the prev events, e.g. when backfilling.
-            partial_state:
-                `True` if `state_ids_before_event` is partial and omits non-critical
-                membership events.
-                `False` if `state_ids_before_event` is the full state.
-                `None` when `state_ids_before_event` is not provided. In this case, the
-                flag will be calculated based on `event`'s prev events.
+        state_group_before_event: Optional[int] = None,
+    ) -> UnpersistedEventContextBase:
+        """
+        Calulates the contents of an unpersisted event context, other than the current
+        state group (which is either provided or calculated when the event context is persisted)
+
+        state_ids_before_event:
+            The event ids of the full state before the event if
+            it can't be calculated from existing events. This is normally
+            only specified when receiving an event from federation where we
+            don't have the prev events, e.g. when backfilling or when the event
+            is being created for batch persisting.
+        partial_state:
+            `True` if `state_ids_before_event` is partial and omits non-critical
+            membership events.
+            `False` if `state_ids_before_event` is the full state.
+            `None` when `state_ids_before_event` is not provided. In this case, the
+            flag will be calculated based on `event`'s prev events.
+        state_group_before_event:
+            the current state group at the time of event, if known
         Returns:
             The event context.
 
@@ -294,7 +298,6 @@ class StateHandler:
             RuntimeError if `state_ids_before_event` is not provided and one or more
                 prev events are missing or outliers.
         """
-
         assert not event.internal_metadata.is_outlier()
 
         #
@@ -306,17 +309,6 @@ class StateHandler:
             state_group_before_event_prev_group = None
             deltas_to_state_group_before_event = None
 
-            # .. though we need to get a state group for it.
-            state_group_before_event = (
-                await self._state_storage_controller.store_state_group(
-                    event.event_id,
-                    event.room_id,
-                    prev_group=None,
-                    delta_ids=None,
-                    current_state_ids=state_ids_before_event,
-                )
-            )
-
             # the partial_state flag must be provided
             assert partial_state is not None
         else:
@@ -345,6 +337,7 @@ class StateHandler:
             logger.debug("calling resolve_state_groups from compute_event_context")
             # we've already taken into account partial state, so no need to wait for
             # complete state here.
+
             entry = await self.resolve_state_groups_for_events(
                 event.room_id,
                 event.prev_event_ids(),
@@ -383,18 +376,19 @@ class StateHandler:
         #
 
         if not event.is_state():
-            return EventContext.with_state(
+            return UnpersistedEventContext(
                 storage=self._storage_controllers,
                 state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
+                state_group_after_event=state_group_before_event,
                 state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
+                prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+                delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
                 partial_state=partial_state,
+                state_map_before_event=state_ids_before_event,
             )
 
         #
-        # otherwise, we'll need to create a new state group for after the event
+        # otherwise, we'll need to set up creating a new state group for after the event
         #
 
         key = (event.type, event.state_key)
@@ -412,88 +406,60 @@ class StateHandler:
 
         delta_ids = {key: event.event_id}
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
-        )
-
-        return EventContext.with_state(
+        return UnpersistedEventContext(
             storage=self._storage_controllers,
-            state_group=state_group_after_event,
             state_group_before_event=state_group_before_event,
+            state_group_after_event=None,
             state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
+            prev_group_for_state_group_before_event=state_group_before_event_prev_group,
+            delta_ids_to_state_group_before_event=deltas_to_state_group_before_event,
             partial_state=partial_state,
+            state_map_before_event=state_ids_before_event,
         )
 
-    async def compute_event_context_for_batched(
+    async def compute_event_context(
         self,
         event: EventBase,
-        state_ids_before_event: StateMap[str],
-        current_state_group: int,
+        state_ids_before_event: Optional[StateMap[str]] = None,
+        partial_state: Optional[bool] = None,
     ) -> EventContext:
-        """
-        Generate an event context for an event that has not yet been persisted to the
-        database. Intended for use with events that are created to be persisted in a batch.
-        Args:
-            event: the event the context is being computed for
-            state_ids_before_event: a state map consisting of the state ids of the events
-            created prior to this event.
-            current_state_group: the current state group before the event.
-        """
-        state_group_before_event_prev_group = None
-        deltas_to_state_group_before_event = None
-
-        state_group_before_event = current_state_group
-
-        # if the event is not state, we are set
-        if not event.is_state():
-            return EventContext.with_state(
-                storage=self._storage_controllers,
-                state_group_before_event=state_group_before_event,
-                state_group=state_group_before_event,
-                state_delta_due_to_event={},
-                prev_group=state_group_before_event_prev_group,
-                delta_ids=deltas_to_state_group_before_event,
-                partial_state=False,
-            )
+        """Build an EventContext structure for a non-outlier event.
 
-        # otherwise, we'll need to create a new state group for after the event
-        key = (event.type, event.state_key)
+        (for an outlier, call EventContext.for_outlier directly)
 
-        if state_ids_before_event is not None:
-            replaces = state_ids_before_event.get(key)
+        This works out what the current state should be for the event, and
+        generates a new state group if necessary.
 
-        if replaces and replaces != event.event_id:
-            event.unsigned["replaces_state"] = replaces
+        Args:
+            event:
+            state_ids_before_event: The event ids of the state before the event if
+                it can't be calculated from existing events. This is normally
+                only specified when receiving an event from federation where we
+                don't have the prev events, e.g. when backfilling.
+            partial_state:
+                `True` if `state_ids_before_event` is partial and omits non-critical
+                membership events.
+                `False` if `state_ids_before_event` is the full state.
+                `None` when `state_ids_before_event` is not provided. In this case, the
+                flag will be calculated based on `event`'s prev events.
+            entry:
+                A state cache entry for the resolved state across the prev events. We may
+                have already calculated this, so if it's available pass it in
+        Returns:
+            The event context.
 
-        delta_ids = {key: event.event_id}
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
+        """
 
-        state_group_after_event = (
-            await self._state_storage_controller.store_state_group(
-                event.event_id,
-                event.room_id,
-                prev_group=state_group_before_event,
-                delta_ids=delta_ids,
-                current_state_ids=None,
-            )
+        unpersisted_context = await self.calculate_context_info(
+            event=event,
+            state_ids_before_event=state_ids_before_event,
+            partial_state=partial_state,
         )
 
-        return EventContext.with_state(
-            storage=self._storage_controllers,
-            state_group=state_group_after_event,
-            state_group_before_event=state_group_before_event,
-            state_delta_due_to_event=delta_ids,
-            prev_group=state_group_before_event,
-            delta_ids=delta_ids,
-            partial_state=False,
-        )
+        return await unpersisted_context.persist(event)
 
     @measure_func()
     async def resolve_state_groups_for_events(