summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_typing.py11
-rw-r--r--tests/replication/_base.py14
-rw-r--r--tests/server.py3
-rw-r--r--tests/unittest.py25
4 files changed, 37 insertions, 16 deletions
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..9a085ccaf6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
 
 
 import json
+from typing import Dict
 
 from mock import ANY, Mock, call
 
 from twisted.internet import defer
+from twisted.web.resource import Resource
 
 from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
 from synapse.types import UserID, create_requester
 
 from tests import unittest
 from tests.test_utils import make_awaitable
 from tests.unittest import override_config
-from tests.utils import register_federation_servlets
 
 # Some local users to test with
 U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
 
 
 class TypingNotificationsTestCase(unittest.HomeserverTestCase):
-    servlets = [register_federation_servlets]
-
     def make_homeserver(self, reactor, clock):
         # we mock out the keyring so as to skip the authentication check on the
         # federation API call.
@@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
 
         return hs
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TransportLayerServer(self.hs)
+        return d
+
     def prepare(self, reactor, clock, hs):
         mock_notifier = hs.get_notifier()
         self.on_new_event = mock_notifier.on_new_event
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..79738ab46f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import attr
 
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
 )
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+        return d
+
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.generic_worker"
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..eee970c43c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -216,8 +216,9 @@ def make_request(
         and not path.startswith(b"/_matrix")
         and not path.startswith(b"/_synapse")
     ):
+        if path.startswith(b"/"):
+            path = path[1:]
         path = b"/_matrix/client/r0/" + path
-        path = path.replace(b"//", b"/")
 
     if not path.startswith(b"/"):
         path = b"/" + path
diff --git a/tests/unittest.py b/tests/unittest.py
index 425b39b1d1..102b0a1f34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -705,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     A federating homeserver that authenticates incoming requests as `other.example.com`.
     """
 
-    def prepare(self, reactor, clock, homeserver):
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+        return d
+
+
+class TestTransportLayerServer(JsonResource):
+    """A test implementation of TransportLayerServer
+
+    authenticates incoming requests as `other.example.com`.
+    """
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
         class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")
 
+        authenticator = Authenticator()
+
         ratelimiter = FederationRateLimiter(
-            clock,
+            hs.get_clock(),
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
@@ -720,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
                 concurrent_requests=1000,
             ),
         )
-        federation_server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
 
-        return super().prepare(reactor, clock, homeserver)
+        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
 
 
 def override_config(extra_config):