diff --git a/changelog.d/11913.misc b/changelog.d/11913.misc
new file mode 100644
index 0000000000..cb70560364
--- /dev/null
+++ b/changelog.d/11913.misc
@@ -0,0 +1 @@
+Tests: replace mocked `Authenticator` with the real thing.
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 7b486aba4a..e40ef95874 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -47,7 +47,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
)
# Get the room complexity
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
@@ -59,7 +59,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
store.get_current_state_event_counts = lambda x: make_awaitable(500 * 1.23)
# Get the room complexity again -- make sure it's our artificial value
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/unstable/rooms/%s/complexity" % (room_1,)
)
self.assertEquals(200, channel.code)
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 03e1e11f49..1af284bd2f 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -113,7 +113,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
self.inject_room_member(room_1, "@user:other.example.com", "join")
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(200, channel.code, channel.result)
@@ -145,7 +145,7 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
room_1 = self.helper.create_room_as(u1, tok=u1_token)
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET", "/_matrix/federation/v1/state/%s" % (room_1,)
)
self.assertEquals(403, channel.code, channel.result)
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index bfa156eebb..686f42ab48 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -245,7 +245,7 @@ class FederationKnockingTestCase(
self.hs, room_id, user_id
)
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/v1/make_knock/%s/%s?ver=%s"
% (
@@ -288,7 +288,7 @@ class FederationKnockingTestCase(
)
# Send the signed knock event into the room
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"PUT",
"/_matrix/federation/v1/send_knock/%s/%s"
% (room_id, signed_knock_event.event_id),
diff --git a/tests/federation/transport/test_server.py b/tests/federation/transport/test_server.py
index 84fa72b9ff..eb62addda8 100644
--- a/tests/federation/transport/test_server.py
+++ b/tests/federation/transport/test_server.py
@@ -22,10 +22,9 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"""Test that unauthenticated requests to the public rooms directory 403 when
allow_public_rooms_over_federation is False.
"""
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/v1/publicRooms",
- federation_auth_origin=b"example.com",
)
self.assertEquals(403, channel.code)
@@ -34,9 +33,8 @@ class RoomDirectoryFederationTests(unittest.FederatingHomeserverTestCase):
"""Test that unauthenticated requests to the public rooms directory 200 when
allow_public_rooms_over_federation is True.
"""
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
"GET",
"/_matrix/federation/v1/publicRooms",
- federation_auth_origin=b"example.com",
)
self.assertEquals(200, channel.code)
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index 4e71b6ec12..ac6b86ff6b 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -107,6 +107,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
return hs
def prepare(self, reactor, clock, homeserver):
+ super().prepare(reactor, clock, homeserver)
# Create some users and a room to play with during the tests
self.user_id = self.register_user("kermit", "monkey")
self.invitee = self.register_user("invitee", "hackme")
@@ -473,8 +474,6 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def _send_event_over_federation(self) -> None:
"""Send a dummy event over federation and check that the request succeeds."""
body = {
- "origin": self.hs.config.server.server_name,
- "origin_server_ts": self.clock.time_msec(),
"pdus": [
{
"sender": self.user_id,
@@ -492,11 +491,10 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
],
}
- channel = self.make_request(
+ channel = self.make_signed_federation_request(
method="PUT",
path="/_matrix/federation/v1/send/1",
content=body,
- federation_auth_origin=self.hs.config.server.server_name.encode("utf8"),
)
self.assertEqual(channel.code, 200, channel.result)
diff --git a/tests/unittest.py b/tests/unittest.py
index 6fc617601a..a71892cb9d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -17,6 +17,7 @@ import gc
import hashlib
import hmac
import inspect
+import json
import logging
import secrets
import time
@@ -36,9 +37,11 @@ from typing import (
)
from unittest.mock import Mock, patch
-from canonicaljson import json
+import canonicaljson
+import signedjson.key
+import unpaddedbase64
-from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.test.proto_helpers import MemoryReactor
@@ -49,8 +52,7 @@ from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
-from synapse.config.ratelimiting import FederationRateLimitConfig
-from synapse.federation.transport import server as federation_server
+from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import (
@@ -61,10 +63,10 @@ from synapse.logging.context import (
)
from synapse.rest import RegisterServletsFunc
from synapse.server import HomeServer
+from synapse.storage.keys import FetchKeyResult
from synapse.types import JsonDict, UserID, create_requester
from synapse.util import Clock
from synapse.util.httpresourcetree import create_resource_tree
-from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
from tests.test_utils import event_injection, setup_awaitable_errors
@@ -755,42 +757,116 @@ class HomeserverTestCase(TestCase):
class FederatingHomeserverTestCase(HomeserverTestCase):
"""
- A federating homeserver that authenticates incoming requests as `other.example.com`.
+ A federating homeserver, set up to validate incoming federation requests
"""
- def create_resource_dict(self) -> Dict[str, Resource]:
- d = super().create_resource_dict()
- d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
- return d
+ OTHER_SERVER_NAME = "other.example.com"
+ OTHER_SERVER_SIGNATURE_KEY = signedjson.key.generate_signing_key("test")
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+ super().prepare(reactor, clock, hs)
-class TestTransportLayerServer(JsonResource):
- """A test implementation of TransportLayerServer
+ # poke the other server's signing key into the key store, so that we don't
+ # make requests for it
+ verify_key = signedjson.key.get_verify_key(self.OTHER_SERVER_SIGNATURE_KEY)
+ verify_key_id = "%s:%s" % (verify_key.alg, verify_key.version)
- authenticates incoming requests as `other.example.com`.
- """
+ self.get_success(
+ hs.get_datastore().store_server_verify_keys(
+ from_server=self.OTHER_SERVER_NAME,
+ ts_added_ms=clock.time_msec(),
+ verify_keys=[
+ (
+ self.OTHER_SERVER_NAME,
+ verify_key_id,
+ FetchKeyResult(
+ verify_key=verify_key,
+ valid_until_ts=clock.time_msec() + 1000,
+ ),
+ )
+ ],
+ )
+ )
+
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
- def __init__(self, hs):
- super().__init__(hs)
+ def make_signed_federation_request(
+ self,
+ method: str,
+ path: str,
+ content: Optional[JsonDict] = None,
+ await_result: bool = True,
+ custom_headers: Optional[Iterable[Tuple[AnyStr, AnyStr]]] = None,
+ client_ip: str = "127.0.0.1",
+ ) -> FakeChannel:
+ """Make an inbound signed federation request to this server
- class Authenticator:
- def authenticate_request(self, request, content):
- return succeed("other.example.com")
+ The request is signed as if it came from "other.example.com", which our HS
+ already has the keys for.
+ """
- authenticator = Authenticator()
+ if custom_headers is None:
+ custom_headers = []
+ else:
+ custom_headers = list(custom_headers)
+
+ custom_headers.append(
+ (
+ "Authorization",
+ _auth_header_for_request(
+ origin=self.OTHER_SERVER_NAME,
+ destination=self.hs.hostname,
+ signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ method=method,
+ path=path,
+ content=content,
+ ),
+ )
+ )
- ratelimiter = FederationRateLimiter(
- hs.get_clock(),
- FederationRateLimitConfig(
- window_size=1,
- sleep_limit=1,
- sleep_delay=1,
- reject_limit=1000,
- concurrent=1000,
- ),
+ return make_request(
+ self.reactor,
+ self.site,
+ method=method,
+ path=path,
+ content=content,
+ shorthand=False,
+ await_result=await_result,
+ custom_headers=custom_headers,
+ client_ip=client_ip,
)
- federation_server.register_servlets(hs, self, authenticator, ratelimiter)
+
+def _auth_header_for_request(
+ origin: str,
+ destination: str,
+ signing_key: signedjson.key.SigningKey,
+ method: str,
+ path: str,
+ content: Optional[JsonDict],
+) -> str:
+ """Build a suitable Authorization header for an outgoing federation request"""
+ request_description: JsonDict = {
+ "method": method,
+ "uri": path,
+ "destination": destination,
+ "origin": origin,
+ }
+ if content is not None:
+ request_description["content"] = content
+ signature_base64 = unpaddedbase64.encode_base64(
+ signing_key.sign(
+ canonicaljson.encode_canonical_json(request_description)
+ ).signature
+ )
+ return (
+ f"X-Matrix origin={origin},"
+ f"key={signing_key.alg}:{signing_key.version},"
+ f"sig={signature_base64}"
+ )
def override_config(extra_config):
|