summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/federation.py247
-rw-r--r--synapse/util/async_helpers.py4
2 files changed, 82 insertions, 169 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 1d39a9a4f5..3f480f2056 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -65,8 +65,7 @@ from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRes
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
 from synapse.types import UserID, get_domain_from_id
-from synapse.util import batch_iter, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
 from synapse.visibility import filter_events_for_server
@@ -238,7 +237,6 @@ class FederationHandler(BaseHandler):
             return None
 
         state = None
-        auth_chain = []
 
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
@@ -348,7 +346,6 @@ class FederationHandler(BaseHandler):
 
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
-                auth_chains = set()
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
@@ -369,24 +366,14 @@ class FederationHandler(BaseHandler):
                             "Requesting state at missing prev_event %s", event_id,
                         )
 
-                        room_version = await self.store.get_room_version(room_id)
-
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
                             # auth events, the requests to fetch those events are deduped
                             # by the get_pdu_cache in federation_client.
-                            (
-                                remote_state,
-                                got_auth_chain,
-                            ) = await self._get_state_for_room(
+                            (remote_state, _,) = await self._get_state_for_room(
                                 origin, room_id, p, include_event_in_state=True
                             )
 
-                            # XXX hrm I'm not convinced that duplicate events will compare
-                            # for equality, so I'm not sure this does what the author
-                            # hoped.
-                            auth_chains.update(got_auth_chain)
-
                             remote_state_map = {
                                 (x.type, x.state_key): x.event_id for x in remote_state
                             }
@@ -395,6 +382,7 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
+                    room_version = await self.store.get_room_version(room_id)
                     state_map = await resolve_events_with_store(
                         room_id,
                         room_version,
@@ -416,7 +404,6 @@ class FederationHandler(BaseHandler):
                     event_map.update(evs)
 
                     state = [event_map[e] for e in six.itervalues(state_map)]
-                    auth_chain = list(auth_chains)
                 except Exception:
                     logger.warning(
                         "[%s %s] Error attempting to resolve state at missing "
@@ -432,9 +419,7 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        await self._process_received_pdu(
-            origin, pdu, state=state, auth_chain=auth_chain
-        )
+        await self._process_received_pdu(origin, pdu, state=state)
 
     async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
@@ -633,10 +618,7 @@ class FederationHandler(BaseHandler):
     ) -> Dict[str, EventBase]:
         """Fetch events from a remote destination, checking if we already have them.
 
-        Args:
-            destination
-            room_id
-            event_ids
+        Persists any events we don't already have as outliers.
 
         If we fail to fetch any of the events, a warning will be logged, and the event
         will be omitted from the result. Likewise, any events which turn out not to
@@ -656,27 +638,15 @@ class FederationHandler(BaseHandler):
                 room_id,
             )
 
-            room_version = await self.store.get_room_version(room_id)
-
-            # XXX 20 requests at once? really?
-            for batch in batch_iter(missing_events, 20):
-                deferreds = [
-                    run_in_background(
-                        self.federation_client.get_pdu,
-                        destinations=[destination],
-                        event_id=e_id,
-                        room_version=room_version,
-                    )
-                    for e_id in batch
-                ]
-
-                res = await make_deferred_yieldable(
-                    defer.DeferredList(deferreds, consumeErrors=True)
-                )
+            await self._get_events_and_persist(
+                destination=destination, room_id=room_id, events=missing_events
+            )
 
-                for success, result in res:
-                    if success and result:
-                        fetched_events[result.event_id] = result
+            # we need to make sure we re-load from the database to get the rejected
+            # state correct.
+            fetched_events.update(
+                (await self.store.get_events(missing_events, allow_rejected=True))
+            )
 
         # check for events which were in the wrong room.
         #
@@ -705,50 +675,26 @@ class FederationHandler(BaseHandler):
 
         return fetched_events
 
-    async def _process_received_pdu(self, origin, event, state, auth_chain):
+    async def _process_received_pdu(
+        self, origin: str, event: EventBase, state: Optional[Iterable[EventBase]],
+    ):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
+
+        Args:
+            origin: server sending the event
+
+            event: event to be persisted
+
+            state: Normally None, but if we are handling a gap in the graph
+                (ie, we are missing one or more prev_events), the resolved state at the
+                event
         """
         room_id = event.room_id
         event_id = event.event_id
 
         logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
-        event_ids = set()
-        if state:
-            event_ids |= {e.event_id for e in state}
-        if auth_chain:
-            event_ids |= {e.event_id for e in auth_chain}
-
-        seen_ids = await self.store.have_seen_events(event_ids)
-
-        if state and auth_chain is not None:
-            # If we have any state or auth_chain given to us by the replication
-            # layer, then we should handle them (if we haven't before.)
-
-            event_infos = []
-
-            for e in itertools.chain(auth_chain, state):
-                if e.event_id in seen_ids:
-                    continue
-                e.internal_metadata.outlier = True
-                auth_ids = e.auth_event_ids()
-                auth = {
-                    (e.type, e.state_key): e
-                    for e in auth_chain
-                    if e.event_id in auth_ids or e.type == EventTypes.Create
-                }
-                event_infos.append(_NewEventInfo(event=e, auth_events=auth))
-                seen_ids.add(e.event_id)
-
-            logger.info(
-                "[%s %s] persisting newly-received auth/state events %s",
-                room_id,
-                event_id,
-                [e.event.event_id for e in event_infos],
-            )
-            await self._handle_new_events(origin, event_infos)
-
         try:
             context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
@@ -803,8 +749,6 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        room_version = await self.store.get_room_version(room_id)
-
         events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
@@ -833,6 +777,9 @@ class FederationHandler(BaseHandler):
 
         event_ids = set(e.event_id for e in events)
 
+        # build a list of events whose prev_events weren't in the batch.
+        # (XXX: this will include events whose prev_events we already have; that doesn't
+        # sound right?)
         edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
 
         logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
@@ -861,95 +808,11 @@ class FederationHandler(BaseHandler):
         auth_events.update(
             {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
         )
-        missing_auth = required_auth - set(auth_events)
-        failed_to_fetch = set()
-
-        # Try and fetch any missing auth events from both DB and remote servers.
-        # We repeatedly do this until we stop finding new auth events.
-        while missing_auth - failed_to_fetch:
-            logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
-            auth_events.update(ret_events)
-
-            required_auth.update(
-                a_id for event in ret_events.values() for a_id in event.auth_event_ids()
-            )
-            missing_auth = required_auth - set(auth_events)
-
-            if missing_auth - failed_to_fetch:
-                logger.info(
-                    "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch,
-                )
-
-                results = await make_deferred_yieldable(
-                    defer.gatherResults(
-                        [
-                            run_in_background(
-                                self.federation_client.get_pdu,
-                                [dest],
-                                event_id,
-                                room_version=room_version,
-                                outlier=True,
-                                timeout=10000,
-                            )
-                            for event_id in missing_auth - failed_to_fetch
-                        ],
-                        consumeErrors=True,
-                    )
-                ).addErrback(unwrapFirstError)
-                auth_events.update({a.event_id: a for a in results if a})
-                required_auth.update(
-                    a_id
-                    for event in results
-                    if event
-                    for a_id in event.auth_event_ids()
-                )
-                missing_auth = required_auth - set(auth_events)
-
-                failed_to_fetch = missing_auth - set(auth_events)
-
-        seen_events = await self.store.have_seen_events(
-            set(auth_events.keys()) | set(state_events.keys())
-        )
-
-        # We now have a chunk of events plus associated state and auth chain to
-        # persist. We do the persistence in two steps:
-        #   1. Auth events and state get persisted as outliers, plus the
-        #      backward extremities get persisted (as non-outliers).
-        #   2. The rest of the events in the chunk get persisted one by one, as
-        #      each one depends on the previous event for its state.
-        #
-        # The important thing is that events in the chunk get persisted as
-        # non-outliers, including when those events are also in the state or
-        # auth chain. Caution must therefore be taken to ensure that they are
-        # not accidentally marked as outliers.
 
-        # Step 1a: persist auth events that *don't* appear in the chunk
         ev_infos = []
-        for a in auth_events.values():
-            # We only want to persist auth events as outliers that we haven't
-            # seen and aren't about to persist as part of the backfilled chunk.
-            if a.event_id in seen_events or a.event_id in event_map:
-                continue
 
-            a.internal_metadata.outlier = True
-            ev_infos.append(
-                _NewEventInfo(
-                    event=a,
-                    auth_events={
-                        (
-                            auth_events[a_id].type,
-                            auth_events[a_id].state_key,
-                        ): auth_events[a_id]
-                        for a_id in a.auth_event_ids()
-                        if a_id in auth_events
-                    },
-                )
-            )
-
-        # Step 1b: persist the events in the chunk we fetched state for (i.e.
-        # the backwards extremities) as non-outliers.
+        # Step 1: persist the events in the chunk we fetched state for (i.e.
+        # the backwards extremities), with custom auth events and state
         for e_id in events_to_state:
             # For paranoia we ensure that these events are marked as
             # non-outliers
@@ -1191,6 +1054,56 @@ class FederationHandler(BaseHandler):
 
         return False
 
+    async def _get_events_and_persist(
+        self, destination: str, room_id: str, events: Iterable[str]
+    ):
+        """Fetch the given events from a server, and persist them as outliers.
+
+        Logs a warning if we can't find the given event.
+        """
+
+        room_version = await self.store.get_room_version(room_id)
+
+        event_infos = []
+
+        async def get_event(event_id: str):
+            with nested_logging_context(event_id):
+                try:
+                    event = await self.federation_client.get_pdu(
+                        [destination], event_id, room_version, outlier=True,
+                    )
+                    if event is None:
+                        logger.warning(
+                            "Server %s didn't return event %s", destination, event_id,
+                        )
+                        return
+
+                    # recursively fetch the auth events for this event
+                    auth_events = await self._get_events_from_store_or_dest(
+                        destination, room_id, event.auth_event_ids()
+                    )
+                    auth = {}
+                    for auth_event_id in event.auth_event_ids():
+                        ae = auth_events.get(auth_event_id)
+                        if ae:
+                            auth[(ae.type, ae.state_key)] = ae
+
+                    event_infos.append(_NewEventInfo(event, None, auth))
+
+                except Exception as e:
+                    logger.warning(
+                        "Error fetching missing state/auth event %s: %s %s",
+                        event_id,
+                        type(e),
+                        e,
+                    )
+
+        await concurrently_execute(get_event, events, 5)
+
+        await self._handle_new_events(
+            destination, event_infos,
+        )
+
     def _sanity_check_event(self, ev):
         """
         Do some early sanity checks of a received event
diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py
index 5c4de2e69f..04b6abdc24 100644
--- a/synapse/util/async_helpers.py
+++ b/synapse/util/async_helpers.py
@@ -140,8 +140,8 @@ def concurrently_execute(func, args, limit):
 
     Args:
         func (func): Function to execute, should return a deferred or coroutine.
-        args (list): List of arguments to pass to func, each invocation of func
-            gets a signle argument.
+        args (Iterable): List of arguments to pass to func, each invocation of func
+            gets a single argument.
         limit (int): Maximum number of conccurent executions.
 
     Returns: