diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 66215af2b8..647ee09279 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,15 +17,18 @@
from typing import Optional
from unittest import mock
+from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import RoomEncryptionAlgorithms
from synapse.api.errors import NotFoundError, SynapseError
from synapse.appservice import ApplicationService
from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.rest import admin
+from synapse.rest.client import devices, login, register
from synapse.server import HomeServer
from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -399,11 +402,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
class DehydrationTestCase(unittest.HomeserverTestCase):
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ login.register_servlets,
+ register.register_servlets,
+ devices.register_servlets,
+ ]
+
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
hs = self.setup_test_homeserver("server")
handler = hs.get_device_handler()
assert isinstance(handler, DeviceHandler)
self.handler = handler
+ self.message_handler = hs.get_device_message_handler()
self.registration = hs.get_registration_handler()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
@@ -418,6 +429,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
stored_dehydrated_device_id = self.get_success(
self.handler.store_dehydrated_device(
user_id=user_id,
+ device_id=None,
device_data={"device_data": {"foo": "bar"}},
initial_device_display_name="dehydrated device",
)
@@ -481,3 +493,88 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
self.assertIsNone(ret)
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+ )
+ def test_dehydrate_v2_and_fetch_events(self) -> None:
+ user_id = "@boris:server"
+
+ self.get_success(self.store.register_user(user_id, "foobar"))
+
+ # First check if we can store and fetch a dehydrated device
+ stored_dehydrated_device_id = self.get_success(
+ self.handler.store_dehydrated_device(
+ user_id=user_id,
+ device_id=None,
+ device_data={"device_data": {"foo": "bar"}},
+ initial_device_display_name="dehydrated device",
+ )
+ )
+
+ device_info = self.get_success(
+ self.handler.get_dehydrated_device(user_id=user_id)
+ )
+ assert device_info is not None
+ retrieved_device_id, device_data = device_info
+ self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+ self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+ # Create a new login for the user
+ device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+ self.registration.register_device(
+ user_id=user_id,
+ device_id=None,
+ initial_display_name="new device",
+ )
+ )
+
+ requester = create_requester(user_id, device_id=device_id)
+
+ # Fetching messages for a non-existing device should return an error
+ self.get_failure(
+ self.message_handler.get_events_for_dehydrated_device(
+ requester=requester,
+ device_id="not the right device ID",
+ since_token=None,
+ limit=10,
+ ),
+ SynapseError,
+ )
+
+ # Send a message to the dehydrated device
+ ensureDeferred(
+ self.message_handler.send_device_message(
+ requester=requester,
+ message_type="test.message",
+ messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
+ )
+ )
+ self.pump()
+
+ # Fetch the message of the dehydrated device
+ res = self.get_success(
+ self.message_handler.get_events_for_dehydrated_device(
+ requester=requester,
+ device_id=stored_dehydrated_device_id,
+ since_token=None,
+ limit=10,
+ )
+ )
+
+ self.assertTrue(len(res["next_batch"]) > 1)
+ self.assertEqual(len(res["events"]), 1)
+ self.assertEqual(res["events"][0]["content"]["body"], "foo")
+
+ # Fetch the message of the dehydrated device again, which should return nothing
+ # and delete the old messages
+ res = self.get_success(
+ self.message_handler.get_events_for_dehydrated_device(
+ requester=requester,
+ device_id=stored_dehydrated_device_id,
+ since_token=res["next_batch"],
+ limit=10,
+ )
+ )
+ self.assertTrue(len(res["next_batch"]) > 1)
+ self.assertEqual(len(res["events"]), 0)
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index d80eea17d3..b7d420cfec 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -13,12 +13,14 @@
# limitations under the License.
from http import HTTPStatus
+from twisted.internet.defer import ensureDeferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.errors import NotFoundError
from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, login, register
+from synapse.rest.client import account, keys, login, register
from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
from synapse.util import Clock
from tests import unittest
@@ -208,8 +210,13 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
login.register_servlets,
register.register_servlets,
devices.register_servlets,
+ keys.register_servlets,
]
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.registration = hs.get_registration_handler()
+ self.message_handler = hs.get_device_message_handler()
+
def test_PUT(self) -> None:
"""Sanity-check that we can PUT a dehydrated device.
@@ -226,7 +233,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
"device_data": {
"algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
"account": "dehydrated_device",
- }
+ },
+ "device_keys": {
+ "user_id": "@alice:test",
+ "device_id": "device1",
+ "valid_until_ts": "80",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+ },
+ },
},
access_token=token,
shorthand=False,
@@ -234,3 +255,128 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
device_id = channel.json_body.get("device_id")
self.assertIsInstance(device_id, str)
+
+ @unittest.override_config(
+ {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+ )
+ def test_dehydrate_msc3814(self) -> None:
+ user = self.register_user("mikey", "pass")
+ token = self.login(user, "pass", device_id="device1")
+ content: JsonDict = {
+ "device_data": {
+ "algorithm": "m.dehydration.v1.olm",
+ },
+ "device_id": "device1",
+ "initial_device_display_name": "foo bar",
+ "device_keys": {
+ "user_id": "@mikey:test",
+ "device_id": "device1",
+ "valid_until_ts": "80",
+ "algorithms": [
+ "m.olm.curve25519-aes-sha2",
+ ],
+ "keys": {
+ "<algorithm>:<device_id>": "<key_base64>",
+ },
+ "signatures": {
+ "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+ },
+ },
+ }
+ channel = self.make_request(
+ "PUT",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ content=content,
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ device_id = channel.json_body.get("device_id")
+ assert device_id is not None
+ self.assertIsInstance(device_id, str)
+ self.assertEqual("device1", device_id)
+
+ # test that we can now GET the dehydrated device info
+ channel = self.make_request(
+ "GET",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ returned_device_id = channel.json_body.get("device_id")
+ self.assertEqual(returned_device_id, device_id)
+ device_data = channel.json_body.get("device_data")
+ expected_device_data = {
+ "algorithm": "m.dehydration.v1.olm",
+ }
+ self.assertEqual(device_data, expected_device_data)
+
+ # create another device for the user
+ (
+ new_device_id,
+ _,
+ _,
+ _,
+ ) = self.get_success(
+ self.registration.register_device(
+ user_id=user,
+ device_id=None,
+ initial_display_name="new device",
+ )
+ )
+ requester = create_requester(user, device_id=new_device_id)
+
+ # Send a message to the dehydrated device
+ ensureDeferred(
+ self.message_handler.send_device_message(
+ requester=requester,
+ message_type="test.message",
+ messages={user: {device_id: {"body": "test_message"}}},
+ )
+ )
+ self.pump()
+
+ # make sure we can fetch the message with our dehydrated device id
+ channel = self.make_request(
+ "POST",
+ f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+ content={},
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ expected_content = {"body": "test_message"}
+ self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+ next_batch_token = channel.json_body.get("next_batch")
+
+ # fetch messages again and make sure that the message was deleted and we are returned an
+ # empty array
+ content = {"next_batch": next_batch_token}
+ channel = self.make_request(
+ "POST",
+ f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+ content=content,
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+ self.assertEqual(channel.json_body["events"], [])
+
+ # make sure we can delete the dehydrated device
+ channel = self.make_request(
+ "DELETE",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 200)
+
+ # ...and after deleting it is no longer available
+ channel = self.make_request(
+ "GET",
+ "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+ access_token=token,
+ shorthand=False,
+ )
+ self.assertEqual(channel.code, 404)
|