diff --git a/changelog.d/12929.misc b/changelog.d/12929.misc
new file mode 100644
index 0000000000..20718d258d
--- /dev/null
+++ b/changelog.d/12929.misc
@@ -0,0 +1 @@
+Clean up the test code for client disconnection.
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e63885c1c9..d33e86db4c 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableFederationServlet(BaseFederationServlet):
@@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet):
return HTTPStatus.OK, {"result": True}
-class BaseFederationServletCancellationTests(
- unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
"""Tests for `BaseFederationServlet` cancellation."""
skip = "`BaseFederationServlet` does not support cancellation yet."
@@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 57b92beb87..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
+from tests.server import FakeChannel, make_request
from tests.unittest import logcontext_clean
logger = logging.getLogger(__name__)
@@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
T = TypeVar("T")
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
- """Provides helper methods for testing cancellation of endpoints."""
+def test_disconnect(
+ reactor: MemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+) -> None:
+ """Disconnects an in-flight request and checks the response.
- def _test_disconnect(
- self,
- reactor: ThreadedMemoryReactorClock,
- channel: FakeChannel,
- expect_cancellation: bool,
- expected_body: Union[bytes, JsonDict],
- expected_code: Optional[int] = None,
- ) -> None:
- """Disconnects an in-flight request and checks the response.
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be cancelled,
+ `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200` or
+ `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
- Args:
- reactor: The twisted reactor running the request handler.
- channel: The `FakeChannel` for the request.
- expect_cancellation: `True` if request processing is expected to be
- cancelled, `False` if the request should run to completion.
- expected_body: The expected response for the request.
- expected_code: The expected status code for the request. Defaults to `200`
- or `499` depending on `expect_cancellation`.
- """
- # Determine the expected status code.
- if expected_code is None:
- if expect_cancellation:
- expected_code = HTTP_STATUS_REQUEST_CANCELLED
- else:
- expected_code = HTTPStatus.OK
-
- request = channel.request
- self.assertFalse(
- channel.is_finished(),
+ request = channel.request
+ if channel.is_finished():
+ raise AssertionError(
"Request finished before we could disconnect - "
- "was `await_result=False` passed to `make_request`?",
+ "ensure `await_result=False` is passed to `make_request`.",
)
- # We're about to disconnect the request. This also disconnects the channel, so
- # we have to rely on mocks to extract the response.
- respond_method: Callable[..., Any]
- if isinstance(expected_body, bytes):
- respond_method = respond_with_html_bytes
+ # We're about to disconnect the request. This also disconnects the channel, so we
+ # have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
else:
- respond_method = respond_with_json
+ respond_mock.assert_not_called()
- with mock.patch(
- f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
- ) as respond_mock:
- # Disconnect the request.
- request.connectionLost(reason=ConnectionDone())
+ # The handler is expected to run to completion.
+ reactor.advance(1.0)
+ respond_mock.assert_called_once()
- if expect_cancellation:
- # An immediate cancellation is expected.
- respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
- else:
- respond_mock.assert_not_called()
-
- # The handler is expected to run to completion.
- reactor.pump([1.0])
- respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+
+ if code != expected_code:
+ raise AssertionError(
+ f"{code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if request.code != expected_code:
+ raise AssertionError(
+ f"{request.code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if body != expected_body:
+ raise AssertionError(
+ f"{body!r} != {expected_body!r} : "
+ "Request did not finish with the expected status code."
+ )
@logcontext_clean
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..bb966c80c6 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -30,7 +30,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
return HTTPStatus.OK, {"result": True}
-class TestRestServletCancellation(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""
servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index a5ab093a27..822a957c3a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -25,7 +25,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableReplicationEndpoint(ReplicationEndpoint):
@@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return HTTPStatus.OK, {"result": True}
-class ReplicationEndpointCancellationTestCase(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self):
@@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/test_server.py b/tests/test_server.py
index 0f1eb43cbc..847432f791 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -34,7 +34,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
return HTTPStatus.OK, b"ok"
-class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""
def setUp(self):
@@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
@@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
)
-class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""
def setUp(self):
@@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
)
|