summary refs log tree commit diff
path: root/tests/test_server.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/test_server.py111
1 files changed, 109 insertions, 2 deletions
diff --git a/tests/test_server.py b/tests/test_server.py
index f2ffbc895b..0f1eb43cbc 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -13,18 +13,28 @@
 # limitations under the License.
 
 import re
+from http import HTTPStatus
+from typing import Tuple
 
 from twisted.internet.defer import Deferred
 from twisted.web.resource import Resource
 
 from synapse.api.errors import Codes, RedirectException, SynapseError
 from synapse.config.server import parse_listener_def
-from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource
-from synapse.http.site import SynapseSite
+from synapse.http.server import (
+    DirectServeHtmlResource,
+    DirectServeJsonResource,
+    JsonResource,
+    OptionsResource,
+    cancellable,
+)
+from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.logging.context import make_deferred_yieldable
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
 from tests.server import (
     FakeSite,
     ThreadedMemoryReactorClock,
@@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase):
 
         self.assertEqual(channel.result["code"], b"200")
         self.assertNotIn("body", channel.result)
+
+
+class CancellableDirectServeJsonResource(DirectServeJsonResource):
+    def __init__(self, clock: Clock):
+        super().__init__()
+        self.clock = clock
+
+    @cancellable
+    async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+    async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+
+class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
+    ERROR_TEMPLATE = "{code} {msg}"
+
+    def __init__(self, clock: Clock):
+        super().__init__()
+        self.clock = clock
+
+    @cancellable
+    async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, b"ok"
+
+    async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, b"ok"
+
+
+class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+    """Tests for `DirectServeJsonResource` cancellation."""
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+        self.clock = Clock(self.reactor)
+        self.resource = CancellableDirectServeJsonResource(self.clock)
+        self.site = FakeSite(self.resource, self.reactor)
+
+    def test_cancellable_disconnect(self) -> None:
+        """Test that handlers with the `@cancellable` flag can be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "GET", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=True,
+            expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+        )
+
+    def test_uncancellable_disconnect(self) -> None:
+        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "POST", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=False,
+            expected_body={"result": True},
+        )
+
+
+class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+    """Tests for `DirectServeHtmlResource` cancellation."""
+
+    def setUp(self):
+        self.reactor = ThreadedMemoryReactorClock()
+        self.clock = Clock(self.reactor)
+        self.resource = CancellableDirectServeHtmlResource(self.clock)
+        self.site = FakeSite(self.resource, self.reactor)
+
+    def test_cancellable_disconnect(self) -> None:
+        """Test that handlers with the `@cancellable` flag can be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "GET", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=True,
+            expected_body=b"499 Request cancelled",
+        )
+
+    def test_uncancellable_disconnect(self) -> None:
+        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+        channel = make_request(
+            self.reactor, self.site, "POST", "/sleep", await_result=False
+        )
+        self._test_disconnect(
+            self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
+        )