summary refs log tree commit diff
path: root/tests/federation/test_federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/federation/test_federation_client.py')
-rw-r--r--tests/federation/test_federation_client.py75
1 files changed, 75 insertions, 0 deletions
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index 50e376f695..a538215931 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -23,14 +23,23 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client import login, room
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
 
+from tests.test_utils import event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
 class FederationClientTest(FederatingHomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
     def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
         super().prepare(reactor, clock, homeserver)
 
@@ -231,6 +240,72 @@ class FederationClientTest(FederatingHomeserverTestCase):
 
         return remote_pdu
 
+    def test_backfill_invalid_signature_records_failed_pull_attempts(
+        self,
+    ) -> None:
+        """
+        Test to make sure that events from /backfill with invalid signatures get
+        recorded as failed pull attempts.
+        """
+        OTHER_USER = f"@user:{self.OTHER_SERVER_NAME}"
+        main_store = self.hs.get_datastores().main
+
+        # Create the room
+        user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        room_id = self.helper.create_room_as(room_creator=user_id, tok=tok)
+
+        # We purposely don't run `add_hashes_and_signatures_from_other_server`
+        # over this because we want the signature check to fail.
+        pulled_event, _ = self.get_success(
+            event_injection.create_event(
+                self.hs,
+                room_id=room_id,
+                sender=OTHER_USER,
+                type="test_event_type",
+                content={"body": "garply"},
+            )
+        )
+
+        # We expect an outbound request to /backfill, so stub that out
+        self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
+            _mock_response(
+                {
+                    "origin": "yet.another.server",
+                    "origin_server_ts": 900,
+                    # Mimic the other server returning our new `pulled_event`
+                    "pdus": [pulled_event.get_pdu_json()],
+                }
+            )
+        )
+
+        self.get_success(
+            self.hs.get_federation_client().backfill(
+                # We use "yet.another.server" instead of
+                # `self.OTHER_SERVER_NAME` because we want to see the behavior
+                # from `_check_sigs_and_hash_and_fetch_one` where it tries to
+                # fetch the PDU again from the origin server if the signature
+                # fails. Just want to make sure that the failure is counted from
+                # both code paths.
+                dest="yet.another.server",
+                room_id=room_id,
+                limit=1,
+                extremities=[pulled_event.event_id],
+            ),
+        )
+
+        # Make sure our failed pull attempt was recorded
+        backfill_num_attempts = self.get_success(
+            main_store.db_pool.simple_select_one_onecol(
+                table="event_failed_pull_attempts",
+                keyvalues={"event_id": pulled_event.event_id},
+                retcol="num_attempts",
+            )
+        )
+        # This is 2 because it failed once from `self.OTHER_SERVER_NAME` and the
+        # other from "yet.another.server"
+        self.assertEqual(backfill_num_attempts, 2)
+
 
 def _mock_response(resp: JsonDict):
     body = json.dumps(resp).encode("utf-8")