summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/12929.misc1
-rw-r--r--tests/federation/transport/server/test__base.py10
-rw-r--r--tests/http/server/_base.py132
-rw-r--r--tests/http/test_servlet.py10
-rw-r--r--tests/replication/http/test__base.py10
-rw-r--r--tests/test_server.py14
6 files changed, 89 insertions, 88 deletions
diff --git a/changelog.d/12929.misc b/changelog.d/12929.misc
new file mode 100644
index 0000000000..20718d258d
--- /dev/null
+++ b/changelog.d/12929.misc
@@ -0,0 +1 @@
+Clean up the test code for client disconnection.
diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e63885c1c9..d33e86db4c 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict
 from synapse.util.ratelimitutils import FederationRateLimiter
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 class CancellableFederationServlet(BaseFederationServlet):
@@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet):
         return HTTPStatus.OK, {"result": True}
 
 
-class BaseFederationServletCancellationTests(
-    unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
     """Tests for `BaseFederationServlet` cancellation."""
 
     skip = "`BaseFederationServlet` does not support cancellation yet."
@@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests(
         # request won't be processed.
         self.pump()
 
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests(
         # request won't be processed.
         self.pump()
 
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 57b92beb87..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.types import JsonDict
 
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request
+from tests.server import FakeChannel, make_request
 from tests.unittest import logcontext_clean
 
 logger = logging.getLogger(__name__)
@@ -56,75 +55,82 @@ logger = logging.getLogger(__name__)
 T = TypeVar("T")
 
 
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
-    """Provides helper methods for testing cancellation of endpoints."""
+def test_disconnect(
+    reactor: MemoryReactorClock,
+    channel: FakeChannel,
+    expect_cancellation: bool,
+    expected_body: Union[bytes, JsonDict],
+    expected_code: Optional[int] = None,
+) -> None:
+    """Disconnects an in-flight request and checks the response.
 
-    def _test_disconnect(
-        self,
-        reactor: ThreadedMemoryReactorClock,
-        channel: FakeChannel,
-        expect_cancellation: bool,
-        expected_body: Union[bytes, JsonDict],
-        expected_code: Optional[int] = None,
-    ) -> None:
-        """Disconnects an in-flight request and checks the response.
+    Args:
+        reactor: The twisted reactor running the request handler.
+        channel: The `FakeChannel` for the request.
+        expect_cancellation: `True` if request processing is expected to be cancelled,
+            `False` if the request should run to completion.
+        expected_body: The expected response for the request.
+        expected_code: The expected status code for the request. Defaults to `200` or
+            `499` depending on `expect_cancellation`.
+    """
+    # Determine the expected status code.
+    if expected_code is None:
+        if expect_cancellation:
+            expected_code = HTTP_STATUS_REQUEST_CANCELLED
+        else:
+            expected_code = HTTPStatus.OK
 
-        Args:
-            reactor: The twisted reactor running the request handler.
-            channel: The `FakeChannel` for the request.
-            expect_cancellation: `True` if request processing is expected to be
-                cancelled, `False` if the request should run to completion.
-            expected_body: The expected response for the request.
-            expected_code: The expected status code for the request. Defaults to `200`
-                or `499` depending on `expect_cancellation`.
-        """
-        # Determine the expected status code.
-        if expected_code is None:
-            if expect_cancellation:
-                expected_code = HTTP_STATUS_REQUEST_CANCELLED
-            else:
-                expected_code = HTTPStatus.OK
-
-        request = channel.request
-        self.assertFalse(
-            channel.is_finished(),
+    request = channel.request
+    if channel.is_finished():
+        raise AssertionError(
             "Request finished before we could disconnect - "
-            "was `await_result=False` passed to `make_request`?",
+            "ensure `await_result=False` is passed to `make_request`.",
         )
 
-        # We're about to disconnect the request. This also disconnects the channel, so
-        # we have to rely on mocks to extract the response.
-        respond_method: Callable[..., Any]
-        if isinstance(expected_body, bytes):
-            respond_method = respond_with_html_bytes
+    # We're about to disconnect the request. This also disconnects the channel, so we
+    # have to rely on mocks to extract the response.
+    respond_method: Callable[..., Any]
+    if isinstance(expected_body, bytes):
+        respond_method = respond_with_html_bytes
+    else:
+        respond_method = respond_with_json
+
+    with mock.patch(
+        f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+    ) as respond_mock:
+        # Disconnect the request.
+        request.connectionLost(reason=ConnectionDone())
+
+        if expect_cancellation:
+            # An immediate cancellation is expected.
+            respond_mock.assert_called_once()
         else:
-            respond_method = respond_with_json
+            respond_mock.assert_not_called()
 
-        with mock.patch(
-            f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
-        ) as respond_mock:
-            # Disconnect the request.
-            request.connectionLost(reason=ConnectionDone())
+            # The handler is expected to run to completion.
+            reactor.advance(1.0)
+            respond_mock.assert_called_once()
 
-            if expect_cancellation:
-                # An immediate cancellation is expected.
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
-            else:
-                respond_mock.assert_not_called()
-
-                # The handler is expected to run to completion.
-                reactor.pump([1.0])
-                respond_mock.assert_called_once()
-                args, _kwargs = respond_mock.call_args
-                code, body = args[1], args[2]
-                self.assertEqual(code, expected_code)
-                self.assertEqual(request.code, expected_code)
-                self.assertEqual(body, expected_body)
+        args, _kwargs = respond_mock.call_args
+        code, body = args[1], args[2]
+
+        if code != expected_code:
+            raise AssertionError(
+                f"{code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if request.code != expected_code:
+            raise AssertionError(
+                f"{request.code} != {expected_code} : "
+                "Request did not finish with the expected status code."
+            )
+
+        if body != expected_body:
+            raise AssertionError(
+                f"{body!r} != {expected_body!r} : "
+                "Request did not finish with the expected status code."
+            )
 
 
 @logcontext_clean
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..bb966c80c6 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -30,7 +30,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
         return HTTPStatus.OK, {"result": True}
 
 
-class TestRestServletCancellation(
-    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
     """Tests for `RestServlet` cancellation."""
 
     servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
     def test_cancellable_disconnect(self) -> None:
         """Test that handlers with the `@cancellable` flag can be cancelled."""
         channel = self.make_request("GET", "/sleep", await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
     def test_uncancellable_disconnect(self) -> None:
         """Test that handlers without the `@cancellable` flag cannot be cancelled."""
         channel = self.make_request("POST", "/sleep", await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index a5ab093a27..822a957c3a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -25,7 +25,7 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 
 
 class CancellableReplicationEndpoint(ReplicationEndpoint):
@@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
         return HTTPStatus.OK, {"result": True}
 
 
-class ReplicationEndpointCancellationTestCase(
-    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
     """Tests for `ReplicationEndpoint` cancellation."""
 
     def create_test_resource(self):
@@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase(
         """Test that handlers with the `@cancellable` flag can be cancelled."""
         path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
         channel = self.make_request("POST", path, await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase(
         """Test that handlers without the `@cancellable` flag cannot be cancelled."""
         path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
         channel = self.make_request("POST", path, await_result=False)
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
diff --git a/tests/test_server.py b/tests/test_server.py
index 0f1eb43cbc..847432f791 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -34,7 +34,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
 from tests.server import (
     FakeSite,
     ThreadedMemoryReactorClock,
@@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
         return HTTPStatus.OK, b"ok"
 
 
-class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeJsonResourceCancellationTests(unittest.TestCase):
     """Tests for `DirectServeJsonResource` cancellation."""
 
     def setUp(self):
@@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "GET", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "POST", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=False,
@@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
         )
 
 
-class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
     """Tests for `DirectServeHtmlResource` cancellation."""
 
     def setUp(self):
@@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "GET", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor,
             channel,
             expect_cancellation=True,
@@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
         channel = make_request(
             self.reactor, self.site, "POST", "/sleep", await_result=False
         )
-        self._test_disconnect(
+        test_disconnect(
             self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
         )