summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py52
1 files changed, 49 insertions, 3 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 3cb55a7e96..e654c0442d 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -14,7 +14,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import gc
 import hashlib
 import hmac
@@ -23,11 +22,12 @@ import logging
 import time
 from typing import Optional, Tuple, Type, TypeVar, Union
 
-from mock import Mock
+from mock import Mock, patch
 
 from canonicaljson import json
 
 from twisted.internet.defer import Deferred, ensureDeferred, succeed
+from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
 from twisted.trial import unittest
 
@@ -92,7 +92,7 @@ class TestCase(unittest.TestCase):
     root logger's logging level while that test (case|method) runs."""
 
     def __init__(self, methodName, *args, **kwargs):
-        super(TestCase, self).__init__(methodName, *args, **kwargs)
+        super().__init__(methodName, *args, **kwargs)
 
         method = getattr(self, methodName)
 
@@ -169,6 +169,19 @@ def INFO(target):
     return target
 
 
+def logcontext_clean(target):
+    """A decorator which marks the TestCase or method as 'logcontext_clean'
+
+    ... ie, any logcontext errors should cause a test failure
+    """
+
+    def logcontext_error(msg):
+        raise AssertionError("logcontext error: %s" % (msg))
+
+    patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error)
+    return patcher(target)
+
+
 class HomeserverTestCase(TestCase):
     """
     A base TestCase that reduces boilerplate for HomeServer-using test cases.
@@ -353,6 +366,7 @@ class HomeserverTestCase(TestCase):
         request: Type[T] = SynapseRequest,
         shorthand: bool = True,
         federation_auth_origin: str = None,
+        content_is_form: bool = False,
     ) -> Tuple[T, FakeChannel]:
         """
         Create a SynapseRequest at the path using the method and containing the
@@ -368,6 +382,8 @@ class HomeserverTestCase(TestCase):
             with the usual REST API path, if it doesn't contain it.
             federation_auth_origin (bytes|None): if set to not-None, we will add a fake
                 Authorization header pretenting to be the given server name.
+            content_is_form: Whether the content is URL encoded form data. Adds the
+                'Content-Type': 'application/x-www-form-urlencoded' header.
 
         Returns:
             Tuple[synapse.http.site.SynapseRequest, channel]
@@ -384,6 +400,7 @@ class HomeserverTestCase(TestCase):
             request,
             shorthand,
             federation_auth_origin,
+            content_is_form,
         )
 
     def render(self, request):
@@ -459,6 +476,35 @@ class HomeserverTestCase(TestCase):
         self.pump()
         return self.failureResultOf(d, exc)
 
+    def get_success_or_raise(self, d, by=0.0):
+        """Drive deferred to completion and return result or raise exception
+        on failure.
+        """
+
+        if inspect.isawaitable(d):
+            deferred = ensureDeferred(d)
+        if not isinstance(deferred, Deferred):
+            return d
+
+        results = []  # type: list
+        deferred.addBoth(results.append)
+
+        self.pump(by=by)
+
+        if not results:
+            self.fail(
+                "Success result expected on {!r}, found no result instead".format(
+                    deferred
+                )
+            )
+
+        result = results[0]
+
+        if isinstance(result, Failure):
+            result.raiseException()
+
+        return result
+
     def register_user(self, username, password, admin=False):
         """
         Register a user. Requires the Admin API be registered.