summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13680.feature1
-rw-r--r--synapse/api/auth.py5
-rw-r--r--synapse/handlers/device.py3
-rw-r--r--synapse/handlers/e2e_keys.py40
-rw-r--r--synapse/rest/client/keys.py6
-rw-r--r--synapse/storage/controllers/state.py4
-rw-r--r--synapse/storage/databases/main/devices.py4
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py5
-rw-r--r--synapse/storage/databases/main/event_federation.py2
-rw-r--r--synapse/storage/databases/main/events_worker.py4
-rw-r--r--synapse/storage/databases/main/roommember.py2
-rw-r--r--synapse/storage/databases/main/state.py2
-rw-r--r--synapse/storage/databases/main/stream.py2
-rw-r--r--synapse/storage/databases/state/store.py3
-rw-r--r--synapse/storage/util/partial_state_events_tracker.py3
-rw-r--r--synapse/types.py5
-rw-r--r--tests/http/server/_base.py10
-rw-r--r--tests/rest/client/test_keys.py29
18 files changed, 110 insertions, 20 deletions
diff --git a/changelog.d/13680.feature b/changelog.d/13680.feature
new file mode 100644
index 0000000000..4234c7e082
--- /dev/null
+++ b/changelog.d/13680.feature
@@ -0,0 +1 @@
+Cancel the processing of key query requests when they time out.
\ No newline at end of file
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9a1aea083f..8e54ef84b2 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -38,6 +38,7 @@ from synapse.logging.opentracing import (
     trace,
 )
 from synapse.types import Requester, create_requester
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -118,6 +119,7 @@ class Auth:
             errcode=Codes.NOT_JOINED,
         )
 
+    @cancellable
     async def get_user_by_req(
         self,
         request: SynapseRequest,
@@ -166,6 +168,7 @@ class Auth:
                     parent_span.set_tag("appservice_id", requester.app_service.id)
             return requester
 
+    @cancellable
     async def _wrapped_get_user_by_req(
         self,
         request: SynapseRequest,
@@ -281,6 +284,7 @@ class Auth:
                 403, "Application service has not registered this user (%s)" % user_id
             )
 
+    @cancellable
     async def _get_appservice_user(self, request: Request) -> Optional[Requester]:
         """
         Given a request, reads the request parameters to determine:
@@ -523,6 +527,7 @@ class Auth:
         return bool(query_params) or bool(auth_headers)
 
     @staticmethod
+    @cancellable
     def get_access_token_from_request(request: Request) -> str:
         """Extracts the access_token from the request.
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 9c2c3a0e68..c5ac169644 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -52,6 +52,7 @@ from synapse.types import (
 from synapse.util import stringutils
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches.expiringcache import ExpiringCache
+from synapse.util.cancellation import cancellable
 from synapse.util.metrics import measure_func
 from synapse.util.retryutils import NotRetryingDestination
 
@@ -124,6 +125,7 @@ class DeviceWorkerHandler:
 
         return device
 
+    @cancellable
     async def get_device_changes_in_shared_rooms(
         self, user_id: str, room_ids: Collection[str], from_token: StreamToken
     ) -> Collection[str]:
@@ -163,6 +165,7 @@ class DeviceWorkerHandler:
 
     @trace
     @measure_func("device.get_user_ids_changed")
+    @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
     ) -> JsonDict:
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c938339ddd..ec81639c78 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -37,7 +37,8 @@ from synapse.types import (
     get_verify_key_from_cross_signing_key,
 )
 from synapse.util import json_decoder, unwrapFirstError
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import Linearizer, delay_cancellation
+from synapse.util.cancellation import cancellable
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -91,6 +92,7 @@ class E2eKeysHandler:
         )
 
     @trace
+    @cancellable
     async def query_devices(
         self,
         query_body: JsonDict,
@@ -208,22 +210,26 @@ class E2eKeysHandler:
                     r[user_id] = remote_queries[user_id]
 
             # Now fetch any devices that we don't have in our cache
+            # TODO It might make sense to propagate cancellations into the
+            #      deferreds which are querying remote homeservers.
             await make_deferred_yieldable(
-                defer.gatherResults(
-                    [
-                        run_in_background(
-                            self._query_devices_for_destination,
-                            results,
-                            cross_signing_keys,
-                            failures,
-                            destination,
-                            queries,
-                            timeout,
-                        )
-                        for destination, queries in remote_queries_not_in_cache.items()
-                    ],
-                    consumeErrors=True,
-                ).addErrback(unwrapFirstError)
+                delay_cancellation(
+                    defer.gatherResults(
+                        [
+                            run_in_background(
+                                self._query_devices_for_destination,
+                                results,
+                                cross_signing_keys,
+                                failures,
+                                destination,
+                                queries,
+                                timeout,
+                            )
+                            for destination, queries in remote_queries_not_in_cache.items()
+                        ],
+                        consumeErrors=True,
+                    ).addErrback(unwrapFirstError)
+                )
             )
 
             ret = {"device_keys": results, "failures": failures}
@@ -347,6 +353,7 @@ class E2eKeysHandler:
 
         return
 
+    @cancellable
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
@@ -393,6 +400,7 @@ class E2eKeysHandler:
         }
 
     @trace
+    @cancellable
     async def query_local_devices(
         self, query: Mapping[str, Optional[List[str]]]
     ) -> Dict[str, Dict[str, dict]]:
diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py
index a395694fa5..f653d2a3e1 100644
--- a/synapse/rest/client/keys.py
+++ b/synapse/rest/client/keys.py
@@ -27,9 +27,9 @@ from synapse.http.servlet import (
 )
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import log_kv, set_tag
+from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.types import JsonDict, StreamToken
-
-from ._base import client_patterns, interactive_auth_handler
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -156,6 +156,7 @@ class KeyQueryServlet(RestServlet):
         self.auth = hs.get_auth()
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
 
+    @cancellable
     async def on_POST(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
@@ -199,6 +200,7 @@ class KeyChangesServlet(RestServlet):
         self.device_handler = hs.get_device_handler()
         self.store = hs.get_datastores().main
 
+    @cancellable
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index ba5380ce3e..bbe568bf05 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -36,6 +36,7 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialStateEventsTracker,
 )
 from synapse.types import MutableStateMap, StateMap
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -229,6 +230,7 @@ class StateStorageController:
 
     @trace
     @tag_args
+    @cancellable
     async def get_state_ids_for_events(
         self,
         event_ids: Collection[str],
@@ -350,6 +352,7 @@ class StateStorageController:
 
     @trace
     @tag_args
+    @cancellable
     async def get_state_group_for_events(
         self,
         event_ids: Collection[str],
@@ -398,6 +401,7 @@ class StateStorageController:
             event_id, room_id, prev_group, delta_ids, current_state_ids
         )
 
+    @cancellable
     async def get_current_state_ids(
         self,
         room_id: str,
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index ca0fe8c4be..5d700ca6c3 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -53,6 +53,7 @@ from synapse.util import json_decoder, json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import LruCache
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.stringutils import shortstr
 
@@ -668,6 +669,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
         ...
 
     @trace
+    @cancellable
     async def get_user_devices_from_cache(
         self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Tuple[Set[str], Dict[str, Dict[str, JsonDict]]]:
@@ -743,6 +745,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
 
         return self._device_list_stream_cache.get_all_entities_changed(from_key)
 
+    @cancellable
     async def get_users_whose_devices_changed(
         self,
         from_key: int,
@@ -1221,6 +1224,7 @@ class DeviceWorkerStore(EndToEndKeyWorkerStore):
             desc="get_min_device_lists_changes_in_room",
         )
 
+    @cancellable
     async def get_device_list_changes_in_rooms(
         self, room_ids: Collection[str], from_id: int
     ) -> Optional[Set[str]]:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 46c0d06157..8e9e1b0b4b 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -50,6 +50,7 @@ from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -135,6 +136,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         return now_stream_id, []
 
     @trace
+    @cancellable
     async def get_e2e_device_keys_for_cs_api(
         self, query_list: List[Tuple[str, Optional[str]]]
     ) -> Dict[str, Dict[str, JsonDict]]:
@@ -197,6 +199,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         ...
 
     @trace
+    @cancellable
     async def get_e2e_device_keys_and_signatures(
         self,
         query_list: Collection[Tuple[str, Optional[str]]],
@@ -887,6 +890,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
 
         return keys
 
+    @cancellable
     async def get_e2e_cross_signing_keys_bulk(
         self, user_ids: List[str], from_user_id: Optional[str] = None
     ) -> Dict[str, Optional[Dict[str, JsonDict]]]:
@@ -902,7 +906,6 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             keys were not found, either their user ID will not be in the dict,
             or their user ID will map to None.
         """
-
         result = await self._get_bare_e2e_cross_signing_keys_bulk(user_ids)
 
         if from_user_id:
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index e687f87eca..ca47a22bf1 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -48,6 +48,7 @@ from synapse.types import JsonDict
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.lrucache import LruCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -976,6 +977,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return int(min_depth) if min_depth is not None else None
 
+    @cancellable
     async def get_forward_extremities_for_room_at_stream_ordering(
         self, room_id: str, stream_ordering: int
     ) -> List[str]:
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 84f17a9945..52914febf9 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -81,6 +81,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.async_helpers import ObservableDeferred, delay_cancellation
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.caches.lrucache import AsyncLruCache
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -339,6 +340,7 @@ class EventsWorkerStore(SQLBaseStore):
     ) -> Optional[EventBase]:
         ...
 
+    @cancellable
     async def get_event(
         self,
         event_id: str,
@@ -433,6 +435,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     @trace
     @tag_args
+    @cancellable
     async def get_events_as_list(
         self,
         event_ids: Collection[str],
@@ -584,6 +587,7 @@ class EventsWorkerStore(SQLBaseStore):
 
         return events
 
+    @cancellable
     async def _get_events_from_cache_or_db(
         self, event_ids: Iterable[str], allow_rejected: bool = False
     ) -> Dict[str, EventCacheEntry]:
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 4f0adb136a..a77e49dc66 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -55,6 +55,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
@@ -770,6 +771,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             _get_users_server_still_shares_room_with_txn,
         )
 
+    @cancellable
     async def get_rooms_for_user(
         self, user_id: str, on_invalidate: Optional[Callable[[], None]] = None
     ) -> FrozenSet[str]:
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 0b10af0e58..e607ccfdc9 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -36,6 +36,7 @@ from synapse.storage.state import StateFilter
 from synapse.types import JsonDict, JsonMapping, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
+from synapse.util.cancellation import cancellable
 from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
@@ -281,6 +282,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
 
     # FIXME: how should this be cached?
+    @cancellable
     async def get_partial_filtered_current_state_ids(
         self, room_id: str, state_filter: Optional[StateFilter] = None
     ) -> StateMap[str]:
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index a347430aa7..3f9bfaeac5 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -72,6 +72,7 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 from synapse.types import PersistedEventPosition, RoomStreamToken
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -597,6 +598,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         return ret, key
 
+    @cancellable
     async def get_membership_changes_for_user(
         self,
         user_id: str,
diff --git a/synapse/storage/databases/state/store.py b/synapse/storage/databases/state/store.py
index bb64543c1f..f8cfcaca83 100644
--- a/synapse/storage/databases/state/store.py
+++ b/synapse/storage/databases/state/store.py
@@ -31,6 +31,7 @@ from synapse.storage.util.sequence import build_sequence_generator
 from synapse.types import MutableStateMap, StateKey, StateMap
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.dictionary_cache import DictionaryCache
+from synapse.util.cancellation import cancellable
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -156,6 +157,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
             "get_state_group_delta", _get_state_group_delta_txn
         )
 
+    @cancellable
     async def _get_state_groups_from_groups(
         self, groups: List[int], state_filter: StateFilter
     ) -> Dict[int, StateMap[str]]:
@@ -235,6 +237,7 @@ class StateGroupDataStore(StateBackgroundUpdateStore, SQLBaseStore):
 
         return state_filter.filter_state(state_dict_ids), not missing_types
 
+    @cancellable
     async def _get_state_for_groups(
         self, groups: Iterable[int], state_filter: Optional[StateFilter] = None
     ) -> Dict[int, MutableStateMap[str]]:
diff --git a/synapse/storage/util/partial_state_events_tracker.py b/synapse/storage/util/partial_state_events_tracker.py
index b4bf49dace..8d8894d1d5 100644
--- a/synapse/storage/util/partial_state_events_tracker.py
+++ b/synapse/storage/util/partial_state_events_tracker.py
@@ -24,6 +24,7 @@ from synapse.logging.opentracing import trace_with_opname
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.room import RoomWorkerStore
 from synapse.util import unwrapFirstError
+from synapse.util.cancellation import cancellable
 
 logger = logging.getLogger(__name__)
 
@@ -60,6 +61,7 @@ class PartialStateEventsTracker:
                 o.callback(None)
 
     @trace_with_opname("PartialStateEventsTracker.await_full_state")
+    @cancellable
     async def await_full_state(self, event_ids: Collection[str]) -> None:
         """Wait for all the given events to have full state.
 
@@ -154,6 +156,7 @@ class PartialCurrentStateTracker:
                 o.callback(None)
 
     @trace_with_opname("PartialCurrentStateTracker.await_full_state")
+    @cancellable
     async def await_full_state(self, room_id: str) -> None:
         # We add the deferred immediately so that the DB call to check for
         # partial state doesn't race when we unpartial the room.
diff --git a/synapse/types.py b/synapse/types.py
index 668d48d646..ec44601f54 100644
--- a/synapse/types.py
+++ b/synapse/types.py
@@ -52,6 +52,7 @@ from twisted.internet.interfaces import (
 )
 
 from synapse.api.errors import Codes, SynapseError
+from synapse.util.cancellation import cancellable
 from synapse.util.stringutils import parse_and_validate_server_name
 
 if TYPE_CHECKING:
@@ -699,7 +700,11 @@ class StreamToken:
     START: ClassVar["StreamToken"]
 
     @classmethod
+    @cancellable
     async def from_string(cls, store: "DataStore", string: str) -> "StreamToken":
+        """
+        Creates a RoomStreamToken from its textual representation.
+        """
         try:
             keys = string.split(cls._SEPARATOR)
             while len(keys) < len(attr.fields(cls)):
diff --git a/tests/http/server/_base.py b/tests/http/server/_base.py
index 5726e60cee..5071f83574 100644
--- a/tests/http/server/_base.py
+++ b/tests/http/server/_base.py
@@ -140,6 +140,8 @@ def make_request_with_cancellation_test(
     method: str,
     path: str,
     content: Union[bytes, str, JsonDict] = b"",
+    *,
+    token: Optional[str] = None,
 ) -> FakeChannel:
     """Performs a request repeatedly, disconnecting at successive `await`s, until
     one completes.
@@ -211,7 +213,13 @@ def make_request_with_cancellation_test(
                 with deferred_patch.patch():
                     # Start the request.
                     channel = make_request(
-                        reactor, site, method, path, content, await_result=False
+                        reactor,
+                        site,
+                        method,
+                        path,
+                        content,
+                        await_result=False,
+                        access_token=token,
                     )
                     request = channel.request
 
diff --git a/tests/rest/client/test_keys.py b/tests/rest/client/test_keys.py
index bbc8e74243..741fecea77 100644
--- a/tests/rest/client/test_keys.py
+++ b/tests/rest/client/test_keys.py
@@ -19,6 +19,7 @@ from synapse.rest import admin
 from synapse.rest.client import keys, login
 
 from tests import unittest
+from tests.http.server._base import make_request_with_cancellation_test
 
 
 class KeyQueryTestCase(unittest.HomeserverTestCase):
@@ -89,3 +90,31 @@ class KeyQueryTestCase(unittest.HomeserverTestCase):
             Codes.BAD_JSON,
             channel.result,
         )
+
+    def test_key_query_cancellation(self) -> None:
+        """
+        Tests that /keys/query is cancellable and does not swallow the
+        CancelledError.
+        """
+        self.register_user("alice", "wonderland")
+        alice_token = self.login("alice", "wonderland")
+
+        bob = self.register_user("bob", "uncle")
+
+        channel = make_request_with_cancellation_test(
+            "test_key_query_cancellation",
+            self.reactor,
+            self.site,
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    # Empty list means we request keys for all bob's devices
+                    bob: [],
+                },
+            },
+            token=alice_token,
+        )
+
+        self.assertEqual(200, channel.code, msg=channel.result["body"])
+        self.assertIn(bob, channel.json_body["device_keys"])