summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13815.feature1
-rw-r--r--synapse/federation/federation_base.py25
-rw-r--r--synapse/federation/federation_client.py50
-rw-r--r--tests/federation/test_federation_client.py75
-rw-r--r--tests/test_federation.py4
5 files changed, 140 insertions, 15 deletions
diff --git a/changelog.d/13815.feature b/changelog.d/13815.feature
new file mode 100644
index 0000000000..ba411f5067
--- /dev/null
+++ b/changelog.d/13815.feature
@@ -0,0 +1 @@
+Keep track when an event pulled over federation fails its signature check so we can intelligently back-off in the future.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index abe2c1971a..6bd4742140 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Awaitable, Callable, Optional
 
 from synapse.api.constants import MAX_DEPTH, EventContentFields, EventTypes, Membership
 from synapse.api.errors import Codes, SynapseError
@@ -58,7 +58,12 @@ class FederationBase:
 
     @trace
     async def _check_sigs_and_hash(
-        self, room_version: RoomVersion, pdu: EventBase
+        self,
+        room_version: RoomVersion,
+        pdu: EventBase,
+        record_failure_callback: Optional[
+            Callable[[EventBase, str], Awaitable[None]]
+        ] = None,
     ) -> EventBase:
         """Checks that event is correctly signed by the sending server.
 
@@ -70,6 +75,11 @@ class FederationBase:
         Args:
             room_version: The room version of the PDU
             pdu: the event to be checked
+            record_failure_callback: A callback to run whenever the given event
+                fails signature or hash checks. This includes exceptions
+                that would be normally be thrown/raised but also things like
+                checking for event tampering where we just return the redacted
+                event.
 
         Returns:
               * the original event if the checks pass
@@ -80,7 +90,12 @@ class FederationBase:
           InvalidEventSignatureError if the signature check failed. Nothing
              will be logged in this case.
         """
-        await _check_sigs_on_pdu(self.keyring, room_version, pdu)
+        try:
+            await _check_sigs_on_pdu(self.keyring, room_version, pdu)
+        except InvalidEventSignatureError as exc:
+            if record_failure_callback:
+                await record_failure_callback(pdu, str(exc))
+            raise exc
 
         if not check_event_content_hash(pdu):
             # let's try to distinguish between failures because the event was
@@ -116,6 +131,10 @@ class FederationBase:
                         "event_id": pdu.event_id,
                     }
                 )
+                if record_failure_callback:
+                    await record_failure_callback(
+                        pdu, "Event content has been tampered with"
+                    )
             return redacted_event
 
         spam_check = await self.spam_checker.check_event_for_spam(pdu)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 464672a3da..4dca711cd2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -278,7 +278,7 @@ class FederationClient(FederationBase):
         pdus = [event_from_pdu_json(p, room_version) for p in transaction_data_pdus]
 
         # Check signatures and hash of pdus, removing any from the list that fail checks
-        pdus[:] = await self._check_sigs_and_hash_and_fetch(
+        pdus[:] = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
             dest, pdus, room_version=room_version
         )
 
@@ -328,7 +328,17 @@ class FederationClient(FederationBase):
 
             # Check signatures are correct.
             try:
-                signed_pdu = await self._check_sigs_and_hash(room_version, pdu)
+
+                async def _record_failure_callback(
+                    event: EventBase, cause: str
+                ) -> None:
+                    await self.store.record_event_failed_pull_attempt(
+                        event.room_id, event.event_id, cause
+                    )
+
+                signed_pdu = await self._check_sigs_and_hash(
+                    room_version, pdu, _record_failure_callback
+                )
             except InvalidEventSignatureError as e:
                 errmsg = f"event id {pdu.event_id}: {e}"
                 logger.warning("%s", errmsg)
@@ -547,24 +557,28 @@ class FederationClient(FederationBase):
             len(auth_event_map),
         )
 
-        valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+        valid_auth_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
             destination, auth_event_map.values(), room_version
         )
 
-        valid_state_events = await self._check_sigs_and_hash_and_fetch(
-            destination, state_event_map.values(), room_version
+        valid_state_events = (
+            await self._check_sigs_and_hash_for_pulled_events_and_fetch(
+                destination, state_event_map.values(), room_version
+            )
         )
 
         return valid_state_events, valid_auth_events
 
     @trace
-    async def _check_sigs_and_hash_and_fetch(
+    async def _check_sigs_and_hash_for_pulled_events_and_fetch(
         self,
         origin: str,
         pdus: Collection[EventBase],
         room_version: RoomVersion,
     ) -> List[EventBase]:
-        """Checks the signatures and hashes of a list of events.
+        """
+        Checks the signatures and hashes of a list of pulled events we got from
+        federation and records any signature failures as failed pull attempts.
 
         If a PDU fails its signature check then we check if we have it in
         the database, and if not then request it from the sender's server (if that
@@ -597,11 +611,17 @@ class FederationClient(FederationBase):
 
         valid_pdus: List[EventBase] = []
 
+        async def _record_failure_callback(event: EventBase, cause: str) -> None:
+            await self.store.record_event_failed_pull_attempt(
+                event.room_id, event.event_id, cause
+            )
+
         async def _execute(pdu: EventBase) -> None:
             valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
                 pdu=pdu,
                 origin=origin,
                 room_version=room_version,
+                record_failure_callback=_record_failure_callback,
             )
 
             if valid_pdu:
@@ -618,6 +638,9 @@ class FederationClient(FederationBase):
         pdu: EventBase,
         origin: str,
         room_version: RoomVersion,
+        record_failure_callback: Optional[
+            Callable[[EventBase, str], Awaitable[None]]
+        ] = None,
     ) -> Optional[EventBase]:
         """Takes a PDU and checks its signatures and hashes.
 
@@ -634,6 +657,11 @@ class FederationClient(FederationBase):
             origin
             pdu
             room_version
+            record_failure_callback: A callback to run whenever the given event
+                fails signature or hash checks. This includes exceptions
+                that would be normally be thrown/raised but also things like
+                checking for event tampering where we just return the redacted
+                event.
 
         Returns:
             The PDU (possibly redacted) if it has valid signatures and hashes.
@@ -641,7 +669,9 @@ class FederationClient(FederationBase):
         """
 
         try:
-            return await self._check_sigs_and_hash(room_version, pdu)
+            return await self._check_sigs_and_hash(
+                room_version, pdu, record_failure_callback
+            )
         except InvalidEventSignatureError as e:
             logger.warning(
                 "Signature on retrieved event %s was invalid (%s). "
@@ -694,7 +724,7 @@ class FederationClient(FederationBase):
 
         auth_chain = [event_from_pdu_json(p, room_version) for p in res["auth_chain"]]
 
-        signed_auth = await self._check_sigs_and_hash_and_fetch(
+        signed_auth = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
             destination, auth_chain, room_version=room_version
         )
 
@@ -1401,7 +1431,7 @@ class FederationClient(FederationBase):
                 event_from_pdu_json(e, room_version) for e in content.get("events", [])
             ]
 
-            signed_events = await self._check_sigs_and_hash_and_fetch(
+            signed_events = await self._check_sigs_and_hash_for_pulled_events_and_fetch(
                 destination, events, room_version=room_version
             )
         except HttpResponseException as e:
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 50e376f695..a538215931 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -23,14 +23,23 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
+from tests.test_utils import event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
 class FederationClientTest(FederatingHomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
         super().prepare(reactor, clock, homeserver)
 
@@ -231,6 +240,72 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         return remote_pdu
 
+    def test_backfill_invalid_signature_records_failed_pull_attempts(
+        self,
+    ) -> None:
+        """
+        Test to make sure that events from /backfill with invalid signatures get
+        recorded as failed pull attempts.
+        """
+        OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+        main_store = self.hs.get_datastores().main
+
+        # Create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # We purposely don't run `add_hashes_and_signatures_from_other_server`
+        # over this because we want the signature check to fail.
+        pulled_event, _ = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room_id,
+                sender=OTHER_USER,
+                type="test_event_type",
+                content={"body": "garply"},
+            )
+        )
+
+        # We expect an outbound request to /backfill, so stub that out
+        self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
+            _mock_response(
+                {
+                    "origin": "yet.another.server",
+                    "origin_server_ts": 900,
+                    # Mimic the other server returning our new `pulled_event`
+                    "pdus": [pulled_event.get_pdu_json()],
+                }
+            )
+        )
+
+        self.get_success(
+            self.hs.get_federation_client().backfill(
+                # We use "yet.another.server" instead of
+                # `self.OTHER_SERVER_NAME` because we want to see the behavior
+                # from `_check_sigs_and_hash_and_fetch_one` where it tries to
+                # fetch the PDU again from the origin server if the signature
+                # fails. Just want to make sure that the failure is counted from
+                # both code paths.
+                dest="yet.another.server",
+                room_id=room_id,
+                limit=1,
+                extremities=[pulled_event.event_id],
+            ),
+        )
+
+        # Make sure our failed pull attempt was recorded
+        backfill_num_attempts = self.get_success(
+            main_store.db_pool.simple_select_one_onecol(
+                table="event_failed_pull_attempts",
+                keyvalues={"event_id": pulled_event.event_id},
+                retcol="num_attempts",
+            )
+        )
+        # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
+        # other from "yet.another.server"
+        self.assertEqual(backfill_num_attempts, 2)
+
 
 def _mock_response(resp: JsonDict):
     body = json.dumps(resp).encode("utf-8")
diff --git a/tests/test_federation.py b/tests/test_federation.py
index 779fad1f63..80e5c590d8 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -86,8 +86,8 @@ class MessageAcceptTests(unittest.HomeserverTestCase):
 
         federation_event_handler._check_event_auth = _check_event_auth
         self.client = self.homeserver.get_federation_client()
-        self.client._check_sigs_and_hash_and_fetch = lambda dest, pdus, **k: succeed(
-            pdus
+        self.client._check_sigs_and_hash_for_pulled_events_and_fetch = (
+            lambda dest, pdus, **k: succeed(pdus)
         )
 
         # Send the join, it should return None (which is not an error)