diff options
Diffstat (limited to 'tests/http')
-rw-r--r-- | tests/http/server/_base.py | 132 | ||||
-rw-r--r-- | tests/http/test_servlet.py | 10 |
2 files changed, 73 insertions, 69 deletions
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py index 57b92beb87..994d8880b0 100644 --- a/tests/http/server/_base.py +++ b/tests/http/server/_base.py @@ -46,8 +46,7 @@ from synapse.http.site import SynapseRequest from synapse.logging.context import LoggingContext, make_deferred_yieldable from synapse.types import JsonDict -from tests import unittest -from tests.server import FakeChannel, ThreadedMemoryReactorClock, make_request +from tests.server import FakeChannel, make_request from tests.unittest import logcontext_clean logger = logging.getLogger(__name__) @@ -56,75 +55,82 @@ logger = logging.getLogger(__name__) T = TypeVar("T") -class EndpointCancellationTestHelperMixin(unittest.TestCase): - """Provides helper methods for testing cancellation of endpoints.""" +def test_disconnect( + reactor: MemoryReactorClock, + channel: FakeChannel, + expect_cancellation: bool, + expected_body: Union[bytes, JsonDict], + expected_code: Optional[int] = None, +) -> None: + """Disconnects an in-flight request and checks the response. - def _test_disconnect( - self, - reactor: ThreadedMemoryReactorClock, - channel: FakeChannel, - expect_cancellation: bool, - expected_body: Union[bytes, JsonDict], - expected_code: Optional[int] = None, - ) -> None: - """Disconnects an in-flight request and checks the response. + Args: + reactor: The twisted reactor running the request handler. + channel: The `FakeChannel` for the request. + expect_cancellation: `True` if request processing is expected to be cancelled, + `False` if the request should run to completion. + expected_body: The expected response for the request. + expected_code: The expected status code for the request. Defaults to `200` or + `499` depending on `expect_cancellation`. + """ + # Determine the expected status code. + if expected_code is None: + if expect_cancellation: + expected_code = HTTP_STATUS_REQUEST_CANCELLED + else: + expected_code = HTTPStatus.OK - Args: - reactor: The twisted reactor running the request handler. - channel: The `FakeChannel` for the request. - expect_cancellation: `True` if request processing is expected to be - cancelled, `False` if the request should run to completion. - expected_body: The expected response for the request. - expected_code: The expected status code for the request. Defaults to `200` - or `499` depending on `expect_cancellation`. - """ - # Determine the expected status code. - if expected_code is None: - if expect_cancellation: - expected_code = HTTP_STATUS_REQUEST_CANCELLED - else: - expected_code = HTTPStatus.OK - - request = channel.request - self.assertFalse( - channel.is_finished(), + request = channel.request + if channel.is_finished(): + raise AssertionError( "Request finished before we could disconnect - " - "was `await_result=False` passed to `make_request`?", + "ensure `await_result=False` is passed to `make_request`.", ) - # We're about to disconnect the request. This also disconnects the channel, so - # we have to rely on mocks to extract the response. - respond_method: Callable[..., Any] - if isinstance(expected_body, bytes): - respond_method = respond_with_html_bytes + # We're about to disconnect the request. This also disconnects the channel, so we + # have to rely on mocks to extract the response. + respond_method: Callable[..., Any] + if isinstance(expected_body, bytes): + respond_method = respond_with_html_bytes + else: + respond_method = respond_with_json + + with mock.patch( + f"synapse.http.server.{respond_method.__name__}", wraps=respond_method + ) as respond_mock: + # Disconnect the request. + request.connectionLost(reason=ConnectionDone()) + + if expect_cancellation: + # An immediate cancellation is expected. + respond_mock.assert_called_once() else: - respond_method = respond_with_json + respond_mock.assert_not_called() - with mock.patch( - f"synapse.http.server.{respond_method.__name__}", wraps=respond_method - ) as respond_mock: - # Disconnect the request. - request.connectionLost(reason=ConnectionDone()) + # The handler is expected to run to completion. + reactor.advance(1.0) + respond_mock.assert_called_once() - if expect_cancellation: - # An immediate cancellation is expected. - respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code, body = args[1], args[2] - self.assertEqual(code, expected_code) - self.assertEqual(request.code, expected_code) - self.assertEqual(body, expected_body) - else: - respond_mock.assert_not_called() - - # The handler is expected to run to completion. - reactor.pump([1.0]) - respond_mock.assert_called_once() - args, _kwargs = respond_mock.call_args - code, body = args[1], args[2] - self.assertEqual(code, expected_code) - self.assertEqual(request.code, expected_code) - self.assertEqual(body, expected_body) + args, _kwargs = respond_mock.call_args + code, body = args[1], args[2] + + if code != expected_code: + raise AssertionError( + f"{code} != {expected_code} : " + "Request did not finish with the expected status code." + ) + + if request.code != expected_code: + raise AssertionError( + f"{request.code} != {expected_code} : " + "Request did not finish with the expected status code." + ) + + if body != expected_body: + raise AssertionError( + f"{body!r} != {expected_body!r} : " + "Request did not finish with the expected status code." + ) @logcontext_clean diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py index b3655d7b44..bb966c80c6 100644 --- a/tests/http/test_servlet.py +++ b/tests/http/test_servlet.py @@ -30,7 +30,7 @@ from synapse.server import HomeServer from synapse.types import JsonDict from tests import unittest -from tests.http.server._base import EndpointCancellationTestHelperMixin +from tests.http.server._base import test_disconnect def make_request(content): @@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet): return HTTPStatus.OK, {"result": True} -class TestRestServletCancellation( - unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin -): +class TestRestServletCancellation(unittest.HomeserverTestCase): """Tests for `RestServlet` cancellation.""" servlets = [ @@ -120,7 +118,7 @@ class TestRestServletCancellation( def test_cancellable_disconnect(self) -> None: """Test that handlers with the `@cancellable` flag can be cancelled.""" channel = self.make_request("GET", "/sleep", await_result=False) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=True, @@ -130,7 +128,7 @@ class TestRestServletCancellation( def test_uncancellable_disconnect(self) -> None: """Test that handlers without the `@cancellable` flag cannot be cancelled.""" channel = self.make_request("POST", "/sleep", await_result=False) - self._test_disconnect( + test_disconnect( self.reactor, channel, expect_cancellation=False, |