summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py115
1 files changed, 98 insertions, 17 deletions
diff --git a/tests/unittest.py b/tests/unittest.py

index 7dbb64af59..d26804b5b5 100644 --- a/tests/unittest.py +++ b/tests/unittest.py
@@ -17,6 +17,7 @@ import gc import hashlib import hmac import logging +import time from mock import Mock @@ -24,14 +25,16 @@ from canonicaljson import json import twisted import twisted.logger -from twisted.internet.defer import Deferred +from twisted.internet.defer import Deferred, succeed +from twisted.python.threadpool import ThreadPool from twisted.trial import unittest +from synapse.api.constants import EventTypes from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest from synapse.server import HomeServer -from synapse.types import UserID, create_requester +from synapse.types import Requester, UserID, create_requester from synapse.util.logcontext import LoggingContext from tests.server import get_clock, make_request, render, setup_test_homeserver @@ -163,6 +166,7 @@ class HomeserverTestCase(TestCase): servlets = [] hijack_auth = True + needs_threadpool = False def setUp(self): """ @@ -191,15 +195,19 @@ class HomeserverTestCase(TestCase): if self.hijack_auth: def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } + return succeed( + { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + ) def get_user_by_req(request, allow_guest=False, rights="access"): - return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + return succeed( + create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None + ) ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -208,9 +216,26 @@ class HomeserverTestCase(TestCase): return_value="1234" ) + if self.needs_threadpool: + self.reactor.threadpool = ThreadPool() + self.addCleanup(self.reactor.threadpool.stop) + self.reactor.threadpool.start() + if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) + def wait_on_thread(self, deferred, timeout=10): + """ + Wait until a Deferred is done, where it's waiting on a real thread. + """ + start_time = time.time() + + while not deferred.called: + if start_time + timeout < time.time(): + raise ValueError("Timed out waiting for threadpool") + self.reactor.advance(0.01) + time.sleep(0.01) + def make_homeserver(self, reactor, clock): """ Make and return a homeserver. @@ -297,7 +322,7 @@ class HomeserverTestCase(TestCase): Tuple[synapse.http.site.SynapseRequest, channel] """ if isinstance(content, dict): - content = json.dumps(content).encode('utf8') + content = json.dumps(content).encode("utf8") return make_request( self.reactor, @@ -341,7 +366,7 @@ class HomeserverTestCase(TestCase): # Parse the config from a config dict into a HomeServerConfig config_obj = HomeServerConfig() - config_obj.parse_config_dict(config) + config_obj.parse_config_dict(config, "", "") kwargs["config"] = config_obj hs = setup_test_homeserver(self.addCleanup, *args, **kwargs) @@ -388,7 +413,7 @@ class HomeserverTestCase(TestCase): Returns: The MXID of the new user (unicode). """ - self.hs.config.registration_shared_secret = u"shared" + self.hs.config.registration_shared_secret = "shared" # Create the user request, channel = self.make_request("GET", "/_matrix/client/r0/admin/register") @@ -396,13 +421,13 @@ class HomeserverTestCase(TestCase): nonce = channel.json_body["nonce"] want_mac = hmac.new(key=b"shared", digestmod=hashlib.sha1) - nonce_str = b"\x00".join([username.encode('utf8'), password.encode('utf8')]) + nonce_str = b"\x00".join([username.encode("utf8"), password.encode("utf8")]) if admin: nonce_str += b"\x00admin" else: nonce_str += b"\x00notadmin" - want_mac.update(nonce.encode('ascii') + b"\x00" + nonce_str) + want_mac.update(nonce.encode("ascii") + b"\x00" + nonce_str) want_mac = want_mac.hexdigest() body = json.dumps( @@ -415,7 +440,7 @@ class HomeserverTestCase(TestCase): } ) request, channel = self.make_request( - "POST", "/_matrix/client/r0/admin/register", body.encode('utf8') + "POST", "/_matrix/client/r0/admin/register", body.encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200) @@ -434,7 +459,7 @@ class HomeserverTestCase(TestCase): body["device_id"] = device_id request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 200, channel.result) @@ -442,6 +467,62 @@ class HomeserverTestCase(TestCase): access_token = channel.json_body["access_token"] return access_token + def create_and_send_event( + self, room_id, user, soft_failed=False, prev_event_ids=None + ): + """ + Create and send an event. + + Args: + soft_failed (bool): Whether to create a soft failed event or not + prev_event_ids (list[str]|None): Explicitly set the prev events, + or if None just use the default + + Returns: + str: The new event's ID. + """ + event_creator = self.hs.get_event_creation_handler() + secrets = self.hs.get_secrets() + requester = Requester(user, None, False, None, None) + + prev_events_and_hashes = None + if prev_event_ids: + prev_events_and_hashes = [[p, {}, 0] for p in prev_event_ids] + + event, context = self.get_success( + event_creator.create_event( + requester, + { + "type": EventTypes.Message, + "room_id": room_id, + "sender": user.to_string(), + "content": {"body": secrets.token_hex(), "msgtype": "m.text"}, + }, + prev_events_and_hashes=prev_events_and_hashes, + ) + ) + + if soft_failed: + event.internal_metadata.soft_failed = True + + self.get_success(event_creator.send_nonmember_event(requester, event, context)) + + return event.event_id + + def add_extremity(self, room_id, event_id): + """ + Add the given event as an extremity to the room. + """ + self.get_success( + self.hs.get_datastore()._simple_insert( + table="event_forward_extremities", + values={"room_id": room_id, "event_id": event_id}, + desc="test_add_extremity", + ) + ) + + self.hs.get_datastore().get_latest_event_ids_in_room.invalidate((room_id,)) + def attempt_wrong_password_login(self, username, password): """Attempts to login as the user with the given password, asserting that the attempt *fails*. @@ -449,7 +530,7 @@ class HomeserverTestCase(TestCase): body = {"type": "m.login.password", "user": username, "password": password} request, channel = self.make_request( - "POST", "/_matrix/client/r0/login", json.dumps(body).encode('utf8') + "POST", "/_matrix/client/r0/login", json.dumps(body).encode("utf8") ) self.render(request) self.assertEqual(channel.code, 403, channel.result)