diff --git a/changelog.d/10144.misc b/changelog.d/10144.misc
new file mode 100644
index 0000000000..fe96d645d7
--- /dev/null
+++ b/changelog.d/10144.misc
@@ -0,0 +1 @@
+Limit the number of in-flight `/keys/query` requests from a single device.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 974487800d..3972849d4d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -79,9 +79,15 @@ class E2eKeysHandler:
"client_keys", self.on_federation_query_client_keys
)
+ # Limit the number of in-flight requests from a single device.
+ self._query_devices_linearizer = Linearizer(
+ name="query_devices",
+ max_count=10,
+ )
+
@trace
async def query_devices(
- self, query_body: JsonDict, timeout: int, from_user_id: str
+ self, query_body: JsonDict, timeout: int, from_user_id: str, from_device_id: str
) -> JsonDict:
"""Handle a device key query from a client
@@ -105,191 +111,197 @@ class E2eKeysHandler:
from_user_id: the user making the query. This is used when
adding cross-signing signatures to limit what signatures users
can see.
+ from_device_id: the device making the query. This is used to limit
+ the number of in-flight queries at a time.
"""
-
- device_keys_query = query_body.get(
- "device_keys", {}
- ) # type: Dict[str, Iterable[str]]
-
- # separate users by domain.
- # make a map from domain to user_id to device_ids
- local_query = {}
- remote_queries = {}
-
- for user_id, device_ids in device_keys_query.items():
- # we use UserID.from_string to catch invalid user ids
- if self.is_mine(UserID.from_string(user_id)):
- local_query[user_id] = device_ids
- else:
- remote_queries[user_id] = device_ids
-
- set_tag("local_key_query", local_query)
- set_tag("remote_key_query", remote_queries)
-
- # First get local devices.
- # A map of destination -> failure response.
- failures = {} # type: Dict[str, JsonDict]
- results = {}
- if local_query:
- local_result = await self.query_local_devices(local_query)
- for user_id, keys in local_result.items():
- if user_id in local_query:
- results[user_id] = keys
-
- # Get cached cross-signing keys
- cross_signing_keys = await self.get_cross_signing_keys_from_cache(
- device_keys_query, from_user_id
- )
-
- # Now attempt to get any remote devices from our local cache.
- # A map of destination -> user ID -> device IDs.
- remote_queries_not_in_cache = {} # type: Dict[str, Dict[str, Iterable[str]]]
- if remote_queries:
- query_list = [] # type: List[Tuple[str, Optional[str]]]
- for user_id, device_ids in remote_queries.items():
- if device_ids:
- query_list.extend((user_id, device_id) for device_id in device_ids)
+ with await self._query_devices_linearizer.queue((from_user_id, from_device_id)):
+ device_keys_query = query_body.get(
+ "device_keys", {}
+ ) # type: Dict[str, Iterable[str]]
+
+ # separate users by domain.
+ # make a map from domain to user_id to device_ids
+ local_query = {}
+ remote_queries = {}
+
+ for user_id, device_ids in device_keys_query.items():
+ # we use UserID.from_string to catch invalid user ids
+ if self.is_mine(UserID.from_string(user_id)):
+ local_query[user_id] = device_ids
else:
- query_list.append((user_id, None))
-
- (
- user_ids_not_in_cache,
- remote_results,
- ) = await self.store.get_user_devices_from_cache(query_list)
- for user_id, devices in remote_results.items():
- user_devices = results.setdefault(user_id, {})
- for device_id, device in devices.items():
- keys = device.get("keys", None)
- device_display_name = device.get("device_display_name", None)
- if keys:
- result = dict(keys)
- unsigned = result.setdefault("unsigned", {})
- if device_display_name:
- unsigned["device_display_name"] = device_display_name
- user_devices[device_id] = result
-
- # check for missing cross-signing keys.
- for user_id in remote_queries.keys():
- cached_cross_master = user_id in cross_signing_keys["master_keys"]
- cached_cross_selfsigning = (
- user_id in cross_signing_keys["self_signing_keys"]
- )
-
- # check if we are missing only one of cross-signing master or
- # self-signing key, but the other one is cached.
- # as we need both, this will issue a federation request.
- # if we don't have any of the keys, either the user doesn't have
- # cross-signing set up, or the cached device list
- # is not (yet) updated.
- if cached_cross_master ^ cached_cross_selfsigning:
- user_ids_not_in_cache.add(user_id)
-
- # add those users to the list to fetch over federation.
- for user_id in user_ids_not_in_cache:
- domain = get_domain_from_id(user_id)
- r = remote_queries_not_in_cache.setdefault(domain, {})
- r[user_id] = remote_queries[user_id]
-
- # Now fetch any devices that we don't have in our cache
- @trace
- async def do_remote_query(destination):
- """This is called when we are querying the device list of a user on
- a remote homeserver and their device list is not in the device list
- cache. If we share a room with this user and we're not querying for
- specific user we will update the cache with their device list.
- """
-
- destination_query = remote_queries_not_in_cache[destination]
-
- # We first consider whether we wish to update the device list cache with
- # the users device list. We want to track a user's devices when the
- # authenticated user shares a room with the queried user and the query
- # has not specified a particular device.
- # If we update the cache for the queried user we remove them from further
- # queries. We use the more efficient batched query_client_keys for all
- # remaining users
- user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
-
- if device_list:
- continue
+ remote_queries[user_id] = device_ids
+
+ set_tag("local_key_query", local_query)
+ set_tag("remote_key_query", remote_queries)
+
+ # First get local devices.
+ # A map of destination -> failure response.
+ failures = {} # type: Dict[str, JsonDict]
+ results = {}
+ if local_query:
+ local_result = await self.query_local_devices(local_query)
+ for user_id, keys in local_result.items():
+ if user_id in local_query:
+ results[user_id] = keys
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
+ # Get cached cross-signing keys
+ cross_signing_keys = await self.get_cross_signing_keys_from_cache(
+ device_keys_query, from_user_id
+ )
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- if self._is_master:
- user_devices = await self.device_handler.device_list_updater.user_device_resync(
- user_id
+ # Now attempt to get any remote devices from our local cache.
+ # A map of destination -> user ID -> device IDs.
+ remote_queries_not_in_cache = (
+ {}
+ ) # type: Dict[str, Dict[str, Iterable[str]]]
+ if remote_queries:
+ query_list = [] # type: List[Tuple[str, Optional[str]]]
+ for user_id, device_ids in remote_queries.items():
+ if device_ids:
+ query_list.extend(
+ (user_id, device_id) for device_id in device_ids
)
else:
- user_devices = await self._user_device_resync_client(
- user_id=user_id
- )
-
- user_devices = user_devices["devices"]
- user_results = results.setdefault(user_id, {})
- for device in user_devices:
- user_results[device["device_id"]] = device["keys"]
- user_ids_updated.append(user_id)
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
-
- if len(destination_query) == len(user_ids_updated):
- # We've updated all the users in the query and we do not need to
- # make any further remote calls.
- return
+ query_list.append((user_id, None))
- # Remove all the users from the query which we have updated
- for user_id in user_ids_updated:
- destination_query.pop(user_id)
+ (
+ user_ids_not_in_cache,
+ remote_results,
+ ) = await self.store.get_user_devices_from_cache(query_list)
+ for user_id, devices in remote_results.items():
+ user_devices = results.setdefault(user_id, {})
+ for device_id, device in devices.items():
+ keys = device.get("keys", None)
+ device_display_name = device.get("device_display_name", None)
+ if keys:
+ result = dict(keys)
+ unsigned = result.setdefault("unsigned", {})
+ if device_display_name:
+ unsigned["device_display_name"] = device_display_name
+ user_devices[device_id] = result
+
+ # check for missing cross-signing keys.
+ for user_id in remote_queries.keys():
+ cached_cross_master = user_id in cross_signing_keys["master_keys"]
+ cached_cross_selfsigning = (
+ user_id in cross_signing_keys["self_signing_keys"]
+ )
- try:
- remote_result = await self.federation.query_client_keys(
- destination, {"device_keys": destination_query}, timeout=timeout
- )
+ # check if we are missing only one of cross-signing master or
+ # self-signing key, but the other one is cached.
+ # as we need both, this will issue a federation request.
+ # if we don't have any of the keys, either the user doesn't have
+ # cross-signing set up, or the cached device list
+ # is not (yet) updated.
+ if cached_cross_master ^ cached_cross_selfsigning:
+ user_ids_not_in_cache.add(user_id)
+
+ # add those users to the list to fetch over federation.
+ for user_id in user_ids_not_in_cache:
+ domain = get_domain_from_id(user_id)
+ r = remote_queries_not_in_cache.setdefault(domain, {})
+ r[user_id] = remote_queries[user_id]
+
+ # Now fetch any devices that we don't have in our cache
+ @trace
+ async def do_remote_query(destination):
+ """This is called when we are querying the device list of a user on
+ a remote homeserver and their device list is not in the device list
+ cache. If we share a room with this user and we're not querying for
+ specific user we will update the cache with their device list.
+ """
+
+ destination_query = remote_queries_not_in_cache[destination]
+
+ # We first consider whether we wish to update the device list cache with
+ # the users device list. We want to track a user's devices when the
+ # authenticated user shares a room with the queried user and the query
+ # has not specified a particular device.
+ # If we update the cache for the queried user we remove them from further
+ # queries. We use the more efficient batched query_client_keys for all
+ # remaining users
+ user_ids_updated = []
+ for (user_id, device_list) in destination_query.items():
+ if user_id in user_ids_updated:
+ continue
+
+ if device_list:
+ continue
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ continue
+
+ # We've decided we're sharing a room with this user and should
+ # probably be tracking their device lists. However, we haven't
+ # done an initial sync on the device list so we do it now.
+ try:
+ if self._is_master:
+ user_devices = await self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ user_devices = await self._user_device_resync_client(
+ user_id=user_id
+ )
+
+ user_devices = user_devices["devices"]
+ user_results = results.setdefault(user_id, {})
+ for device in user_devices:
+ user_results[device["device_id"]] = device["keys"]
+ user_ids_updated.append(user_id)
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
+
+ if len(destination_query) == len(user_ids_updated):
+ # We've updated all the users in the query and we do not need to
+ # make any further remote calls.
+ return
+
+ # Remove all the users from the query which we have updated
+ for user_id in user_ids_updated:
+ destination_query.pop(user_id)
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in destination_query:
- results[user_id] = keys
+ try:
+ remote_result = await self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}, timeout=timeout
+ )
- if "master_keys" in remote_result:
- for user_id, key in remote_result["master_keys"].items():
+ for user_id, keys in remote_result["device_keys"].items():
if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
+ results[user_id] = keys
- if "self_signing_keys" in remote_result:
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
- except Exception as e:
- failure = _exception_to_failure(e)
- failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", failure)
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
- await make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
+ except Exception as e:
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
+
+ await make_deferred_yieldable(
+ defer.gatherResults(
+ [
+ run_in_background(do_remote_query, destination)
+ for destination in remote_queries_not_in_cache
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+ )
- ret = {"device_keys": results, "failures": failures}
+ ret = {"device_keys": results, "failures": failures}
- ret.update(cross_signing_keys)
+ ret.update(cross_signing_keys)
- return ret
+ return ret
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
diff --git a/synapse/rest/client/v2_alpha/keys.py b/synapse/rest/client/v2_alpha/keys.py
index a57ccbb5e5..4a28f2c072 100644
--- a/synapse/rest/client/v2_alpha/keys.py
+++ b/synapse/rest/client/v2_alpha/keys.py
@@ -160,9 +160,12 @@ class KeyQueryServlet(RestServlet):
async def on_POST(self, request):
requester = await self.auth.get_user_by_req(request, allow_guest=True)
user_id = requester.user.to_string()
+ device_id = requester.device_id
timeout = parse_integer(request, "timeout", 10 * 1000)
body = parse_json_object_from_request(request)
- result = await self.e2e_keys_handler.query_devices(body, timeout, user_id)
+ result = await self.e2e_keys_handler.query_devices(
+ body, timeout, user_id, device_id
+ )
return 200, result
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 61a00130b8..e0a24824cc 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -257,7 +257,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
self.get_success(self.handler.upload_signing_keys_for_user(local_user, keys2))
devices = self.get_success(
- self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
+ self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user, "device123"
+ )
)
self.assertDictEqual(devices["master_keys"], {local_user: keys2["master_key"]})
@@ -357,7 +359,9 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
device_key_1["signatures"][local_user]["ed25519:abc"] = "base64+signature"
device_key_2["signatures"][local_user]["ed25519:def"] = "base64+signature"
devices = self.get_success(
- self.handler.query_devices({"device_keys": {local_user: []}}, 0, local_user)
+ self.handler.query_devices(
+ {"device_keys": {local_user: []}}, 0, local_user, "device123"
+ )
)
del devices["device_keys"][local_user]["abc"]["unsigned"]
del devices["device_keys"][local_user]["def"]["unsigned"]
@@ -591,7 +595,10 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
# fetch the signed keys/devices and make sure that the signatures are there
ret = self.get_success(
self.handler.query_devices(
- {"device_keys": {local_user: [], other_user: []}}, 0, local_user
+ {"device_keys": {local_user: [], other_user: []}},
+ 0,
+ local_user,
+ "device123",
)
)
|