summary refs log tree commit diff
path: root/tests/http
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http')
-rw-r--r--tests/http/server/_base.py132
-rw-r--r--tests/http/test_servlet.py10
2 files changed, 73 insertions, 69 deletions
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 57b92beb87..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.types import JsonDict
 
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
+from tests.server import FakeChannel, make_request
 from tests.unittest import logcontext_clean
 
 logger = logging.getLogger(__name__)
@@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
-    """Provides helper methods for testing cancellation of endpoints."""
+def test_disconnect(
+    reactor: MemoryReactorClock,
+    channel: FakeChannel,
+    expect_cancellation: bool,
+    expected_body: Union[bytes, JsonDict],
+    expected_code: Optional[int] = None,
+) -> None:
+    """Disconnects an in-flight request and checks the response.
 
-    def _test_disconnect(
-        self,
-        reactor: ThreadedMemoryReactorClock,
-        channel: FakeChannel,
-        expect_cancellation: bool,
-        expected_body: Union[bytes, JsonDict],
-        expected_code: Optional[int] = None,
-    ) -> None:
-        """Disconnects an in-flight request and checks the response.
+    Args:
+        reactor: The twisted reactor running the request handler.
+        channel: The `FakeChannel` for the request.
+        expect_cancellation: `True` if request processing is expected to be cancelled,
+            `False` if the request should run to completion.
+        expected_body: The expected response for the request.
+        expected_code: The expected status code for the request. Defaults to `200` or
+            `499` depending on `expect_cancellation`.
+    """
+    # Determine the expected status code.
+    if expected_code is None:
+        if expect_cancellation:
+            expected_code = HTTP_STATUS_REQUEST_CANCELLED
+        else:
+            expected_code = HTTPStatus.OK
 
-        Args:
-            reactor: The twisted reactor running the request handler.
-            channel: The `FakeChannel` for the request.
-            expect_cancellation: `True` if request processing is expected to be
-                cancelled, `False` if the request should run to completion.
-            expected_body: The expected response for the request.
-            expected_code: The expected status code for the request. Defaults to `200`
-                or `499` depending on `expect_cancellation`.
-        """
-        # Determine the expected status code.
-        if expected_code is None:
-            if expect_cancellation:
-                expected_code = HTTP_STATUS_REQUEST_CANCELLED
-            else:
-                expected_code = HTTPStatus.OK
-
-        request = channel.request
-        self.assertFalse(
-            channel.is_finished(),
+    request = channel.request
+    if channel.is_finished():
+        raise AssertionError(
             "Request finished before we could disconnect - "
-            "was `await_result=False` passed to `make_request`?",
+            "ensure `await_result=False` is passed to `make_request`.",
         )
 
-        # We're about to disconnect the request. This also disconnects the channel, so
-        # we have to rely on mocks to extract the response.
-        respond_method: Callable[..., Any]
-        if isinstance(expected_body, bytes):
-            respond_method = respond_with_html_bytes
+    # We're about to disconnect the request. This also disconnects the channel, so we
+    # have to rely on mocks to extract the response.
+    respond_method: Callable[..., Any]
+    if isinstance(expected_body, bytes):
+        respond_method = respond_with_html_bytes
+    else:
+        respond_method = respond_with_json
+
+    with mock.patch(
+        f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+    ) as respond_mock:
+        # Disconnect the request.
+        request.connectionLost(reason=ConnectionDone())
+
+        if expect_cancellation:
+            # An immediate cancellation is expected.
+            respond_mock.assert_called_once()
         else:
-            respond_method = respond_with_json
+            respond_mock.assert_not_called()
 
-        with mock.patch(
-            f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
-        ) as respond_mock:
-            # Disconnect the request.
-            request.connectionLost(reason=ConnectionDone())
+            # The handler is expected to run to completion.
+            reactor.advance(1.0)
+            respond_mock.assert_called_once()
 
-            if expect_cancellation:
-                # An immediate cancellation is expected.
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
-            else:
-                respond_mock.assert_not_called()
-
-                # The handler is expected to run to completion.
-                reactor.pump([1.0])
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
+        args, _kwargs = respond_mock.call_args
+        code, body = args[1], args[2]
+
+        if code != expected_code:
+            raise AssertionError(
+                f"{code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if request.code != expected_code:
+            raise AssertionError(
+                f"{request.code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if body != expected_body:
+            raise AssertionError(
+                f"{body!r} != {expected_body!r} : "
+                "Request did not finish with the expected status code."
+            )
 
 
 @logcontext_clean
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..bb966c80c6 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -30,7 +30,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
         return HTTPStatus.OK, {"result": True}
 
 
-class TestRestServletCancellation(
-    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
     """Tests for `RestServlet` cancellation."""
 
     servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
     def test_cancellable_disconnect(self) -> None:
         """Test that handlers with the `@cancellable` flag can be cancelled."""
         channel = self.make_request("GET", "/sleep", await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
     def test_uncancellable_disconnect(self) -> None:
         """Test that handlers without the `@cancellable` flag cannot be cancelled."""
         channel = self.make_request("POST", "/sleep", await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,