diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 66215af2b8..647ee09279 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,15 +17,18 @@
from typing import Optional
from unittest import mock
+from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.rest import admin
+from synapse.rest.client import devices, login, register
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -399,11 +402,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ devices.register_servlets,
+ ]
+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server")
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
+ self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -418,6 +429,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
stored_dehydrated_device_id = self.get_success(
self.handler.store_dehydrated_device(
user_id=user_id,
+ device_id=None,
device_data={"device_data": {"foo": "bar"}},
initial_device_display_name="dehydrated device",
)
@@ -481,3 +493,88 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
self.assertIsNone(ret)
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+ )
+ def test_dehydrate_v2_and_fetch_events(self) -> None:
+ user_id = "@boris:server"
+
+ self.get_success(self.store.register_user(user_id, "foobar"))
+
+ # First check if we can store and fetch a dehydrated device
+ stored_dehydrated_device_id = self.get_success(
+ self.handler.store_dehydrated_device(
+ user_id=user_id,
+ device_id=None,
+ device_data={"device_data": {"foo": "bar"}},
+ initial_device_display_name="dehydrated device",
+ )
+ )
+
+ device_info = self.get_success(
+ self.handler.get_dehydrated_device(user_id=user_id)
+ )
+ assert device_info is not None
+ retrieved_device_id, device_data = device_info
+ self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+ self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+ # Create a new login for the user
+ device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+ self.registration.register_device(
+ user_id=user_id,
+ device_id=None,
+ initial_display_name="new device",
+ )
+ )
+
+ requester = create_requester(user_id, device_id=device_id)
+
+ # Fetching messages for a non-existing device should return an error
+ self.get_failure(
+ self.message_handler.get_events_for_dehydrated_device(
+ requester=requester,
+ device_id="not the right device ID",
+ since_token=None,
+ limit=10,
+ ),
+ SynapseError,
+ )
+
+ # Send a message to the dehydrated device
+ ensureDeferred(
+ self.message_handler.send_device_message(
+ requester=requester,
+ message_type="test.message",
+ messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
+ )
+ )
+ self.pump()
+
+ # Fetch the message of the dehydrated device
+ res = self.get_success(
+ self.message_handler.get_events_for_dehydrated_device(
+ requester=requester,
+ device_id=stored_dehydrated_device_id,
+ since_token=None,
+ limit=10,
+ )
+ )
+
+ self.assertTrue(len(res["next_batch"]) > 1)
+ self.assertEqual(len(res["events"]), 1)
+ self.assertEqual(res["events"][0]["content"]["body"], "foo")
+
+ # Fetch the message of the dehydrated device again, which should return nothing
+ # and delete the old messages
+ res = self.get_success(
+ self.message_handler.get_events_for_dehydrated_device(
+ requester=requester,
+ device_id=stored_dehydrated_device_id,
+ since_token=res["next_batch"],
+ limit=10,
+ )
+ )
+ self.assertTrue(len(res["next_batch"]) > 1)
+ self.assertEqual(len(res["events"]), 0)
|