summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py182
1 files changed, 150 insertions, 32 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 9e7527291b..0615ce0870 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -80,6 +80,18 @@ PDU_RETRY_TIME_MS = 1 * 60 * 1000
 T = TypeVar("T")
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class PulledPduInfo:
+    """
+    A result object that stores the PDU and info about it like which homeserver we
+    pulled it from (`pull_origin`)
+    """
+
+    pdu: EventBase
+    # Which homeserver we pulled the PDU from
+    pull_origin: str
+
+
 class InvalidResponseError(RuntimeError):
     """Helper for _try_destination_list: indicates that the server returned a response
     we couldn't parse
@@ -114,7 +126,9 @@ class FederationClient(FederationBase):
         self.hostname = hs.hostname
         self.signing_key = hs.signing_key
 
-        self._get_pdu_cache: ExpiringCache[str, EventBase] = ExpiringCache(
+        # Cache mapping `event_id` to a tuple of the event itself and the `pull_origin`
+        # (which server we pulled the event from)
+        self._get_pdu_cache: ExpiringCache[str, Tuple[EventBase, str]] = ExpiringCache(
             cache_name="get_pdu_cache",
             clock=self._clock,
             max_len=1000,
@@ -278,7 +292,7 @@ class FederationClient(FederationBase):
         pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus]
 
         # Check signatures and hash of pdus, removing any from the list that fail checks
-        pdus[:] = await self._check_sigs_and_hash_and_fetch(
+        pdus[:] = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
             dest, pdus, room_version=room_version
         )
 
@@ -328,7 +342,17 @@ class FederationClient(FederationBase):
 
             # Check signatures are correct.
             try:
-                signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+                async def _record_failure_callback(
+                    event: EventBase, cause: str
+                ) -> None:
+                    await self.store.record_event_failed_pull_attempt(
+                        event.room_id, event.event_id, cause
+                    )
+
+                signed_pdu = await self._check_sigs_and_hash(
+                    room_version, pdu, _record_failure_callback
+                )
             except InvalidEventSignatureError as e:
                 errmsg = f"event id {pdu.event_id}: {e}"
                 logger.warning("%s", errmsg)
@@ -342,11 +366,11 @@ class FederationClient(FederationBase):
     @tag_args
     async def get_pdu(
         self,
-        destinations: Iterable[str],
+        destinations: Collection[str],
         event_id: str,
         room_version: RoomVersion,
         timeout: Optional[int] = None,
-    ) -> Optional[EventBase]:
+    ) -> Optional[PulledPduInfo]:
         """Requests the PDU with given origin and ID from the remote home
         servers.
 
@@ -361,11 +385,11 @@ class FederationClient(FederationBase):
                 moving to the next destination. None indicates no timeout.
 
         Returns:
-            The requested PDU, or None if we were unable to find it.
+            The requested PDU wrapped in `PulledPduInfo`, or None if we were unable to find it.
         """
 
         logger.debug(
-            "get_pdu: event_id=%s from destinations=%s", event_id, destinations
+            "get_pdu(event_id=%s): from destinations=%s", event_id, destinations
         )
 
         # TODO: Rate limit the number of times we try and get the same event.
@@ -374,19 +398,25 @@ class FederationClient(FederationBase):
         # it gets persisted to the database), so we cache the results of the lookup.
         # Note that this is separate to the regular get_event cache which caches
         # events once they have been persisted.
-        event = self._get_pdu_cache.get(event_id)
+        get_pdu_cache_entry = self._get_pdu_cache.get(event_id)
 
+        event = None
+        pull_origin = None
+        if get_pdu_cache_entry:
+            event, pull_origin = get_pdu_cache_entry
         # If we don't see the event in the cache, go try to fetch it from the
         # provided remote federated destinations
-        if not event:
+        else:
             pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
 
+            # TODO: We can probably refactor this to use `_try_destination_list`
             for destination in destinations:
                 now = self._clock.time_msec()
                 last_attempt = pdu_attempts.get(destination, 0)
                 if last_attempt + PDU_RETRY_TIME_MS > now:
                     logger.debug(
-                        "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)",
+                        "get_pdu(event_id=%s): skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)",
+                        event_id,
                         destination,
                         last_attempt,
                         PDU_RETRY_TIME_MS,
@@ -401,43 +431,48 @@ class FederationClient(FederationBase):
                         room_version=room_version,
                         timeout=timeout,
                     )
+                    pull_origin = destination
 
                     pdu_attempts[destination] = now
 
                     if event:
                         # Prime the cache
-                        self._get_pdu_cache[event.event_id] = event
+                        self._get_pdu_cache[event.event_id] = (event, pull_origin)
 
                         # Now that we have an event, we can break out of this
                         # loop and stop asking other destinations.
                         break
 
+                except NotRetryingDestination as e:
+                    logger.info("get_pdu(event_id=%s): %s", event_id, e)
+                    continue
+                except FederationDeniedError:
+                    logger.info(
+                        "get_pdu(event_id=%s): Not attempting to fetch PDU from %s because the homeserver is not on our federation whitelist",
+                        event_id,
+                        destination,
+                    )
+                    continue
                 except SynapseError as e:
                     logger.info(
-                        "Failed to get PDU %s from %s because %s",
+                        "get_pdu(event_id=%s): Failed to get PDU from %s because %s",
                         event_id,
                         destination,
                         e,
                     )
                     continue
-                except NotRetryingDestination as e:
-                    logger.info(str(e))
-                    continue
-                except FederationDeniedError as e:
-                    logger.info(str(e))
-                    continue
                 except Exception as e:
                     pdu_attempts[destination] = now
 
                     logger.info(
-                        "Failed to get PDU %s from %s because %s",
+                        "get_pdu(event_id=%s): Failed to get PDU from %s because %s",
                         event_id,
                         destination,
                         e,
                     )
                     continue
 
-        if not event:
+        if not event or not pull_origin:
             return None
 
         # `event` now refers to an object stored in `get_pdu_cache`. Our
@@ -449,7 +484,7 @@ class FederationClient(FederationBase):
             event.room_version,
         )
 
-        return event_copy
+        return PulledPduInfo(event_copy, pull_origin)
 
     @trace
     @tag_args
@@ -547,24 +582,28 @@ class FederationClient(FederationBase):
             len(auth_event_map),
         )
 
-        valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+        valid_auth_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
             destination, auth_event_map.values(), room_version
         )
 
-        valid_state_events = await self._check_sigs_and_hash_and_fetch(
-            destination, state_event_map.values(), room_version
+        valid_state_events = (
+            await self._check_sigs_and_hash_for_pulled_events_and_fetch(
+                destination, state_event_map.values(), room_version
+            )
         )
 
         return valid_state_events, valid_auth_events
 
     @trace
-    async def _check_sigs_and_hash_and_fetch(
+    async def _check_sigs_and_hash_for_pulled_events_and_fetch(
         self,
         origin: str,
         pdus: Collection[EventBase],
         room_version: RoomVersion,
     ) -> List[EventBase]:
-        """Checks the signatures and hashes of a list of events.
+        """
+        Checks the signatures and hashes of a list of pulled events we got from
+        federation and records any signature failures as failed pull attempts.
 
         If a PDU fails its signature check then we check if we have it in
         the database, and if not then request it from the sender's server (if that
@@ -597,11 +636,17 @@ class FederationClient(FederationBase):
 
         valid_pdus: List[EventBase] = []
 
+        async def _record_failure_callback(event: EventBase, cause: str) -> None:
+            await self.store.record_event_failed_pull_attempt(
+                event.room_id, event.event_id, cause
+            )
+
         async def _execute(pdu: EventBase) -> None:
             valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                 pdu=pdu,
                 origin=origin,
                 room_version=room_version,
+                record_failure_callback=_record_failure_callback,
             )
 
             if valid_pdu:
@@ -618,6 +663,9 @@ class FederationClient(FederationBase):
         pdu: EventBase,
         origin: str,
         room_version: RoomVersion,
+        record_failure_callback: Optional[
+            Callable[[EventBase, str], Awaitable[None]]
+        ] = None,
     ) -> Optional[EventBase]:
         """Takes a PDU and checks its signatures and hashes.
 
@@ -634,6 +682,11 @@ class FederationClient(FederationBase):
             origin
             pdu
             room_version
+            record_failure_callback: A callback to run whenever the given event
+                fails signature or hash checks. This includes exceptions
+                that would be normally be thrown/raised but also things like
+                checking for event tampering where we just return the redacted
+                event.
 
         Returns:
             The PDU (possibly redacted) if it has valid signatures and hashes.
@@ -641,7 +694,9 @@ class FederationClient(FederationBase):
         """
 
         try:
-            return await self._check_sigs_and_hash(room_version, pdu)
+            return await self._check_sigs_and_hash(
+                room_version, pdu, record_failure_callback
+            )
         except InvalidEventSignatureError as e:
             logger.warning(
                 "Signature on retrieved event %s was invalid (%s). "
@@ -669,12 +724,14 @@ class FederationClient(FederationBase):
         pdu_origin = get_domain_from_id(pdu.sender)
         if not res and pdu_origin != origin:
             try:
-                res = await self.get_pdu(
+                pulled_pdu_info = await self.get_pdu(
                     destinations=[pdu_origin],
                     event_id=pdu.event_id,
                     room_version=room_version,
                     timeout=10000,
                 )
+                if pulled_pdu_info is not None:
+                    res = pulled_pdu_info.pdu
             except SynapseError:
                 pass
 
@@ -694,7 +751,7 @@ class FederationClient(FederationBase):
 
         auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]]
 
-        signed_auth = await self._check_sigs_and_hash_and_fetch(
+        signed_auth = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
             destination, auth_chain, room_version=room_version
         )
 
@@ -776,6 +833,7 @@ class FederationClient(FederationBase):
             )
 
         for destination in destinations:
+            # We don't want to ask our own server for information we don't have
             if destination == self.server_name:
                 continue
 
@@ -784,9 +842,21 @@ class FederationClient(FederationBase):
             except (
                 RequestSendFailed,
                 InvalidResponseError,
-                NotRetryingDestination,
             ) as e:
                 logger.warning("Failed to %s via %s: %s", description, destination, e)
+                # Skip to the next homeserver in the list to try.
+                continue
+            except NotRetryingDestination as e:
+                logger.info("%s: %s", description, e)
+                continue
+            except FederationDeniedError:
+                logger.info(
+                    "%s: Not attempting to %s from %s because the homeserver is not on our federation whitelist",
+                    description,
+                    description,
+                    destination,
+                )
+                continue
             except UnsupportedRoomVersionError:
                 raise
             except HttpResponseException as e:
@@ -1264,7 +1334,7 @@ class FederationClient(FederationBase):
         return resp[1]
 
     async def send_knock(self, destinations: List[str], pdu: EventBase) -> JsonDict:
-        """Attempts to send a knock event to given a list of servers. Iterates
+        """Attempts to send a knock event to a given list of servers. Iterates
         through the list until one attempt succeeds.
 
         Doing so will cause the remote server to add the event to the graph,
@@ -1401,7 +1471,7 @@ class FederationClient(FederationBase):
                 event_from_pdu_json(e, room_version) for e in content.get("events", [])
             ]
 
-            signed_events = await self._check_sigs_and_hash_and_fetch(
+            signed_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
                 destination, events, room_version=room_version
             )
         except HttpResponseException as e:
@@ -1579,6 +1649,54 @@ class FederationClient(FederationBase):
         return result
 
     async def timestamp_to_event(
+        self, *, destinations: List[str], room_id: str, timestamp: int, direction: str
+    ) -> Optional["TimestampToEventResponse"]:
+        """
+        Calls each remote federating server from `destinations` asking for their closest
+        event to the given timestamp in the given direction until we get a response.
+        Also validates the response to always return the expected keys or raises an
+        error.
+
+        Args:
+            destinations: The domains of homeservers to try fetching from
+            room_id: Room to fetch the event from
+            timestamp: The point in time (inclusive) we should navigate from in
+                the given direction to find the closest event.
+            direction: ["f"|"b"] to indicate whether we should navigate forward
+                or backward from the given timestamp to find the closest event.
+
+        Returns:
+            A parsed TimestampToEventResponse including the closest event_id
+            and origin_server_ts or None if no destination has a response.
+        """
+
+        async def _timestamp_to_event_from_destination(
+            destination: str,
+        ) -> TimestampToEventResponse:
+            return await self._timestamp_to_event_from_destination(
+                destination, room_id, timestamp, direction
+            )
+
+        try:
+            # Loop through each homeserver candidate until we get a succesful response
+            timestamp_to_event_response = await self._try_destination_list(
+                "timestamp_to_event",
+                destinations,
+                # TODO: The requested timestamp may lie in a part of the
+                #   event graph that the remote server *also* didn't have,
+                #   in which case they will have returned another event
+                #   which may be nowhere near the requested timestamp. In
+                #   the future, we may need to reconcile that gap and ask
+                #   other homeservers, and/or extend `/timestamp_to_event`
+                #   to return events on *both* sides of the timestamp to
+                #   help reconcile the gap faster.
+                _timestamp_to_event_from_destination,
+            )
+            return timestamp_to_event_response
+        except SynapseError:
+            return None
+
+    async def _timestamp_to_event_from_destination(
         self, destination: str, room_id: str, timestamp: int, direction: str
     ) -> "TimestampToEventResponse":
         """