diff --git a/tests/unittest.py b/tests/unittest.py
index 68d2586efd..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import gc
import hashlib
import hmac
@@ -23,11 +22,12 @@ import logging
import time
from typing import Optional, Tuple, Type, TypeVar, Union
-from mock import Mock
+from mock import Mock, patch
from canonicaljson import json
from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
from twisted.python.threadpool import ThreadPool
from twisted.trial import unittest
@@ -92,7 +92,7 @@ class TestCase(unittest.TestCase):
root logger's logging level while that test (case|method) runs."""
def __init__(self, methodName, *args, **kwargs):
- super(TestCase, self).__init__(methodName, *args, **kwargs)
+ super().__init__(methodName, *args, **kwargs)
method = getattr(self, methodName)
@@ -169,6 +169,19 @@ def INFO(target):
return target
+def logcontext_clean(target):
+ """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+ ... ie, any logcontext errors should cause a test failure
+ """
+
+ def logcontext_error(msg):
+ raise AssertionError("logcontext error: %s" % (msg))
+
+ patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+ return patcher(target)
+
+
class HomeserverTestCase(TestCase):
"""
A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -241,20 +254,20 @@ class HomeserverTestCase(TestCase):
if hasattr(self, "user_id"):
if self.hijack_auth:
- def get_user_by_access_token(token=None, allow_guest=False):
- return succeed(
- {
- "user": UserID.from_string(self.helper.auth_user_id),
- "token_id": 1,
- "is_guest": False,
- }
- )
-
- def get_user_by_req(request, allow_guest=False, rights="access"):
- return succeed(
- create_requester(
- UserID.from_string(self.helper.auth_user_id), 1, False, None
- )
+ async def get_user_by_access_token(token=None, allow_guest=False):
+ return {
+ "user": UserID.from_string(self.helper.auth_user_id),
+ "token_id": 1,
+ "is_guest": False,
+ }
+
+ async def get_user_by_req(request, allow_guest=False, rights="access"):
+ return create_requester(
+ UserID.from_string(self.helper.auth_user_id),
+ 1,
+ False,
+ False,
+ None,
)
self.hs.get_auth().get_user_by_req = get_user_by_req
@@ -353,6 +366,7 @@ class HomeserverTestCase(TestCase):
request: Type[T] = SynapseRequest,
shorthand: bool = True,
federation_auth_origin: str = None,
+ content_is_form: bool = False,
) -> Tuple[T, FakeChannel]:
"""
Create a SynapseRequest at the path using the method and containing the
@@ -368,6 +382,8 @@ class HomeserverTestCase(TestCase):
with the usual REST API path, if it doesn't contain it.
federation_auth_origin (bytes|None): if set to not-None, we will add a fake
Authorization header pretenting to be the given server name.
+ content_is_form: Whether the content is URL encoded form data. Adds the
+ 'Content-Type': 'application/x-www-form-urlencoded' header.
Returns:
Tuple[synapse.http.site.SynapseRequest, channel]
@@ -384,6 +400,7 @@ class HomeserverTestCase(TestCase):
request,
shorthand,
federation_auth_origin,
+ content_is_form,
)
def render(self, request):
@@ -422,8 +439,8 @@ class HomeserverTestCase(TestCase):
async def run_bg_updates():
with LoggingContext("run_bg_updates", request="run_bg_updates-1"):
- while not await stor.db.updates.has_completed_background_updates():
- await stor.db.updates.do_next_background_update(1)
+ while not await stor.db_pool.updates.has_completed_background_updates():
+ await stor.db_pool.updates.do_next_background_update(1)
hs = setup_test_homeserver(self.addCleanup, *args, **kwargs)
stor = hs.get_datastore()
@@ -459,6 +476,35 @@ class HomeserverTestCase(TestCase):
self.pump()
return self.failureResultOf(d, exc)
+ def get_success_or_raise(self, d, by=0.0):
+ """Drive deferred to completion and return result or raise exception
+ on failure.
+ """
+
+ if inspect.isawaitable(d):
+ deferred = ensureDeferred(d)
+ if not isinstance(deferred, Deferred):
+ return d
+
+ results = [] # type: list
+ deferred.addBoth(results.append)
+
+ self.pump(by=by)
+
+ if not results:
+ self.fail(
+ "Success result expected on {!r}, found no result instead".format(
+ deferred
+ )
+ )
+
+ result = results[0]
+
+ if isinstance(result, Failure):
+ result.raiseException()
+
+ return result
+
def register_user(self, username, password, admin=False):
"""
Register a user. Requires the Admin API be registered.
@@ -544,7 +590,7 @@ class HomeserverTestCase(TestCase):
"""
event_creator = self.hs.get_event_creation_handler()
secrets = self.hs.get_secrets()
- requester = Requester(user, None, False, None, None)
+ requester = Requester(user, None, False, False, None, None)
event, context = self.get_success(
event_creator.create_event(
@@ -571,7 +617,7 @@ class HomeserverTestCase(TestCase):
Add the given event as an extremity to the room.
"""
self.get_success(
- self.hs.get_datastore().db.simple_insert(
+ self.hs.get_datastore().db_pool.simple_insert(
table="event_forward_extremities",
values={"room_id": room_id, "event_id": event_id},
desc="test_add_extremity",
@@ -614,7 +660,7 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
"""
def prepare(self, reactor, clock, homeserver):
- class Authenticator(object):
+ class Authenticator:
def authenticate_request(self, request, content):
return succeed("other.example.com")
|