diff --git a/tests/http/server/__init__.py b/tests/http/server/__init__.py
new file mode 100644
index 0000000000..3a5f22c022
--- /dev/null
+++ b/tests/http/server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
new file mode 100644
index 0000000000..b9f1a381aa
--- /dev/null
+++ b/tests/http/server/_base.py
@@ -0,0 +1,100 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unles4s required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from http import HTTPStatus
+from typing import Any, Callable, Optional, Union
+from unittest import mock
+
+from twisted.internet.error import ConnectionDone
+
+from synapse.http.server import (
+ HTTP_STATUS_REQUEST_CANCELLED,
+ respond_with_html_bytes,
+ respond_with_json,
+)
+from synapse.types import JsonDict
+
+from tests import unittest
+from tests.server import FakeChannel, ThreadedMemoryReactorClock
+
+
+class EndpointCancellationTestHelperMixin(unittest.TestCase):
+ """Provides helper methods for testing cancellation of endpoints."""
+
+ def _test_disconnect(
+ self,
+ reactor: ThreadedMemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+ ) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be
+ cancelled, `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200`
+ or `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
+
+ request = channel.request
+ self.assertFalse(
+ channel.is_finished(),
+ "Request finished before we could disconnect - "
+ "was `await_result=False` passed to `make_request`?",
+ )
+
+ # We're about to disconnect the request. This also disconnects the channel, so
+ # we have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
+ else:
+ respond_mock.assert_not_called()
+
+ # The handler is expected to run to completion.
+ reactor.pump([1.0])
+ respond_mock.assert_called_once()
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+ self.assertEqual(code, expected_code)
+ self.assertEqual(request.code, expected_code)
+ self.assertEqual(body, expected_body)
|