summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12630.misc1
-rw-r--r--tests/http/server/__init__.py13
-rw-r--r--tests/http/server/_base.py100
-rw-r--r--tests/server.py13
4 files changed, 127 insertions, 0 deletions
diff --git a/changelog.d/12630.misc b/changelog.d/12630.misc
new file mode 100644
index 0000000000..43e12603e2
--- /dev/null
+++ b/changelog.d/12630.misc
@@ -0,0 +1 @@
+Add a helper class for testing request cancellation.
diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
new file mode 100644
index 0000000000..3a5f22c022
--- /dev/null
+++ b/tests/http/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
new file mode 100644
index 0000000000..b9f1a381aa
--- /dev/null
+++ b/tests/http/server/_base.py
@@ -0,0 +1,100 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unles4s required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+
+from twisted.internet.error import ConnectionDone
+
+from synapse.http.server import (
+    HTTP_STATUS_REQUEST_CANCELLED,
+    respond_with_html_bytes,
+    respond_with_json,
+)
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel, ThreadedMemoryReactorClock
+
+
+class EndpointCancellationTestHelperMixin(unittest.TestCase):
+    """Provides helper methods for testing cancellation of endpoints."""
+
+    def _test_disconnect(
+        self,
+        reactor: ThreadedMemoryReactorClock,
+        channel: FakeChannel,
+        expect_cancellation: bool,
+        expected_body: Union[bytes, JsonDict],
+        expected_code: Optional[int] = None,
+    ) -> None:
+        """Disconnects an in-flight request and checks the response.
+
+        Args:
+            reactor: The twisted reactor running the request handler.
+            channel: The `FakeChannel` for the request.
+            expect_cancellation: `True` if request processing is expected to be
+                cancelled, `False` if the request should run to completion.
+            expected_body: The expected response for the request.
+            expected_code: The expected status code for the request. Defaults to `200`
+                or `499` depending on `expect_cancellation`.
+        """
+        # Determine the expected status code.
+        if expected_code is None:
+            if expect_cancellation:
+                expected_code = HTTP_STATUS_REQUEST_CANCELLED
+            else:
+                expected_code = HTTPStatus.OK
+
+        request = channel.request
+        self.assertFalse(
+            channel.is_finished(),
+            "Request finished before we could disconnect - "
+            "was `await_result=False` passed to `make_request`?",
+        )
+
+        # We're about to disconnect the request. This also disconnects the channel, so
+        # we have to rely on mocks to extract the response.
+        respond_method: Callable[..., Any]
+        if isinstance(expected_body, bytes):
+            respond_method = respond_with_html_bytes
+        else:
+            respond_method = respond_with_json
+
+        with mock.patch(
+            f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+        ) as respond_mock:
+            # Disconnect the request.
+            request.connectionLost(reason=ConnectionDone())
+
+            if expect_cancellation:
+                # An immediate cancellation is expected.
+                respond_mock.assert_called_once()
+                args, _kwargs = respond_mock.call_args
+                code, body = args[1], args[2]
+                self.assertEqual(code, expected_code)
+                self.assertEqual(request.code, expected_code)
+                self.assertEqual(body, expected_body)
+            else:
+                respond_mock.assert_not_called()
+
+                # The handler is expected to run to completion.
+                reactor.pump([1.0])
+                respond_mock.assert_called_once()
+                args, _kwargs = respond_mock.call_args
+                code, body = args[1], args[2]
+                self.assertEqual(code, expected_code)
+                self.assertEqual(request.code, expected_code)
+                self.assertEqual(body, expected_body)
diff --git a/tests/server.py b/tests/server.py
index 8f30e250c8..aaefcfc46c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -109,6 +109,17 @@ class FakeChannel:
     _ip: str = "127.0.0.1"
     _producer: Optional[Union[IPullProducer, IPushProducer]] = None
     resource_usage: Optional[ContextResourceUsage] = None
+    _request: Optional[Request] = None
+
+    @property
+    def request(self) -> Request:
+        assert self._request is not None
+        return self._request
+
+    @request.setter
+    def request(self, request: Request) -> None:
+        assert self._request is None
+        self._request = request
 
     @property
     def json_body(self):
@@ -322,6 +333,8 @@ def make_request(
     channel = FakeChannel(site, reactor, ip=client_ip)
 
     req = request(channel, site)
+    channel.request = req
+
     req.content = BytesIO(content)
     # Twisted expects to be at the end of the content when parsing the request.
     req.content.seek(0, SEEK_END)