summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/federation/test_complexity.py4
-rw-r--r--tests/federation/test_federation_server.py4
-rw-r--r--tests/federation/transport/test_knocking.py4
-rw-r--r--tests/federation/transport/test_server.py6
-rw-r--r--tests/rest/client/test_third_party_rules.py6
-rw-r--r--tests/unittest.py136
6 files changed, 116 insertions, 44 deletions
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 7b486aba4a..e40ef95874 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -47,7 +47,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         )
 
         # Get the room complexity
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
         self.assertEquals(200, channel.code)
@@ -59,7 +59,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
 
         # Get the room complexity again -- make sure it's our artificial value
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
         )
         self.assertEquals(200, channel.code)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 03e1e11f49..1af284bd2f 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -113,7 +113,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         room_1 = self.helper.create_room_as(u1, tok=u1_token)
         self.inject_room_member(room_1, "@user:other.example.com", "join")
 
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
         self.assertEquals(200, channel.code, channel.result)
@@ -145,7 +145,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
 
         room_1 = self.helper.create_room_as(u1, tok=u1_token)
 
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET", "/_matrix/federation/v1/state/%s" % (room_1,)
         )
         self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index bfa156eebb..686f42ab48 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -245,7 +245,7 @@ class FederationKnockingTestCase(
             self.hs, room_id, user_id
         )
 
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET",
             "/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
             % (
@@ -288,7 +288,7 @@ class FederationKnockingTestCase(
         )
 
         # Send the signed knock event into the room
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "PUT",
             "/_matrix/federation/v1/send_knock/%s/%s"
             % (room_id, signed_knock_event.event_id),
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 84fa72b9ff..eb62addda8 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -22,10 +22,9 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
         """Test that unauthenticated requests to the public rooms directory 403 when
         allow_public_rooms_over_federation is False.
         """
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET",
             "/_matrix/federation/v1/publicRooms",
-            federation_auth_origin=b"example.com",
         )
         self.assertEquals(403, channel.code)
 
@@ -34,9 +33,8 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
         """Test that unauthenticated requests to the public rooms directory 200 when
         allow_public_rooms_over_federation is True.
         """
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             "GET",
             "/_matrix/federation/v1/publicRooms",
-            federation_auth_origin=b"example.com",
         )
         self.assertEquals(200, channel.code)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 4e71b6ec12..ac6b86ff6b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         return hs
 
     def prepare(self, reactor, clock, homeserver):
+        super().prepare(reactor, clock, homeserver)
         # Create some users and a room to play with during the tests
         self.user_id = self.register_user("kermit", "monkey")
         self.invitee = self.register_user("invitee", "hackme")
@@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
     def _send_event_over_federation(self) -> None:
         """Send a dummy event over federation and check that the request succeeds."""
         body = {
-            "origin": self.hs.config.server.server_name,
-            "origin_server_ts": self.clock.time_msec(),
             "pdus": [
                 {
                     "sender": self.user_id,
@@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
             ],
         }
 
-        channel = self.make_request(
+        channel = self.make_signed_federation_request(
             method="PUT",
             path="/_matrix/federation/v1/send/1",
             content=body,
-            federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
         )
 
         self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/unittest.py b/tests/unittest.py
index 6fc617601a..a71892cb9d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,6 +17,7 @@ import gc
 import hashlib
 import hmac
 import inspect
+import json
 import logging
 import secrets
 import time
@@ -36,9 +37,11 @@ from typing import (
 )
 from unittest.mock import Mock, patch
 
-from canonicaljson import json
+import canonicaljson
+import signedjson.key
+import unpaddedbase64
 
-from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 from twisted.test.proto_helpers import MemoryReactor
@@ -49,8 +52,7 @@ from twisted.web.server import Request
 from synapse import events
 from synapse.api.constants import EventTypes, Membership
 from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server as federation_server
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import (
@@ -61,10 +63,10 @@ from synapse.logging.context import (
 )
 from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
 from tests.test_utils import event_injection, setup_awaitable_errors
@@ -755,42 +757,116 @@ class HomeserverTestCase(TestCase):
 
 class FederatingHomeserverTestCase(HomeserverTestCase):
     """
-    A federating homeserver that authenticates incoming requests as `other.example.com`.
+    A federating homeserver, set up to validate incoming federation requests
     """
 
-    def create_resource_dict(self) -> Dict[str, Resource]:
-        d = super().create_resource_dict()
-        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
-        return d
+    OTHER_SERVER_NAME = "other.example.com"
+    OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        super().prepare(reactor, clock, hs)
 
-class TestTransportLayerServer(JsonResource):
-    """A test implementation of TransportLayerServer
+        # poke the other server's signing key into the key store, so that we don't
+        # make requests for it
+        verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
+        verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
-    authenticates incoming requests as `other.example.com`.
-    """
+        self.get_success(
+            hs.get_datastore().store_server_verify_keys(
+                from_server=self.OTHER_SERVER_NAME,
+                ts_added_ms=clock.time_msec(),
+                verify_keys=[
+                    (
+                        self.OTHER_SERVER_NAME,
+                        verify_key_id,
+                        FetchKeyResult(
+                            verify_key=verify_key,
+                            valid_until_ts=clock.time_msec() + 1000,
+                        ),
+                    )
+                ],
+            )
+        )
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
 
-    def __init__(self, hs):
-        super().__init__(hs)
+    def make_signed_federation_request(
+        self,
+        method: str,
+        path: str,
+        content: Optional[JsonDict] = None,
+        await_result: bool = True,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        client_ip: str = "127.0.0.1",
+    ) -> FakeChannel:
+        """Make an inbound signed federation request to this server
 
-        class Authenticator:
-            def authenticate_request(self, request, content):
-                return succeed("other.example.com")
+        The request is signed as if it came from "other.example.com", which our HS
+        already has the keys for.
+        """
 
-        authenticator = Authenticator()
+        if custom_headers is None:
+            custom_headers = []
+        else:
+            custom_headers = list(custom_headers)
+
+        custom_headers.append(
+            (
+                "Authorization",
+                _auth_header_for_request(
+                    origin=self.OTHER_SERVER_NAME,
+                    destination=self.hs.hostname,
+                    signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+                    method=method,
+                    path=path,
+                    content=content,
+                ),
+            )
+        )
 
-        ratelimiter = FederationRateLimiter(
-            hs.get_clock(),
-            FederationRateLimitConfig(
-                window_size=1,
-                sleep_limit=1,
-                sleep_delay=1,
-                reject_limit=1000,
-                concurrent=1000,
-            ),
+        return make_request(
+            self.reactor,
+            self.site,
+            method=method,
+            path=path,
+            content=content,
+            shorthand=False,
+            await_result=await_result,
+            custom_headers=custom_headers,
+            client_ip=client_ip,
         )
 
-        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
+
+def _auth_header_for_request(
+    origin: str,
+    destination: str,
+    signing_key: signedjson.key.SigningKey,
+    method: str,
+    path: str,
+    content: Optional[JsonDict],
+) -> str:
+    """Build a suitable Authorization header for an outgoing federation request"""
+    request_description: JsonDict = {
+        "method": method,
+        "uri": path,
+        "destination": destination,
+        "origin": origin,
+    }
+    if content is not None:
+        request_description["content"] = content
+    signature_base64 = unpaddedbase64.encode_base64(
+        signing_key.sign(
+            canonicaljson.encode_canonical_json(request_description)
+        ).signature
+    )
+    return (
+        f"X-Matrix origin={origin},"
+        f"key={signing_key.alg}:{signing_key.version},"
+        f"sig={signature_base64}"
+    )
 
 
 def override_config(extra_config):