summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2023-10-30 12:03:36 +0000
committerRichard van der Hoff <richard@matrix.org>2023-10-30 12:03:36 +0000
commit6dbad839980deb2aee07c1df3d19b80a0e178512 (patch)
treecdf79f75ec422108d2ffc92da86d11115e220e94
parentImprove tracing for `claim_one_time_keys` (diff)
downloadsynapse-6dbad839980deb2aee07c1df3d19b80a0e178512.tar.xz
Implement MSC4072: return result for all /keys/claims
-rw-r--r--synapse/config/experimental.py7
-rw-r--r--synapse/handlers/e2e_keys.py93
-rw-r--r--tests/handlers/test_e2e_keys.py70
3 files changed, 154 insertions, 16 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 9f830e7094..568efd641e 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -15,7 +15,6 @@
 import enum
 from typing import TYPE_CHECKING, Any, Optional
 
-import attr
 import attr.validators
 
 from synapse.api.errors import LimitExceededError
@@ -419,3 +418,9 @@ class ExperimentalConfig(Config):
         self.msc4028_push_encrypted_events = experimental.get(
             "msc4028_push_encrypted_events", False
         )
+
+        # MSC4072: Return an empty dict from /keys/claim for unknown devices or those
+        # with exhausted OTKs
+        self.msc4072_empty_dict_for_exhausted_devices = experimental.get(
+            "msc4072_empty_dict_for_exhausted_devices", False
+        )
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 7d44127ebf..bb628fdb00 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -592,6 +592,12 @@ class E2eKeysHandler:
             for user_id, device_id, algorithm, count in local_query
         ]
 
+        # prepopulate the response to make sure that all queried users/devices are
+        # included, even if the user/device is unknown or has run out of OTKs
+        if self.config.experimental.msc4072_empty_dict_for_exhausted_devices:
+            for user_id, device_id, _, _ in local_query:
+                result_dict.setdefault(user_id, {}).setdefault(device_id, {})
+
         otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
         update_result_dict(otk_results)
 
@@ -669,6 +675,25 @@ class E2eKeysHandler:
         timeout: Optional[int],
         always_include_fallback_keys: bool,
     ) -> JsonDict:
+        """
+        Handle a /keys/claim request.
+
+        Handles requests for local users with a db lookup, and makes federation
+        requests for remote users.
+
+        Args:
+            query: map from user ID, to map from device ID, to map from algorithm name
+                to number of keys needed
+                (``{user_id: {device_id: {algorithm: number_of_keys}}}``)
+
+            user: The user id of the requesting user
+
+            timeout: number of milliseconds to wait for the response from remote servers.
+                ``config.federation.client_timeout_ms`` by default.
+
+            always_include_fallback_keys: True to always include fallback keys, even
+                for devices which still have one-time keys.
+        """
         local_query: List[Tuple[str, str, str, int]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, Dict[str, int]]]] = {}
 
@@ -707,9 +732,18 @@ class E2eKeysHandler:
                 remote_result = await self.federation.claim_client_keys(
                     user, destination, device_keys, timeout=timeout
                 )
-                for user_id, keys in remote_result["one_time_keys"].items():
-                    if user_id in device_keys:
-                        json_result[user_id] = keys
+                try:
+                    destination_result = filter_remote_claimed_keys(
+                        device_keys,
+                        remote_result,
+                        self.config.experimental.msc4072_empty_dict_for_exhausted_devices,
+                    )
+                except Exception as e:
+                    logger.warning(
+                        "Error parsing /keys/claim response from server %s",
+                        destination, exc_info=e,
+                    )
+                    raise
 
             except Exception as e:
                 failure = _exception_to_failure(e)
@@ -717,6 +751,11 @@ class E2eKeysHandler:
                 set_tag("error", True)
                 set_tag("reason", str(failure))
 
+            else:
+                # only populate json_result once we know there will not be an entry in
+                # failures for this destination.
+                json_result.update(destination_result)
+
         await make_deferred_yieldable(
             defer.gatherResults(
                 [
@@ -1632,3 +1671,51 @@ class SigningKeyEduUpdater:
                 device_ids = device_ids + new_device_ids
 
             await self._device_handler.notify_device_update(user_id, device_ids)
+
+
+def filter_remote_claimed_keys(
+    destination_query: Dict[str, Dict[str, Dict[str, int]]],
+    remote_response: JsonDict,
+    msc4072_empty_dict_for_exhausted_devices: bool,
+) -> JsonDict:
+    """
+    Process the response from a federation /keys/claim request
+
+    Checks that there are no redundant entries, and that all the entries that
+    should be there are present.
+
+    Args:
+        destination_query: user->device->key map that was sent in the request to
+           this server
+        remote_response: response from the remote server
+        msc4072_empty_dict_for_exhausted_devices: true to include an entry in the
+           result for every queried device
+
+    Returns:
+        user->device->key map to be merged into the results
+    """
+    remote_otks = remote_response["one_time_keys"]
+
+    destination_result: JsonDict = {}
+
+    if msc4072_empty_dict_for_exhausted_devices:
+        # We need to make sure there is an entry in destination_result for
+        # every queried (user, device) even if the remote server did not
+        # populate it; so we iterate the query and populate
+        # destination_result based on the federation result.
+        for user_id, user_query in destination_query.items():
+            remote_user_result = remote_otks.get(user_id, {})
+            destination_user_result = destination_result[user_id] = {}
+            for device_id in user_query.keys():
+                destination_user_result[device_id] = remote_user_result.get(
+                    device_id, {}
+                )
+    else:
+        # We need to make sure that remote servers do not poison the
+        # result with data for users which do not belong to it, so we only
+        # copy data for users that were queried.
+        for user_id, keys in remote_otks.items():
+            if user_id in destination_query:
+                destination_result[user_id] = keys
+
+    return destination_result
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index c5556f2844..8a105c5712 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -144,35 +144,81 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             SynapseError,
         )
 
-    def test_claim_one_time_key(self) -> None:
-        local_user = "@boris:" + self.hs.hostname
+    @parameterized.expand([(True,), (False,)])
+    def test_claim_one_time_key(self, msc4072: bool) -> None:
+        self.hs.config.experimental.msc4072_empty_dict_for_exhausted_devices = msc4072
+
+        local_known_user = "@boris:" + self.hs.hostname
         device_id = "xyz"
-        keys = {"alg1:k1": "key1"}
+        local_unknown_user = "@charlie:" + self.hs.hostname
 
+        remote_known_user = "@dave:xyz"
+        remote_unknown_user = "@errol:xyz"
+
+        # upload a key for the local user
         res = self.get_success(
             self.handler.upload_keys_for_user(
-                local_user, device_id, {"one_time_keys": keys}
+                local_known_user, device_id, {"one_time_keys": {"alg1:k1": "key1"}}
             )
         )
         self.assertDictEqual(
             res, {"one_time_key_counts": {"alg1": 1, "signed_curve25519": 0}}
         )
 
+        # mock out the response for remote users. We pretend that the remote server
+        # hasn't heard of MSC4072 and returns an incomplete result. (Even once
+        # MSC4072 is stable, we still need to handle incomplete results.)
+        #
+        # we also include a spurious result to check it gets filtered out.
+        self.hs.get_federation_client().claim_client_keys = mock.AsyncMock(  # type: ignore[method-assign]
+            return_value={
+                "one_time_keys": {
+                    remote_known_user: {"ghi": {"alg1": "keykey"}},
+                    "@other:xyz": {"zzz": {"alg1": "dodgykey"}},
+                }
+            }
+        )
+
         res2 = self.get_success(
             self.handler.claim_one_time_keys(
-                {local_user: {device_id: {"alg1": 1}}},
+                {
+                    local_known_user: {device_id: {"alg1": 1}, "abc": {"alg2": 1}},
+                    local_unknown_user: {"def": {"alg1": 1}},
+                    remote_known_user: {"ghi": {"alg1": 1}, "jkl": {"alg1": 1}},
+                    remote_unknown_user: {"mno": {"alg1": 1}},
+                },
                 self.requester,
                 timeout=None,
                 always_include_fallback_keys=False,
             )
         )
-        self.assertEqual(
-            res2,
-            {
-                "failures": {},
-                "one_time_keys": {local_user: {device_id: {"alg1:k1": "key1"}}},
-            },
-        )
+
+        if msc4072:
+            # empty result for each unknown device
+            self.assertEqual(
+                res2,
+                {
+                    "failures": {},
+                    "one_time_keys": {
+                        local_known_user: {device_id: {"alg1:k1": "key1"}, "abc": {}},
+                        local_unknown_user: {"def": {}},
+                        remote_known_user: {"ghi": {"alg1": "keykey"}, "jkl": {}},
+                        remote_unknown_user: {"mno": {}},
+                    },
+                },
+            )
+        else:
+            # only known devices
+            self.assertEqual(
+                res2,
+                {
+                    "failures": {},
+                    "one_time_keys": {
+                        local_known_user: {device_id: {"alg1:k1": "key1"}},
+                        remote_known_user: {"ghi": {"alg1": "keykey"}},
+                    },
+                },
+            )
 
     def test_fallback_key(self) -> None:
         local_user = "@boris:" + self.hs.hostname