diff options
Diffstat (limited to 'tests/unittest.py')
-rw-r--r-- | tests/unittest.py | 42 |
1 files changed, 33 insertions, 9 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index 36df43c137..a09e76c7c2 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -17,6 +17,7 @@ import gc import hashlib import hmac import logging +import time from mock import Mock @@ -24,16 +25,17 @@ from canonicaljson import json import twisted import twisted.logger -from twisted.internet.defer import Deferred +from twisted.internet.defer import Deferred, succeed +from twisted.python.threadpool import ThreadPool from twisted.trial import unittest from synapse.api.constants import EventTypes from synapse.config.homeserver import HomeServerConfig from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest +from synapse.logging.context import LoggingContext from synapse.server import HomeServer from synapse.types import Requester, UserID, create_requester -from synapse.util.logcontext import LoggingContext from tests.server import get_clock, make_request, render, setup_test_homeserver from tests.test_utils.logging_setup import setup_logging @@ -164,6 +166,7 @@ class HomeserverTestCase(TestCase): servlets = [] hijack_auth = True + needs_threadpool = False def setUp(self): """ @@ -192,15 +195,19 @@ class HomeserverTestCase(TestCase): if self.hijack_auth: def get_user_by_access_token(token=None, allow_guest=False): - return { - "user": UserID.from_string(self.helper.auth_user_id), - "token_id": 1, - "is_guest": False, - } + return succeed( + { + "user": UserID.from_string(self.helper.auth_user_id), + "token_id": 1, + "is_guest": False, + } + ) def get_user_by_req(request, allow_guest=False, rights="access"): - return create_requester( - UserID.from_string(self.helper.auth_user_id), 1, False, None + return succeed( + create_requester( + UserID.from_string(self.helper.auth_user_id), 1, False, None + ) ) self.hs.get_auth().get_user_by_req = get_user_by_req @@ -209,9 +216,26 @@ class HomeserverTestCase(TestCase): return_value="1234" ) + if self.needs_threadpool: + self.reactor.threadpool = ThreadPool() + self.addCleanup(self.reactor.threadpool.stop) + self.reactor.threadpool.start() + if hasattr(self, "prepare"): self.prepare(self.reactor, self.clock, self.hs) + def wait_on_thread(self, deferred, timeout=10): + """ + Wait until a Deferred is done, where it's waiting on a real thread. + """ + start_time = time.time() + + while not deferred.called: + if start_time + timeout < time.time(): + raise ValueError("Timed out waiting for threadpool") + self.reactor.advance(0.01) + time.sleep(0.01) + def make_homeserver(self, reactor, clock): """ Make and return a homeserver. |