summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12698.misc1
-rw-r--r--synapse/http/server.py2
-rw-r--r--tests/test_server.py111
3 files changed, 112 insertions, 2 deletions
diff --git a/changelog.d/12698.misc b/changelog.d/12698.misc
new file mode 100644
index 0000000000..5d626352f9
--- /dev/null
+++ b/changelog.d/12698.misc
@@ -0,0 +1 @@
+Respect the `@cancellable` flag for `DirectServe{Html,Json}Resource`s.
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 4b4debc5cd..f6d4d8db86 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -382,6 +382,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
+            request.is_render_cancellable = is_method_cancellable(method_handler)
+
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
diff --git a/tests/test_server.py b/tests/test_server.py
index f2ffbc895b..0f1eb43cbc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -13,18 +13,28 @@
 # limitations under the License.
 
 import re
+from http import HTTPStatus
+from typing import Tuple
 
 from twisted.internet.defer import Deferred
 from twisted.web.resource import Resource
 
 from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.config.server import parse_listener_def
-from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    JsonResource,
+    OptionsResource,
+    cancellable,
+)
+from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import make_deferred_yieldable
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
 from tests.server import (
     FakeSite,
     ThreadedMemoryReactorClock,
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
+
+
+class CancellableDirectServeJsonResource(DirectServeJsonResource):
+    def __init__(self, clock: Clock):
+        super().__init__()
+        self.clock = clock
+
+    @cancellable
+    async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+    async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+
+class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
+    ERROR_TEMPLATE = "{code} {msg}"
+
+    def __init__(self, clock: Clock):
+        super().__init__()
+        self.clock = clock
+
+    @cancellable
+    async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, b"ok"
+
+    async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, b"ok"
+
+
+class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+    """Tests for `DirectServeJsonResource` cancellation."""
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+        self.clock = Clock(self.reactor)
+        self.resource = CancellableDirectServeJsonResource(self.clock)
+        self.site = FakeSite(self.resource, self.reactor)
+
+    def test_cancellable_disconnect(self) -> None:
+        """Test that handlers with the `@cancellable` flag can be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "GET", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=True,
+            expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+        )
+
+    def test_uncancellable_disconnect(self) -> None:
+        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "POST", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=False,
+            expected_body={"result": True},
+        )
+
+
+class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+    """Tests for `DirectServeHtmlResource` cancellation."""
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+        self.clock = Clock(self.reactor)
+        self.resource = CancellableDirectServeHtmlResource(self.clock)
+        self.site = FakeSite(self.resource, self.reactor)
+
+    def test_cancellable_disconnect(self) -> None:
+        """Test that handlers with the `@cancellable` flag can be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "GET", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=True,
+            expected_body=b"499 Request cancelled",
+        )
+
+    def test_uncancellable_disconnect(self) -> None:
+        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "POST", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
+        )