diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index b53e7a20ec..434718ddfc 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -1462,7 +1462,7 @@ def register_servlets(hs, resource, authenticator, ratelimiter, servlet_groups=N
Args:
hs (synapse.server.HomeServer): homeserver
- resource (TransportLayerServer): resource class to register to
+ resource (JsonResource): resource class to register to
authenticator (Authenticator): authenticator to use
ratelimiter (util.ratelimitutils.FederationRateLimiter): ratelimiter to use
servlet_groups (list[str], optional): List of servlet groups to register.
diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py
index abbdf2d524..9a085ccaf6 100644
--- a/tests/handlers/test_typing.py
+++ b/tests/handlers/test_typing.py
@@ -15,18 +15,20 @@
import json
+from typing import Dict
from mock import ANY, Mock, call
from twisted.internet import defer
+from twisted.web.resource import Resource
from synapse.api.errors import AuthError
+from synapse.federation.transport.server import TransportLayerServer
from synapse.types import UserID, create_requester
from tests import unittest
from tests.test_utils import make_awaitable
from tests.unittest import override_config
-from tests.utils import register_federation_servlets
# Some local users to test with
U_APPLE = UserID.from_string("@apple:test")
@@ -53,8 +55,6 @@ def _make_edu_transaction_json(edu_type, content):
class TypingNotificationsTestCase(unittest.HomeserverTestCase):
- servlets = [register_federation_servlets]
-
def make_homeserver(self, reactor, clock):
# we mock out the keyring so as to skip the authentication check on the
# federation API call.
@@ -77,6 +77,11 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase):
return hs
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TransportLayerServer(self.hs)
+ return d
+
def prepare(self, reactor, clock, hs):
mock_notifier = hs.get_notifier()
self.on_new_event = mock_notifier.on_new_event
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..79738ab46f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
diff --git a/tests/server.py b/tests/server.py
index a51ad0c14e..eee970c43c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -216,8 +216,9 @@ def make_request(
and not path.startswith(b"/_matrix")
and not path.startswith(b"/_synapse")
):
+ if path.startswith(b"/"):
+ path = path[1:]
path = b"/_matrix/client/r0/" + path
- path = path.replace(b"//", b"/")
if not path.startswith(b"/"):
path = b"/" + path
diff --git a/tests/unittest.py b/tests/unittest.py
index 425b39b1d1..102b0a1f34 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -705,13 +705,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
A federating homeserver that authenticates incoming requests as `other.example.com`.
"""
- def prepare(self, reactor, clock, homeserver):
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+ return d
+
+
+class TestTransportLayerServer(JsonResource):
+ """A test implementation of TransportLayerServer
+
+ authenticates incoming requests as `other.example.com`.
+ """
+
+ def __init__(self, hs):
+ super().__init__(hs)
+
class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
+ authenticator = Authenticator()
+
ratelimiter = FederationRateLimiter(
- clock,
+ hs.get_clock(),
FederationRateLimitConfig(
window_size=1,
sleep_limit=1,
@@ -720,11 +736,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
concurrent_requests=1000,
),
)
- federation_server.register_servlets(
- homeserver, self.resource, Authenticator(), ratelimiter
- )
- return super().prepare(reactor, clock, homeserver)
+ federation_server.register_servlets(hs, self, authenticator, ratelimiter)
def override_config(extra_config):
|