diff --git a/tests/test_server.py b/tests/test_server.py
index f2ffbc895b..0f1eb43cbc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -13,18 +13,28 @@
# limitations under the License.
import re
+from http import HTTPStatus
+from typing import Tuple
from twisted.internet.defer import Deferred
from twisted.web.resource import Resource
from synapse.api.errors import Codes, RedirectException, SynapseError
from synapse.config.server import parse_listener_def
-from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite
+from synapse.http.server import (
+ DirectServeHtmlResource,
+ DirectServeJsonResource,
+ JsonResource,
+ OptionsResource,
+ cancellable,
+)
+from synapse.http.site import SynapseRequest, SynapseSite
from synapse.logging.context import make_deferred_yieldable
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
self.assertEqual(channel.result["code"], b"200")
self.assertNotIn("body", channel.result)
+
+
+class CancellableDirectServeJsonResource(DirectServeJsonResource):
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, {"result": True}
+
+
+class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
+ ERROR_TEMPLATE = "{code} {msg}"
+
+ def __init__(self, clock: Clock):
+ super().__init__()
+ self.clock = clock
+
+ @cancellable
+ async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+ async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
+ await self.clock.sleep(1.0)
+ return HTTPStatus.OK, b"ok"
+
+
+class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeJsonResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeJsonResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=False,
+ expected_body={"result": True},
+ )
+
+
+class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+ """Tests for `DirectServeHtmlResource` cancellation."""
+
+ def setUp(self):
+ self.reactor = ThreadedMemoryReactorClock()
+ self.clock = Clock(self.reactor)
+ self.resource = CancellableDirectServeHtmlResource(self.clock)
+ self.site = FakeSite(self.resource, self.reactor)
+
+ def test_cancellable_disconnect(self) -> None:
+ """Test that handlers with the `@cancellable` flag can be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "GET", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor,
+ channel,
+ expect_cancellation=True,
+ expected_body=b"499 Request cancelled",
+ )
+
+ def test_uncancellable_disconnect(self) -> None:
+ """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+ channel = make_request(
+ self.reactor, self.site, "POST", "/sleep", await_result=False
+ )
+ self._test_disconnect(
+ self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
+ )
|