summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-06-09 07:05:32 -0400
committerGitHub <noreply@github.com>2021-06-09 07:05:32 -0400
commit11846dff8c667cbe6861ddc821ca7c53e3e2d890 (patch)
tree5671b8fd94e1cc09d6f0008ff237dd07ea9fd2f3
parentClean up the interface for injecting opentracing over HTTP (#10143) (diff)
downloadsynapse-11846dff8c667cbe6861ddc821ca7c53e3e2d890.tar.xz
Limit the number of in-flight /keys/query requests from a single device. (#10144)
-rw-r--r--changelog.d/10144.misc1
-rw-r--r--synapse/handlers/e2e_keys.py350
-rw-r--r--synapse/rest/client/v2_alpha/keys.py5
-rw-r--r--tests/handlers/test_e2e_keys.py13
4 files changed, 196 insertions, 173 deletions
diff --git a/changelog.d/10144.misc b/changelog.d/10144.misc
new file mode 100644
index 0000000000..fe96d645d7
--- /dev/null
+++ b/changelog.d/10144.misc
@@ -0,0 +1 @@
+Limit the number of in-flight `/keys/query` requests from a single device.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 974487800d..3972849d4d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -79,9 +79,15 @@ class E2eKeysHandler:
             "client_keys", self.on_federation_query_client_keys
         )
 
+        # Limit the number of in-flight requests from a single device.
+        self._query_devices_linearizer = Linearizer(
+            name="query_devices",
+            max_count=10,
+        )
+
     @trace
     async def query_devices(
-        self, query_body: JsonDict, timeout: int, from_user_id: str
+        self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
     ) -> JsonDict:
         """Handle a device key query from a client
 
@@ -105,191 +111,197 @@ class E2eKeysHandler:
             from_user_id: the user making the query.  This is used when
                 adding cross-signing signatures to limit what signatures users
                 can see.
+            from_device_id: the device making the query. This is used to limit
+                the number of in-flight queries at a time.
         """
-
-        device_keys_query = query_body.get(
-            "device_keys", {}
-        )  # type: Dict[str, Iterable[str]]
-
-        # separate users by domain.
-        # make a map from domain to user_id to device_ids
-        local_query = {}
-        remote_queries = {}
-
-        for user_id, device_ids in device_keys_query.items():
-            # we use UserID.from_string to catch invalid user ids
-            if self.is_mine(UserID.from_string(user_id)):
-                local_query[user_id] = device_ids
-            else:
-                remote_queries[user_id] = device_ids
-
-        set_tag("local_key_query", local_query)
-        set_tag("remote_key_query", remote_queries)
-
-        # First get local devices.
-        # A map of destination -> failure response.
-        failures = {}  # type: Dict[str, JsonDict]
-        results = {}
-        if local_query:
-            local_result = await self.query_local_devices(local_query)
-            for user_id, keys in local_result.items():
-                if user_id in local_query:
-                    results[user_id] = keys
-
-        # Get cached cross-signing keys
-        cross_signing_keys = await self.get_cross_signing_keys_from_cache(
-            device_keys_query, from_user_id
-        )
-
-        # Now attempt to get any remote devices from our local cache.
-        # A map of destination -> user ID -> device IDs.
-        remote_queries_not_in_cache = {}  # type: Dict[str, Dict[str, Iterable[str]]]
-        if remote_queries:
-            query_list = []  # type: List[Tuple[str, Optional[str]]]
-            for user_id, device_ids in remote_queries.items():
-                if device_ids:
-                    query_list.extend((user_id, device_id) for device_id in device_ids)
+        with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+            device_keys_query = query_body.get(
+                "device_keys", {}
+            )  # type: Dict[str, Iterable[str]]
+
+            # separate users by domain.
+            # make a map from domain to user_id to device_ids
+            local_query = {}
+            remote_queries = {}
+
+            for user_id, device_ids in device_keys_query.items():
+                # we use UserID.from_string to catch invalid user ids
+                if self.is_mine(UserID.from_string(user_id)):
+                    local_query[user_id] = device_ids
                 else:
-                    query_list.append((user_id, None))
-
-            (
-                user_ids_not_in_cache,
-                remote_results,
-            ) = await self.store.get_user_devices_from_cache(query_list)
-            for user_id, devices in remote_results.items():
-                user_devices = results.setdefault(user_id, {})
-                for device_id, device in devices.items():
-                    keys = device.get("keys", None)
-                    device_display_name = device.get("device_display_name", None)
-                    if keys:
-                        result = dict(keys)
-                        unsigned = result.setdefault("unsigned", {})
-                        if device_display_name:
-                            unsigned["device_display_name"] = device_display_name
-                        user_devices[device_id] = result
-
-            # check for missing cross-signing keys.
-            for user_id in remote_queries.keys():
-                cached_cross_master = user_id in cross_signing_keys["master_keys"]
-                cached_cross_selfsigning = (
-                    user_id in cross_signing_keys["self_signing_keys"]
-                )
-
-                # check if we are missing only one of cross-signing master or
-                # self-signing key, but the other one is cached.
-                # as we need both, this will issue a federation request.
-                # if we don't have any of the keys, either the user doesn't have
-                # cross-signing set up, or the cached device list
-                # is not (yet) updated.
-                if cached_cross_master ^ cached_cross_selfsigning:
-                    user_ids_not_in_cache.add(user_id)
-
-            # add those users to the list to fetch over federation.
-            for user_id in user_ids_not_in_cache:
-                domain = get_domain_from_id(user_id)
-                r = remote_queries_not_in_cache.setdefault(domain, {})
-                r[user_id] = remote_queries[user_id]
-
-        # Now fetch any devices that we don't have in our cache
-        @trace
-        async def do_remote_query(destination):
-            """This is called when we are querying the device list of a user on
-            a remote homeserver and their device list is not in the device list
-            cache. If we share a room with this user and we're not querying for
-            specific user we will update the cache with their device list.
-            """
-
-            destination_query = remote_queries_not_in_cache[destination]
-
-            # We first consider whether we wish to update the device list cache with
-            # the users device list. We want to track a user's devices when the
-            # authenticated user shares a room with the queried user and the query
-            # has not specified a particular device.
-            # If we update the cache for the queried user we remove them from further
-            # queries. We use the more efficient batched query_client_keys for all
-            # remaining users
-            user_ids_updated = []
-            for (user_id, device_list) in destination_query.items():
-                if user_id in user_ids_updated:
-                    continue
-
-                if device_list:
-                    continue
+                    remote_queries[user_id] = device_ids
+
+            set_tag("local_key_query", local_query)
+            set_tag("remote_key_query", remote_queries)
+
+            # First get local devices.
+            # A map of destination -> failure response.
+            failures = {}  # type: Dict[str, JsonDict]
+            results = {}
+            if local_query:
+                local_result = await self.query_local_devices(local_query)
+                for user_id, keys in local_result.items():
+                    if user_id in local_query:
+                        results[user_id] = keys
 
-                room_ids = await self.store.get_rooms_for_user(user_id)
-                if not room_ids:
-                    continue
+            # Get cached cross-signing keys
+            cross_signing_keys = await self.get_cross_signing_keys_from_cache(
+                device_keys_query, from_user_id
+            )
 
-                # We've decided we're sharing a room with this user and should
-                # probably be tracking their device lists. However, we haven't
-                # done an initial sync on the device list so we do it now.
-                try:
-                    if self._is_master:
-                        user_devices = await self.device_handler.device_list_updater.user_device_resync(
-                            user_id
+            # Now attempt to get any remote devices from our local cache.
+            # A map of destination -> user ID -> device IDs.
+            remote_queries_not_in_cache = (
+                {}
+            )  # type: Dict[str, Dict[str, Iterable[str]]]
+            if remote_queries:
+                query_list = []  # type: List[Tuple[str, Optional[str]]]
+                for user_id, device_ids in remote_queries.items():
+                    if device_ids:
+                        query_list.extend(
+                            (user_id, device_id) for device_id in device_ids
                         )
                     else:
-                        user_devices = await self._user_device_resync_client(
-                            user_id=user_id
-                        )
-
-                    user_devices = user_devices["devices"]
-                    user_results = results.setdefault(user_id, {})
-                    for device in user_devices:
-                        user_results[device["device_id"]] = device["keys"]
-                    user_ids_updated.append(user_id)
-                except Exception as e:
-                    failures[destination] = _exception_to_failure(e)
-
-            if len(destination_query) == len(user_ids_updated):
-                # We've updated all the users in the query and we do not need to
-                # make any further remote calls.
-                return
+                        query_list.append((user_id, None))
 
-            # Remove all the users from the query which we have updated
-            for user_id in user_ids_updated:
-                destination_query.pop(user_id)
+                (
+                    user_ids_not_in_cache,
+                    remote_results,
+                ) = await self.store.get_user_devices_from_cache(query_list)
+                for user_id, devices in remote_results.items():
+                    user_devices = results.setdefault(user_id, {})
+                    for device_id, device in devices.items():
+                        keys = device.get("keys", None)
+                        device_display_name = device.get("device_display_name", None)
+                        if keys:
+                            result = dict(keys)
+                            unsigned = result.setdefault("unsigned", {})
+                            if device_display_name:
+                                unsigned["device_display_name"] = device_display_name
+                            user_devices[device_id] = result
+
+                # check for missing cross-signing keys.
+                for user_id in remote_queries.keys():
+                    cached_cross_master = user_id in cross_signing_keys["master_keys"]
+                    cached_cross_selfsigning = (
+                        user_id in cross_signing_keys["self_signing_keys"]
+                    )
 
-            try:
-                remote_result = await self.federation.query_client_keys(
-                    destination, {"device_keys": destination_query}, timeout=timeout
-                )
+                    # check if we are missing only one of cross-signing master or
+                    # self-signing key, but the other one is cached.
+                    # as we need both, this will issue a federation request.
+                    # if we don't have any of the keys, either the user doesn't have
+                    # cross-signing set up, or the cached device list
+                    # is not (yet) updated.
+                    if cached_cross_master ^ cached_cross_selfsigning:
+                        user_ids_not_in_cache.add(user_id)
+
+                # add those users to the list to fetch over federation.
+                for user_id in user_ids_not_in_cache:
+                    domain = get_domain_from_id(user_id)
+                    r = remote_queries_not_in_cache.setdefault(domain, {})
+                    r[user_id] = remote_queries[user_id]
+
+            # Now fetch any devices that we don't have in our cache
+            @trace
+            async def do_remote_query(destination):
+                """This is called when we are querying the device list of a user on
+                a remote homeserver and their device list is not in the device list
+                cache. If we share a room with this user and we're not querying for
+                specific user we will update the cache with their device list.
+                """
+
+                destination_query = remote_queries_not_in_cache[destination]
+
+                # We first consider whether we wish to update the device list cache with
+                # the users device list. We want to track a user's devices when the
+                # authenticated user shares a room with the queried user and the query
+                # has not specified a particular device.
+                # If we update the cache for the queried user we remove them from further
+                # queries. We use the more efficient batched query_client_keys for all
+                # remaining users
+                user_ids_updated = []
+                for (user_id, device_list) in destination_query.items():
+                    if user_id in user_ids_updated:
+                        continue
+
+                    if device_list:
+                        continue
+
+                    room_ids = await self.store.get_rooms_for_user(user_id)
+                    if not room_ids:
+                        continue
+
+                    # We've decided we're sharing a room with this user and should
+                    # probably be tracking their device lists. However, we haven't
+                    # done an initial sync on the device list so we do it now.
+                    try:
+                        if self._is_master:
+                            user_devices = await self.device_handler.device_list_updater.user_device_resync(
+                                user_id
+                            )
+                        else:
+                            user_devices = await self._user_device_resync_client(
+                                user_id=user_id
+                            )
+
+                        user_devices = user_devices["devices"]
+                        user_results = results.setdefault(user_id, {})
+                        for device in user_devices:
+                            user_results[device["device_id"]] = device["keys"]
+                        user_ids_updated.append(user_id)
+                    except Exception as e:
+                        failures[destination] = _exception_to_failure(e)
+
+                if len(destination_query) == len(user_ids_updated):
+                    # We've updated all the users in the query and we do not need to
+                    # make any further remote calls.
+                    return
+
+                # Remove all the users from the query which we have updated
+                for user_id in user_ids_updated:
+                    destination_query.pop(user_id)
 
-                for user_id, keys in remote_result["device_keys"].items():
-                    if user_id in destination_query:
-                        results[user_id] = keys
+                try:
+                    remote_result = await self.federation.query_client_keys(
+                        destination, {"device_keys": destination_query}, timeout=timeout
+                    )
 
-                if "master_keys" in remote_result:
-                    for user_id, key in remote_result["master_keys"].items():
+                    for user_id, keys in remote_result["device_keys"].items():
                         if user_id in destination_query:
-                            cross_signing_keys["master_keys"][user_id] = key
+                            results[user_id] = keys
 
-                if "self_signing_keys" in remote_result:
-                    for user_id, key in remote_result["self_signing_keys"].items():
-                        if user_id in destination_query:
-                            cross_signing_keys["self_signing_keys"][user_id] = key
+                    if "master_keys" in remote_result:
+                        for user_id, key in remote_result["master_keys"].items():
+                            if user_id in destination_query:
+                                cross_signing_keys["master_keys"][user_id] = key
 
-            except Exception as e:
-                failure = _exception_to_failure(e)
-                failures[destination] = failure
-                set_tag("error", True)
-                set_tag("reason", failure)
+                    if "self_signing_keys" in remote_result:
+                        for user_id, key in remote_result["self_signing_keys"].items():
+                            if user_id in destination_query:
+                                cross_signing_keys["self_signing_keys"][user_id] = key
 
-        await make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(do_remote_query, destination)
-                    for destination in remote_queries_not_in_cache
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
-        )
+                except Exception as e:
+                    failure = _exception_to_failure(e)
+                    failures[destination] = failure
+                    set_tag("error", True)
+                    set_tag("reason", failure)
+
+            await make_deferred_yieldable(
+                defer.gatherResults(
+                    [
+                        run_in_background(do_remote_query, destination)
+                        for destination in remote_queries_not_in_cache
+                    ],
+                    consumeErrors=True,
+                ).addErrback(unwrapFirstError)
+            )
 
-        ret = {"device_keys": results, "failures": failures}
+            ret = {"device_keys": results, "failures": failures}
 
-        ret.update(cross_signing_keys)
+            ret.update(cross_signing_keys)
 
-        return ret
+            return ret
 
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a57ccbb5e5..4a28f2c072 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
     async def on_POST(self, request):
         requester = await self.auth.get_user_by_req(request, allow_guest=True)
         user_id = requester.user.to_string()
+        device_id = requester.device_id
         timeout = parse_integer(request, "timeout", 10 * 1000)
         body = parse_json_object_from_request(request)
-        result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
+        result = await self.e2e_keys_handler.query_devices(
+            body, timeout, user_id, device_id
+        )
         return 200, result
 
 
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 61a00130b8..e0a24824cc 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
 
         devices = self.get_success(
-            self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
+            self.handler.query_devices(
+                {"device_keys": {local_user: []}}, 0, local_user, "device123"
+            )
         )
         self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
 
@@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
         device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
         devices = self.get_success(
-            self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
+            self.handler.query_devices(
+                {"device_keys": {local_user: []}}, 0, local_user, "device123"
+            )
         )
         del devices["device_keys"][local_user]["abc"]["unsigned"]
         del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
         # fetch the signed keys/devices and make sure that the signatures are there
         ret = self.get_success(
             self.handler.query_devices(
-                {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+                {"device_keys": {local_user: [], other_user: []}},
+                0,
+                local_user,
+                "device123",
             )
         )