summary refs log tree commit diff
path: root/tests/handlers/test_device.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_device.py')
-rw-r--r--tests/handlers/test_device.py102
1 files changed, 99 insertions, 3 deletions
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index ee48f9e546..647ee09279 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,15 +17,18 @@
 from typing import Optional
 from unittest import mock
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.rest import admin
+from synapse.rest.client import devices, login, register
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -41,7 +44,6 @@ class DeviceTestCase(unittest.HomeserverTestCase):
         self.appservice_api = mock.Mock()
         hs = self.setup_test_homeserver(
             "server",
-            federation_http_client=None,
             application_service_api=self.appservice_api,
         )
         handler = hs.get_device_handler()
@@ -400,11 +402,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        devices.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        hs = self.setup_test_homeserver("server", federation_http_client=None)
+        hs = self.setup_test_homeserver("server")
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.message_handler = hs.get_device_message_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
@@ -419,6 +429,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         stored_dehydrated_device_id = self.get_success(
             self.handler.store_dehydrated_device(
                 user_id=user_id,
+                device_id=None,
                 device_data={"device_data": {"foo": "bar"}},
                 initial_device_display_name="dehydrated device",
             )
@@ -482,3 +493,88 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
 
         self.assertIsNone(ret)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_v2_and_fetch_events(self) -> None:
+        user_id = "@boris:server"
+
+        self.get_success(self.store.register_user(user_id, "foobar"))
+
+        # First check if we can store and fetch a dehydrated device
+        stored_dehydrated_device_id = self.get_success(
+            self.handler.store_dehydrated_device(
+                user_id=user_id,
+                device_id=None,
+                device_data={"device_data": {"foo": "bar"}},
+                initial_device_display_name="dehydrated device",
+            )
+        )
+
+        device_info = self.get_success(
+            self.handler.get_dehydrated_device(user_id=user_id)
+        )
+        assert device_info is not None
+        retrieved_device_id, device_data = device_info
+        self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+        self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+        # Create a new login for the user
+        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+            self.registration.register_device(
+                user_id=user_id,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+
+        requester = create_requester(user_id, device_id=device_id)
+
+        # Fetching messages for a non-existing device should return an error
+        self.get_failure(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id="not the right device ID",
+                since_token=None,
+                limit=10,
+            ),
+            SynapseError,
+        )
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
+            )
+        )
+        self.pump()
+
+        # Fetch the message of the dehydrated device
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=None,
+                limit=10,
+            )
+        )
+
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
+
+        # Fetch the message of the dehydrated device again, which should return nothing
+        # and delete the old messages
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=res["next_batch"],
+                limit=10,
+            )
+        )
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 0)