diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 5726e60cee..5071f83574 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -140,6 +140,8 @@ def make_request_with_cancellation_test(
method: str,
path: str,
content: Union[bytes, str, JsonDict] = b"",
+ *,
+ token: Optional[str] = None,
) -> FakeChannel:
"""Performs a request repeatedly, disconnecting at successive `await`s, until
one completes.
@@ -211,7 +213,13 @@ def make_request_with_cancellation_test(
with deferred_patch.patch():
# Start the request.
channel = make_request(
- reactor, site, method, path, content, await_result=False
+ reactor,
+ site,
+ method,
+ path,
+ content,
+ await_result=False,
+ access_token=token,
)
request = channel.request
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index bbc8e74243..741fecea77 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -19,6 +19,7 @@ from synapse.rest import admin
from synapse.rest.client import keys, login
from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
Codes.BAD_JSON,
channel.result,
)
+
+ def test_key_query_cancellation(self) -> None:
+ """
+ Tests that /keys/query is cancellable and does not swallow the
+ CancelledError.
+ """
+ self.register_user("alice", "wonderland")
+ alice_token = self.login("alice", "wonderland")
+
+ bob = self.register_user("bob", "uncle")
+
+ channel = make_request_with_cancellation_test(
+ "test_key_query_cancellation",
+ self.reactor,
+ self.site,
+ "POST",
+ "/_matrix/client/r0/keys/query",
+ {
+ "device_keys": {
+ # Empty list means we request keys for all bob's devices
+ bob: [],
+ },
+ },
+ token=alice_token,
+ )
+
+ self.assertEqual(200, channel.code, msg=channel.result["body"])
+ self.assertIn(bob, channel.json_body["device_keys"])
|