summary refs log tree commit diff
path: root/tests/http/server/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http/server/_base.py')
-rw-r--r--tests/http/server/_base.py132
1 files changed, 69 insertions, 63 deletions
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 57b92beb87..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.types import JsonDict
 
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
+from tests.server import FakeChannel, make_request
 from tests.unittest import logcontext_clean
 
 logger = logging.getLogger(__name__)
@@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
-    """Provides helper methods for testing cancellation of endpoints."""
+def test_disconnect(
+    reactor: MemoryReactorClock,
+    channel: FakeChannel,
+    expect_cancellation: bool,
+    expected_body: Union[bytes, JsonDict],
+    expected_code: Optional[int] = None,
+) -> None:
+    """Disconnects an in-flight request and checks the response.
 
-    def _test_disconnect(
-        self,
-        reactor: ThreadedMemoryReactorClock,
-        channel: FakeChannel,
-        expect_cancellation: bool,
-        expected_body: Union[bytes, JsonDict],
-        expected_code: Optional[int] = None,
-    ) -> None:
-        """Disconnects an in-flight request and checks the response.
+    Args:
+        reactor: The twisted reactor running the request handler.
+        channel: The `FakeChannel` for the request.
+        expect_cancellation: `True` if request processing is expected to be cancelled,
+            `False` if the request should run to completion.
+        expected_body: The expected response for the request.
+        expected_code: The expected status code for the request. Defaults to `200` or
+            `499` depending on `expect_cancellation`.
+    """
+    # Determine the expected status code.
+    if expected_code is None:
+        if expect_cancellation:
+            expected_code = HTTP_STATUS_REQUEST_CANCELLED
+        else:
+            expected_code = HTTPStatus.OK
 
-        Args:
-            reactor: The twisted reactor running the request handler.
-            channel: The `FakeChannel` for the request.
-            expect_cancellation: `True` if request processing is expected to be
-                cancelled, `False` if the request should run to completion.
-            expected_body: The expected response for the request.
-            expected_code: The expected status code for the request. Defaults to `200`
-                or `499` depending on `expect_cancellation`.
-        """
-        # Determine the expected status code.
-        if expected_code is None:
-            if expect_cancellation:
-                expected_code = HTTP_STATUS_REQUEST_CANCELLED
-            else:
-                expected_code = HTTPStatus.OK
-
-        request = channel.request
-        self.assertFalse(
-            channel.is_finished(),
+    request = channel.request
+    if channel.is_finished():
+        raise AssertionError(
             "Request finished before we could disconnect - "
-            "was `await_result=False` passed to `make_request`?",
+            "ensure `await_result=False` is passed to `make_request`.",
         )
 
-        # We're about to disconnect the request. This also disconnects the channel, so
-        # we have to rely on mocks to extract the response.
-        respond_method: Callable[..., Any]
-        if isinstance(expected_body, bytes):
-            respond_method = respond_with_html_bytes
+    # We're about to disconnect the request. This also disconnects the channel, so we
+    # have to rely on mocks to extract the response.
+    respond_method: Callable[..., Any]
+    if isinstance(expected_body, bytes):
+        respond_method = respond_with_html_bytes
+    else:
+        respond_method = respond_with_json
+
+    with mock.patch(
+        f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+    ) as respond_mock:
+        # Disconnect the request.
+        request.connectionLost(reason=ConnectionDone())
+
+        if expect_cancellation:
+            # An immediate cancellation is expected.
+            respond_mock.assert_called_once()
         else:
-            respond_method = respond_with_json
+            respond_mock.assert_not_called()
 
-        with mock.patch(
-            f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
-        ) as respond_mock:
-            # Disconnect the request.
-            request.connectionLost(reason=ConnectionDone())
+            # The handler is expected to run to completion.
+            reactor.advance(1.0)
+            respond_mock.assert_called_once()
 
-            if expect_cancellation:
-                # An immediate cancellation is expected.
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
-            else:
-                respond_mock.assert_not_called()
-
-                # The handler is expected to run to completion.
-                reactor.pump([1.0])
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
+        args, _kwargs = respond_mock.call_args
+        code, body = args[1], args[2]
+
+        if code != expected_code:
+            raise AssertionError(
+                f"{code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if request.code != expected_code:
+            raise AssertionError(
+                f"{request.code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if body != expected_body:
+            raise AssertionError(
+                f"{body!r} != {expected_body!r} : "
+                "Request did not finish with the expected status code."
+            )
 
 
 @logcontext_clean