diff --git a/tests/unittest.py b/tests/unittest.py
index 561cebc223..b30b7d1718 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2016 OpenMarket Ltd
# Copyright 2018 New Vector
+# Copyright 2019 Matrix.org Federation C.I.C
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -13,9 +14,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
import gc
import hashlib
import hmac
+import inspect
import logging
import time
@@ -23,17 +26,21 @@ from mock import Mock
from canonicaljson import json
-from twisted.internet.defer import Deferred, succeed
+from twisted.internet.defer import Deferred, ensureDeferred, succeed
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
-from synapse.api.constants import EventTypes
+from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.ratelimiting import FederationRateLimitConfig
+from synapse.federation.transport import server as federation_server
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext
from synapse.server import HomeServer
from synapse.types import Requester, UserID, create_requester
+from synapse.util.ratelimitutils import FederationRateLimiter
from tests.server import get_clock, make_request, render, setup_test_homeserver
from tests.test_utils.logging_setup import setup_logging
@@ -395,10 +402,12 @@ class HomeserverTestCase(TestCase):
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
- # Run the database background updates.
- if hasattr(stor, "do_next_background_update"):
- while not self.get_success(stor.has_completed_background_updates()):
- self.get_success(stor.do_next_background_update(1))
+ # Run the database background updates, when running against "master".
+ if hs.__class__.__name__ == "TestHomeServer":
+ while not self.get_success(
+ stor.db.updates.has_completed_background_updates()
+ ):
+ self.get_success(stor.db.updates.do_next_background_update(1))
return hs
@@ -409,6 +418,8 @@ class HomeserverTestCase(TestCase):
self.reactor.pump([by] * 100)
def get_success(self, d, by=0.0):
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump(by=by)
@@ -418,6 +429,8 @@ class HomeserverTestCase(TestCase):
"""
Run a Deferred and get a Failure from it. The failure must be of the type `exc`.
"""
+ if inspect.isawaitable(d):
+ d = ensureDeferred(d)
if not isinstance(d, Deferred):
return d
self.pump()
@@ -538,7 +551,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore()._simple_insert(
+ self.hs.get_datastore().db.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
@@ -559,6 +572,66 @@ class HomeserverTestCase(TestCase):
self.render(request)
self.assertEqual(channel.code, 403, channel.result)
+ def inject_room_member(self, room: str, user: str, membership: Membership) -> None:
+ """
+ Inject a membership event into a room.
+
+ Args:
+ room: Room ID to inject the event into.
+ user: MXID of the user to inject the membership for.
+ membership: The membership type.
+ """
+ event_builder_factory = self.hs.get_event_builder_factory()
+ event_creation_handler = self.hs.get_event_creation_handler()
+
+ room_version = self.get_success(self.hs.get_datastore().get_room_version(room))
+
+ builder = event_builder_factory.for_room_version(
+ KNOWN_ROOM_VERSIONS[room_version],
+ {
+ "type": EventTypes.Member,
+ "sender": user,
+ "state_key": user,
+ "room_id": room,
+ "content": {"membership": membership},
+ },
+ )
+
+ event, context = self.get_success(
+ event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
+
+class FederatingHomeserverTestCase(HomeserverTestCase):
+ """
+ A federating homeserver that authenticates incoming requests as `other.example.com`.
+ """
+
+ def prepare(self, reactor, clock, homeserver):
+ class Authenticator(object):
+ def authenticate_request(self, request, content):
+ return succeed("other.example.com")
+
+ ratelimiter = FederationRateLimiter(
+ clock,
+ FederationRateLimitConfig(
+ window_size=1,
+ sleep_limit=1,
+ sleep_msec=1,
+ reject_limit=1000,
+ concurrent_requests=1000,
+ ),
+ )
+ federation_server.register_servlets(
+ homeserver, self.resource, Authenticator(), ratelimiter
+ )
+
+ return super().prepare(reactor, clock, homeserver)
+
def override_config(extra_config):
"""A decorator which can be applied to test functions to give additional HS config
|