summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2022-08-01 13:53:56 +0100
committerGitHub <noreply@github.com>2022-08-01 13:53:56 +0100
commit224d792dd7827fb53b9ed3b73a99f4acefb456ba (patch)
tree1fa31a187862b516245c46e0d55cb77e2e110b26 /synapse
parentEnable Complement CI tests in the 'latest deps' test run. (#13213) (diff)
downloadsynapse-224d792dd7827fb53b9ed3b73a99f4acefb456ba.tar.xz
Refactor `_resolve_state_at_missing_prevs` to return an `EventContext` (#13404)
Previously, `_resolve_state_at_missing_prevs` returned the resolved
state before an event and a partial state flag. These were unwieldy to
carry around would only ever be used to build an event context. Build
the event context directly instead.

Signed-off-by: Sean Quah <seanq@matrix.org>
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation_event.py126
-rw-r--r--synapse/state/__init__.py8
-rw-r--r--synapse/storage/controllers/state.py4
3 files changed, 56 insertions, 82 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index bcc755a376..612e5aaa5b 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -23,7 +23,6 @@ from typing import (
     Dict,
     Iterable,
     List,
-    Optional,
     Sequence,
     Set,
     Tuple,
@@ -278,9 +277,8 @@ class FederationEventHandler:
                 )
 
         try:
-            await self._process_received_pdu(
-                origin, pdu, state_ids=None, partial_state=None
-            )
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
         except PartialStateConflictError:
             # The room was un-partial stated while we were processing the PDU.
             # Try once more, with full state this time.
@@ -288,9 +286,8 @@ class FederationEventHandler:
                 "Room %s was un-partial stated while processing the PDU, trying again.",
                 room_id,
             )
-            await self._process_received_pdu(
-                origin, pdu, state_ids=None, partial_state=None
-            )
+            context = await self._state_handler.compute_event_context(pdu)
+            await self._process_received_pdu(origin, pdu, context)
 
     async def on_send_membership_event(
         self, origin: str, event: EventBase
@@ -320,6 +317,7 @@ class FederationEventHandler:
             The event and context of the event after inserting it into the room graph.
 
         Raises:
+            RuntimeError if any prev_events are missing
             SynapseError if the event is not accepted into the room
             PartialStateConflictError if the room was un-partial stated in between
                 computing the state at the event and persisting it. The caller should
@@ -380,7 +378,7 @@ class FederationEventHandler:
         # need to.
         await self._event_creation_handler.cache_joined_hosts_for_event(event, context)
 
-        await self._check_for_soft_fail(event, None, origin=origin)
+        await self._check_for_soft_fail(event, context=context, origin=origin)
         await self._run_push_actions_and_persist_event(event, context)
         return event, context
 
@@ -538,36 +536,10 @@ class FederationEventHandler:
             #
             # This is the same operation as we do when we receive a regular event
             # over federation.
-            state_ids, partial_state = await self._resolve_state_at_missing_prevs(
+            context = await self._compute_event_context_with_maybe_missing_prevs(
                 destination, event
             )
-
-            # There are three possible cases for (state_ids, partial_state):
-            #   * `state_ids` and `partial_state` are both `None` if we had all the
-            #     prev_events. The prev_events may or may not have partial state and
-            #     we won't know until we compute the event context.
-            #   * `state_ids` is not `None` and `partial_state` is `False` if we were
-            #     missing some prev_events (but we have full state for any we did
-            #     have). We calculated the full state after the prev_events.
-            #   * `state_ids` is not `None` and `partial_state` is `True` if we were
-            #     missing some, but not all, prev_events. At least one of the
-            #     prev_events we did have had partial state, so we calculated a partial
-            #     state after the prev_events.
-
-            context = None
-            if state_ids is not None and partial_state:
-                # the state after the prev events is still partial. We can't de-partial
-                # state the event, so don't bother building the event context.
-                pass
-            else:
-                # build a new state group for it if need be
-                context = await self._state_handler.compute_event_context(
-                    event,
-                    state_ids_before_event=state_ids,
-                    partial_state=partial_state,
-                )
-
-            if context is None or context.partial_state:
+            if context.partial_state:
                 # this can happen if some or all of the event's prev_events still have
                 # partial state. We were careful to only pick events from the db without
                 # partial-state prev events, so that implies that a prev event has
@@ -840,26 +812,25 @@ class FederationEventHandler:
 
         try:
             try:
-                state_ids, partial_state = await self._resolve_state_at_missing_prevs(
+                context = await self._compute_event_context_with_maybe_missing_prevs(
                     origin, event
                 )
                 await self._process_received_pdu(
                     origin,
                     event,
-                    state_ids=state_ids,
-                    partial_state=partial_state,
+                    context,
                     backfilled=backfilled,
                 )
             except PartialStateConflictError:
                 # The room was un-partial stated while we were processing the event.
                 # Try once more, with full state this time.
-                state_ids, partial_state = await self._resolve_state_at_missing_prevs(
+                context = await self._compute_event_context_with_maybe_missing_prevs(
                     origin, event
                 )
 
                 # We ought to have full state now, barring some unlikely race where we left and
                 # rejoned the room in the background.
-                if state_ids is not None and partial_state:
+                if context.partial_state:
                     raise AssertionError(
                         f"Event {event.event_id} still has a partial resolved state "
                         f"after room {event.room_id} was un-partial stated"
@@ -868,8 +839,7 @@ class FederationEventHandler:
                 await self._process_received_pdu(
                     origin,
                     event,
-                    state_ids=state_ids,
-                    partial_state=partial_state,
+                    context,
                     backfilled=backfilled,
                 )
         except FederationError as e:
@@ -878,15 +848,18 @@ class FederationEventHandler:
             else:
                 raise
 
-    async def _resolve_state_at_missing_prevs(
+    async def _compute_event_context_with_maybe_missing_prevs(
         self, dest: str, event: EventBase
-    ) -> Tuple[Optional[StateMap[str]], Optional[bool]]:
-        """Calculate the state at an event with missing prev_events.
+    ) -> EventContext:
+        """Build an EventContext structure for a non-outlier event whose prev_events may
+        be missing.
 
-        This is used when we have pulled a batch of events from a remote server, and
-        still don't have all the prev_events.
+        This is used when we have pulled a batch of events from a remote server, and may
+        not have all the prev_events.
 
-        If we already have all the prev_events for `event`, this method does nothing.
+        To build an EventContext, we need to calculate the state before the event. If we
+        already have all the prev_events for `event`, we can simply use the state after
+        the prev_events to calculate the state before `event`.
 
         Otherwise, the missing prevs become new backwards extremities, and we fall back
         to asking the remote server for the state after each missing `prev_event`,
@@ -907,10 +880,7 @@ class FederationEventHandler:
             event: an event to check for missing prevs.
 
         Returns:
-            if we already had all the prev events, `None, None`. Otherwise, returns a
-            tuple containing:
-             * the event ids of the state at `event`.
-             * a boolean indicating whether the state may be partial.
+            The event context.
 
         Raises:
             FederationError if we fail to get the state from the remote server after any
@@ -924,7 +894,7 @@ class FederationEventHandler:
         missing_prevs = prevs - seen
 
         if not missing_prevs:
-            return None, None
+            return await self._state_handler.compute_event_context(event)
 
         logger.info(
             "Event %s is missing prev_events %s: calculating state for a "
@@ -990,7 +960,9 @@ class FederationEventHandler:
                 "We can't get valid state history.",
                 affected=event_id,
             )
-        return state_map, partial_state
+        return await self._state_handler.compute_event_context(
+            event, state_ids_before_event=state_map, partial_state=partial_state
+        )
 
     async def _get_state_ids_after_missing_prev_event(
         self,
@@ -1159,8 +1131,7 @@ class FederationEventHandler:
         self,
         origin: str,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
-        partial_state: Optional[bool],
+        context: EventContext,
         backfilled: bool = False,
     ) -> None:
         """Called when we have a new non-outlier event.
@@ -1182,32 +1153,18 @@ class FederationEventHandler:
 
             event: event to be persisted
 
-            state_ids: Normally None, but if we are handling a gap in the graph
-                (ie, we are missing one or more prev_events), the resolved state at the
-                event
-
-            partial_state:
-                `True` if `state_ids` is partial and omits non-critical membership
-                events.
-                `False` if `state_ids` is the full state.
-                `None` if `state_ids` is not provided. In this case, the flag will be
-                calculated based on `event`'s prev events.
+            context: The `EventContext` to persist the event with.
 
             backfilled: True if this is part of a historical batch of events (inhibits
                 notification to clients, and validation of device keys.)
 
         PartialStateConflictError: if the room was un-partial stated in between
-            computing the state at the event and persisting it. The caller should retry
-            exactly once in this case.
+            computing the state at the event and persisting it. The caller should
+            recompute `context` and retry exactly once when this happens.
         """
         logger.debug("Processing event: %s", event)
         assert not event.internal_metadata.outlier
 
-        context = await self._state_handler.compute_event_context(
-            event,
-            state_ids_before_event=state_ids,
-            partial_state=partial_state,
-        )
         try:
             await self._check_event_auth(origin, event, context)
         except AuthError as e:
@@ -1219,7 +1176,7 @@ class FederationEventHandler:
             # For new (non-backfilled and non-outlier) events we check if the event
             # passes auth based on the current state. If it doesn't then we
             # "soft-fail" the event.
-            await self._check_for_soft_fail(event, state_ids, origin=origin)
+            await self._check_for_soft_fail(event, context=context, origin=origin)
 
         await self._run_push_actions_and_persist_event(event, context, backfilled)
 
@@ -1782,7 +1739,7 @@ class FederationEventHandler:
     async def _check_for_soft_fail(
         self,
         event: EventBase,
-        state_ids: Optional[StateMap[str]],
+        context: EventContext,
         origin: str,
     ) -> None:
         """Checks if we should soft fail the event; if so, marks the event as
@@ -1793,7 +1750,7 @@ class FederationEventHandler:
 
         Args:
             event
-            state_ids: The state at the event if we don't have all the event's prev events
+            context: The `EventContext` which we are about to persist the event with.
             origin: The host the event originates from.
         """
         if await self._store.is_partial_state_room(event.room_id):
@@ -1819,11 +1776,15 @@ class FederationEventHandler:
         auth_types = auth_types_for_event(room_version_obj, event)
 
         # Calculate the "current state".
-        if state_ids is not None:
-            # If we're explicitly given the state then we won't have all the
-            # prev events, and so we have a gap in the graph. In this case
-            # we want to be a little careful as we might have been down for
-            # a while and have an incorrect view of the current state,
+        seen_event_ids = await self._store.have_events_in_timeline(prev_event_ids)
+        has_missing_prevs = bool(prev_event_ids - seen_event_ids)
+        if has_missing_prevs:
+            # We don't have all the prev_events of this event, which means we have a
+            # gap in the graph, and the new event is going to become a new backwards
+            # extremity.
+            #
+            # In this case we want to be a little careful as we might have been
+            # down for a while and have an incorrect view of the current state,
             # however we still want to do checks as gaps are easy to
             # maliciously manufacture.
             #
@@ -1836,6 +1797,7 @@ class FederationEventHandler:
                 event.room_id, extrem_ids
             )
             state_sets: List[StateMap[str]] = list(state_sets_d.values())
+            state_ids = await context.get_prev_state_ids()
             state_sets.append(state_ids)
             current_state_ids = (
                 await self._state_resolution_handler.resolve_events_with_store(
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 69834de0de..c355e4f98a 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -278,6 +278,10 @@ class StateHandler:
                 flag will be calculated based on `event`'s prev events.
         Returns:
             The event context.
+
+        Raises:
+            RuntimeError if `state_ids_before_event` is not provided and one or more
+                prev events are missing or outliers.
         """
 
         assert not event.internal_metadata.is_outlier()
@@ -432,6 +436,10 @@ class StateHandler:
 
         Returns:
             The resolved state
+
+        Raises:
+            RuntimeError if we don't have a state group for one or more of the events
+               (ie. they are outliers or unknown)
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index 20805c94fa..1e35046e07 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -338,6 +338,10 @@ class StateStorageController:
             event_ids: events to get state groups for
             await_full_state: if true, will block if we do not yet have complete
                state at these events.
+
+        Raises:
+            RuntimeError if we don't have a state group for one or more of the events
+               (ie. they are outliers or unknown)
         """
         if await_full_state:
             await self._partial_state_events_tracker.await_full_state(event_ids)