summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2021-09-24 11:56:33 +0100
committerGitHub <noreply@github.com>2021-09-24 11:56:33 +0100
commit85551b7a8555eb4e4456d5cf2db0fecd4a44621c (patch)
tree719de71dbcf4af43cf1719cc26b94b9fe5ba02a6
parentSimplify `_auth_and_persist_fetched_events` (#10901) (diff)
downloadsynapse-85551b7a8555eb4e4456d5cf2db0fecd4a44621c.tar.xz
Factor out common code for persisting fetched auth events (#10896)
* Factor more stuff out of `_get_events_and_persist`

It turns out that the event-sorting algorithm in `_get_events_and_persist` is
also useful in other circumstances. Here we move the current
`_auth_and_persist_fetched_events` to `_auth_and_persist_fetched_events_inner`,
and then factor the sorting part out to `_auth_and_persist_fetched_events`.

* `_get_remote_auth_chain_for_event`: remove redundant `outlier` assignment

`get_event_auth` returns events with the outlier flag already set, so this is
redundant (though we need to update a test where `get_event_auth` is mocked).

* `_get_remote_auth_chain_for_event`: move existing-event tests earlier

Move a couple of tests outside the loop. This is a bit inefficient for now, but
a future commit will make it better. It should be functionally identical.

* `_get_remote_auth_chain_for_event`: use `_auth_and_persist_fetched_events`

We can use the same codepath for persisting the events fetched as part of an
auth chain as for those fetched individually by `_get_events_and_persist` for
building the state at a backwards extremity.

* `_get_remote_auth_chain_for_event`: use a dict for efficiency

`_auth_and_persist_fetched_events` sorts the events itself, so we no longer
need to care about maintaining the ordering from `get_event_auth` (and no
longer need to sort by depth in `get_event_auth`).

That means that we can use a map, making it easier to filter out events we
already have, etc.

* changelog

* `_auth_and_persist_fetched_events`: improve docstring
Diffstat (limited to '')
-rw-r--r--changelog.d/10896.misc1
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/handlers/federation_event.py103
-rw-r--r--tests/handlers/test_federation.py7
4 files changed, 55 insertions, 58 deletions
diff --git a/changelog.d/10896.misc b/changelog.d/10896.misc
new file mode 100644
index 0000000000..41de995842
--- /dev/null
+++ b/changelog.d/10896.misc
@@ -0,0 +1 @@
+ Clean up some of the federation event authentication code for clarity.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 1416abd0fb..584836c04a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -501,8 +501,6 @@ class FederationClient(FederationBase):
             destination, auth_chain, outlier=True, room_version=room_version
         )
 
-        signed_auth.sort(key=lambda e: e.depth)
-
         return signed_auth
 
     def _is_unknown_endpoint(
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index 4eefcc36d8..8fd9e51044 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -1080,7 +1080,7 @@ class FederationEventHandler:
 
         room_version = await self._store.get_room_version(room_id)
 
-        event_map: Dict[str, EventBase] = {}
+        events: List[EventBase] = []
 
         async def get_event(event_id: str) -> None:
             with nested_logging_context(event_id):
@@ -1098,8 +1098,7 @@ class FederationEventHandler:
                             event_id,
                         )
                         return
-
-                    event_map[event.event_id] = event
+                    events.append(event)
 
                 except Exception as e:
                     logger.warning(
@@ -1110,11 +1109,29 @@ class FederationEventHandler:
                     )
 
         await concurrently_execute(get_event, event_ids, 5)
-        logger.info("Fetched %i events of %i requested", len(event_map), len(event_ids))
+        logger.info("Fetched %i events of %i requested", len(events), len(event_ids))
+        await self._auth_and_persist_fetched_events(destination, room_id, events)
+
+    async def _auth_and_persist_fetched_events(
+        self, origin: str, room_id: str, events: Iterable[EventBase]
+    ) -> None:
+        """Persist the events fetched by _get_events_and_persist or _get_remote_auth_chain_for_event
+
+        The events to be persisted must be outliers.
+
+        We first sort the events to make sure that we process each event's auth_events
+        before the event itself, and then auth and persist them.
+
+        Notifies about the events where appropriate.
+
+        Params:
+            origin: where the events came from
+            room_id: the room that the events are meant to be in (though this has
+               not yet been checked)
+            events: the events that have been fetched
+        """
+        event_map = {event.event_id: event for event in events}
 
-        # we now need to auth the events in an order which ensures that each event's
-        # auth_events are authed before the event itself.
-        #
         # XXX: it might be possible to kick this process off in parallel with fetching
         # the events.
         while event_map:
@@ -1141,22 +1158,18 @@ class FederationEventHandler:
                 "Persisting %i of %i remaining events", len(roots), len(event_map)
             )
 
-            await self._auth_and_persist_fetched_events(destination, room_id, roots)
+            await self._auth_and_persist_fetched_events_inner(origin, room_id, roots)
 
             for ev in roots:
                 del event_map[ev.event_id]
 
-    async def _auth_and_persist_fetched_events(
+    async def _auth_and_persist_fetched_events_inner(
         self, origin: str, room_id: str, fetched_events: Collection[EventBase]
     ) -> None:
-        """Persist the events fetched by _get_events_and_persist.
+        """Helper for _auth_and_persist_fetched_events
 
-        The events should not depend on one another, e.g. this should be used to persist
-        a bunch of outliers, but not a chunk of individual events that depend
-        on each other for state calculations.
-
-        We also assume that all of the auth events for all of the events have already
-        been persisted.
+        Persists a batch of events where we have (theoretically) already persisted all
+        of their auth events.
 
         Notifies about the events where appropriate.
 
@@ -1164,7 +1177,7 @@ class FederationEventHandler:
             origin: where the events came from
             room_id: the room that the events are meant to be in (though this has
                not yet been checked)
-            event_id: map from event_id -> event for the fetched events
+            fetched_events: the events to persist
         """
         # get all the auth events for all the events in this batch. By now, they should
         # have been persisted.
@@ -1558,53 +1571,33 @@ class FederationEventHandler:
             event_id: the event for which we are lacking auth events
         """
         try:
-            remote_auth_chain = await self._federation_client.get_event_auth(
-                destination, room_id, event_id
-            )
+            remote_event_map = {
+                e.event_id: e
+                for e in await self._federation_client.get_event_auth(
+                    destination, room_id, event_id
+                )
+            }
         except RequestSendFailed as e1:
             # The other side isn't around or doesn't implement the
             # endpoint, so lets just bail out.
             logger.info("Failed to get event auth from remote: %s", e1)
             return
 
-        seen_remotes = await self._store.have_seen_events(
-            room_id, [e.event_id for e in remote_auth_chain]
-        )
+        logger.info("/event_auth returned %i events", len(remote_event_map))
 
-        for auth_event in remote_auth_chain:
-            if auth_event.event_id in seen_remotes:
-                continue
+        # `event` may be returned, but we should not yet process it.
+        remote_event_map.pop(event_id, None)
 
-            if auth_event.event_id == event_id:
-                continue
+        # nor should we reprocess any events we have already seen.
+        seen_remotes = await self._store.have_seen_events(
+            room_id, remote_event_map.keys()
+        )
+        for s in seen_remotes:
+            remote_event_map.pop(s, None)
 
-            try:
-                auth_ids = auth_event.auth_event_ids()
-                auth = {
-                    (e.type, e.state_key): e
-                    for e in remote_auth_chain
-                    if e.event_id in auth_ids or e.type == EventTypes.Create
-                }
-                auth_event.internal_metadata.outlier = True
-
-                logger.debug(
-                    "_check_event_auth %s missing_auth: %s",
-                    event_id,
-                    auth_event.event_id,
-                )
-                missing_auth_event_context = EventContext.for_outlier()
-                missing_auth_event_context = await self._check_event_auth(
-                    destination,
-                    auth_event,
-                    missing_auth_event_context,
-                    claimed_auth_event_map=auth,
-                )
-                await self.persist_events_and_notify(
-                    room_id,
-                    [(auth_event, missing_auth_event_context)],
-                )
-            except AuthError:
-                pass
+        await self._auth_and_persist_fetched_events(
+            destination, room_id, remote_event_map.values()
+        )
 
     async def _update_context_for_auth_events(
         self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 6c67a16de9..936ebf3dde 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -308,7 +308,12 @@ class FederationTestCase(unittest.HomeserverTestCase):
         async def get_event_auth(
             destination: str, room_id: str, event_id: str
         ) -> List[EventBase]:
-            return auth_events
+            return [
+                event_from_pdu_json(
+                    ae.get_pdu_json(), room_version=room_version, outlier=True
+                )
+                for ae in auth_events
+            ]
 
         self.handler.federation_client.get_event_auth = get_event_auth