diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ce7525e29c..ee48f9e546 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -15,15 +15,22 @@
# limitations under the License.
from typing import Optional
+from unittest import mock
from twisted.test.proto_helpers import MemoryReactor
+from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
+from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
from synapse.server import HomeServer
+from synapse.storage.databases.main.appservice import _make_exclusive_regex
+from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
+from tests.test_utils import make_awaitable
+from tests.unittest import override_config
user1 = "@boris:aaa"
user2 = "@theresa:bbb"
@@ -31,7 +38,12 @@ user2 = "@theresa:bbb"
class DeviceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- hs = self.setup_test_homeserver("server", federation_http_client=None)
+ self.appservice_api = mock.Mock()
+ hs = self.setup_test_homeserver(
+ "server",
+ federation_http_client=None,
+ application_service_api=self.appservice_api,
+ )
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
@@ -265,6 +277,127 @@ class DeviceTestCase(unittest.HomeserverTestCase):
)
self.reactor.advance(1000)
+ @override_config({"experimental_features": {"msc3984_appservice_key_query": True}})
+ def test_on_federation_query_user_devices_appservice(self) -> None:
+ """Test that querying of appservices for keys overrides responses from the database."""
+ local_user = "@boris:" + self.hs.hostname
+ device_1 = "abc"
+ device_2 = "def"
+ device_3 = "ghi"
+
+ # There are 3 devices:
+ #
+ # 1. One which is uploaded to the homeserver.
+ # 2. One which is uploaded to the homeserver, but a newer copy is returned
+ # by the appservice.
+ # 3. One which is only returned by the appservice.
+ device_key_1: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_1,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:abc": "base64+ed25519+key",
+ "curve25519:abc": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:abc": "base64+signature"}},
+ }
+ device_key_2a: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:def": "base64+ed25519+key",
+ "curve25519:def": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:def": "base64+signature"}},
+ }
+
+ device_key_2b: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_2,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ # The device ID is the same (above), but the keys are different.
+ "keys": {
+ "ed25519:xyz": "base64+ed25519+key",
+ "curve25519:xyz": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:xyz": "base64+signature"}},
+ }
+ device_key_3: JsonDict = {
+ "user_id": local_user,
+ "device_id": device_3,
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2,
+ ],
+ "keys": {
+ "ed25519:jkl": "base64+ed25519+key",
+ "curve25519:jkl": "base64+curve25519+key",
+ },
+ "signatures": {local_user: {"ed25519:jkl": "base64+signature"}},
+ }
+
+ # Upload keys for devices 1 & 2a.
+ e2e_keys_handler = self.hs.get_e2e_keys_handler()
+ self.get_success(
+ e2e_keys_handler.upload_keys_for_user(
+ local_user, device_1, {"device_keys": device_key_1}
+ )
+ )
+ self.get_success(
+ e2e_keys_handler.upload_keys_for_user(
+ local_user, device_2, {"device_keys": device_key_2a}
+ )
+ )
+
+ # Inject an appservice interested in this user.
+ appservice = ApplicationService(
+ token="i_am_an_app_service",
+ id="1234",
+ namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]},
+ # Note: this user does not have to match the regex above
+ sender="@as_main:test",
+ )
+ self.hs.get_datastores().main.services_cache = [appservice]
+ self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex(
+ [appservice]
+ )
+
+ # Setup a response.
+ self.appservice_api.query_keys.return_value = make_awaitable(
+ {
+ "device_keys": {
+ local_user: {device_2: device_key_2b, device_3: device_key_3}
+ }
+ }
+ )
+
+ # Request all devices.
+ res = self.get_success(
+ self.handler.on_federation_query_user_devices(local_user)
+ )
+ self.assertIn("devices", res)
+ res_devices = res["devices"]
+ for device in res_devices:
+ device["keys"].pop("unsigned", None)
+ self.assertEqual(
+ res_devices,
+ [
+ {"device_id": device_1, "keys": device_key_1},
+ {"device_id": device_2, "keys": device_key_2b},
+ {"device_id": device_3, "keys": device_key_3},
+ ],
+ )
+
class DehydrationTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
|