summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py139
1 files changed, 106 insertions, 33 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 561cebc223..6b6f224e9c 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
 # Copyright 2018 New Vector
+# Copyright 2019 Matrix.org Federation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,29 +14,47 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 import gc
 import hashlib
 import hmac
+import inspect
 import logging
 import time
+from typing import Optional, Tuple, Type, TypeVar, Union
 
 from mock import Mock
 
 from canonicaljson import json
 
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
 from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server as federation_server
 from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
-from synapse.logging.context import LoggingContext
+from synapse.http.site import SynapseRequest, SynapseSite
+from synapse.logging.context import (
+    SENTINEL_CONTEXT,
+    LoggingContext,
+    current_context,
+    set_current_context,
+)
 from synapse.server import HomeServer
 from synapse.types import Requester, UserID, create_requester
-
-from tests.server import get_clock, make_request, render, setup_test_homeserver
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests.server import (
+    FakeChannel,
+    get_clock,
+    make_request,
+    render,
+    setup_test_homeserver,
+)
+from tests.test_utils import event_injection
 from tests.test_utils.logging_setup import setup_logging
 from tests.utils import default_config, setupdb
 
@@ -64,6 +83,9 @@ def around(target):
     return _around
 
 
+T = TypeVar("T")
+
+
 class TestCase(unittest.TestCase):
     """A subclass of twisted.trial's TestCase which looks for 'loglevel'
     attributes on both itself and its individual test methods, to override the
@@ -80,10 +102,10 @@ class TestCase(unittest.TestCase):
         def setUp(orig):
             # if we're not starting in the sentinel logcontext, then to be honest
             # all future bets are off.
-            if LoggingContext.current_context() is not LoggingContext.sentinel:
+            if current_context():
                 self.fail(
                     "Test starting with non-sentinel logging context %s"
-                    % (LoggingContext.current_context(),)
+                    % (current_context(),)
                 )
 
             old_level = logging.getLogger().level
@@ -105,7 +127,7 @@ class TestCase(unittest.TestCase):
             # force a GC to workaround problems with deferreds leaking logcontexts when
             # they are GCed (see the logcontext docs)
             gc.collect()
-            LoggingContext.set_current_context(LoggingContext.sentinel)
+            set_current_context(SENTINEL_CONTEXT)
 
             return ret
 
@@ -203,6 +225,15 @@ class HomeserverTestCase(TestCase):
         # Register the resources
         self.resource = self.create_test_json_resource()
 
+        # create a site to wrap the resource.
+        self.site = SynapseSite(
+            logger_name="synapse.access.http.fake",
+            site_tag="test",
+            config={},
+            resource=self.resource,
+            server_version_string="1",
+        )
+
         from tests.rest.client.v1.utils import RestHelper
 
         self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
@@ -285,14 +316,11 @@ class HomeserverTestCase(TestCase):
 
         return resource
 
-    def default_config(self, name="test"):
+    def default_config(self):
         """
         Get a default HomeServer config dict.
-
-        Args:
-            name (str): The homeserver name/domain.
         """
-        config = default_config(name)
+        config = default_config("test")
 
         # apply any additional config which was specified via the override_config
         # decorator.
@@ -318,14 +346,14 @@ class HomeserverTestCase(TestCase):
 
     def make_request(
         self,
-        method,
-        path,
-        content=b"",
-        access_token=None,
-        request=SynapseRequest,
-        shorthand=True,
-        federation_auth_origin=None,
-    ):
+        method: Union[bytes, str],
+        path: Union[bytes, str],
+        content: Union[bytes, dict] = b"",
+        access_token: Optional[str] = None,
+        request: Type[T] = SynapseRequest,
+        shorthand: bool = True,
+        federation_auth_origin: str = None,
+    ) -> Tuple[T, FakeChannel]:
         """
         Create a SynapseRequest at the path using the method and containing the
         given content.
@@ -392,13 +420,17 @@ class HomeserverTestCase(TestCase):
         config_obj.parse_config_dict(config, "", "")
         kwargs["config"] = config_obj
 
+        async def run_bg_updates():
+            with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
+                while not await stor.db.updates.has_completed_background_updates():
+                    await stor.db.updates.do_next_background_update(1)
+
         hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
         stor = hs.get_datastore()
 
-        # Run the database background updates.
-        if hasattr(stor, "do_next_background_update"):
-            while not self.get_success(stor.has_completed_background_updates()):
-                self.get_success(stor.do_next_background_update(1))
+        # Run the database background updates, when running against "master".
+        if hs.__class__.__name__ == "TestHomeServer":
+            self.get_success(run_bg_updates())
 
         return hs
 
@@ -409,6 +441,8 @@ class HomeserverTestCase(TestCase):
         self.reactor.pump([by] * 100)
 
     def get_success(self, d, by=0.0):
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump(by=by)
@@ -418,6 +452,8 @@ class HomeserverTestCase(TestCase):
         """
         Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
         """
+        if inspect.isawaitable(d):
+            d = ensureDeferred(d)
         if not isinstance(d, Deferred):
             return d
         self.pump()
@@ -441,7 +477,7 @@ class HomeserverTestCase(TestCase):
         # Create the user
         request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
         self.render(request)
-        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.code, 200, msg=channel.result)
         nonce = channel.json_body["nonce"]
 
         want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
@@ -461,6 +497,7 @@ class HomeserverTestCase(TestCase):
                 "password": password,
                 "admin": admin,
                 "mac": want_mac,
+                "inhibit_login": True,
             }
         )
         request, channel = self.make_request(
@@ -509,10 +546,6 @@ class HomeserverTestCase(TestCase):
         secrets = self.hs.get_secrets()
         requester = Requester(user, None, False, None, None)
 
-        prev_events_and_hashes = None
-        if prev_event_ids:
-            prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids]
-
         event, context = self.get_success(
             event_creator.create_event(
                 requester,
@@ -522,7 +555,7 @@ class HomeserverTestCase(TestCase):
                     "sender": user.to_string(),
                     "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
                 },
-                prev_events_and_hashes=prev_events_and_hashes,
+                prev_event_ids=prev_event_ids,
             )
         )
 
@@ -538,7 +571,7 @@ class HomeserverTestCase(TestCase):
         Add the given event as an extremity to the room.
         """
         self.get_success(
-            self.hs.get_datastore()._simple_insert(
+            self.hs.get_datastore().db.simple_insert(
                 table="event_forward_extremities",
                 values={"room_id": room_id, "event_id": event_id},
                 desc="test_add_extremity",
@@ -559,6 +592,46 @@ class HomeserverTestCase(TestCase):
         self.render(request)
         self.assertEqual(channel.code, 403, channel.result)
 
+    def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+        """
+        Inject a membership event into a room.
+
+        Deprecated: use event_injection.inject_room_member directly
+
+        Args:
+            room: Room ID to inject the event into.
+            user: MXID of the user to inject the membership for.
+            membership: The membership type.
+        """
+        event_injection.inject_member_event(self.hs, room, user, membership)
+
+
+class FederatingHomeserverTestCase(HomeserverTestCase):
+    """
+    A federating homeserver that authenticates incoming requests as `other.example.com`.
+    """
+
+    def prepare(self, reactor, clock, homeserver):
+        class Authenticator(object):
+            def authenticate_request(self, request, content):
+                return succeed("other.example.com")
+
+        ratelimiter = FederationRateLimiter(
+            clock,
+            FederationRateLimitConfig(
+                window_size=1,
+                sleep_limit=1,
+                sleep_msec=1,
+                reject_limit=1000,
+                concurrent_requests=1000,
+            ),
+        )
+        federation_server.register_servlets(
+            homeserver, self.resource, Authenticator(), ratelimiter
+        )
+
+        return super().prepare(reactor, clock, homeserver)
+
 
 def override_config(extra_config):
     """A decorator which can be applied to test functions to give additional HS config