diff options
Diffstat (limited to 'tests/unittest.py')
-rw-r--r-- | tests/unittest.py | 52 |
1 files changed, 49 insertions, 3 deletions
diff --git a/tests/unittest.py b/tests/unittest.py index 3cb55a7e96..e654c0442d 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -14,7 +14,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import gc import hashlib import hmac @@ -23,11 +22,12 @@ import logging import time from typing import Optional, Tuple, Type, TypeVar, Union -from mock import Mock +from mock import Mock, patch from canonicaljson import json from twisted.internet.defer import Deferred, ensureDeferred, succeed +from twisted.python.failure import Failure from twisted.python.threadpool import ThreadPool from twisted.trial import unittest @@ -92,7 +92,7 @@ class TestCase(unittest.TestCase): root logger's logging level while that test (case|method) runs.""" def __init__(self, methodName, *args, **kwargs): - super(TestCase, self).__init__(methodName, *args, **kwargs) + super().__init__(methodName, *args, **kwargs) method = getattr(self, methodName) @@ -169,6 +169,19 @@ def INFO(target): return target +def logcontext_clean(target): + """A decorator which marks the TestCase or method as 'logcontext_clean' + + ... ie, any logcontext errors should cause a test failure + """ + + def logcontext_error(msg): + raise AssertionError("logcontext error: %s" % (msg)) + + patcher = patch("synapse.logging.context.logcontext_error", new=logcontext_error) + return patcher(target) + + class HomeserverTestCase(TestCase): """ A base TestCase that reduces boilerplate for HomeServer-using test cases. @@ -353,6 +366,7 @@ class HomeserverTestCase(TestCase): request: Type[T] = SynapseRequest, shorthand: bool = True, federation_auth_origin: str = None, + content_is_form: bool = False, ) -> Tuple[T, FakeChannel]: """ Create a SynapseRequest at the path using the method and containing the @@ -368,6 +382,8 @@ class HomeserverTestCase(TestCase): with the usual REST API path, if it doesn't contain it. federation_auth_origin (bytes|None): if set to not-None, we will add a fake Authorization header pretenting to be the given server name. + content_is_form: Whether the content is URL encoded form data. Adds the + 'Content-Type': 'application/x-www-form-urlencoded' header. Returns: Tuple[synapse.http.site.SynapseRequest, channel] @@ -384,6 +400,7 @@ class HomeserverTestCase(TestCase): request, shorthand, federation_auth_origin, + content_is_form, ) def render(self, request): @@ -459,6 +476,35 @@ class HomeserverTestCase(TestCase): self.pump() return self.failureResultOf(d, exc) + def get_success_or_raise(self, d, by=0.0): + """Drive deferred to completion and return result or raise exception + on failure. + """ + + if inspect.isawaitable(d): + deferred = ensureDeferred(d) + if not isinstance(deferred, Deferred): + return d + + results = [] # type: list + deferred.addBoth(results.append) + + self.pump(by=by) + + if not results: + self.fail( + "Success result expected on {!r}, found no result instead".format( + deferred + ) + ) + + result = results[0] + + if isinstance(result, Failure): + result.raiseException() + + return result + def register_user(self, username, password, admin=False): """ Register a user. Requires the Admin API be registered. |