diff --git a/tests/federation/transport/server/test__base.py b/tests/federation/transport/server/test__base.py
index e63885c1c9..d33e86db4c 100644
--- a/tests/federation/transport/server/test__base.py
+++ b/tests/federation/transport/server/test__base.py
@@ -24,7 +24,7 @@ from synapse.types import JsonDict
from synapse.util.ratelimitutils import FederationRateLimiter
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableFederationServlet(BaseFederationServlet):
@@ -54,9 +54,7 @@ class CancellableFederationServlet(BaseFederationServlet):
return HTTPStatus.OK, {"result": True}
-class BaseFederationServletCancellationTests(
- unittest.FederatingHomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class BaseFederationServletCancellationTests(unittest.FederatingHomeserverTestCase):
"""Tests for `BaseFederationServlet` cancellation."""
skip = "`BaseFederationServlet` does not support cancellation yet."
@@ -86,7 +84,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -106,7 +104,7 @@ class BaseFederationServletCancellationTests(
# request won't be processed.
self.pump()
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 01ea7d2a42..b8b465d35b 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -154,7 +154,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
self._record_users()
# delete the device
- self.get_success(self.handler.delete_device(user1, "abc"))
+ self.get_success(self.handler.delete_devices(user1, ["abc"]))
# check the device was deleted
self.get_failure(self.handler.get_device(user1, "abc"), NotFoundError)
@@ -179,7 +179,7 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
# delete the device
- self.get_success(self.handler.delete_device(user1, "abc"))
+ self.get_success(self.handler.delete_devices(user1, ["abc"]))
# check that the device_inbox was deleted
res = self.get_success(
diff --git a/tests/handlers/test_room_summary.py b/tests/handlers/test_room_summary.py
index 0546655690..aa650756e4 100644
--- a/tests/handlers/test_room_summary.py
+++ b/tests/handlers/test_room_summary.py
@@ -178,7 +178,7 @@ class SpaceSummaryTestCase(unittest.HomeserverTestCase):
result_room_ids.append(result_room["room_id"])
result_children_ids.append(
[
- (cs["room_id"], cs["state_key"])
+ (result_room["room_id"], cs["state_key"])
for cs in result_room["children_state"]
]
)
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index b9f1a381aa..994d8880b0 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -12,89 +12,543 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import inspect
+import itertools
+import logging
from http import HTTPStatus
-from typing import Any, Callable, Optional, Union
+from typing import (
+ Any,
+ Callable,
+ ContextManager,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+)
from unittest import mock
+from unittest.mock import Mock
+from twisted.internet.defer import Deferred
from twisted.internet.error import ConnectionDone
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactorClock
+from twisted.web.server import Site
from synapse.http.server import (
HTTP_STATUS_REQUEST_CANCELLED,
respond_with_html_bytes,
respond_with_json,
)
+from synapse.http.site import SynapseRequest
+from synapse.logging.context import LoggingContext, make_deferred_yieldable
from synapse.types import JsonDict
-from tests import unittest
-from tests.server import FakeChannel, ThreadedMemoryReactorClock
+from tests.server import FakeChannel, make_request
+from tests.unittest import logcontext_clean
+logger = logging.getLogger(__name__)
-class EndpointCancellationTestHelperMixin(unittest.TestCase):
- """Provides helper methods for testing cancellation of endpoints."""
- def _test_disconnect(
- self,
- reactor: ThreadedMemoryReactorClock,
- channel: FakeChannel,
- expect_cancellation: bool,
- expected_body: Union[bytes, JsonDict],
- expected_code: Optional[int] = None,
- ) -> None:
- """Disconnects an in-flight request and checks the response.
+T = TypeVar("T")
- Args:
- reactor: The twisted reactor running the request handler.
- channel: The `FakeChannel` for the request.
- expect_cancellation: `True` if request processing is expected to be
- cancelled, `False` if the request should run to completion.
- expected_body: The expected response for the request.
- expected_code: The expected status code for the request. Defaults to `200`
- or `499` depending on `expect_cancellation`.
- """
- # Determine the expected status code.
- if expected_code is None:
- if expect_cancellation:
- expected_code = HTTP_STATUS_REQUEST_CANCELLED
- else:
- expected_code = HTTPStatus.OK
-
- request = channel.request
- self.assertFalse(
- channel.is_finished(),
+
+def test_disconnect(
+ reactor: MemoryReactorClock,
+ channel: FakeChannel,
+ expect_cancellation: bool,
+ expected_body: Union[bytes, JsonDict],
+ expected_code: Optional[int] = None,
+) -> None:
+ """Disconnects an in-flight request and checks the response.
+
+ Args:
+ reactor: The twisted reactor running the request handler.
+ channel: The `FakeChannel` for the request.
+ expect_cancellation: `True` if request processing is expected to be cancelled,
+ `False` if the request should run to completion.
+ expected_body: The expected response for the request.
+ expected_code: The expected status code for the request. Defaults to `200` or
+ `499` depending on `expect_cancellation`.
+ """
+ # Determine the expected status code.
+ if expected_code is None:
+ if expect_cancellation:
+ expected_code = HTTP_STATUS_REQUEST_CANCELLED
+ else:
+ expected_code = HTTPStatus.OK
+
+ request = channel.request
+ if channel.is_finished():
+ raise AssertionError(
"Request finished before we could disconnect - "
- "was `await_result=False` passed to `make_request`?",
+ "ensure `await_result=False` is passed to `make_request`.",
)
- # We're about to disconnect the request. This also disconnects the channel, so
- # we have to rely on mocks to extract the response.
- respond_method: Callable[..., Any]
- if isinstance(expected_body, bytes):
- respond_method = respond_with_html_bytes
+ # We're about to disconnect the request. This also disconnects the channel, so we
+ # have to rely on mocks to extract the response.
+ respond_method: Callable[..., Any]
+ if isinstance(expected_body, bytes):
+ respond_method = respond_with_html_bytes
+ else:
+ respond_method = respond_with_json
+
+ with mock.patch(
+ f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
+ ) as respond_mock:
+ # Disconnect the request.
+ request.connectionLost(reason=ConnectionDone())
+
+ if expect_cancellation:
+ # An immediate cancellation is expected.
+ respond_mock.assert_called_once()
else:
- respond_method = respond_with_json
+ respond_mock.assert_not_called()
- with mock.patch(
- f"synapse.http.server.{respond_method.__name__}", wraps=respond_method
- ) as respond_mock:
- # Disconnect the request.
- request.connectionLost(reason=ConnectionDone())
+ # The handler is expected to run to completion.
+ reactor.advance(1.0)
+ respond_mock.assert_called_once()
- if expect_cancellation:
- # An immediate cancellation is expected.
- respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
- else:
- respond_mock.assert_not_called()
-
- # The handler is expected to run to completion.
- reactor.pump([1.0])
+ args, _kwargs = respond_mock.call_args
+ code, body = args[1], args[2]
+
+ if code != expected_code:
+ raise AssertionError(
+ f"{code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if request.code != expected_code:
+ raise AssertionError(
+ f"{request.code} != {expected_code} : "
+ "Request did not finish with the expected status code."
+ )
+
+ if body != expected_body:
+ raise AssertionError(
+ f"{body!r} != {expected_body!r} : "
+ "Request did not finish with the expected status code."
+ )
+
+
+@logcontext_clean
+def make_request_with_cancellation_test(
+ test_name: str,
+ reactor: MemoryReactorClock,
+ site: Site,
+ method: str,
+ path: str,
+ content: Union[bytes, str, JsonDict] = b"",
+) -> FakeChannel:
+ """Performs a request repeatedly, disconnecting at successive `await`s, until
+ one completes.
+
+ Fails if:
+ * A logging context is lost during cancellation.
+ * A logging context get restarted after it is marked as finished, eg. if
+ a request's logging context is used by some processing started by the
+ request, but the request neglects to cancel that processing or wait for it
+ to complete.
+
+ Note that "Re-starting finished log context" errors get raised within the
+ request handling code and may or may not get caught. These errors will
+ likely manifest as a different logging context error at a later point. When
+ debugging logging context failures, setting a breakpoint in
+ `logcontext_error` can prove useful.
+ * A request gets stuck, possibly due to a previous cancellation.
+ * The request does not return a 499 when the client disconnects.
+ This implies that a `CancelledError` was swallowed somewhere.
+
+ It is up to the caller to verify that the request returns the correct data when
+ it finally runs to completion.
+
+ Note that this function can only cover a single code path and does not guarantee
+ that an endpoint is compatible with cancellation on every code path.
+ To allow inspection of the code path that is being tested, this function will
+ log the stack trace at every `await` that gets cancelled. To view these log
+ lines, `trial` can be run with the `SYNAPSE_TEST_LOG_LEVEL=INFO` environment
+ variable, which will include the log lines in `_trial_temp/test.log`.
+ Alternatively, `_log_for_request` can be modified to write to `sys.stdout`.
+
+ Args:
+ test_name: The name of the test, which will be logged.
+ reactor: The twisted reactor running the request handler.
+ site: The twisted `Site` to use to render the request.
+ method: The HTTP request method ("verb").
+ path: The HTTP path, suitably URL encoded (e.g. escaped UTF-8 & spaces and
+ such).
+ content: The body of the request.
+
+ Returns:
+ The `FakeChannel` object which stores the result of the final request that
+ runs to completion.
+ """
+ # To process a request, a coroutine run is created for the async method handling
+ # the request. That method may then start other coroutine runs, wrapped in
+ # `Deferred`s.
+ #
+ # We would like to trigger a cancellation at the first `await`, re-run the
+ # request and cancel at the second `await`, and so on. By patching
+ # `Deferred.__next__`, we can intercept `await`s, track which ones we have or
+ # have not seen, and force them to block when they wouldn't have.
+
+ # The set of previously seen `await`s.
+ # Each element is a stringified stack trace.
+ seen_awaits: Set[Tuple[str, ...]] = set()
+
+ _log_for_request(
+ 0, f"Running make_request_with_cancellation_test for {test_name}..."
+ )
+
+ for request_number in itertools.count(1):
+ deferred_patch = Deferred__next__Patch(seen_awaits, request_number)
+
+ try:
+ with mock.patch(
+ "synapse.http.server.respond_with_json", wraps=respond_with_json
+ ) as respond_mock:
+ with deferred_patch.patch():
+ # Start the request.
+ channel = make_request(
+ reactor, site, method, path, content, await_result=False
+ )
+ request = channel.request
+
+ # Run the request until we see a new `await` which we have not
+ # yet cancelled at, or it completes.
+ while not respond_mock.called and not deferred_patch.new_await_seen:
+ previous_awaits_seen = deferred_patch.awaits_seen
+
+ reactor.advance(0.0)
+
+ if deferred_patch.awaits_seen == previous_awaits_seen:
+ # We didn't see any progress. Try advancing the clock.
+ reactor.advance(1.0)
+
+ if deferred_patch.awaits_seen == previous_awaits_seen:
+ # We still didn't see any progress. The request might be
+ # stuck.
+ raise AssertionError(
+ "Request appears to be stuck, possibly due to a "
+ "previous cancelled request"
+ )
+
+ if respond_mock.called:
+ # The request ran to completion and we are done with testing it.
+
+ # `respond_with_json` writes the response asynchronously, so we
+ # might have to give the reactor a kick before the channel gets
+ # the response.
+ deferred_patch.unblock_awaits()
+ channel.await_result()
+
+ return channel
+
+ # Disconnect the client and wait for the response.
+ request.connectionLost(reason=ConnectionDone())
+
+ _log_for_request(request_number, "--- disconnected ---")
+
+ # Advance the reactor just enough to get a response.
+ # We don't want to advance the reactor too far, because we can only
+ # detect re-starts of finished logging contexts after we set the
+ # finished flag below.
+ for _ in range(2):
+ # We may need to pump the reactor to allow `delay_cancellation`s to
+ # finish.
+ if not respond_mock.called:
+ reactor.advance(0.0)
+
+ # Try advancing the clock if that didn't work.
+ if not respond_mock.called:
+ reactor.advance(1.0)
+
+ # `delay_cancellation`s may be waiting for processing that we've
+ # forced to block. Try unblocking them, followed by another round of
+ # pumping the reactor.
+ if not respond_mock.called:
+ deferred_patch.unblock_awaits()
+
+ # Mark the request's logging context as finished. If it gets
+ # activated again, an `AssertionError` will be raised and bubble up
+ # through request handling code. This `AssertionError` may or may not be
+ # caught. Eventually some other code will deactivate the logging
+ # context which will raise a different `AssertionError` because
+ # resource usage won't have been correctly tracked.
+ if isinstance(request, SynapseRequest) and request.logcontext:
+ request.logcontext.finished = True
+
+ # Check that the request finished with a 499,
+ # ie. the `CancelledError` wasn't swallowed.
respond_mock.assert_called_once()
- args, _kwargs = respond_mock.call_args
- code, body = args[1], args[2]
- self.assertEqual(code, expected_code)
- self.assertEqual(request.code, expected_code)
- self.assertEqual(body, expected_body)
+
+ if request.code != HTTP_STATUS_REQUEST_CANCELLED:
+ raise AssertionError(
+ f"{request.code} != {HTTP_STATUS_REQUEST_CANCELLED} : "
+ "Cancelled request did not finish with the correct status code."
+ )
+ finally:
+ # Unblock any processing that might be shared between requests, if we
+ # haven't already done so.
+ deferred_patch.unblock_awaits()
+
+ assert False, "unreachable" # noqa: B011
+
+
+class Deferred__next__Patch:
+ """A `Deferred.__next__` patch that will intercept `await`s and force them
+ to block once it sees a new `await`.
+
+ When done with the patch, `unblock_awaits()` must be called to clean up after any
+ `await`s that were forced to block, otherwise processing shared between multiple
+ requests, such as database queries started by `@cached`, will become permanently
+ stuck.
+
+ Usage:
+ seen_awaits = set()
+ deferred_patch = Deferred__next__Patch(seen_awaits, 1)
+ try:
+ with deferred_patch.patch():
+ # do things
+ ...
+ finally:
+ deferred_patch.unblock_awaits()
+ """
+
+ def __init__(self, seen_awaits: Set[Tuple[str, ...]], request_number: int):
+ """
+ Args:
+ seen_awaits: The set of stack traces of `await`s that have been previously
+ seen. When the `Deferred.__next__` patch sees a new `await`, it will add
+ it to the set.
+ request_number: The request number to log against.
+ """
+ self._request_number = request_number
+ self._seen_awaits = seen_awaits
+
+ self._original_Deferred___next__ = Deferred.__next__
+
+ # The number of `await`s on `Deferred`s we have seen so far.
+ self.awaits_seen = 0
+
+ # Whether we have seen a new `await` not in `seen_awaits`.
+ self.new_await_seen = False
+
+ # To force `await`s on resolved `Deferred`s to block, we make up a new
+ # unresolved `Deferred` and return it out of `Deferred.__next__` /
+ # `coroutine.send()`. We have to resolve it later, in case the `await`ing
+ # coroutine is part of some shared processing, such as `@cached`.
+ self._to_unblock: Dict[Deferred, Union[object, Failure]] = {}
+
+ # The last stack we logged.
+ self._previous_stack: List[inspect.FrameInfo] = []
+
+ def patch(self) -> ContextManager[Mock]:
+ """Returns a context manager which patches `Deferred.__next__`."""
+
+ def Deferred___next__(
+ deferred: "Deferred[T]", value: object = None
+ ) -> "Deferred[T]":
+ """Intercepts `await`s on `Deferred`s and rigs them to block once we have
+ seen enough of them.
+
+ `Deferred.__next__` will normally:
+ * return `self` if the `Deferred` is unresolved, in which case
+ `coroutine.send()` will return the `Deferred`, and
+ `_defer.inlineCallbacks` will stop running the coroutine until the
+ `Deferred` is resolved.
+ * raise a `StopIteration(result)`, containing the result of the `await`.
+ * raise another exception, which will come out of the `await`.
+ """
+ self.awaits_seen += 1
+
+ stack = _get_stack(skip_frames=1)
+ stack_hash = _hash_stack(stack)
+
+ if stack_hash not in self._seen_awaits:
+ # Block at the current `await` onwards.
+ self._seen_awaits.add(stack_hash)
+ self.new_await_seen = True
+
+ if not self.new_await_seen:
+ # This `await` isn't interesting. Let it proceed normally.
+
+ # Don't log the stack. It's been seen before in a previous run.
+ self._previous_stack = stack
+
+ return self._original_Deferred___next__(deferred, value)
+
+ # We want to block at the current `await`.
+ if deferred.called and not deferred.paused:
+ # This `Deferred` already has a result.
+ # We return a new, unresolved, `Deferred` for `_inlineCallbacks` to wait
+ # on. This blocks the coroutine that did this `await`.
+ # We queue it up for unblocking later.
+ new_deferred: "Deferred[T]" = Deferred()
+ self._to_unblock[new_deferred] = deferred.result
+
+ _log_await_stack(
+ stack,
+ self._previous_stack,
+ self._request_number,
+ "force-blocked await",
+ )
+ self._previous_stack = stack
+
+ return make_deferred_yieldable(new_deferred)
+
+ # This `Deferred` does not have a result yet.
+ # The `await` will block normally, so we don't have to do anything.
+ _log_await_stack(
+ stack,
+ self._previous_stack,
+ self._request_number,
+ "blocking await",
+ )
+ self._previous_stack = stack
+
+ return self._original_Deferred___next__(deferred, value)
+
+ return mock.patch.object(Deferred, "__next__", new=Deferred___next__)
+
+ def unblock_awaits(self) -> None:
+ """Unblocks any shared processing that we forced to block.
+
+ Must be called when done, otherwise processing shared between multiple requests,
+ such as database queries started by `@cached`, will become permanently stuck.
+ """
+ to_unblock = self._to_unblock
+ self._to_unblock = {}
+ for deferred, result in to_unblock.items():
+ deferred.callback(result)
+
+
+def _log_for_request(request_number: int, message: str) -> None:
+ """Logs a message for an iteration of `make_request_with_cancellation_test`."""
+ # We want consistent alignment when logging stack traces, so ensure the logging
+ # context has a fixed width name.
+ with LoggingContext(name=f"request-{request_number:<2}"):
+ logger.info(message)
+
+
+def _log_await_stack(
+ stack: List[inspect.FrameInfo],
+ previous_stack: List[inspect.FrameInfo],
+ request_number: int,
+ note: str,
+) -> None:
+ """Logs the stack for an `await` in `make_request_with_cancellation_test`.
+
+ Only logs the part of the stack that has changed since the previous call.
+
+ Example output looks like:
+ ```
+ delay_cancellation:750 (synapse/util/async_helpers.py:750)
+ DatabasePool._runInteraction:768 (synapse/storage/database.py:768)
+ > *blocked on await* at DatabasePool.runWithConnection:891 (synapse/storage/database.py:891)
+ ```
+
+ Args:
+ stack: The stack to log, as returned by `_get_stack()`.
+ previous_stack: The previous stack logged, with callers appearing before
+ callees.
+ request_number: The request number to log against.
+ note: A note to attach to the last stack frame, eg. "blocked on await".
+ """
+ for i, frame_info in enumerate(stack[:-1]):
+ # Skip any frames in common with the previous logging.
+ if i < len(previous_stack) and frame_info == previous_stack[i]:
+ continue
+
+ frame = _format_stack_frame(frame_info)
+ message = f"{' ' * i}{frame}"
+ _log_for_request(request_number, message)
+
+ # Always print the final frame with the `await`.
+ # If the frame with the `await` started another coroutine run, we may have already
+ # printed a deeper stack which includes our final frame. We want to log where all
+ # `await`s happen, so we reprint the frame in this case.
+ i = len(stack) - 1
+ frame_info = stack[i]
+ frame = _format_stack_frame(frame_info)
+ message = f"{' ' * i}> *{note}* at {frame}"
+ _log_for_request(request_number, message)
+
+
+def _format_stack_frame(frame_info: inspect.FrameInfo) -> str:
+ """Returns a string representation of a stack frame.
+
+ Used for debug logging.
+
+ Returns:
+ A string, formatted like
+ "JsonResource._async_render:559 (synapse/http/server.py:559)".
+ """
+ method_name = _get_stack_frame_method_name(frame_info)
+
+ return (
+ f"{method_name}:{frame_info.lineno} ({frame_info.filename}:{frame_info.lineno})"
+ )
+
+
+def _get_stack(skip_frames: int) -> List[inspect.FrameInfo]:
+ """Captures the stack for a request.
+
+ Skips any twisted frames and stops at `JsonResource.wrapped_async_request_handler`.
+
+ Used for debug logging.
+
+ Returns:
+ A list of `inspect.FrameInfo`s, with callers appearing before callees.
+ """
+ stack = []
+
+ skip_frames += 1 # Also skip `get_stack` itself.
+
+ for frame_info in inspect.stack()[skip_frames:]:
+ # Skip any twisted `inlineCallbacks` gunk.
+ if "/twisted/" in frame_info.filename:
+ continue
+
+ # Exclude the reactor frame, upwards.
+ method_name = _get_stack_frame_method_name(frame_info)
+ if method_name == "ThreadedMemoryReactorClock.advance":
+ break
+
+ stack.append(frame_info)
+
+ # Stop at `JsonResource`'s `wrapped_async_request_handler`, which is the entry
+ # point for request handling.
+ if frame_info.function == "wrapped_async_request_handler":
+ break
+
+ return stack[::-1]
+
+
+def _get_stack_frame_method_name(frame_info: inspect.FrameInfo) -> str:
+ """Returns the name of a stack frame's method.
+
+ eg. "JsonResource._async_render".
+ """
+ method_name = frame_info.function
+
+ # Prefix the class name for instance methods.
+ frame_self = frame_info.frame.f_locals.get("self")
+ if frame_self:
+ method = getattr(frame_self, method_name, None)
+ if method:
+ method_name = method.__qualname__
+ else:
+ # We couldn't find the method on `self`.
+ # Make something up. It's useful to know which class "contains" a
+ # function anyway.
+ method_name = f"{type(frame_self).__name__} {method_name}"
+
+ return method_name
+
+
+def _hash_stack(stack: List[inspect.FrameInfo]):
+ """Turns a stack into a hashable value that can be put into a set."""
+ return tuple(_format_stack_frame(frame) for frame in stack)
diff --git a/tests/http/test_servlet.py b/tests/http/test_servlet.py
index b3655d7b44..bb966c80c6 100644
--- a/tests/http/test_servlet.py
+++ b/tests/http/test_servlet.py
@@ -30,7 +30,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
def make_request(content):
@@ -108,9 +108,7 @@ class CancellableRestServlet(RestServlet):
return HTTPStatus.OK, {"result": True}
-class TestRestServletCancellation(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class TestRestServletCancellation(unittest.HomeserverTestCase):
"""Tests for `RestServlet` cancellation."""
servlets = [
@@ -120,7 +118,7 @@ class TestRestServletCancellation(
def test_cancellable_disconnect(self) -> None:
"""Test that handlers with the `@cancellable` flag can be cancelled."""
channel = self.make_request("GET", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -130,7 +128,7 @@ class TestRestServletCancellation(
def test_uncancellable_disconnect(self) -> None:
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
channel = self.make_request("POST", "/sleep", await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index a5ab093a27..822a957c3a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -25,7 +25,7 @@ from synapse.server import HomeServer
from synapse.types import JsonDict
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableReplicationEndpoint(ReplicationEndpoint):
@@ -69,9 +69,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return HTTPStatus.OK, {"result": True}
-class ReplicationEndpointCancellationTestCase(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self):
@@ -87,7 +85,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -98,7 +96,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 62e4db23ef..aa84906548 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -728,6 +728,7 @@ class RelationsTestCase(BaseRelationsTestCase):
class RelationPaginationTestCase(BaseRelationsTestCase):
+ @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
def test_basic_paginate_relations(self) -> None:
"""Tests that calling pagination API correctly the latest relations."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index f523d89b8f..4be83dfd6d 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -42,6 +42,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
from tests.test_utils import make_awaitable
PATH_PREFIX = b"/_matrix/client/api/v1"
@@ -471,6 +472,49 @@ class RoomPermissionsTestCase(RoomBase):
)
+class RoomStateTestCase(RoomBase):
+ """Tests /rooms/$room_id/state."""
+
+ user_id = "@sid1:red"
+
+ def test_get_state_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state" % room_id,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertCountEqual(
+ [state_event["type"] for state_event in channel.json_body],
+ {
+ "m.room.create",
+ "m.room.power_levels",
+ "m.room.join_rules",
+ "m.room.member",
+ "m.room.history_visibility",
+ },
+ )
+
+ def test_get_state_event_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/state/$event_type` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_state_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/state/m.room.member/%s" % (room_id, self.user_id),
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(channel.json_body, {"membership": "join"})
+
+
class RoomsMemberListTestCase(RoomBase):
"""Tests /rooms/$room_id/members/list REST events."""
@@ -591,6 +635,62 @@ class RoomsMemberListTestCase(RoomBase):
channel = self.make_request("GET", room_path)
self.assertEqual(200, channel.code, msg=channel.result["body"])
+ def test_get_member_list_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members" % room_id,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
+ def test_get_member_list_with_at_token_cancellation(self) -> None:
+ """Test cancellation of a `/rooms/$room_id/members?at=<sync token>` request."""
+ room_id = self.helper.create_room_as(self.user_id)
+
+ # first sync to get an at token
+ channel = self.make_request("GET", "/sync")
+ self.assertEqual(200, channel.code)
+ sync_token = channel.json_body["next_batch"]
+
+ channel = make_request_with_cancellation_test(
+ "test_get_member_list_with_at_token_cancellation",
+ self.reactor,
+ self.site,
+ "GET",
+ "/rooms/%s/members?at=%s" % (room_id, sync_token),
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+ self.assertLessEqual(
+ {
+ "content": {"membership": "join"},
+ "room_id": room_id,
+ "sender": self.user_id,
+ "state_key": self.user_id,
+ "type": "m.room.member",
+ "user_id": self.user_id,
+ }.items(),
+ channel.json_body["chunk"][0].items(),
+ )
+
class RoomsCreateTestCase(RoomBase):
"""Tests /rooms and /rooms/$room_id REST events."""
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8370a27195..78b83d97b6 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -13,7 +13,17 @@
# limitations under the License.
import itertools
-from typing import List
+from typing import (
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Mapping,
+ Optional,
+ Set,
+ Tuple,
+ TypeVar,
+)
import attr
@@ -22,13 +32,13 @@ from twisted.internet import defer
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
-from synapse.events import make_event_from_dict
+from synapse.events import EventBase, make_event_from_dict
from synapse.state.v2 import (
_get_auth_chain_difference,
lexicographical_topological_sort,
resolve_events_with_store,
)
-from synapse.types import EventID
+from synapse.types import EventID, StateMap
from tests import unittest
@@ -48,7 +58,7 @@ ORIGIN_SERVER_TS = 0
class FakeClock:
- def sleep(self, msec):
+ def sleep(self, msec: float) -> "defer.Deferred[None]":
return defer.succeed(None)
@@ -60,7 +70,14 @@ class FakeEvent:
as domain.
"""
- def __init__(self, id, sender, type, state_key, content):
+ def __init__(
+ self,
+ id: str,
+ sender: str,
+ type: str,
+ state_key: Optional[str],
+ content: Mapping[str, object],
+ ):
self.node_id = id
self.event_id = EventID(id, "example.com").to_string()
self.sender = sender
@@ -69,12 +86,12 @@ class FakeEvent:
self.content = content
self.room_id = ROOM_ID
- def to_event(self, auth_events, prev_events):
+ def to_event(self, auth_events: List[str], prev_events: List[str]) -> EventBase:
"""Given the auth_events and prev_events, convert to a Frozen Event
Args:
- auth_events (list[str]): list of event_ids
- prev_events (list[str]): list of event_ids
+ auth_events: list of event_ids
+ prev_events: list of event_ids
Returns:
FrozenEvent
@@ -164,7 +181,7 @@ INITIAL_EDGES = ["START", "IMZ", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"]
class StateTestCase(unittest.TestCase):
- def test_ban_vs_pl(self):
+ def test_ban_vs_pl(self) -> None:
events = [
FakeEvent(
id="PA",
@@ -202,7 +219,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_join_rule_evasion(self):
+ def test_join_rule_evasion(self) -> None:
events = [
FakeEvent(
id="JR",
@@ -226,7 +243,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_offtopic_pl(self):
+ def test_offtopic_pl(self) -> None:
events = [
FakeEvent(
id="PA",
@@ -257,7 +274,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic_basic(self):
+ def test_topic_basic(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -297,7 +314,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic_reset(self):
+ def test_topic_reset(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -327,7 +344,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_topic(self):
+ def test_topic(self) -> None:
events = [
FakeEvent(
id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
@@ -380,7 +397,7 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def test_mainline_sort(self):
+ def test_mainline_sort(self) -> None:
"""Tests that the mainline ordering works correctly."""
events = [
@@ -434,22 +451,26 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
- def do_check(self, events, edges, expected_state_ids):
+ def do_check(
+ self,
+ events: List[FakeEvent],
+ edges: List[List[str]],
+ expected_state_ids: List[str],
+ ) -> None:
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
Args:
- events (list[FakeEvent])
- edges (list[list[str]]): A list of chains of event edges, e.g.
+ events
+ edges: A list of chains of event edges, e.g.
`[[A, B, C]]` are edges A->B and B->C.
- expected_state_ids (list[str]): The expected state at END, (excluding
+ expected_state_ids: The expected state at END, (excluding
the keys that haven't changed since START).
"""
# We want to sort the events into topological order for processing.
- graph = {}
+ graph: Dict[str, Set[str]] = {}
- # node_id -> FakeEvent
- fake_event_map = {}
+ fake_event_map: Dict[str, FakeEvent] = {}
for ev in itertools.chain(INITIAL_EVENTS, events):
graph[ev.node_id] = set()
@@ -462,10 +483,8 @@ class StateTestCase(unittest.TestCase):
for a, b in pairwise(edge_list):
graph[a].add(b)
- # event_id -> FrozenEvent
- event_map = {}
- # node_id -> state
- state_at_event = {}
+ event_map: Dict[str, EventBase] = {}
+ state_at_event: Dict[str, StateMap[str]] = {}
# We copy the map as the sort consumes the graph
graph_copy = {k: set(v) for k, v in graph.items()}
@@ -496,7 +515,16 @@ class StateTestCase(unittest.TestCase):
if fake_event.state_key is not None:
state_after[(fake_event.type, fake_event.state_key)] = event_id
- auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event))
+ # This type ignore is a bit sad. Things we have tried:
+ # 1. Define a `GenericEvent` Protocol satisfied by FakeEvent, EventBase and
+ # EventBuilder. But this is Hard because the relevant attributes are
+ # DictProperty[T] descriptors on EventBase but normal Ts on FakeEvent.
+ # 2. Define a `GenericEvent` Protocol describing `FakeEvent` only, and
+ # change this function to accept Union[Event, EventBase, EventBuilder].
+ # This seems reasonable to me, but mypy isn't happy. I think that's
+ # a mypy bug, see https://github.com/python/mypy/issues/5570
+ # Instead, resort to a type-ignore.
+ auth_types = set(auth_types_for_event(RoomVersions.V6, fake_event)) # type: ignore[arg-type]
auth_events = []
for key in auth_types:
@@ -530,8 +558,14 @@ class StateTestCase(unittest.TestCase):
class LexicographicalTestCase(unittest.TestCase):
- def test_simple(self):
- graph = {"l": {"o"}, "m": {"n", "o"}, "n": {"o"}, "o": set(), "p": {"o"}}
+ def test_simple(self) -> None:
+ graph: Dict[str, Set[str]] = {
+ "l": {"o"},
+ "m": {"n", "o"},
+ "n": {"o"},
+ "o": set(),
+ "p": {"o"},
+ }
res = list(lexicographical_topological_sort(graph, key=lambda x: x))
@@ -539,7 +573,7 @@ class LexicographicalTestCase(unittest.TestCase):
class SimpleParamStateTestCase(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
# We build up a simple DAG.
event_map = {}
@@ -627,7 +661,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
]
}
- def test_event_map_none(self):
+ def test_event_map_none(self) -> None:
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
@@ -649,7 +683,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
events.
"""
- def test_simple(self):
+ def test_simple(self) -> None:
# Test getting the auth difference for a simple chain with a single
# unpersisted event:
#
@@ -695,7 +729,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {c.event_id})
- def test_multiple_unpersisted_chain(self):
+ def test_multiple_unpersisted_chain(self) -> None:
# Test getting the auth difference for a simple chain with multiple
# unpersisted events:
#
@@ -752,7 +786,7 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, c.event_id})
- def test_unpersisted_events_different_sets(self):
+ def test_unpersisted_events_different_sets(self) -> None:
# Test getting the auth difference for with multiple unpersisted events
# in different branches:
#
@@ -820,7 +854,10 @@ class AuthChainDifferenceTestCase(unittest.TestCase):
self.assertEqual(difference, {d.event_id, e.event_id})
-def pairwise(iterable):
+T = TypeVar("T")
+
+
+def pairwise(iterable: Iterable[T]) -> Iterable[Tuple[T, T]]:
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
next(b, None)
@@ -829,24 +866,26 @@ def pairwise(iterable):
@attr.s
class TestStateResolutionStore:
- event_map = attr.ib()
+ event_map: Dict[str, EventBase] = attr.ib()
- def get_events(self, event_ids, allow_rejected=False):
+ def get_events(
+ self, event_ids: Collection[str], allow_rejected: bool = False
+ ) -> "defer.Deferred[Dict[str, EventBase]]":
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- allow_rejected (bool): If True return rejected events.
+ event_ids: The event_ids of the events to fetch
+ allow_rejected: If True return rejected events.
Returns:
- Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ Dict from event_id to event.
"""
return defer.succeed(
{eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
)
- def _get_auth_chain(self, event_ids: List[str]) -> List[str]:
+ def _get_auth_chain(self, event_ids: Iterable[str]) -> List[str]:
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -880,7 +919,9 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, room_id, auth_sets):
+ def get_auth_chain_difference(
+ self, room_id: str, auth_sets: List[Set[str]]
+ ) -> "defer.Deferred[Set[str]]":
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
diff --git a/tests/test_server.py b/tests/test_server.py
index 0f1eb43cbc..847432f791 100644
--- a/tests/test_server.py
+++ b/tests/test_server.py
@@ -34,7 +34,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
from tests.server import (
FakeSite,
ThreadedMemoryReactorClock,
@@ -407,7 +407,7 @@ class CancellableDirectServeHtmlResource(DirectServeHtmlResource):
return HTTPStatus.OK, b"ok"
-class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeJsonResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeJsonResource` cancellation."""
def setUp(self):
@@ -421,7 +421,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -433,7 +433,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
@@ -441,7 +441,7 @@ class DirectServeJsonResourceCancellationTests(EndpointCancellationTestHelperMix
)
-class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMixin):
+class DirectServeHtmlResourceCancellationTests(unittest.TestCase):
"""Tests for `DirectServeHtmlResource` cancellation."""
def setUp(self):
@@ -455,7 +455,7 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "GET", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -467,6 +467,6 @@ class DirectServeHtmlResourceCancellationTests(EndpointCancellationTestHelperMix
channel = make_request(
self.reactor, self.site, "POST", "/sleep", await_result=False
)
- self._test_disconnect(
+ test_disconnect(
self.reactor, channel, expect_cancellation=False, expected_body=b"ok"
)
|