summary refs log tree commit diff
path: root/tests/http
diff options
context:
space:
mode:
Diffstat (limited to 'tests/http')
-rw-r--r--tests/http/test_servlet.py60
1 files changed, 59 insertions, 1 deletions
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index a80bfb9f4e..ad521525cf 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -12,16 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from http import HTTPStatus
 from io import BytesIO
+from typing import Tuple
 from unittest.mock import Mock
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import cancellable
 from synapse.http.servlet import (
+    RestServlet,
     parse_json_object_from_request,
     parse_json_value_from_request,
 )
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+from synapse.server import HomeServer
+from synapse.types import JsonDict
 
 from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
 
 
 def make_request(content):
@@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase):
         # Test not an object
         with self.assertRaises(SynapseError):
             parse_json_object_from_request(make_request(b'["foo"]'))
+
+
+class CancellableRestServlet(RestServlet):
+    """A `RestServlet` with a mix of cancellable and uncancellable handlers."""
+
+    PATTERNS = client_patterns("/sleep$")
+
+    def __init__(self, hs: HomeServer):
+        super().__init__()
+        self.clock = hs.get_clock()
+
+    @cancellable
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+
+class TestRestServletCancellation(
+    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+    """Tests for `RestServlet` cancellation."""
+
+    servlets = [
+        lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
+    ]
+
+    def test_cancellable_disconnect(self) -> None:
+        """Test that handlers with the `@cancellable` flag can be cancelled."""
+        channel = self.make_request("GET", "/sleep", await_result=False)
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=True,
+            expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+        )
+
+    def test_uncancellable_disconnect(self) -> None:
+        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+        channel = self.make_request("POST", "/sleep", await_result=False)
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=False,
+            expected_body={"result": True},
+        )