diff --git a/tests/unittest.py b/tests/unittest.py
index 7dbb64af59..8816a4d152 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,28 +14,42 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import hashlib
import hmac
+import inspect
import logging
+import time
+from typing import Optional, Tuple, Type, TypeVar, Union
from mock import Mock
from canonicaljson import json
-import twisted
-import twisted.logger
-from twisted.internet.defer import Deferred
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
-from synapse.http.site import SynapseRequest
+from synapse.http.site import SynapseRequest, SynapseSite
+from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
-from synapse.types import UserID, create_requester
-from synapse.util.logcontext import LoggingContext
-
-from tests.server import get_clock, make_request, render, setup_test_homeserver
+from synapse.types import Requester, UserID, create_requester
+from synapse.util.ratelimitutils import FederationRateLimiter
+
+from tests.server import (
+ FakeChannel,
+ get_clock,
+ make_request,
+ render,
+ setup_test_homeserver,
+)
from tests.test_utils.logging_setup import setup_logging
from tests.utils import default_config, setupdb
@@ -63,6 +78,9 @@ def around(target):
return _around
+T = TypeVar("T")
+
+
class TestCase(unittest.TestCase):
"""A subclass of twisted.trial's TestCase which looks for 'loglevel'
attributes on both itself and its individual test methods, to override the
@@ -77,10 +95,6 @@ class TestCase(unittest.TestCase):
@around(self)
def setUp(orig):
- # enable debugging of delayed calls - this means that we get a
- # traceback when a unit test exits leaving things on the reactor.
- twisted.internet.base.DelayedCall.debug = True
-
# if we're not starting in the sentinel logcontext, then to be honest
# all future bets are off.
if LoggingContext.current_context() is not LoggingContext.sentinel:
@@ -154,6 +168,21 @@ class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
+ Defines a setUp method which creates a mock reactor, and instantiates a homeserver
+ running on that reactor.
+
+ There are various hooks for modifying the way that the homeserver is instantiated:
+
+ * override make_homeserver, for example by making it pass different parameters into
+ setup_test_homeserver.
+
+ * override default_config, to return a modified configuration dictionary for use
+ by setup_test_homeserver.
+
+ * On a per-test basis, you can use the @override_config decorator to give a
+ dictionary containing additional configuration settings to be added to the basic
+ config dict.
+
Attributes:
servlets (list[function]): List of servlet registration function.
user_id (str): The user ID to assume if auth is hijacked.
@@ -163,6 +192,14 @@ class HomeserverTestCase(TestCase):
servlets = []
hijack_auth = True
+ needs_threadpool = False
+
+ def __init__(self, methodName, *args, **kwargs):
+ super().__init__(methodName, *args, **kwargs)
+
+ # see if we have any additional config for this test
+ method = getattr(self, methodName)
+ self._extra_config = getattr(method, "_extra_config", None)
def setUp(self):
"""
@@ -183,6 +220,15 @@ class HomeserverTestCase(TestCase):
# Register the resources
self.resource = self.create_test_json_resource()
+ # create a site to wrap the resource.
+ self.site = SynapseSite(
+ logger_name="synapse.access.http.fake",
+ site_tag="test",
+ config={},
+ resource=self.resource,
+ server_version_string="1",
+ )
+
from tests.rest.client.v1.utils import RestHelper
self.helper = RestHelper(self.hs, self.resource, getattr(self, "user_id", None))
@@ -191,15 +237,19 @@ class HomeserverTestCase(TestCase):
if self.hijack_auth:
def get_user_by_access_token(token=None, allow_guest=False):
- return {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
+ return succeed(
+ {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+ )
def get_user_by_req(request, allow_guest=False, rights="access"):
- return create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
+ return succeed(
+ create_requester(
+ UserID.from_string(self.helper.auth_user_id), 1, False, None
+ )
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -208,9 +258,26 @@ class HomeserverTestCase(TestCase):
return_value="1234"
)
+ if self.needs_threadpool:
+ self.reactor.threadpool = ThreadPool()
+ self.addCleanup(self.reactor.threadpool.stop)
+ self.reactor.threadpool.start()
+
if hasattr(self, "prepare"):
self.prepare(self.reactor, self.clock, self.hs)
+ def wait_on_thread(self, deferred, timeout=10):
+ """
+ Wait until a Deferred is done, where it's waiting on a real thread.
+ """
+ start_time = time.time()
+
+ while not deferred.called:
+ if start_time + timeout < time.time():
+ raise ValueError("Timed out waiting for threadpool")
+ self.reactor.advance(0.01)
+ time.sleep(0.01)
+
def make_homeserver(self, reactor, clock):
"""
Make and return a homeserver.
@@ -251,7 +318,14 @@ class HomeserverTestCase(TestCase):
Args:
name (str): The homeserver name/domain.
"""
- return default_config(name)
+ config = default_config(name)
+
+ # apply any additional config which was specified via the override_config
+ # decorator.
+ if self._extra_config is not None:
+ config.update(self._extra_config)
+
+ return config
def prepare(self, reactor, clock, homeserver):
"""
@@ -270,14 +344,14 @@ class HomeserverTestCase(TestCase):
def make_request(
self,
- method,
- path,
- content=b"",
- access_token=None,
- request=SynapseRequest,
- shorthand=True,
- federation_auth_origin=None,
- ):
+ method: Union[bytes, str],
+ path: Union[bytes, str],
+ content: Union[bytes, dict] = b"",
+ access_token: Optional[str] = None,
+ request: Type[T] = SynapseRequest,
+ shorthand: bool = True,
+ federation_auth_origin: str = None,
+ ) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
given content.
@@ -297,7 +371,7 @@ class HomeserverTestCase(TestCase):
Tuple[synapse.http.site.SynapseRequest, channel]
"""
if isinstance(content, dict):
- content = json.dumps(content).encode('utf8')
+ content = json.dumps(content).encode("utf8")
return make_request(
self.reactor,
@@ -341,16 +415,18 @@ class HomeserverTestCase(TestCase):
# Parse the config from a config dict into a HomeServerConfig
config_obj = HomeServerConfig()
- config_obj.parse_config_dict(config)
+ config_obj.parse_config_dict(config, "", "")
kwargs["config"] = config_obj
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
- # Run the database background updates.
- if hasattr(stor, "do_next_background_update"):
- while not self.get_success(stor.has_completed_background_updates()):
- self.get_success(stor.do_next_background_update(1))
+ # Run the database background updates, when running against "master".
+ if hs.__class__.__name__ == "TestHomeServer":
+ while not self.get_success(
+ stor.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(stor.db.updates.do_next_background_update(1))
return hs
@@ -361,6 +437,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@@ -370,6 +448,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
@@ -388,21 +468,22 @@ class HomeserverTestCase(TestCase):
Returns:
The MXID of the new user (unicode).
"""
- self.hs.config.registration_shared_secret = u"shared"
+ self.hs.config.registration_shared_secret = "shared"
# Create the user
request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register")
self.render(request)
+ self.assertEqual(channel.code, 200, msg=channel.result)
nonce = channel.json_body["nonce"]
want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1)
- nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')])
+ nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")])
if admin:
nonce_str += b"\x00admin"
else:
nonce_str += b"\x00notadmin"
- want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str)
+ want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str)
want_mac = want_mac.hexdigest()
body = json.dumps(
@@ -415,10 +496,10 @@ class HomeserverTestCase(TestCase):
}
)
request, channel = self.make_request(
- "POST", "/_matrix/client/r0/admin/register", body.encode('utf8')
+ "POST", "/_matrix/client/r0/admin/register", body.encode("utf8")
)
self.render(request)
- self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.code, 200, channel.json_body)
user_id = channel.json_body["user_id"]
return user_id
@@ -434,7 +515,7 @@ class HomeserverTestCase(TestCase):
body["device_id"] = device_id
request, channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.render(request)
self.assertEqual(channel.code, 200, channel.result)
@@ -442,6 +523,58 @@ class HomeserverTestCase(TestCase):
access_token = channel.json_body["access_token"]
return access_token
+ def create_and_send_event(
+ self, room_id, user, soft_failed=False, prev_event_ids=None
+ ):
+ """
+ Create and send an event.
+
+ Args:
+ soft_failed (bool): Whether to create a soft failed event or not
+ prev_event_ids (list[str]|None): Explicitly set the prev events,
+ or if None just use the default
+
+ Returns:
+ str: The new event's ID.
+ """
+ event_creator = self.hs.get_event_creation_handler()
+ secrets = self.hs.get_secrets()
+ requester = Requester(user, None, False, None, None)
+
+ event, context = self.get_success(
+ event_creator.create_event(
+ requester,
+ {
+ "type": EventTypes.Message,
+ "room_id": room_id,
+ "sender": user.to_string(),
+ "content": {"body": secrets.token_hex(), "msgtype": "m.text"},
+ },
+ prev_event_ids=prev_event_ids,
+ )
+ )
+
+ if soft_failed:
+ event.internal_metadata.soft_failed = True
+
+ self.get_success(event_creator.send_nonmember_event(requester, event, context))
+
+ return event.event_id
+
+ def add_extremity(self, room_id, event_id):
+ """
+ Add the given event as an extremity to the room.
+ """
+ self.get_success(
+ self.hs.get_datastore().db.simple_insert(
+ table="event_forward_extremities",
+ values={"room_id": room_id, "event_id": event_id},
+ desc="test_add_extremity",
+ )
+ )
+
+ self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,))
+
def attempt_wrong_password_login(self, username, password):
"""Attempts to login as the user with the given password, asserting
that the attempt *fails*.
@@ -449,7 +582,93 @@ class HomeserverTestCase(TestCase):
body = {"type": "m.login.password", "user": username, "password": password}
request, channel = self.make_request(
- "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8')
+ "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8")
)
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+
+ def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ """
+ Inject a membership event into a room.
+
+ Args:
+ room: Room ID to inject the event into.
+ user: MXID of the user to inject the membership for.
+ membership: The membership type.
+ """
+ event_builder_factory = self.hs.get_event_builder_factory()
+ event_creation_handler = self.hs.get_event_creation_handler()
+
+ room_version = self.get_success(
+ self.hs.get_datastore().get_room_version_id(room)
+ )
+
+ builder = event_builder_factory.for_room_version(
+ KNOWN_ROOM_VERSIONS[room_version],
+ {
+ "type": EventTypes.Member,
+ "sender": user,
+ "state_key": user,
+ "room_id": room,
+ "content": {"membership": membership},
+ },
+ )
+
+ event, context = self.get_success(
+ event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
+
+class FederatingHomeserverTestCase(HomeserverTestCase):
+ """
+ A federating homeserver that authenticates incoming requests as `other.example.com`.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return succeed("other.example.com")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ federation_server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ return super().prepare(reactor, clock, homeserver)
+
+
+def override_config(extra_config):
+ """A decorator which can be applied to test functions to give additional HS config
+
+ For use
+
+ For example:
+
+ class MyTestCase(HomeserverTestCase):
+ @override_config({"enable_registration": False, ...})
+ def test_foo(self):
+ ...
+
+ Args:
+ extra_config(dict): Additional config settings to be merged into the default
+ config dict before instantiating the test homeserver.
+ """
+
+ def decorator(func):
+ func._extra_config = extra_config
+ return func
+
+ return decorator
|