summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py40
1 files changed, 24 insertions, 16 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c938339ddd..ec81639c78 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -37,7 +37,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -91,6 +92,7 @@ class E2eKeysHandler:
         )
 
     @trace
+    @cancellable
     async def query_devices(
         self,
         query_body: JsonDict,
@@ -208,22 +210,26 @@ class E2eKeysHandler:
                     r[user_id] = remote_queries[user_id]
 
             # Now fetch any devices that we don't have in our cache
+            # TODO It might make sense to propagate cancellations into the
+            #      deferreds which are querying remote homeservers.
             await make_deferred_yieldable(
-                defer.gatherResults(
-                    [
-                        run_in_background(
-                            self._query_devices_for_destination,
-                            results,
-                            cross_signing_keys,
-                            failures,
-                            destination,
-                            queries,
-                            timeout,
-                        )
-                        for destination, queries in remote_queries_not_in_cache.items()
-                    ],
-                    consumeErrors=True,
-                ).addErrback(unwrapFirstError)
+                delay_cancellation(
+                    defer.gatherResults(
+                        [
+                            run_in_background(
+                                self._query_devices_for_destination,
+                                results,
+                                cross_signing_keys,
+                                failures,
+                                destination,
+                                queries,
+                                timeout,
+                            )
+                            for destination, queries in remote_queries_not_in_cache.items()
+                        ],
+                        consumeErrors=True,
+                    ).addErrback(unwrapFirstError)
+                )
             )
 
             ret = {"device_keys": results, "failures": failures}
@@ -347,6 +353,7 @@ class E2eKeysHandler:
 
         return
 
+    @cancellable
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
@@ -393,6 +400,7 @@ class E2eKeysHandler:
         }
 
     @trace
+    @cancellable
     async def query_local_devices(
         self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]: