diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d0fb2fc7dc..60c11e3d21 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -201,95 +201,19 @@ class E2eKeysHandler:
r[user_id] = remote_queries[user_id]
# Now fetch any devices that we don't have in our cache
- @trace
- async def do_remote_query(destination: str) -> None:
- """This is called when we are querying the device list of a user on
- a remote homeserver and their device list is not in the device list
- cache. If we share a room with this user and we're not querying for
- specific user we will update the cache with their device list.
- """
-
- destination_query = remote_queries_not_in_cache[destination]
-
- # We first consider whether we wish to update the device list cache with
- # the users device list. We want to track a user's devices when the
- # authenticated user shares a room with the queried user and the query
- # has not specified a particular device.
- # If we update the cache for the queried user we remove them from further
- # queries. We use the more efficient batched query_client_keys for all
- # remaining users
- user_ids_updated = []
- for (user_id, device_list) in destination_query.items():
- if user_id in user_ids_updated:
- continue
-
- if device_list:
- continue
-
- room_ids = await self.store.get_rooms_for_user(user_id)
- if not room_ids:
- continue
-
- # We've decided we're sharing a room with this user and should
- # probably be tracking their device lists. However, we haven't
- # done an initial sync on the device list so we do it now.
- try:
- if self._is_master:
- user_devices = await self.device_handler.device_list_updater.user_device_resync(
- user_id
- )
- else:
- user_devices = await self._user_device_resync_client(
- user_id=user_id
- )
-
- user_devices = user_devices["devices"]
- user_results = results.setdefault(user_id, {})
- for device in user_devices:
- user_results[device["device_id"]] = device["keys"]
- user_ids_updated.append(user_id)
- except Exception as e:
- failures[destination] = _exception_to_failure(e)
-
- if len(destination_query) == len(user_ids_updated):
- # We've updated all the users in the query and we do not need to
- # make any further remote calls.
- return
-
- # Remove all the users from the query which we have updated
- for user_id in user_ids_updated:
- destination_query.pop(user_id)
-
- try:
- remote_result = await self.federation.query_client_keys(
- destination, {"device_keys": destination_query}, timeout=timeout
- )
-
- for user_id, keys in remote_result["device_keys"].items():
- if user_id in destination_query:
- results[user_id] = keys
-
- if "master_keys" in remote_result:
- for user_id, key in remote_result["master_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["master_keys"][user_id] = key
-
- if "self_signing_keys" in remote_result:
- for user_id, key in remote_result["self_signing_keys"].items():
- if user_id in destination_query:
- cross_signing_keys["self_signing_keys"][user_id] = key
-
- except Exception as e:
- failure = _exception_to_failure(e)
- failures[destination] = failure
- set_tag("error", True)
- set_tag("reason", failure)
-
await make_deferred_yieldable(
defer.gatherResults(
[
- run_in_background(do_remote_query, destination)
- for destination in remote_queries_not_in_cache
+ run_in_background(
+ self._query_devices_for_destination,
+ results,
+ cross_signing_keys,
+ failures,
+ destination,
+ queries,
+ timeout,
+ )
+ for destination, queries in remote_queries_not_in_cache.items()
],
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -301,6 +225,121 @@ class E2eKeysHandler:
return ret
+ @trace
+ async def _query_devices_for_destination(
+ self,
+ results: JsonDict,
+ cross_signing_keys: JsonDict,
+ failures: Dict[str, JsonDict],
+ destination: str,
+ destination_query: Dict[str, Iterable[str]],
+ timeout: int,
+ ) -> None:
+ """This is called when we are querying the device list of a user on
+ a remote homeserver and their device list is not in the device list
+ cache. If we share a room with this user and we're not querying for
+ specific user we will update the cache with their device list.
+
+ Args:
+ results: A map from user ID to their device keys, which gets
+ updated with the newly fetched keys.
+ cross_signing_keys: Map from user ID to their cross signing keys,
+ which gets updated with the newly fetched keys.
+ failures: Map of destinations to failures that have occurred while
+ attempting to fetch keys.
+ destination: The remote server to query
+ destination_query: The query dict of devices to query the remote
+ server for.
+ timeout: The timeout for remote HTTP requests.
+ """
+
+ # We first consider whether we wish to update the device list cache with
+ # the users device list. We want to track a user's devices when the
+ # authenticated user shares a room with the queried user and the query
+ # has not specified a particular device.
+ # If we update the cache for the queried user we remove them from further
+ # queries. We use the more efficient batched query_client_keys for all
+ # remaining users
+ user_ids_updated = []
+ for (user_id, device_list) in destination_query.items():
+ if user_id in user_ids_updated:
+ continue
+
+ if device_list:
+ continue
+
+ room_ids = await self.store.get_rooms_for_user(user_id)
+ if not room_ids:
+ continue
+
+ # We've decided we're sharing a room with this user and should
+ # probably be tracking their device lists. However, we haven't
+ # done an initial sync on the device list so we do it now.
+ try:
+ if self._is_master:
+ resync_results = await self.device_handler.device_list_updater.user_device_resync(
+ user_id
+ )
+ else:
+ resync_results = await self._user_device_resync_client(
+ user_id=user_id
+ )
+
+ # Add the device keys to the results.
+ user_devices = resync_results["devices"]
+ user_results = results.setdefault(user_id, {})
+ for device in user_devices:
+ user_results[device["device_id"]] = device["keys"]
+ user_ids_updated.append(user_id)
+
+ # Add any cross signing keys to the results.
+ master_key = resync_results.get("master_key")
+ self_signing_key = resync_results.get("self_signing_key")
+
+ if master_key:
+ cross_signing_keys["master_keys"][user_id] = master_key
+
+ if self_signing_key:
+ cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
+ except Exception as e:
+ failures[destination] = _exception_to_failure(e)
+
+ if len(destination_query) == len(user_ids_updated):
+ # We've updated all the users in the query and we do not need to
+ # make any further remote calls.
+ return
+
+ # Remove all the users from the query which we have updated
+ for user_id in user_ids_updated:
+ destination_query.pop(user_id)
+
+ try:
+ remote_result = await self.federation.query_client_keys(
+ destination, {"device_keys": destination_query}, timeout=timeout
+ )
+
+ for user_id, keys in remote_result["device_keys"].items():
+ if user_id in destination_query:
+ results[user_id] = keys
+
+ if "master_keys" in remote_result:
+ for user_id, key in remote_result["master_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["master_keys"][user_id] = key
+
+ if "self_signing_keys" in remote_result:
+ for user_id, key in remote_result["self_signing_keys"].items():
+ if user_id in destination_query:
+ cross_signing_keys["self_signing_keys"][user_id] = key
+
+ except Exception as e:
+ failure = _exception_to_failure(e)
+ failures[destination] = failure
+ set_tag("error", True)
+ set_tag("reason", failure)
+
+ return
+
async def get_cross_signing_keys_from_cache(
self, query: Iterable[str], from_user_id: Optional[str]
) -> Dict[str, Dict[str, dict]]:
|