summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2023-10-30 12:03:36 +0000
committerRichard van der Hoff <richard@matrix.org>2023-10-30 12:03:36 +0000
commit6dbad839980deb2aee07c1df3d19b80a0e178512 (patch)
treecdf79f75ec422108d2ffc92da86d11115e220e94 /synapse/handlers
parentImprove tracing for `claim_one_time_keys` (diff)
downloadsynapse-6dbad839980deb2aee07c1df3d19b80a0e178512.tar.xz
Implement MSC4072: return result for all /keys/claims
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/e2e_keys.py93
1 files changed, 90 insertions, 3 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 7d44127ebf..bb628fdb00 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -592,6 +592,12 @@ class E2eKeysHandler:
             for user_id, device_id, algorithm, count in local_query
         ]
 
+        # prepopulate the response to make sure that all queried users/devices are
+        # included, even if the user/device is unknown or has run out of OTKs
+        if self.config.experimental.msc4072_empty_dict_for_exhausted_devices:
+            for user_id, device_id, _, _ in local_query:
+                result_dict.setdefault(user_id, {}).setdefault(device_id, {})
+
         otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
         update_result_dict(otk_results)
 
@@ -669,6 +675,25 @@ class E2eKeysHandler:
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
+        """
+        Handle a /keys/claim request.
+
+        Handles requests for local users with a db lookup, and makes federation
+        requests for remote users.
+
+        Args:
+            query: map from user ID, to map from device ID, to map from algorithm name
+                to number of keys needed
+                (``{user_id: {device_id: {algorithm: number_of keys}}}``)
+
+            user: The user id of the requesting user
+
+            timeout: number of milliseconds to wait for the response from remote servers.
+                ``config.federation.client_timeout_ms`` by default.
+
+            always_include_fallback_keys: True to always include fallback keys, even
+                for devices which still have one-time keys.
+        """
         local_query: List[Tuple[str, str, str, int]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 
@@ -707,9 +732,18 @@ class E2eKeysHandler:
                 remote_result = await self.federation.claim_client_keys(
                     user, destination, device_keys, timeout=timeout
                 )
-                for user_id, keys in remote_result["one_time_keys"].items():
-                    if user_id in device_keys:
-                        json_result[user_id] = keys
+                try:
+                    destination_result = filter_remote_claimed_keys(
+                        device_keys,
+                        remote_result,
+                        self.config.experimental.msc4072_empty_dict_for_exhausted_devices,
+                    )
+                except Exception as e:
+                    logger.warning(
+                        f"Error parsing /keys/claim response from server {destination}",
+                        e,
+                    )
+                    raise
 
             except Exception as e:
                 failure = _exception_to_failure(e)
@@ -717,6 +751,11 @@ class E2eKeysHandler:
                 set_tag("error", True)
                 set_tag("reason", str(failure))
 
+            else:
+                # only populate json_result once we know there will not be an entry in
+                # failures for this destination.
+                json_result.update(destination_result)
+
         await make_deferred_yieldable(
             defer.gatherResults(
                 [
@@ -1632,3 +1671,51 @@ class SigningKeyEduUpdater:
                 device_ids = device_ids + new_device_ids
 
             await self._device_handler.notify_device_update(user_id, device_ids)
+
+
+def filter_remote_claimed_keys(
+    destination_query: Dict[str, Dict[str, Dict[str, int]]],
+    remote_response: JsonDict,
+    msc4072_empty_dict_for_exhausted_devices: bool,
+) -> JsonDict:
+    """
+    Process the response from a federation /keys/claim request
+
+    Checks that there are no redundant entries, and that all the entries that
+    should be there are present.
+
+    Args:
+        destination_query: user->device->key map that was sent in the request to
+           this server
+        remote_response: response from the remote server
+        msc4072_empty_dict_for_exhausted_devices: true to include an entry in the
+           result for every queried device
+
+    Returns:
+        user->device->key map to be merged into the results
+    """
+    remote_otks = remote_response["one_time_keys"]
+
+    destination_result: JsonDict = {}
+
+    if msc4072_empty_dict_for_exhausted_devices:
+        # We need to make sure there is an entry in destination_result for
+        # every queried (user, device) even if the remote server did not
+        # populate it; so we iterate the query and populate
+        # destination_result based on the federation result.
+        for user_id, user_query in destination_query.items():
+            remote_user_result = remote_otks.get(user_id, {})
+            destination_user_result = destination_result[user_id] = {}
+            for device_id in user_query.keys():
+                destination_user_result[device_id] = remote_user_result.get(
+                    device_id, {}
+                )
+    else:
+        # We need to make sure that remote servers do not poison the
+        # result with data for users which do not belong to it, so we only
+        # copy data for users that were queried.
+        for user_id, keys in remote_otks.items():
+            if user_id in destination_query:
+                destination_result[user_id] = keys
+
+    return destination_result