diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c938339ddd..ec81639c78 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -37,7 +37,8 @@ from synapse.types import (
get_verify_key_from_cross_signing_key,
)
from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
from synapse.util.retryutils import NotRetryingDestination
if TYPE_CHECKING:
@@ -91,6 +92,7 @@ class E2eKeysHandler:
)
@trace
+ @cancellable
async def query_devices(
self,
query_body: JsonDict,
@@ -208,22 +210,26 @@ class E2eKeysHandler:
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
+ # TODO It might make sense to propagate cancellations into the
+ # deferreds which are querying remote homeservers.
await make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self._query_devices_for_destination,
- results,
- cross_signing_keys,
- failures,
- destination,
- queries,
- timeout,
- )
- for destination, queries in remote_queries_not_in_cache.items()
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ delay_cancellation(
+ defer.gatherResults(
+ [
+ run_in_background(
+ self._query_devices_for_destination,
+ results,
+ cross_signing_keys,
+ failures,
+ destination,
+ queries,
+ timeout,
+ )
+ for destination, queries in remote_queries_not_in_cache.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
)
ret = {"device_keys": results, "failures": failures}
@@ -347,6 +353,7 @@ class E2eKeysHandler:
return
+ @cancellable
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
@@ -393,6 +400,7 @@ class E2eKeysHandler:
}
@trace
+ @cancellable
async def query_local_devices(
self, query: Mapping[str, Optional[List[str]]]
) -> Dict[str, Dict[str, dict]]:
|