summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11234.bugfix1
-rw-r--r--synapse/handlers/e2e_keys.py211
-rw-r--r--tests/handlers/test_e2e_keys.py151
3 files changed, 277 insertions, 86 deletions
diff --git a/changelog.d/11234.bugfix b/changelog.d/11234.bugfix
new file mode 100644
index 0000000000..c0c02a58f6
--- /dev/null
+++ b/changelog.d/11234.bugfix
@@ -0,0 +1 @@
+Fix long-standing bug where cross signing keys were not included in the response to `/r0/keys/query` the first time a remote user was queried.
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index d0fb2fc7dc..60c11e3d21 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -201,95 +201,19 @@ class E2eKeysHandler:
                     r[user_id] = remote_queries[user_id]
 
             # Now fetch any devices that we don't have in our cache
-            @trace
-            async def do_remote_query(destination: str) -> None:
-                """This is called when we are querying the device list of a user on
-                a remote homeserver and their device list is not in the device list
-                cache. If we share a room with this user and we're not querying for
-                specific user we will update the cache with their device list.
-                """
-
-                destination_query = remote_queries_not_in_cache[destination]
-
-                # We first consider whether we wish to update the device list cache with
-                # the users device list. We want to track a user's devices when the
-                # authenticated user shares a room with the queried user and the query
-                # has not specified a particular device.
-                # If we update the cache for the queried user we remove them from further
-                # queries. We use the more efficient batched query_client_keys for all
-                # remaining users
-                user_ids_updated = []
-                for (user_id, device_list) in destination_query.items():
-                    if user_id in user_ids_updated:
-                        continue
-
-                    if device_list:
-                        continue
-
-                    room_ids = await self.store.get_rooms_for_user(user_id)
-                    if not room_ids:
-                        continue
-
-                    # We've decided we're sharing a room with this user and should
-                    # probably be tracking their device lists. However, we haven't
-                    # done an initial sync on the device list so we do it now.
-                    try:
-                        if self._is_master:
-                            user_devices = await self.device_handler.device_list_updater.user_device_resync(
-                                user_id
-                            )
-                        else:
-                            user_devices = await self._user_device_resync_client(
-                                user_id=user_id
-                            )
-
-                        user_devices = user_devices["devices"]
-                        user_results = results.setdefault(user_id, {})
-                        for device in user_devices:
-                            user_results[device["device_id"]] = device["keys"]
-                        user_ids_updated.append(user_id)
-                    except Exception as e:
-                        failures[destination] = _exception_to_failure(e)
-
-                if len(destination_query) == len(user_ids_updated):
-                    # We've updated all the users in the query and we do not need to
-                    # make any further remote calls.
-                    return
-
-                # Remove all the users from the query which we have updated
-                for user_id in user_ids_updated:
-                    destination_query.pop(user_id)
-
-                try:
-                    remote_result = await self.federation.query_client_keys(
-                        destination, {"device_keys": destination_query}, timeout=timeout
-                    )
-
-                    for user_id, keys in remote_result["device_keys"].items():
-                        if user_id in destination_query:
-                            results[user_id] = keys
-
-                    if "master_keys" in remote_result:
-                        for user_id, key in remote_result["master_keys"].items():
-                            if user_id in destination_query:
-                                cross_signing_keys["master_keys"][user_id] = key
-
-                    if "self_signing_keys" in remote_result:
-                        for user_id, key in remote_result["self_signing_keys"].items():
-                            if user_id in destination_query:
-                                cross_signing_keys["self_signing_keys"][user_id] = key
-
-                except Exception as e:
-                    failure = _exception_to_failure(e)
-                    failures[destination] = failure
-                    set_tag("error", True)
-                    set_tag("reason", failure)
-
             await make_deferred_yieldable(
                 defer.gatherResults(
                     [
-                        run_in_background(do_remote_query, destination)
-                        for destination in remote_queries_not_in_cache
+                        run_in_background(
+                            self._query_devices_for_destination,
+                            results,
+                            cross_signing_keys,
+                            failures,
+                            destination,
+                            queries,
+                            timeout,
+                        )
+                        for destination, queries in remote_queries_not_in_cache.items()
                     ],
                     consumeErrors=True,
                 ).addErrback(unwrapFirstError)
@@ -301,6 +225,121 @@ class E2eKeysHandler:
 
             return ret
 
+    @trace
+    async def _query_devices_for_destination(
+        self,
+        results: JsonDict,
+        cross_signing_keys: JsonDict,
+        failures: Dict[str, JsonDict],
+        destination: str,
+        destination_query: Dict[str, Iterable[str]],
+        timeout: int,
+    ) -> None:
+        """This is called when we are querying the device list of a user on
+        a remote homeserver and their device list is not in the device list
+        cache. If we share a room with this user and we're not querying for
+        specific user we will update the cache with their device list.
+
+        Args:
+            results: A map from user ID to their device keys, which gets
+                updated with the newly fetched keys.
+            cross_signing_keys: Map from user ID to their cross signing keys,
+                which gets updated with the newly fetched keys.
+            failures: Map of destinations to failures that have occurred while
+                attempting to fetch keys.
+            destination: The remote server to query
+            destination_query: The query dict of devices to query the remote
+                server for.
+            timeout: The timeout for remote HTTP requests.
+        """
+
+        # We first consider whether we wish to update the device list cache with
+        # the users device list. We want to track a user's devices when the
+        # authenticated user shares a room with the queried user and the query
+        # has not specified a particular device.
+        # If we update the cache for the queried user we remove them from further
+        # queries. We use the more efficient batched query_client_keys for all
+        # remaining users
+        user_ids_updated = []
+        for (user_id, device_list) in destination_query.items():
+            if user_id in user_ids_updated:
+                continue
+
+            if device_list:
+                continue
+
+            room_ids = await self.store.get_rooms_for_user(user_id)
+            if not room_ids:
+                continue
+
+            # We've decided we're sharing a room with this user and should
+            # probably be tracking their device lists. However, we haven't
+            # done an initial sync on the device list so we do it now.
+            try:
+                if self._is_master:
+                    resync_results = await self.device_handler.device_list_updater.user_device_resync(
+                        user_id
+                    )
+                else:
+                    resync_results = await self._user_device_resync_client(
+                        user_id=user_id
+                    )
+
+                # Add the device keys to the results.
+                user_devices = resync_results["devices"]
+                user_results = results.setdefault(user_id, {})
+                for device in user_devices:
+                    user_results[device["device_id"]] = device["keys"]
+                user_ids_updated.append(user_id)
+
+                # Add any cross signing keys to the results.
+                master_key = resync_results.get("master_key")
+                self_signing_key = resync_results.get("self_signing_key")
+
+                if master_key:
+                    cross_signing_keys["master_keys"][user_id] = master_key
+
+                if self_signing_key:
+                    cross_signing_keys["self_signing_keys"][user_id] = self_signing_key
+            except Exception as e:
+                failures[destination] = _exception_to_failure(e)
+
+        if len(destination_query) == len(user_ids_updated):
+            # We've updated all the users in the query and we do not need to
+            # make any further remote calls.
+            return
+
+        # Remove all the users from the query which we have updated
+        for user_id in user_ids_updated:
+            destination_query.pop(user_id)
+
+        try:
+            remote_result = await self.federation.query_client_keys(
+                destination, {"device_keys": destination_query}, timeout=timeout
+            )
+
+            for user_id, keys in remote_result["device_keys"].items():
+                if user_id in destination_query:
+                    results[user_id] = keys
+
+            if "master_keys" in remote_result:
+                for user_id, key in remote_result["master_keys"].items():
+                    if user_id in destination_query:
+                        cross_signing_keys["master_keys"][user_id] = key
+
+            if "self_signing_keys" in remote_result:
+                for user_id, key in remote_result["self_signing_keys"].items():
+                    if user_id in destination_query:
+                        cross_signing_keys["self_signing_keys"][user_id] = key
+
+        except Exception as e:
+            failure = _exception_to_failure(e)
+            failures[destination] = failure
+            set_tag("error", True)
+            set_tag("reason", failure)
+
+        return
+
     async def get_cross_signing_keys_from_cache(
         self, query: Iterable[str], from_user_id: Optional[str]
     ) -> Dict[str, Dict[str, dict]]:
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 39e7b1ab25..0c3b86fda9 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -17,6 +17,8 @@ from unittest import mock
 
 from signedjson import key as key, sign as sign
 
+from twisted.internet import defer
+
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import Codes, SynapseError
 
@@ -630,3 +632,152 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase):
             ],
             other_master_key["signatures"][local_user]["ed25519:" + usersigning_pubkey],
         )
+
+    def test_query_devices_remote_no_sync(self):
+        """Tests that querying keys for a remote user that we don't share a room
+        with returns the cross signing keys correctly.
+        """
+
+        remote_user_id = "@test:other"
+        local_user_id = "@test:test"
+
+        remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+        remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+        self.hs.get_federation_client().query_client_keys = mock.Mock(
+            return_value=defer.succeed(
+                {
+                    "device_keys": {remote_user_id: {}},
+                    "master_keys": {
+                        remote_user_id: {
+                            "user_id": remote_user_id,
+                            "usage": ["master"],
+                            "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                        },
+                    },
+                    "self_signing_keys": {
+                        remote_user_id: {
+                            "user_id": remote_user_id,
+                            "usage": ["self_signing"],
+                            "keys": {
+                                "ed25519:"
+                                + remote_self_signing_key: remote_self_signing_key
+                            },
+                        }
+                    },
+                }
+            )
+        )
+
+        e2e_handler = self.hs.get_e2e_keys_handler()
+
+        query_result = self.get_success(
+            e2e_handler.query_devices(
+                {
+                    "device_keys": {remote_user_id: []},
+                },
+                timeout=10,
+                from_user_id=local_user_id,
+                from_device_id="some_device_id",
+            )
+        )
+
+        self.assertEqual(query_result["failures"], {})
+        self.assertEqual(
+            query_result["master_keys"],
+            {
+                remote_user_id: {
+                    "user_id": remote_user_id,
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                },
+            },
+        )
+        self.assertEqual(
+            query_result["self_signing_keys"],
+            {
+                remote_user_id: {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
+                    },
+                }
+            },
+        )
+
+    def test_query_devices_remote_sync(self):
+        """Tests that querying keys for a remote user that we share a room with,
+        but haven't yet fetched the keys for, returns the cross signing keys
+        correctly.
+        """
+
+        remote_user_id = "@test:other"
+        local_user_id = "@test:test"
+
+        self.store.get_rooms_for_user = mock.Mock(
+            return_value=defer.succeed({"some_room_id"})
+        )
+
+        remote_master_key = "85T7JXPFBAySB/jwby4S3lBPTqY3+Zg53nYuGmu1ggY"
+        remote_self_signing_key = "QeIiFEjluPBtI7WQdG365QKZcFs9kqmHir6RBD0//nQ"
+
+        self.hs.get_federation_client().query_user_devices = mock.Mock(
+            return_value=defer.succeed(
+                {
+                    "user_id": remote_user_id,
+                    "stream_id": 1,
+                    "devices": [],
+                    "master_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["master"],
+                        "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                    },
+                    "self_signing_key": {
+                        "user_id": remote_user_id,
+                        "usage": ["self_signing"],
+                        "keys": {
+                            "ed25519:"
+                            + remote_self_signing_key: remote_self_signing_key
+                        },
+                    },
+                }
+            )
+        )
+
+        e2e_handler = self.hs.get_e2e_keys_handler()
+
+        query_result = self.get_success(
+            e2e_handler.query_devices(
+                {
+                    "device_keys": {remote_user_id: []},
+                },
+                timeout=10,
+                from_user_id=local_user_id,
+                from_device_id="some_device_id",
+            )
+        )
+
+        self.assertEqual(query_result["failures"], {})
+        self.assertEqual(
+            query_result["master_keys"],
+            {
+                remote_user_id: {
+                    "user_id": remote_user_id,
+                    "usage": ["master"],
+                    "keys": {"ed25519:" + remote_master_key: remote_master_key},
+                }
+            },
+        )
+        self.assertEqual(
+            query_result["self_signing_keys"],
+            {
+                remote_user_id: {
+                    "user_id": remote_user_id,
+                    "usage": ["self_signing"],
+                    "keys": {
+                        "ed25519:" + remote_self_signing_key: remote_self_signing_key
+                    },
+                }
+            },
+        )