diff --git a/tests/unittest.py b/tests/unittest.py
index e654c0442d..a9d59e31f7 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -20,7 +20,7 @@ import hmac
import inspect
import logging
import time
-from typing import Optional, Tuple, Type, TypeVar, Union
+from typing import Optional, Tuple, Type, TypeVar, Union, overload
from mock import Mock, patch
@@ -30,6 +30,7 @@ from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
+from twisted.web.resource import Resource
from synapse.api.constants import EventTypes, Membership
from synapse.config.homeserver import HomeServerConfig
@@ -44,17 +45,11 @@ from synapse.logging.context import (
set_current_context,
)
from synapse.server import HomeServer
-from synapse.types import Requester, UserID, create_requester
+from synapse.types import UserID, create_requester
from synapse.util.ratelimitutils import FederationRateLimiter
-from tests.server import (
- FakeChannel,
- get_clock,
- make_request,
- render,
- setup_test_homeserver,
-)
-from tests.test_utils import event_injection
+from tests.server import FakeChannel, get_clock, make_request, setup_test_homeserver
+from tests.test_utils import event_injection, setup_awaitable_errors
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -119,6 +114,10 @@ class TestCase(unittest.TestCase):
logging.getLogger().setLevel(level)
+ # Trial messes with the warnings configuration, thus this has to be
+ # done in the context of an individual TestCase.
+ self.addCleanup(setup_awaitable_errors())
+
return orig()
@around(self)
@@ -235,13 +234,11 @@ class HomeserverTestCase(TestCase):
if not isinstance(self.hs, HomeServer):
raise Exception("A homeserver wasn't returned, but %r" % (self.hs,))
- # Register the resources
- self.resource = self.create_test_json_resource()
-
- # create a site to wrap the resource.
+ # create the root resource, and a site to wrap it.
+ self.resource = self.create_test_resource()
self.site = SynapseSite(
logger_name="synapse.access.http.fake",
- site_tag="test",
+ site_tag=self.hs.config.server.server_name,
config=self.hs.config.server.listeners[0],
resource=self.resource,
server_version_string="1",
@@ -249,22 +246,29 @@ class HomeserverTestCase(TestCase):
from tests.rest.client.v1.utils import RestHelper
- self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
+ self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
if hasattr(self, "user_id"):
if self.hijack_auth:
+ # We need a valid token ID to satisfy foreign key constraints.
+ token_id = self.get_success(
+ self.hs.get_datastore().add_access_token_to_user(
+ self.helper.auth_user_id, "some_fake_token", None, None,
+ )
+ )
+
async def get_user_by_access_token(token=None, allow_guest=False):
return {
"user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
+ "token_id": token_id,
"is_guest": False,
}
async def get_user_by_req(request, allow_guest=False, rights="access"):
return create_requester(
UserID.from_string(self.helper.auth_user_id),
- 1,
+ token_id,
False,
False,
None,
@@ -312,15 +316,12 @@ class HomeserverTestCase(TestCase):
hs = self.setup_test_homeserver()
return hs
- def create_test_json_resource(self):
+ def create_test_resource(self) -> Resource:
"""
- Create a test JsonResource, with the relevant servlets registerd to it
-
- The default implementation calls each function in `servlets` to do the
- registration.
+ Create a the root resource for the test server.
- Returns:
- JsonResource:
+ The default implementation creates a JsonResource and calls each function in
+ `servlets` to register servletes against it
"""
resource = JsonResource(self.hs)
@@ -357,6 +358,38 @@ class HomeserverTestCase(TestCase):
Function to optionally be overridden in subclasses.
"""
+ # Annoyingly mypy doesn't seem to pick up the fact that T is SynapseRequest
+ # when the `request` arg isn't given, so we define an explicit override to
+ # cover that case.
+ @overload
+ def make_request(
+ self,
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, dict] = b"",
+ access_token: Optional[str] = None,
+ shorthand: bool = True,
+ federation_auth_origin: str = None,
+ content_is_form: bool = False,
+ await_result: bool = True,
+ ) -> Tuple[SynapseRequest, FakeChannel]:
+ ...
+
+ @overload
+ def make_request(
+ self,
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, dict] = b"",
+ access_token: Optional[str] = None,
+ request: Type[T] = SynapseRequest,
+ shorthand: bool = True,
+ federation_auth_origin: str = None,
+ content_is_form: bool = False,
+ await_result: bool = True,
+ ) -> Tuple[T, FakeChannel]:
+ ...
+
def make_request(
self,
method: Union[bytes, str],
@@ -367,6 +400,7 @@ class HomeserverTestCase(TestCase):
shorthand: bool = True,
federation_auth_origin: str = None,
content_is_form: bool = False,
+ await_result: bool = True,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -385,14 +419,16 @@ class HomeserverTestCase(TestCase):
content_is_form: Whether the content is URL encoded form data. Adds the
'Content-Type': 'application/x-www-form-urlencoded' header.
+ await_result: whether to wait for the request to complete rendering. If
+ true (the default), will pump the test reactor until the the renderer
+ tells the channel the request is finished.
+
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
"""
- if isinstance(content, dict):
- content = json.dumps(content).encode("utf8")
-
return make_request(
self.reactor,
+ self.site,
method,
path,
content,
@@ -401,18 +437,9 @@ class HomeserverTestCase(TestCase):
shorthand,
federation_auth_origin,
content_is_form,
+ await_result,
)
- def render(self, request):
- """
- Render a request against the resources registered by the test class's
- servlets.
-
- Args:
- request (synapse.http.site.SynapseRequest): The request to render.
- """
- render(request, self.resource, self.reactor)
-
def setup_test_homeserver(self, *args, **kwargs):
"""
Set up the test homeserver, meant to be called by the overridable
@@ -505,24 +532,29 @@ class HomeserverTestCase(TestCase):
return result
- def register_user(self, username, password, admin=False):
+ def register_user(
+ self,
+ username: str,
+ password: str,
+ admin: Optional[bool] = False,
+ displayname: Optional[str] = None,
+ ) -> str:
"""
Register a user. Requires the Admin API be registered.
Args:
- username (bytes/unicode): The user part of the new user.
- password (bytes/unicode): The password of the new user.
- admin (bool): Whether the user should be created as an admin
- or not.
+ username: The user part of the new user.
+ password: The password of the new user.
+ admin: Whether the user should be created as an admin or not.
+ displayname: The displayname of the new user.
Returns:
- The MXID of the new user (unicode).
+ The MXID of the new user.
"""
self.hs.config.registration_shared_secret = "shared"
# Create the user
- request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
- self.render(request)
+ request, channel = self.make_request("GET", "/_synapse/admin/v1/register")
self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
@@ -540,6 +572,7 @@ class HomeserverTestCase(TestCase):
{
"nonce": nonce,
"username": username,
+ "displayname": displayname,
"password": password,
"admin": admin,
"mac": want_mac,
@@ -547,9 +580,8 @@ class HomeserverTestCase(TestCase):
}
)
request, channel = self.make_request(
- "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
+ "POST", "/_synapse/admin/v1/register", body.encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
@@ -568,7 +600,6 @@ class HomeserverTestCase(TestCase):
request, channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 200, channel.result)
access_token = channel.json_body["access_token"]
@@ -590,7 +621,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, False, None, None)
+ requester = create_requester(user)
event, context = self.get_success(
event_creator.create_event(
@@ -608,7 +639,9 @@ class HomeserverTestCase(TestCase):
if soft_failed:
event.internal_metadata.soft_failed = True
- self.get_success(event_creator.send_nonmember_event(requester, event, context))
+ self.get_success(
+ event_creator.handle_new_client_event(requester, event, context)
+ )
return event.event_id
@@ -635,7 +668,6 @@ class HomeserverTestCase(TestCase):
request, channel = self.make_request(
"POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
- self.render(request)
self.assertEqual(channel.code, 403, channel.result)
def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
|