summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/client/test_devices.py77
1 files changed, 74 insertions, 3 deletions
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 3cf29c10ea..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
 from synapse.rest import admin, devices, room, sync
 from synapse.rest.client import account, keys, login, register
 from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
                     "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
                 },
             },
+            "fallback_keys": {
+                "alg1:device1": "f4llb4ckk3y",
+                "signed_<algorithm>:<device_id>": {
+                    "fallback": "true",
+                    "key": "f4llb4ckk3y",
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+                    },
+                },
+            },
+            "one_time_keys": {"alg1:k1": "0net1m3k3y"},
         }
         channel = self.make_request(
             "PUT",
@@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         }
         self.assertEqual(device_data, expected_device_data)
 
+        # test that the keys are correctly uploaded
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device1"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["device_keys"][user][device_id]["keys"],
+            content["device_keys"]["keys"],
+        )
+        # first claim should return the onetime key we uploaded
+        res = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+            },
+        )
+        # second claim should return fallback key
+        res2 = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res2,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+            },
+        )
+
         # create another device for the user
         (
             new_device_id,
@@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         expected_content = {"body": "test_message"}
         self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+        # fetch messages again and make sure that the message was not deleted
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
         next_batch_token = channel.json_body.get("next_batch")
 
-        # fetch messages again and make sure that the message was deleted and we are returned an
-        # empty array
+        # make sure fetching messages with next batch token works - there are no unfetched
+        # messages so we should receive an empty array
         content = {"next_batch": next_batch_token}
         channel = self.make_request(
             "POST",