diff options
Diffstat (limited to 'tests/unittest.py')
-rw-r--r-- | tests/unittest.py | 134 |
1 files changed, 83 insertions, 51 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index a9d59e31f7..767d5d6077 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -20,7 +20,7 @@ import hmac import inspect import logging import time -from typing import Optional, Tuple, Type, TypeVar, Union, overload +from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union from mock import Mock, patch @@ -46,6 +46,7 @@ from synapse.logging.context import ( ) from synapse.server import HomeServer from synapse.types import UserID, create_requester +from synapse.util.httpresourcetree import create_resource_tree from synapse.util.ratelimitutils import FederationRateLimiter from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver @@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase): """ Create a the root resource for the test server. - The default implementation creates a JsonResource and calls each function in - `servlets` to register servletes against it + The default calls `self.create_resource_dict` and builds the resultant dict + into a tree. """ - resource = JsonResource(self.hs) + root_resource = Resource() + create_resource_tree(self.create_resource_dict(), root_resource) + return root_resource - for servlet in self.servlets: - servlet(self.hs, resource) + def create_resource_dict(self) -> Dict[str, Resource]: + """Create a resource tree for the test server - return resource + A resource tree is a mapping from path to twisted.web.resource. + + The default implementation creates a JsonResource and calls each function in + `servlets` to register servlets against it. + """ + servlet_resource = JsonResource(self.hs) + for servlet in self.servlets: + servlet(self.hs, servlet_resource) + return { + "/_matrix/client": servlet_resource, + "/_synapse/admin": servlet_resource, + } def default_config(self): """ @@ -358,24 +372,6 @@ class HomeserverTestCase(TestCase): Function to optionally be overridden in subclasses. """ - # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest - # when the `request` arg isn't given, so we define an explicit override to - # cover that case. - @overload - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[SynapseRequest, FakeChannel]: - ... - - @overload def make_request( self, method: Union[bytes, str], @@ -387,21 +383,11 @@ class HomeserverTestCase(TestCase): federation_auth_origin: str = None, content_is_form: bool = False, await_result: bool = True, - ) -> Tuple[T, FakeChannel]: - ... - - def make_request( - self, - method: Union[bytes, str], - path: Union[bytes, str], - content: Union[bytes, dict] = b"", - access_token: Optional[str] = None, - request: Type[T] = SynapseRequest, - shorthand: bool = True, - federation_auth_origin: str = None, - content_is_form: bool = False, - await_result: bool = True, - ) -> Tuple[T, FakeChannel]: + custom_headers: Optional[ + Iterable[Tuple[Union[bytes, str], Union[bytes, str]]] + ] = None, + client_ip: str = "127.0.0.1", + ) -> FakeChannel: """ Create a SynapseRequest at the path using the method and containing the given content. @@ -423,8 +409,13 @@ class HomeserverTestCase(TestCase): true (the default), will pump the test reactor until the the renderer tells the channel the request is finished. + custom_headers: (name, value) pairs to add as request headers + + client_ip: The IP to use as the requesting IP. Useful for testing + ratelimiting. + Returns: - Tuple[synapse.http.site.SynapseRequest, channel] + The FakeChannel object which stores the result of the request. """ return make_request( self.reactor, @@ -438,6 +429,8 @@ class HomeserverTestCase(TestCase): federation_auth_origin, content_is_form, await_result, + custom_headers, + client_ip, ) def setup_test_homeserver(self, *args, **kwargs): @@ -554,7 +547,7 @@ class HomeserverTestCase(TestCase): self.hs.config.registration_shared_secret = "shared" # Create the user - request, channel = self.make_request("GET", "/_synapse/admin/v1/register") + channel = self.make_request("GET", "/_synapse/admin/v1/register") self.assertEqual(channel.code, 200, msg=channel.result) nonce = channel.json_body["nonce"] @@ -579,7 +572,7 @@ class HomeserverTestCase(TestCase): "inhibit_login": True, } ) - request, channel = self.make_request( + channel = self.make_request( "POST", "/_synapse/admin/v1/register", body.encode("utf8") ) self.assertEqual(channel.code, 200, channel.json_body) @@ -597,7 +590,7 @@ class HomeserverTestCase(TestCase): if device_id: body["device_id"] = device_id - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 200, channel.result) @@ -665,7 +658,7 @@ class HomeserverTestCase(TestCase): """ body = {"type": "m.login.password", "user": username, "password": password} - request, channel = self.make_request( + channel = self.make_request( "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.assertEqual(channel.code, 403, channel.result) @@ -691,13 +684,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase): A federating homeserver that authenticates incoming requests as `other.example.com`. """ - def prepare(self, reactor, clock, homeserver): + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_matrix/federation"] = TestTransportLayerServer(self.hs) + return d + + +class TestTransportLayerServer(JsonResource): + """A test implementation of TransportLayerServer + + authenticates incoming requests as `other.example.com`. + """ + + def __init__(self, hs): + super().__init__(hs) + class Authenticator: def authenticate_request(self, request, content): return succeed("other.example.com") + authenticator = Authenticator() + ratelimiter = FederationRateLimiter( - clock, + hs.get_clock(), FederationRateLimitConfig( window_size=1, sleep_limit=1, @@ -706,11 +715,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase): concurrent_requests=1000, ), ) - federation_server.register_servlets( - homeserver, self.resource, Authenticator(), ratelimiter - ) - return super().prepare(reactor, clock, homeserver) + federation_server.register_servlets(hs, self, authenticator, ratelimiter) def override_config(extra_config): @@ -735,3 +741,29 @@ def override_config(extra_config): return func return decorator + + +TV = TypeVar("TV") + + +def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]: + """A test decorator which will skip the decorated test unless a condition is set + + For example: + + class MyTestCase(TestCase): + @skip_unless(HAS_FOO, "Cannot test without foo") + def test_foo(self): + ... + + Args: + condition: If true, the test will be skipped + reason: the reason to give for skipping the test + """ + + def decorator(f: TV) -> TV: + if not condition: + f.skip = reason # type: ignore + return f + + return decorator |