summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation_event.py29
-rw-r--r--synapse/storage/databases/main/state.py8
2 files changed, 31 insertions, 6 deletions
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 2ba2b1527e..bcc755a376 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -581,6 +581,13 @@ class FederationEventHandler:
                     event.event_id,
                 )
                 return
+
+            # since the state at this event has changed, we should now re-evaluate
+            # whether it should have been rejected. We must already have all of the
+            # auth events (from last time we went round this path), so there is no
+            # need to pass the origin.
+            await self._check_event_auth(None, event, context)
+
             await self._store.update_state_for_partial_state_event(event, context)
             self._state_storage_controller.notify_event_un_partial_stated(
                 event.event_id
@@ -1624,13 +1631,15 @@ class FederationEventHandler:
         )
 
     async def _check_event_auth(
-        self, origin: str, event: EventBase, context: EventContext
+        self, origin: Optional[str], event: EventBase, context: EventContext
     ) -> None:
         """
         Checks whether an event should be rejected (for failing auth checks).
 
         Args:
-            origin: The host the event originates from.
+            origin: The host the event originates from. This is used to fetch
+               any missing auth events. It can be set to None, but only if we are
+               sure that we already have all the auth events.
             event: The event itself.
             context:
                 The event context.
@@ -1876,7 +1885,7 @@ class FederationEventHandler:
             event.internal_metadata.soft_failed = True
 
     async def _load_or_fetch_auth_events_for_event(
-        self, destination: str, event: EventBase
+        self, destination: Optional[str], event: EventBase
     ) -> Collection[EventBase]:
         """Fetch this event's auth_events, from database or remote
 
@@ -1892,12 +1901,19 @@ class FederationEventHandler:
         Args:
             destination: where to send the /event_auth request. Typically the server
                that sent us `event` in the first place.
+
+               If this is None, no attempt is made to load any missing auth events:
+               rather, an AssertionError is raised if there are any missing events.
+
             event: the event whose auth_events we want
 
         Returns:
             all of the events listed in `event.auth_events_ids`, after deduplication
 
         Raises:
+            AssertionError if some auth events were missing and no `destination` was
+            supplied.
+
             AuthError if we were unable to fetch the auth_events for any reason.
         """
         event_auth_event_ids = set(event.auth_event_ids())
@@ -1909,6 +1925,13 @@ class FederationEventHandler:
         )
         if not missing_auth_event_ids:
             return event_auth_events.values()
+        if destination is None:
+            # this shouldn't happen: destination must be set unless we know we have already
+            # persisted the auth events.
+            raise AssertionError(
+                "_load_or_fetch_auth_events_for_event() called with no destination for "
+                "an event with missing auth_events"
+            )
 
         logger.info(
             "Event %s refers to unknown auth events %s: fetching auth chain",
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 9674c4a757..f70705a0af 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -419,13 +419,15 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         # anything that was rejected should have the same state as its
         # predecessor.
         if context.rejected:
-            assert context.state_group == context.state_group_before_event
+            state_group = context.state_group_before_event
+        else:
+            state_group = context.state_group
 
         self.db_pool.simple_update_txn(
             txn,
             table="event_to_state_groups",
             keyvalues={"event_id": event.event_id},
-            updatevalues={"state_group": context.state_group},
+            updatevalues={"state_group": state_group},
         )
 
         self.db_pool.simple_delete_one_txn(
@@ -440,7 +442,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         txn.call_after(
             self._get_state_group_for_event.prefill,
             (event.event_id,),
-            context.state_group,
+            state_group,
         )