summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py145
1 files changed, 111 insertions, 34 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 1431848367..a71892cb9d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,6 +17,7 @@ import gc
 import hashlib
 import hmac
 import inspect
+import json
 import logging
 import secrets
 import time
@@ -36,9 +37,11 @@ from typing import (
 )
 from unittest.mock import Mock, patch
 
-from canonicaljson import json
+import canonicaljson
+import signedjson.key
+import unpaddedbase64
 
-from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 from twisted.test.proto_helpers import MemoryReactor
@@ -49,8 +52,7 @@ from twisted.web.server import Request
 from synapse import events
 from synapse.api.constants import EventTypes, Membership
 from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server as federation_server
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import (
@@ -61,10 +63,10 @@ from synapse.logging.context import (
 )
 from synapse.rest import RegisterServletsFunc
 from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
 from tests.test_utils import event_injection, setup_awaitable_errors
@@ -620,18 +622,19 @@ class HomeserverTestCase(TestCase):
         self,
         username: str,
         appservice_token: str,
-    ) -> str:
+    ) -> Tuple[str, str]:
         """Register an appservice user as an application service.
         Requires the client-facing registration API be registered.
 
         Args:
             username: the user to be registered by an application service.
-                Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
+                Should NOT be a full username, i.e. just "localpart" as opposed to "@localpart:hostname"
             appservice_token: the acccess token for that application service.
 
         Raises: if the request to '/register' does not return 200 OK.
 
-        Returns: the MXID of the new user.
+        Returns:
+            The MXID of the new user, the device ID of the new user's first device.
         """
         channel = self.make_request(
             "POST",
@@ -643,7 +646,7 @@ class HomeserverTestCase(TestCase):
             access_token=appservice_token,
         )
         self.assertEqual(channel.code, 200, channel.json_body)
-        return channel.json_body["user_id"]
+        return channel.json_body["user_id"], channel.json_body["device_id"]
 
     def login(
         self,
@@ -754,42 +757,116 @@ class HomeserverTestCase(TestCase):
 
 class FederatingHomeserverTestCase(HomeserverTestCase):
     """
-    A federating homeserver that authenticates incoming requests as `other.example.com`.
+    A federating homeserver, set up to validate incoming federation requests
     """
 
-    def create_resource_dict(self) -> Dict[str, Resource]:
-        d = super().create_resource_dict()
-        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
-        return d
+    OTHER_SERVER_NAME = "other.example.com"
+    OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        super().prepare(reactor, clock, hs)
 
-class TestTransportLayerServer(JsonResource):
-    """A test implementation of TransportLayerServer
+        # poke the other server's signing key into the key store, so that we don't
+        # make requests for it
+        verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
+        verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
 
-    authenticates incoming requests as `other.example.com`.
-    """
+        self.get_success(
+            hs.get_datastore().store_server_verify_keys(
+                from_server=self.OTHER_SERVER_NAME,
+                ts_added_ms=clock.time_msec(),
+                verify_keys=[
+                    (
+                        self.OTHER_SERVER_NAME,
+                        verify_key_id,
+                        FetchKeyResult(
+                            verify_key=verify_key,
+                            valid_until_ts=clock.time_msec() + 1000,
+                        ),
+                    )
+                ],
+            )
+        )
+
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
 
-    def __init__(self, hs):
-        super().__init__(hs)
+    def make_signed_federation_request(
+        self,
+        method: str,
+        path: str,
+        content: Optional[JsonDict] = None,
+        await_result: bool = True,
+        custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+        client_ip: str = "127.0.0.1",
+    ) -> FakeChannel:
+        """Make an inbound signed federation request to this server
 
-        class Authenticator:
-            def authenticate_request(self, request, content):
-                return succeed("other.example.com")
+        The request is signed as if it came from "other.example.com", which our HS
+        already has the keys for.
+        """
 
-        authenticator = Authenticator()
+        if custom_headers is None:
+            custom_headers = []
+        else:
+            custom_headers = list(custom_headers)
+
+        custom_headers.append(
+            (
+                "Authorization",
+                _auth_header_for_request(
+                    origin=self.OTHER_SERVER_NAME,
+                    destination=self.hs.hostname,
+                    signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+                    method=method,
+                    path=path,
+                    content=content,
+                ),
+            )
+        )
 
-        ratelimiter = FederationRateLimiter(
-            hs.get_clock(),
-            FederationRateLimitConfig(
-                window_size=1,
-                sleep_limit=1,
-                sleep_delay=1,
-                reject_limit=1000,
-                concurrent=1000,
-            ),
+        return make_request(
+            self.reactor,
+            self.site,
+            method=method,
+            path=path,
+            content=content,
+            shorthand=False,
+            await_result=await_result,
+            custom_headers=custom_headers,
+            client_ip=client_ip,
         )
 
-        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
+
+def _auth_header_for_request(
+    origin: str,
+    destination: str,
+    signing_key: signedjson.key.SigningKey,
+    method: str,
+    path: str,
+    content: Optional[JsonDict],
+) -> str:
+    """Build a suitable Authorization header for an outgoing federation request"""
+    request_description: JsonDict = {
+        "method": method,
+        "uri": path,
+        "destination": destination,
+        "origin": origin,
+    }
+    if content is not None:
+        request_description["content"] = content
+    signature_base64 = unpaddedbase64.encode_base64(
+        signing_key.sign(
+            canonicaljson.encode_canonical_json(request_description)
+        ).signature
+    )
+    return (
+        f"X-Matrix origin={origin},"
+        f"key={signing_key.alg}:{signing_key.version},"
+        f"sig={signature_base64}"
+    )
 
 
 def override_config(extra_config):