diff options
Diffstat (limited to 'tests/test_server.py')
-rw-r--r-- | tests/test_server.py | 111 |
1 files changed, 109 insertions, 2 deletions
diff --git a/tests/test_server.py b/tests/test_server.py index f2ffbc895b..0f1eb43cbc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -13,18 +13,28 @@ # limitations under the License. import re +from http import HTTPStatus +from typing import Tuple from twisted.internet.defer import Deferred from twisted.web.resource import Resource from synapse.api.errors import Codes, RedirectException, SynapseError from synapse.config.server import parse_listener_def -from synapse.http.server import DirectServeHtmlResource, JsonResource, OptionsResource -from synapse.http.site import SynapseSite +from synapse.http.server import ( + DirectServeHtmlResource, + DirectServeJsonResource, + JsonResource, + OptionsResource, + cancellable, +) +from synapse.http.site import SynapseRequest, SynapseSite from synapse.logging.context import make_deferred_yieldable +from synapse.types import JsonDict from synapse.util import Clock from tests import unittest +from tests.http.server._base import EndpointCancellationTestHelperMixin from tests.server import ( FakeSite, ThreadedMemoryReactorClock, @@ -363,3 +373,100 @@ class WrapHtmlRequestHandlerTests(unittest.TestCase): self.assertEqual(channel.result["code"], b"200") self.assertNotIn("body", channel.result) + + +class CancellableDirectServeJsonResource(DirectServeJsonResource): + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, {"result": True} + + +class CancellableDirectServeHtmlResource(DirectServeHtmlResource): + ERROR_TEMPLATE = "{code} {msg}" + + def __init__(self, clock: Clock): + super().__init__() + self.clock = clock + + @cancellable + async def _async_render_GET(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + async def _async_render_POST(self, request: SynapseRequest) -> Tuple[int, bytes]: + await self.clock.sleep(1.0) + return HTTPStatus.OK, b"ok" + + +class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeJsonResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeJsonResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body={"error": "Request cancelled", "errcode": Codes.UNKNOWN}, + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=False, + expected_body={"result": True}, + ) + + +class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin): + """Tests for `DirectServeHtmlResource` cancellation.""" + + def setUp(self): + self.reactor = ThreadedMemoryReactorClock() + self.clock = Clock(self.reactor) + self.resource = CancellableDirectServeHtmlResource(self.clock) + self.site = FakeSite(self.resource, self.reactor) + + def test_cancellable_disconnect(self) -> None: + """Test that handlers with the `@cancellable` flag can be cancelled.""" + channel = make_request( + self.reactor, self.site, "GET", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, + channel, + expect_cancellation=True, + expected_body=b"499 Request cancelled", + ) + + def test_uncancellable_disconnect(self) -> None: + """Test that handlers without the `@cancellable` flag cannot be cancelled.""" + channel = make_request( + self.reactor, self.site, "POST", "/sleep", await_result=False + ) + self._test_disconnect( + self.reactor, channel, expect_cancellation=False, expected_body=b"ok" + ) |