summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py41
1 files changed, 36 insertions, 5 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 7a6f5954d0..ae393ee53e 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import inspect
 import logging
 import secrets
 import time
-from typing import Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
+from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, TypeVar, Union
 from unittest.mock import Mock, patch
 
 from canonicaljson import json
@@ -28,6 +28,7 @@ from canonicaljson import json
 from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
 from twisted.web.resource import Resource
 
@@ -46,6 +47,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
 
@@ -232,7 +234,7 @@ class HomeserverTestCase(TestCase):
         # Honour the `use_frozen_dicts` config option. We have to do this
         # manually because this is taken care of in the app `start` code, which
         # we don't run. Plus we want to reset it on tearDown.
-        events.USE_FROZEN_DICTS = self.hs.config.use_frozen_dicts
+        events.USE_FROZEN_DICTS = self.hs.config.server.use_frozen_dicts
 
         if self.hs is None:
             raise Exception("No homeserver returned from make_homeserver.")
@@ -371,7 +373,7 @@ class HomeserverTestCase(TestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
         """
         Prepare for the test.  This involves things like mocking out parts of
         the homeserver, or building test data common across the whole test
@@ -447,7 +449,7 @@ class HomeserverTestCase(TestCase):
             client_ip,
         )
 
-    def setup_test_homeserver(self, *args, **kwargs):
+    def setup_test_homeserver(self, *args: Any, **kwargs: Any) -> HomeServer:
         """
         Set up the test homeserver, meant to be called by the overridable
         make_homeserver. It automatically passes through the test class's
@@ -558,7 +560,7 @@ class HomeserverTestCase(TestCase):
         Returns:
             The MXID of the new user.
         """
-        self.hs.config.registration_shared_secret = "shared"
+        self.hs.config.registration.registration_shared_secret = "shared"
 
         # Create the user
         channel = self.make_request("GET", "/_synapse/admin/v1/register")
@@ -594,6 +596,35 @@ class HomeserverTestCase(TestCase):
         user_id = channel.json_body["user_id"]
         return user_id
 
+    def register_appservice_user(
+        self,
+        username: str,
+        appservice_token: str,
+    ) -> str:
+        """Register an appservice user as an application service.
+        Requires the client-facing registration API be registered.
+
+        Args:
+            username: the user to be registered by an application service.
+                Should be a full username, i.e. ""@localpart:hostname" as opposed to just "localpart"
+            appservice_token: the acccess token for that application service.
+
+        Raises: if the request to '/register' does not return 200 OK.
+
+        Returns: the MXID of the new user.
+        """
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/register",
+            {
+                "username": username,
+                "type": "m.login.application_service",
+            },
+            access_token=appservice_token,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+        return channel.json_body["user_id"]
+
     def login(
         self,
         username,