summary refs log tree commit diff
path: root/tests/http
diff options
context:
space:
mode:
authorSean Quah <8349537+squahtx@users.noreply.github.com>2022-05-11 12:25:13 +0100
committerGitHub <noreply@github.com>2022-05-11 12:25:13 +0100
commit9d8e380d2e8267129de921b9b926257c36417cd2 (patch)
tree97bb00ba5fb4ba829992529ebbfaffb1cf173a33 /tests/http
parentRespect the `@cancellable` flag for `DirectServe{Html,Json}Resource`s (#12698) (diff)
downloadsynapse-9d8e380d2e8267129de921b9b926257c36417cd2.tar.xz
Respect the `@cancellable` flag for `RestServlet`s and `BaseFederationServlet`s (#12699)
Both `RestServlet`s and `BaseFederationServlet`s register their handlers
with `HttpServer.register_paths` / `JsonResource.register_paths`. Update
`JsonResource` to respect the `@cancellable` flag on handlers registered
in this way.

Although `ReplicationEndpoint` also registers itself using
`register_paths`, it does not pass the handler method that would have the
`@cancellable` flag directly, and so needs separate handling.

Signed-off-by: Sean Quah <seanq@element.io>
Diffstat (limited to 'tests/http')
-rw-r--r--tests/http/test_servlet.py60
1 files changed, 59 insertions, 1 deletions
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index a80bfb9f4e..ad521525cf 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -12,16 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import json
+from http import HTTPStatus
 from io import BytesIO
+from typing import Tuple
 from unittest.mock import Mock
 
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
+from synapse.http.server import cancellable
 from synapse.http.servlet import (
+    RestServlet,
     parse_json_object_from_request,
     parse_json_value_from_request,
 )
+from synapse.http.site import SynapseRequest
+from synapse.rest.client._base import client_patterns
+from synapse.server import HomeServer
+from synapse.types import JsonDict
 
 from tests import unittest
+from tests.http.server._base import EndpointCancellationTestHelperMixin
 
 
 def make_request(content):
@@ -76,3 +85,52 @@ class TestServletUtils(unittest.TestCase):
         # Test not an object
         with self.assertRaises(SynapseError):
             parse_json_object_from_request(make_request(b'["foo"]'))
+
+
+class CancellableRestServlet(RestServlet):
+    """A `RestServlet` with a mix of cancellable and uncancellable handlers."""
+
+    PATTERNS = client_patterns("/sleep$")
+
+    def __init__(self, hs: HomeServer):
+        super().__init__()
+        self.clock = hs.get_clock()
+
+    @cancellable
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+    async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        await self.clock.sleep(1.0)
+        return HTTPStatus.OK, {"result": True}
+
+
+class TestRestServletCancellation(
+    unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
+):
+    """Tests for `RestServlet` cancellation."""
+
+    servlets = [
+        lambda hs, http_server: CancellableRestServlet(hs).register(http_server)
+    ]
+
+    def test_cancellable_disconnect(self) -> None:
+        """Test that handlers with the `@cancellable` flag can be cancelled."""
+        channel = self.make_request("GET", "/sleep", await_result=False)
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=True,
+            expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN},
+        )
+
+    def test_uncancellable_disconnect(self) -> None:
+        """Test that handlers without the `@cancellable` flag cannot be cancelled."""
+        channel = self.make_request("POST", "/sleep", await_result=False)
+        self._test_disconnect(
+            self.reactor,
+            channel,
+            expect_cancellation=False,
+            expected_body={"result": True},
+        )