diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
new file mode 100644
index 0000000000..3a5f22c022
--- /dev/null
+++ b/tests/http/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
new file mode 100644
index 0000000000..b9f1a381aa
--- /dev/null
+++ b/tests/http/server/_base.py
@@ -0,0 +1,100 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unles4s required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+
+from twisted.internet.error import ConnectionDone
+
+from synapse.http.server import (
+ HTTP_STATUS_REQUEST_CANCELLED,
+ respond_with_html_bytes,
+ respond_with_json,
+)
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel, ThreadedMemoryReactorClock
+
+
+class EndpointCancellationTestHelperMixin(unittest.TestCase):
+ """Provides helper methods for testing cancellation of endpoints."""
+
+ def _test_disconnect(
+ self,
+ reactor: ThreadedMemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+ ) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be
+ cancelled, `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200`
+ or `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
+
+ request = channel.request
+ self.assertFalse(
+ channel.is_finished(),
+ "Request finished before we could disconnect - "
+ "was `await_result=False` passed to `make_request`?",
+ )
+
+ # We're about to disconnect the request. This also disconnects the channel, so
+ # we have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
+ else:
+ respond_mock.assert_not_called()
+
+ # The handler is expected to run to completion.
+ reactor.pump([1.0])
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
diff --git a/tests/server.py b/tests/server.py
index 8f30e250c8..aaefcfc46c 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -109,6 +109,17 @@ class FakeChannel:
_ip: str = "127.0.0.1"
_producer: Optional[Union[IPullProducer, IPushProducer]] = None
resource_usage: Optional[ContextResourceUsage] = None
+ _request: Optional[Request] = None
+
+ @property
+ def request(self) -> Request:
+ assert self._request is not None
+ return self._request
+
+ @request.setter
+ def request(self, request: Request) -> None:
+ assert self._request is None
+ self._request = request
@property
def json_body(self):
@@ -322,6 +333,8 @@ def make_request(
channel = FakeChannel(site, reactor, ip=client_ip)
req = request(channel, site)
+ channel.request = req
+
req.content = BytesIO(content)
# Twisted expects to be at the end of the content when parsing the request.
req.content.seek(0, SEEK_END)
|