summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py134
1 files changed, 83 insertions, 51 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index a9d59e31f7..767d5d6077 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
 import inspect
 import logging
 import time
-from typing import Optional, Tuple, Type, TypeVar, Union, overload
+from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock, patch
 
@@ -46,6 +46,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
@@ -320,15 +321,28 @@ class HomeserverTestCase(TestCase):
         """
         Create a the root resource for the test server.
 
-        The default implementation creates a JsonResource and calls each function in
-        `servlets` to register servletes against it
+        The default calls `self.create_resource_dict` and builds the resultant dict
+        into a tree.
         """
-        resource = JsonResource(self.hs)
+        root_resource = Resource()
+        create_resource_tree(self.create_resource_dict(), root_resource)
+        return root_resource
 
-        for servlet in self.servlets:
-            servlet(self.hs, resource)
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        """Create a resource tree for the test server
 
-        return resource
+        A resource tree is a mapping from path to twisted.web.resource.
+
+        The default implementation creates a JsonResource and calls each function in
+        `servlets` to register servlets against it.
+        """
+        servlet_resource = JsonResource(self.hs)
+        for servlet in self.servlets:
+            servlet(self.hs, servlet_resource)
+        return {
+            "/_matrix/client": servlet_resource,
+            "/_synapse/admin": servlet_resource,
+        }
 
     def default_config(self):
         """
@@ -358,24 +372,6 @@ class HomeserverTestCase(TestCase):
         Function to optionally be overridden in subclasses.
         """
 
-    # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
-    # when the `request` arg isn't given, so we define an explicit override to
-    # cover that case.
-    @overload
-    def make_request(
-        self,
-        method: Union[bytes, str],
-        path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
-        access_token: Optional[str] = None,
-        shorthand: bool = True,
-        federation_auth_origin: str = None,
-        content_is_form: bool = False,
-        await_result: bool = True,
-    ) -> Tuple[SynapseRequest, FakeChannel]:
-        ...
-
-    @overload
     def make_request(
         self,
         method: Union[bytes, str],
@@ -387,21 +383,11 @@ class HomeserverTestCase(TestCase):
         federation_auth_origin: str = None,
         content_is_form: bool = False,
         await_result: bool = True,
-    ) -> Tuple[T, FakeChannel]:
-        ...
-
-    def make_request(
-        self,
-        method: Union[bytes, str],
-        path: Union[bytes, str],
-        content: Union[bytes, dict] = b"",
-        access_token: Optional[str] = None,
-        request: Type[T] = SynapseRequest,
-        shorthand: bool = True,
-        federation_auth_origin: str = None,
-        content_is_form: bool = False,
-        await_result: bool = True,
-    ) -> Tuple[T, FakeChannel]:
+        custom_headers: Optional[
+            Iterable[Tuple[Union[bytes, str], Union[bytes, str]]]
+        ] = None,
+        client_ip: str = "127.0.0.1",
+    ) -> FakeChannel:
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.
@@ -423,8 +409,13 @@ class HomeserverTestCase(TestCase):
                  true (the default), will pump the test reactor until the the renderer
                  tells the channel the request is finished.
 
+            custom_headers: (name, value) pairs to add as request headers
+
+            client_ip: The IP to use as the requesting IP. Useful for testing
+                ratelimiting.
+
         Returns:
-            Tuple[synapse.http.site.SynapseRequest, channel]
+            The FakeChannel object which stores the result of the request.
         """
         return make_request(
             self.reactor,
@@ -438,6 +429,8 @@ class HomeserverTestCase(TestCase):
             federation_auth_origin,
             content_is_form,
             await_result,
+            custom_headers,
+            client_ip,
         )
 
     def setup_test_homeserver(self, *args, **kwargs):
@@ -554,7 +547,7 @@ class HomeserverTestCase(TestCase):
         self.hs.config.registration_shared_secret = "shared"
 
         # Create the user
-        request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
+        channel = self.make_request("GET", "/_synapse/admin/v1/register")
         self.assertEqual(channel.code, 200, msg=channel.result)
         nonce = channel.json_body["nonce"]
 
@@ -579,7 +572,7 @@ class HomeserverTestCase(TestCase):
                 "inhibit_login": True,
             }
         )
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_synapse/admin/v1/register", body.encode("utf8")
         )
         self.assertEqual(channel.code, 200, channel.json_body)
@@ -597,7 +590,7 @@ class HomeserverTestCase(TestCase):
         if device_id:
             body["device_id"] = device_id
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.assertEqual(channel.code, 200, channel.result)
@@ -665,7 +658,7 @@ class HomeserverTestCase(TestCase):
         """
         body = {"type": "m.login.password", "user": username, "password": password}
 
-        request, channel = self.make_request(
+        channel = self.make_request(
             "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
         )
         self.assertEqual(channel.code, 403, channel.result)
@@ -691,13 +684,29 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
     A federating homeserver that authenticates incoming requests as `other.example.com`.
     """
 
-    def prepare(self, reactor, clock, homeserver):
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_matrix/federation"] = TestTransportLayerServer(self.hs)
+        return d
+
+
+class TestTransportLayerServer(JsonResource):
+    """A test implementation of TransportLayerServer
+
+    authenticates incoming requests as `other.example.com`.
+    """
+
+    def __init__(self, hs):
+        super().__init__(hs)
+
         class Authenticator:
             def authenticate_request(self, request, content):
                 return succeed("other.example.com")
 
+        authenticator = Authenticator()
+
         ratelimiter = FederationRateLimiter(
-            clock,
+            hs.get_clock(),
             FederationRateLimitConfig(
                 window_size=1,
                 sleep_limit=1,
@@ -706,11 +715,8 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
                 concurrent_requests=1000,
             ),
         )
-        federation_server.register_servlets(
-            homeserver, self.resource, Authenticator(), ratelimiter
-        )
 
-        return super().prepare(reactor, clock, homeserver)
+        federation_server.register_servlets(hs, self, authenticator, ratelimiter)
 
 
 def override_config(extra_config):
@@ -735,3 +741,29 @@ def override_config(extra_config):
         return func
 
     return decorator
+
+
+TV = TypeVar("TV")
+
+
+def skip_unless(condition: bool, reason: str) -> Callable[[TV], TV]:
+    """A test decorator which will skip the decorated test unless a condition is set
+
+    For example:
+
+    class MyTestCase(TestCase):
+        @skip_unless(HAS_FOO, "Cannot test without foo")
+        def test_foo(self):
+            ...
+
+    Args:
+        condition: If true, the test will be skipped
+        reason: the reason to give for skipping the test
+    """
+
+    def decorator(f: TV) -> TV:
+        if not condition:
+            f.skip = reason  # type: ignore
+        return f
+
+    return decorator