diff options
Diffstat (limited to 'tests/unittest.py')
-rw-r--r-- | tests/unittest.py | 97 |
1 files changed, 46 insertions, 51 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index a9d59e31f7..39e5e7b85c 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Optional, Tuple, Type, TypeVar, Union, overload +from typing import Dict, Optional, Type, TypeVar, Union from mock import Mock, patch @@ -46,6 +46,7 @@ from synapse.logging.context import ( ) from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver @@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase): """ Create a the root resource for the test server. - The default implementation creates a JsonResource and calls each function in - `servlets` to register servletes against it + The default calls `self.create_resource_dict` and builds the resultant dict + into a tree. """ - resource = JsonResource(self.hs) + root_resource = Resource() + create_resource_tree(self.create_resource_dict(), root_resource) + return root_resource - for servlet in self.servlets: - servlet(self.hs, resource) + def create_resource_dict(self) -> Dict[str, Resource]: + """Create a resource tree for the test server + + A resource tree is a mapping from path to twisted.web.resource. - return resource + The default implementation creates a JsonResource and calls each function in + `servlets` to register servlets against it. + """ + servlet_resource = JsonResource(self.hs) + for servlet in self.servlets: + servlet(self.hs, servlet_resource) + return { + "/_matrix/client": servlet_resource, + "/_synapse/admin": servlet_resource, + } def default_config(self): """ @@ -358,38 +372,6 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest - # when the `request` arg isn't given, so we define an explicit override to - # cover that case. - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[SynapseRequest, FakeChannel]: - ... - - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - request: Type[T] = SynapseRequest, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[T, FakeChannel]: - ... - def make_request( self, method: Union[bytes, str], @@ -401,7 +383,7 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, - ) -> Tuple[T, FakeChannel]: + ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the given content. @@ -424,7 +406,7 @@ class HomeserverTestCase(TestCase): tells the channel the request is finished. Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + The FakeChannel object which stores the result of the request. """ return make_request( self.reactor, @@ -554,7 +536,7 @@ class HomeserverTestCase(TestCase): self.hs.config.registration_shared_secret = "shared" # Create the user - request, channel = self.make_request("GET", "/_synapse/admin/v1/register") + channel = self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -579,7 +561,7 @@ class HomeserverTestCase(TestCase): "inhibit_login": True, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_synapse/admin/v1/register", body.encode("utf8") ) self.assertEqual(channel.code, 200, channel.json_body) @@ -597,7 +579,7 @@ class HomeserverTestCase(TestCase): if device_id: body["device_id"] = device_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 200, channel.result) @@ -665,7 +647,7 @@ class HomeserverTestCase(TestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 403, channel.result) @@ -691,13 +673,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase): A federating homeserver that authenticates incoming requests as `other.example.com`. """ - def prepare(self, reactor, clock, homeserver): + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TestTransportLayerServer(self.hs) + return d + + +class TestTransportLayerServer(JsonResource): + """A test implementation of TransportLayerServer + + authenticates incoming requests as `other.example.com`. + """ + + def __init__(self, hs): + super().__init__(hs) + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") + authenticator = Authenticator() + ratelimiter = FederationRateLimiter( - clock, + hs.get_clock(), FederationRateLimitConfig( window_size=1, sleep_limit=1, @@ -706,11 +704,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase): concurrent_requests=1000, ), ) - federation_server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - return super().prepare(reactor, clock, homeserver) + federation_server.register_servlets(hs, self, authenticator, ratelimiter) def override_config(extra_config): |