summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2023-08-08 12:04:46 -0700
committerGitHub <noreply@github.com>2023-08-08 12:04:46 -0700
commit0328b56468fe12c4d86ef636b60964527a510160 (patch)
tree2cff86a1e4518c90f234d4010419c049082a147f /tests
parentFixup changelog (diff)
downloadsynapse-0328b56468fe12c4d86ef636b60964527a510160.tar.xz
Support MSC3814: Dehydrated Devices Part 2 (#16010)
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_device.py9
-rw-r--r--tests/rest/client/test_devices.py77
2 files changed, 79 insertions, 7 deletions
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 647ee09279..e1e58fa6e6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(res["events"]), 1)
         self.assertEqual(res["events"][0]["content"]["body"], "foo")
 
-        # Fetch the message of the dehydrated device again, which should return nothing
-        # and delete the old messages
+        # Fetch the message of the dehydrated device again, which should return
+        # the same message as it has not been deleted
         res = self.get_success(
             self.message_handler.get_events_for_dehydrated_device(
                 requester=requester,
                 device_id=stored_dehydrated_device_id,
-                since_token=res["next_batch"],
+                since_token=None,
                 limit=10,
             )
         )
         self.assertTrue(len(res["next_batch"]) > 1)
-        self.assertEqual(len(res["events"]), 0)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 3cf29c10ea..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
 from synapse.rest import admin, devices, room, sync
 from synapse.rest.client import account, keys, login, register
 from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
                     "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
                 },
             },
+            "fallback_keys": {
+                "alg1:device1": "f4llb4ckk3y",
+                "signed_<algorithm>:<device_id>": {
+                    "fallback": "true",
+                    "key": "f4llb4ckk3y",
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+                    },
+                },
+            },
+            "one_time_keys": {"alg1:k1": "0net1m3k3y"},
         }
         channel = self.make_request(
             "PUT",
@@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         }
         self.assertEqual(device_data, expected_device_data)
 
+        # test that the keys are correctly uploaded
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device1"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["device_keys"][user][device_id]["keys"],
+            content["device_keys"]["keys"],
+        )
+        # first claim should return the onetime key we uploaded
+        res = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+            },
+        )
+        # second claim should return fallback key
+        res2 = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res2,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+            },
+        )
+
         # create another device for the user
         (
             new_device_id,
@@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         expected_content = {"body": "test_message"}
         self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+        # fetch messages again and make sure that the message was not deleted
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
         next_batch_token = channel.json_body.get("next_batch")
 
-        # fetch messages again and make sure that the message was deleted and we are returned an
-        # empty array
+        # make sure fetching messages with next batch token works - there are no unfetched
+        # messages so we should receive an empty array
         content = {"next_batch": next_batch_token}
         channel = self.make_request(
             "POST",