summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_device.py99
-rw-r--r--tests/rest/client/test_devices.py150
2 files changed, 246 insertions, 3 deletions
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 66215af2b8..647ee09279 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,15 +17,18 @@
 from typing import Optional
 from unittest import mock
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.rest import admin
+from synapse.rest.client import devices, login, register
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -399,11 +402,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        devices.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver("server")
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.message_handler = hs.get_device_message_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
@@ -418,6 +429,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         stored_dehydrated_device_id = self.get_success(
             self.handler.store_dehydrated_device(
                 user_id=user_id,
+                device_id=None,
                 device_data={"device_data": {"foo": "bar"}},
                 initial_device_display_name="dehydrated device",
             )
@@ -481,3 +493,88 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
 
         self.assertIsNone(ret)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_v2_and_fetch_events(self) -> None:
+        user_id = "@boris:server"
+
+        self.get_success(self.store.register_user(user_id, "foobar"))
+
+        # First check if we can store and fetch a dehydrated device
+        stored_dehydrated_device_id = self.get_success(
+            self.handler.store_dehydrated_device(
+                user_id=user_id,
+                device_id=None,
+                device_data={"device_data": {"foo": "bar"}},
+                initial_device_display_name="dehydrated device",
+            )
+        )
+
+        device_info = self.get_success(
+            self.handler.get_dehydrated_device(user_id=user_id)
+        )
+        assert device_info is not None
+        retrieved_device_id, device_data = device_info
+        self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+        self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+        # Create a new login for the user
+        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+            self.registration.register_device(
+                user_id=user_id,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+
+        requester = create_requester(user_id, device_id=device_id)
+
+        # Fetching messages for a non-existing device should return an error
+        self.get_failure(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id="not the right device ID",
+                since_token=None,
+                limit=10,
+            ),
+            SynapseError,
+        )
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
+            )
+        )
+        self.pump()
+
+        # Fetch the message of the dehydrated device
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=None,
+                limit=10,
+            )
+        )
+
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
+
+        # Fetch the message of the dehydrated device again, which should return nothing
+        # and delete the old messages
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=res["next_batch"],
+                limit=10,
+            )
+        )
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 0)
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index d80eea17d3..b7d420cfec 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 from http import HTTPStatus
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import NotFoundError
 from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, login, register
+from synapse.rest.client import account, keys, login, register
 from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -208,8 +210,13 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         register.register_servlets,
         devices.register_servlets,
+        keys.register_servlets,
     ]
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.registration = hs.get_registration_handler()
+        self.message_handler = hs.get_device_message_handler()
+
     def test_PUT(self) -> None:
         """Sanity-check that we can PUT a dehydrated device.
 
@@ -226,7 +233,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
                 "device_data": {
                     "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
                     "account": "dehydrated_device",
-                }
+                },
+                "device_keys": {
+                    "user_id": "@alice:test",
+                    "device_id": "device1",
+                    "valid_until_ts": "80",
+                    "algorithms": [
+                        "m.olm.curve25519-aes-sha2",
+                    ],
+                    "keys": {
+                        "<algorithm>:<device_id>": "<key_base64>",
+                    },
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                    },
+                },
             },
             access_token=token,
             shorthand=False,
@@ -234,3 +255,128 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         device_id = channel.json_body.get("device_id")
         self.assertIsInstance(device_id, str)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_msc3814(self) -> None:
+        user = self.register_user("mikey", "pass")
+        token = self.login(user, "pass", device_id="device1")
+        content: JsonDict = {
+            "device_data": {
+                "algorithm": "m.dehydration.v1.olm",
+            },
+            "device_id": "device1",
+            "initial_device_display_name": "foo bar",
+            "device_keys": {
+                "user_id": "@mikey:test",
+                "device_id": "device1",
+                "valid_until_ts": "80",
+                "algorithms": [
+                    "m.olm.curve25519-aes-sha2",
+                ],
+                "keys": {
+                    "<algorithm>:<device_id>": "<key_base64>",
+                },
+                "signatures": {
+                    "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                },
+            },
+        }
+        channel = self.make_request(
+            "PUT",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        device_id = channel.json_body.get("device_id")
+        assert device_id is not None
+        self.assertIsInstance(device_id, str)
+        self.assertEqual("device1", device_id)
+
+        # test that we can now GET the dehydrated device info
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        returned_device_id = channel.json_body.get("device_id")
+        self.assertEqual(returned_device_id, device_id)
+        device_data = channel.json_body.get("device_data")
+        expected_device_data = {
+            "algorithm": "m.dehydration.v1.olm",
+        }
+        self.assertEqual(device_data, expected_device_data)
+
+        # create another device for the user
+        (
+            new_device_id,
+            _,
+            _,
+            _,
+        ) = self.get_success(
+            self.registration.register_device(
+                user_id=user,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+        requester = create_requester(user, device_id=new_device_id)
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user: {device_id: {"body": "test_message"}}},
+            )
+        )
+        self.pump()
+
+        # make sure we can fetch the message with our dehydrated device id
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        expected_content = {"body": "test_message"}
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+        next_batch_token = channel.json_body.get("next_batch")
+
+        # fetch messages again and make sure that the message was deleted and we are returned an
+        # empty array
+        content = {"next_batch": next_batch_token}
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"], [])
+
+        # make sure we can delete the dehydrated device
+        channel = self.make_request(
+            "DELETE",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # ...and after deleting it is no longer available
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 404)