summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15539.misc1
-rw-r--r--synapse/handlers/device.py28
-rw-r--r--tests/handlers/test_device.py135
3 files changed, 163 insertions, 1 deletions
diff --git a/changelog.d/15539.misc b/changelog.d/15539.misc
new file mode 100644
index 0000000000..e5af5dee5c
--- /dev/null
+++ b/changelog.d/15539.misc
@@ -0,0 +1 @@
+Proxy `/user/devices` federation queries to application services for [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984).
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b9d3b7fbc6..5d12a39e26 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -75,10 +75,14 @@ class DeviceWorkerHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.state = hs.get_state_handler()
+        self._appservice_handler = hs.get_application_service_handler()
         self._state_storage = hs.get_storage_controllers().state
         self._auth_handler = hs.get_auth_handler()
         self.server_name = hs.hostname
         self._msc3852_enabled = hs.config.experimental.msc3852_enabled
+        self._query_appservices_for_keys = (
+            hs.config.experimental.msc3984_appservice_key_query
+        )
 
         self.device_list_updater = DeviceListWorkerUpdater(hs)
 
@@ -328,6 +332,30 @@ class DeviceWorkerHandler:
             user_id, "self_signing"
         )
 
+        # Check if the application services have any results.
+        if self._query_appservices_for_keys:
+            # Query the appservice for all devices for this user.
+            query: Dict[str, Optional[List[str]]] = {user_id: None}
+
+            # Query the appservices for any keys.
+            appservice_results = await self._appservice_handler.query_keys(query)
+
+            # Merge results, overriding anything from the database.
+            appservice_devices = appservice_results.get("device_keys", {}).get(
+                user_id, {}
+            )
+
+            # Filter the database results to only those devices that the appservice has
+            # *not* responded with.
+            devices = [d for d in devices if d["device_id"] not in appservice_devices]
+            # Append the appservice response by wrapping each result in another dictionary.
+            devices.extend(
+                {"device_id": device_id, "keys": device}
+                for device_id, device in appservice_devices.items()
+            )
+
+            # TODO Handle cross-signing keys.
+
         return {
             "user_id": user_id,
             "stream_id": stream_id,
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ce7525e29c..ee48f9e546 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -15,15 +15,22 @@
 # limitations under the License.
 
 from typing import Optional
+from unittest import mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
+from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
+from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
 from synapse.server import HomeServer
+from synapse.storage.databases.main.appservice import _make_exclusive_regex
+from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
+from tests.test_utils import make_awaitable
+from tests.unittest import override_config
 
 user1 = "@boris:aaa"
 user2 = "@theresa:bbb"
@@ -31,7 +38,12 @@ user2 = "@theresa:bbb"
 
 class DeviceTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver("server", federation_http_client=None)
+        self.appservice_api = mock.Mock()
+        hs = self.setup_test_homeserver(
+            "server",
+            federation_http_client=None,
+            application_service_api=self.appservice_api,
+        )
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
@@ -265,6 +277,127 @@ class DeviceTestCase(unittest.HomeserverTestCase):
             )
             self.reactor.advance(1000)
 
+    @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
+    def test_on_federation_query_user_devices_appservice(self) -> None:
+        """Test that querying of appservices for keys overrides responses from the database."""
+        local_user = "@boris:" + self.hs.hostname
+        device_1 = "abc"
+        device_2 = "def"
+        device_3 = "ghi"
+
+        # There are 3 devices:
+        #
+        # 1. One which is uploaded to the homeserver.
+        # 2. One which is uploaded to the homeserver, but a newer copy is returned
+        #     by the appservice.
+        # 3. One which is only returned by the appservice.
+        device_key_1: JsonDict = {
+            "user_id": local_user,
+            "device_id": device_1,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
+            "keys": {
+                "ed25519:abc": "base64+ed25519+key",
+                "curve25519:abc": "base64+curve25519+key",
+            },
+            "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+        }
+        device_key_2a: JsonDict = {
+            "user_id": local_user,
+            "device_id": device_2,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
+            "keys": {
+                "ed25519:def": "base64+ed25519+key",
+                "curve25519:def": "base64+curve25519+key",
+            },
+            "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+        }
+
+        device_key_2b: JsonDict = {
+            "user_id": local_user,
+            "device_id": device_2,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
+            # The device ID is the same (above), but the keys are different.
+            "keys": {
+                "ed25519:xyz": "base64+ed25519+key",
+                "curve25519:xyz": "base64+curve25519+key",
+            },
+            "signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
+        }
+        device_key_3: JsonDict = {
+            "user_id": local_user,
+            "device_id": device_3,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+                RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+            ],
+            "keys": {
+                "ed25519:jkl": "base64+ed25519+key",
+                "curve25519:jkl": "base64+curve25519+key",
+            },
+            "signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
+        }
+
+        # Upload keys for devices 1 & 2a.
+        e2e_keys_handler = self.hs.get_e2e_keys_handler()
+        self.get_success(
+            e2e_keys_handler.upload_keys_for_user(
+                local_user, device_1, {"device_keys": device_key_1}
+            )
+        )
+        self.get_success(
+            e2e_keys_handler.upload_keys_for_user(
+                local_user, device_2, {"device_keys": device_key_2a}
+            )
+        )
+
+        # Inject an appservice interested in this user.
+        appservice = ApplicationService(
+            token="i_am_an_app_service",
+            id="1234",
+            namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+            # Note: this user does not have to match the regex above
+            sender="@as_main:test",
+        )
+        self.hs.get_datastores().main.services_cache = [appservice]
+        self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+            [appservice]
+        )
+
+        # Setup a response.
+        self.appservice_api.query_keys.return_value = make_awaitable(
+            {
+                "device_keys": {
+                    local_user: {device_2: device_key_2b, device_3: device_key_3}
+                }
+            }
+        )
+
+        # Request all devices.
+        res = self.get_success(
+            self.handler.on_federation_query_user_devices(local_user)
+        )
+        self.assertIn("devices", res)
+        res_devices = res["devices"]
+        for device in res_devices:
+            device["keys"].pop("unsigned", None)
+        self.assertEqual(
+            res_devices,
+            [
+                {"device_id": device_1, "keys": device_key_1},
+                {"device_id": device_2, "keys": device_key_2b},
+                {"device_id": device_3, "keys": device_key_3},
+            ],
+        )
+
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: