diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 57b92beb87..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
+from tests.server import FakeChannel, make_request
from tests.unittest import logcontext_clean
logger = logging.getLogger(__name__)
@@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
- """Provides helper methods for testing cancellation of endpoints."""
+def test_disconnect(
+ reactor: MemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+) -> None:
+ """Disconnects an in-flight request and checks the response.
- def _test_disconnect(
- self,
- reactor: ThreadedMemoryReactorClock,
- channel: FakeChannel,
- expect_cancellation: bool,
- expected_body: Union[bytes, JsonDict],
- expected_code: Optional[int] = None,
- ) -> None:
- """Disconnects an in-flight request and checks the response.
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be cancelled,
+ `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200` or
+ `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
- Args:
- reactor: The twisted reactor running the request handler.
- channel: The `FakeChannel` for the request.
- expect_cancellation: `True` if request processing is expected to be
- cancelled, `False` if the request should run to completion.
- expected_body: The expected response for the request.
- expected_code: The expected status code for the request. Defaults to `200`
- or `499` depending on `expect_cancellation`.
- """
- # Determine the expected status code.
- if expected_code is None:
- if expect_cancellation:
- expected_code = HTTP_STATUS_REQUEST_CANCELLED
- else:
- expected_code = HTTPStatus.OK
-
- request = channel.request
- self.assertFalse(
- channel.is_finished(),
+ request = channel.request
+ if channel.is_finished():
+ raise AssertionError(
"Request finished before we could disconnect - "
- "was `await_result=False` passed to `make_request`?",
+ "ensure `await_result=False` is passed to `make_request`.",
)
- # We're about to disconnect the request. This also disconnects the channel, so
- # we have to rely on mocks to extract the response.
- respond_method: Callable[..., Any]
- if isinstance(expected_body, bytes):
- respond_method = respond_with_html_bytes
+ # We're about to disconnect the request. This also disconnects the channel, so we
+ # have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
else:
- respond_method = respond_with_json
+ respond_mock.assert_not_called()
- with mock.patch(
- f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
- ) as respond_mock:
- # Disconnect the request.
- request.connectionLost(reason=ConnectionDone())
+ # The handler is expected to run to completion.
+ reactor.advance(1.0)
+ respond_mock.assert_called_once()
- if expect_cancellation:
- # An immediate cancellation is expected.
- respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
- else:
- respond_mock.assert_not_called()
-
- # The handler is expected to run to completion.
- reactor.pump([1.0])
- respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+
+ if code != expected_code:
+ raise AssertionError(
+ f"{code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if request.code != expected_code:
+ raise AssertionError(
+ f"{request.code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if body != expected_body:
+ raise AssertionError(
+ f"{body!r} != {expected_body!r} : "
+ "Request did not finish with the expected status code."
+ )
@logcontext_clean
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..bb966c80c6 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -30,7 +30,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
return HTTPStatus.OK, {"result": True}
-class TestRestServletCancellation(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""
servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
|