summary refs log tree commit diff
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2023-07-24 08:23:19 -0700
committerGitHub <noreply@github.com>2023-07-24 08:23:19 -0700
commit641ff9ef7eaa7f1a632b983f4d36bb28dc23484d (patch)
tree3d61b97c9cbedd0c804c04de32463e57513cabc5
parentFix broken Arch Linux package link (#15981) (diff)
downloadsynapse-641ff9ef7eaa7f1a632b983f4d36bb28dc23484d.tar.xz
Support MSC3814: Dehydrated Devices (#15929)

Signed-off-by: Nicolas Werner <n.werner@famedly.com>
Co-authored-by: Nicolas Werner <n.werner@famedly.com>
Co-authored-by: Nicolas Werner <89468146+nico-famedly@users.noreply.github.com>
Co-authored-by: Hubert Chathi <hubert@uhoreg.ca>
Diffstat (limited to '')
-rw-r--r--changelog.d/15929.feature1
-rw-r--r--synapse/config/experimental.py21
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/devicemessage.py108
-rw-r--r--synapse/rest/client/devices.py232
-rw-r--r--tests/handlers/test_device.py99
-rw-r--r--tests/rest/client/test_devices.py150
7 files changed, 603 insertions, 12 deletions
diff --git a/changelog.d/15929.feature b/changelog.d/15929.feature
new file mode 100644
index 0000000000..c3aaeae66e
--- /dev/null
+++ b/changelog.d/15929.feature
@@ -0,0 +1 @@
+Implement [MSC3814](https://github.com/matrix-org/matrix-spec-proposals/pull/3814), dehydrated devices v2/shrivelled sessions and move [MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697) behind a config flag. Contributed by Nico from Famedly and H-Shay.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 0970f22a75..1695ed8ca3 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -247,6 +247,27 @@ class ExperimentalConfig(Config):
         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
 
+        # MSC2697 (device dehydration)
+        # Enabled by default since this option was added after adding the feature.
+        # It is not recommended that both MSC2697 and MSC3814 both be enabled at
+        # once.
+        self.msc2697_enabled: bool = experimental.get("msc2697_enabled", True)
+
+        # MSC3814 (dehydrated devices with SSSS)
+        # This is an alternative method to achieve the same goals as MSC2697.
+        # It is not recommended that both MSC2697 and MSC3814 both be enabled at
+        # once.
+        self.msc3814_enabled: bool = experimental.get("msc3814_enabled", False)
+
+        if self.msc2697_enabled and self.msc3814_enabled:
+            raise ConfigError(
+                "MSC2697 and MSC3814 should not both be enabled.",
+                (
+                    "experimental_features",
+                    "msc3814_enabled",
+                ),
+            )
+
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d73d9dca08..f3a713f5fa 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -653,6 +653,7 @@ class DeviceHandler(DeviceWorkerHandler):
     async def store_dehydrated_device(
         self,
         user_id: str,
+        device_id: Optional[str],
         device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
     ) -> str:
@@ -661,6 +662,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         Args:
             user_id: the user that we are storing the device for
+            device_id: device id supplied by client
             device_data: the dehydrated device information
             initial_device_display_name: The display name to use for the device
         Returns:
@@ -668,7 +670,7 @@ class DeviceHandler(DeviceWorkerHandler):
         """
         device_id = await self.check_device_registered(
             user_id,
-            None,
+            device_id,
             initial_device_display_name,
         )
         old_device_id = await self.store.store_dehydrated_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..15e94a03cb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -48,6 +49,9 @@ class DeviceMessageHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
+        if hs.config.experimental.msc3814_enabled:
+            self.event_sources = hs.get_event_sources()
+            self.device_handler = hs.get_device_handler()
 
         # We only need to poke the federation sender explicitly if its on the
         # same instance. Other federation sender instances will get notified by
@@ -303,3 +307,103 @@ class DeviceMessageHandler:
                 # Enqueue a new federation transaction to send the new
                 # device messages to each remote destination.
                 self.federation_sender.send_device_messages(destination)
+
+    async def get_events_for_dehydrated_device(
+        self,
+        requester: Requester,
+        device_id: str,
+        since_token: Optional[str],
+        limit: int,
+    ) -> JsonDict:
+        """Fetches up to `limit` events sent to `device_id` starting from `since_token`
+        and returns the new since token. If there are no more messages, returns an empty
+        array.
+
+        Args:
+            requester: the user requesting the messages
+            device_id: ID of the dehydrated device
+            since_token: stream id to start from when fetching messages
+            limit: the number of messages to fetch
+        Returns:
+            A dict containing the to-device messages, as well as a token that the client
+            can provide in the next call to fetch the next batch of messages
+        """
+
+        user_id = requester.user.to_string()
+
+        # only allow fetching messages for the dehydrated device id currently associated
+        # with the user
+        dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
+        if dehydrated_device is None:
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "No dehydrated device exists",
+                Codes.FORBIDDEN,
+            )
+
+        dehydrated_device_id, _ = dehydrated_device
+        if device_id != dehydrated_device_id:
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "You may only fetch messages for your dehydrated device",
+                Codes.FORBIDDEN,
+            )
+
+        since_stream_id = 0
+        if since_token:
+            if not since_token.startswith("d"):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "from parameter %r has an invalid format" % (since_token,),
+                    errcode=Codes.INVALID_PARAM,
+                )
+
+            try:
+                since_stream_id = int(since_token[1:])
+            except Exception:
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "from parameter %r has an invalid format" % (since_token,),
+                    errcode=Codes.INVALID_PARAM,
+                )
+
+            # if we have a since token, delete any to-device messages before that token
+            # (since we now know that the device has received them)
+            deleted = await self.store.delete_messages_for_device(
+                user_id, device_id, since_stream_id
+            )
+            logger.debug(
+                "Deleted %d to-device messages up to %d for user_id %s device_id %s",
+                deleted,
+                since_stream_id,
+                user_id,
+                device_id,
+            )
+
+        to_token = self.event_sources.get_current_token().to_device_key
+
+        messages, stream_id = await self.store.get_messages_for_device(
+            user_id, device_id, since_stream_id, to_token, limit
+        )
+
+        for message in messages:
+            # Remove the message id before sending to client
+            message_id = message.pop("message_id", None)
+            if message_id:
+                set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)
+
+        logger.debug(
+            "Returning %d to-device messages between %d and %d (current token: %d) for "
+            "dehydrated device %s, user_id %s",
+            len(messages),
+            since_stream_id,
+            stream_id,
+            to_token,
+            device_id,
+            user_id,
+        )
+
+        return {
+            "events": messages,
+            "next_batch": f"d{stream_id}",
+        }
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 38dff9703f..690d2ec406 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,19 +14,22 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from pydantic import Extra, StrictStr
 
 from synapse.api import errors
-from synapse.api.errors import NotFoundError, UnrecognizedRequestError
+from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
 from synapse.handlers.device import DeviceHandler
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     parse_and_validate_json_object_from_request,
+    parse_integer,
 )
 from synapse.http.site import SynapseRequest
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.rest.client.models import AuthenticationData
 from synapse.rest.models import RequestBodyModel
@@ -229,6 +232,8 @@ class DehydratedDeviceDataModel(RequestBodyModel):
 class DehydratedDeviceServlet(RestServlet):
     """Retrieve or store a dehydrated device.
 
+    Implements either MSC2697 or MSC3814.
+
     GET /org.matrix.msc2697.v2/dehydrated_device
 
     HTTP/1.1 200 OK
@@ -261,9 +266,7 @@ class DehydratedDeviceServlet(RestServlet):
 
     """
 
-    PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=())
-
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs: "HomeServer", msc2697: bool = True):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -271,6 +274,13 @@ class DehydratedDeviceServlet(RestServlet):
         assert isinstance(handler, DeviceHandler)
         self.device_handler = handler
 
+        self.PATTERNS = client_patterns(
+            "/org.matrix.msc2697.v2/dehydrated_device$"
+            if msc2697
+            else "/org.matrix.msc3814.v1/dehydrated_device$",
+            releases=(),
+        )
+
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         dehydrated_device = await self.device_handler.get_dehydrated_device(
@@ -293,6 +303,7 @@ class DehydratedDeviceServlet(RestServlet):
 
         device_id = await self.device_handler.store_dehydrated_device(
             requester.user.to_string(),
+            None,
             submission.device_data.dict(),
             submission.initial_device_display_name,
         )
@@ -347,6 +358,210 @@ class ClaimDehydratedDeviceServlet(RestServlet):
         return 200, result
 
 
+class DehydratedDeviceEventsServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/org.matrix.msc3814.v1/dehydrated_device/(?P<device_id>[^/]*)/events$",
+        releases=(),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.message_handler = hs.get_device_message_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+
+    class PostBody(RequestBodyModel):
+        next_batch: Optional[StrictStr]
+
+    async def on_POST(
+        self, request: SynapseRequest, device_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        next_batch = parse_and_validate_json_object_from_request(
+            request, self.PostBody
+        ).next_batch
+        limit = parse_integer(request, "limit", 100)
+
+        msgs = await self.message_handler.get_events_for_dehydrated_device(
+            requester=requester,
+            device_id=device_id,
+            since_token=next_batch,
+            limit=limit,
+        )
+
+        return 200, msgs
+
+
+class DehydratedDeviceV2Servlet(RestServlet):
+    """Upload, retrieve, or delete a dehydrated device.
+
+    GET /org.matrix.msc3814.v1/dehydrated_device
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id",
+      "device_data": {
+        "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+        "account": "dehydrated_device"
+      }
+    }
+
+    PUT /org.matrix.msc3814.v1/dehydrated_device
+    Content-Type: application/json
+
+    {
+        "device_id": "dehydrated_device_id",
+        "device_data": {
+            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+            "account": "dehydrated_device"
+        },
+        "device_keys": {
+            "user_id": "<user_id>",
+            "device_id": "<device_id>",
+            "valid_until_ts": <millisecond_timestamp>,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+            ]
+            "keys": {
+                "<algorithm>:<device_id>": "<key_base64>",
+            },
+            "signatures:" {
+                "<user_id>" {
+                    "<algorithm>:<device_id>": "<signature_base64>"
+                }
+            }
+        },
+        "fallback_keys": {
+            "<algorithm>:<device_id>": "<key_base64>",
+            "signed_<algorithm>:<device_id>": {
+                "fallback": true,
+                "key": "<key_base64>",
+                "signatures": {
+                    "<user_id>": {
+                        "<algorithm>:<device_id>": "<key_base64>"
+                    }
+                }
+            }
+        }
+        "one_time_keys": {
+            "<algorithm>:<key_id>": "<key_base64>"
+        },
+
+    }
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id"
+    }
+
+    DELETE /org.matrix.msc3814.v1/dehydrated_device
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id",
+    }
+    """
+
+    PATTERNS = [
+        *client_patterns("/org.matrix.msc3814.v1/dehydrated_device$", releases=()),
+    ]
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.device_handler = handler
+
+        if hs.config.worker.worker_app is None:
+            # if main process
+            self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
+        else:
+            # then a worker
+            self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        dehydrated_device = await self.device_handler.get_dehydrated_device(
+            requester.user.to_string()
+        )
+
+        if dehydrated_device is not None:
+            (device_id, device_data) = dehydrated_device
+            result = {"device_id": device_id, "device_data": device_data}
+            return 200, result
+        else:
+            raise errors.NotFoundError("No dehydrated device available")
+
+    async def on_DELETE(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        dehydrated_device = await self.device_handler.get_dehydrated_device(
+            requester.user.to_string()
+        )
+
+        if dehydrated_device is not None:
+            (device_id, device_data) = dehydrated_device
+
+            result = await self.device_handler.rehydrate_device(
+                requester.user.to_string(),
+                self.auth.get_access_token_from_request(request),
+                device_id,
+            )
+
+            result = {"device_id": device_id}
+
+            return 200, result
+        else:
+            raise errors.NotFoundError("No dehydrated device available")
+
+    class PutBody(RequestBodyModel):
+        device_data: DehydratedDeviceDataModel
+        device_id: StrictStr
+        initial_device_display_name: Optional[StrictStr]
+
+        class Config:
+            extra = Extra.allow
+
+    async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        submission = parse_and_validate_json_object_from_request(request, self.PutBody)
+        requester = await self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+
+        device_info = submission.dict()
+        if "device_keys" not in device_info.keys():
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Device key(s) not found, these must be provided.",
+            )
+
+        # TODO: Those two operations, creating a device and storing the
+        # device's keys should be atomic.
+        device_id = await self.device_handler.store_dehydrated_device(
+            requester.user.to_string(),
+            submission.device_id,
+            submission.device_data.dict(),
+            submission.initial_device_display_name,
+        )
+
+        # TODO: Do we need to do something with the result here?
+        await self.key_uploader(
+            user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+        )
+
+        return 200, {"device_id": device_id}
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     if (
         hs.config.worker.worker_app is None
@@ -354,7 +569,12 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ):
         DeleteDevicesRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
+
     if hs.config.worker.worker_app is None:
         DeviceRestServlet(hs).register(http_server)
-        DehydratedDeviceServlet(hs).register(http_server)
-        ClaimDehydratedDeviceServlet(hs).register(http_server)
+        if hs.config.experimental.msc2697_enabled:
+            DehydratedDeviceServlet(hs, msc2697=True).register(http_server)
+            ClaimDehydratedDeviceServlet(hs).register(http_server)
+        if hs.config.experimental.msc3814_enabled:
+            DehydratedDeviceV2Servlet(hs).register(http_server)
+            DehydratedDeviceEventsServlet(hs).register(http_server)
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 66215af2b8..647ee09279 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -17,15 +17,18 @@
 from typing import Optional
 from unittest import mock
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import RoomEncryptionAlgorithms
 from synapse.api.errors import NotFoundError, SynapseError
 from synapse.appservice import ApplicationService
 from synapse.handlers.device import MAX_DEVICE_DISPLAY_NAME_LEN, DeviceHandler
+from synapse.rest import admin
+from synapse.rest.client import devices, login, register
 from synapse.server import HomeServer
 from synapse.storage.databases.main.appservice import _make_exclusive_regex
-from synapse.types import JsonDict
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -399,11 +402,19 @@ class DeviceTestCase(unittest.HomeserverTestCase):
 
 
 class DehydrationTestCase(unittest.HomeserverTestCase):
+    servlets = [
+        admin.register_servlets_for_client_rest_resource,
+        login.register_servlets,
+        register.register_servlets,
+        devices.register_servlets,
+    ]
+
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver("server")
         handler = hs.get_device_handler()
         assert isinstance(handler, DeviceHandler)
         self.handler = handler
+        self.message_handler = hs.get_device_message_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
@@ -418,6 +429,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         stored_dehydrated_device_id = self.get_success(
             self.handler.store_dehydrated_device(
                 user_id=user_id,
+                device_id=None,
                 device_data={"device_data": {"foo": "bar"}},
                 initial_device_display_name="dehydrated device",
             )
@@ -481,3 +493,88 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         ret = self.get_success(self.handler.get_dehydrated_device(user_id=user_id))
 
         self.assertIsNone(ret)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_v2_and_fetch_events(self) -> None:
+        user_id = "@boris:server"
+
+        self.get_success(self.store.register_user(user_id, "foobar"))
+
+        # First check if we can store and fetch a dehydrated device
+        stored_dehydrated_device_id = self.get_success(
+            self.handler.store_dehydrated_device(
+                user_id=user_id,
+                device_id=None,
+                device_data={"device_data": {"foo": "bar"}},
+                initial_device_display_name="dehydrated device",
+            )
+        )
+
+        device_info = self.get_success(
+            self.handler.get_dehydrated_device(user_id=user_id)
+        )
+        assert device_info is not None
+        retrieved_device_id, device_data = device_info
+        self.assertEqual(retrieved_device_id, stored_dehydrated_device_id)
+        self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
+
+        # Create a new login for the user
+        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+            self.registration.register_device(
+                user_id=user_id,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+
+        requester = create_requester(user_id, device_id=device_id)
+
+        # Fetching messages for a non-existing device should return an error
+        self.get_failure(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id="not the right device ID",
+                since_token=None,
+                limit=10,
+            ),
+            SynapseError,
+        )
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user_id: {stored_dehydrated_device_id: {"body": "foo"}}},
+            )
+        )
+        self.pump()
+
+        # Fetch the message of the dehydrated device
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=None,
+                limit=10,
+            )
+        )
+
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
+
+        # Fetch the message of the dehydrated device again, which should return nothing
+        # and delete the old messages
+        res = self.get_success(
+            self.message_handler.get_events_for_dehydrated_device(
+                requester=requester,
+                device_id=stored_dehydrated_device_id,
+                since_token=res["next_batch"],
+                limit=10,
+            )
+        )
+        self.assertTrue(len(res["next_batch"]) > 1)
+        self.assertEqual(len(res["events"]), 0)
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index d80eea17d3..b7d420cfec 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -13,12 +13,14 @@
 # limitations under the License.
 from http import HTTPStatus
 
+from twisted.internet.defer import ensureDeferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.errors import NotFoundError
 from synapse.rest import admin, devices, room, sync
-from synapse.rest.client import account, login, register
+from synapse.rest.client import account, keys, login, register
 from synapse.server import HomeServer
+from synapse.types import JsonDict, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -208,8 +210,13 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         login.register_servlets,
         register.register_servlets,
         devices.register_servlets,
+        keys.register_servlets,
     ]
 
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.registration = hs.get_registration_handler()
+        self.message_handler = hs.get_device_message_handler()
+
     def test_PUT(self) -> None:
         """Sanity-check that we can PUT a dehydrated device.
 
@@ -226,7 +233,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
                 "device_data": {
                     "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
                     "account": "dehydrated_device",
-                }
+                },
+                "device_keys": {
+                    "user_id": "@alice:test",
+                    "device_id": "device1",
+                    "valid_until_ts": "80",
+                    "algorithms": [
+                        "m.olm.curve25519-aes-sha2",
+                    ],
+                    "keys": {
+                        "<algorithm>:<device_id>": "<key_base64>",
+                    },
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                    },
+                },
             },
             access_token=token,
             shorthand=False,
@@ -234,3 +255,128 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, HTTPStatus.OK, channel.json_body)
         device_id = channel.json_body.get("device_id")
         self.assertIsInstance(device_id, str)
+
+    @unittest.override_config(
+        {"experimental_features": {"msc2697_enabled": False, "msc3814_enabled": True}}
+    )
+    def test_dehydrate_msc3814(self) -> None:
+        user = self.register_user("mikey", "pass")
+        token = self.login(user, "pass", device_id="device1")
+        content: JsonDict = {
+            "device_data": {
+                "algorithm": "m.dehydration.v1.olm",
+            },
+            "device_id": "device1",
+            "initial_device_display_name": "foo bar",
+            "device_keys": {
+                "user_id": "@mikey:test",
+                "device_id": "device1",
+                "valid_until_ts": "80",
+                "algorithms": [
+                    "m.olm.curve25519-aes-sha2",
+                ],
+                "keys": {
+                    "<algorithm>:<device_id>": "<key_base64>",
+                },
+                "signatures": {
+                    "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
+                },
+            },
+        }
+        channel = self.make_request(
+            "PUT",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        device_id = channel.json_body.get("device_id")
+        assert device_id is not None
+        self.assertIsInstance(device_id, str)
+        self.assertEqual("device1", device_id)
+
+        # test that we can now GET the dehydrated device info
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        returned_device_id = channel.json_body.get("device_id")
+        self.assertEqual(returned_device_id, device_id)
+        device_data = channel.json_body.get("device_data")
+        expected_device_data = {
+            "algorithm": "m.dehydration.v1.olm",
+        }
+        self.assertEqual(device_data, expected_device_data)
+
+        # create another device for the user
+        (
+            new_device_id,
+            _,
+            _,
+            _,
+        ) = self.get_success(
+            self.registration.register_device(
+                user_id=user,
+                device_id=None,
+                initial_display_name="new device",
+            )
+        )
+        requester = create_requester(user, device_id=new_device_id)
+
+        # Send a message to the dehydrated device
+        ensureDeferred(
+            self.message_handler.send_device_message(
+                requester=requester,
+                message_type="test.message",
+                messages={user: {device_id: {"body": "test_message"}}},
+            )
+        )
+        self.pump()
+
+        # make sure we can fetch the message with our dehydrated device id
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        expected_content = {"body": "test_message"}
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+        next_batch_token = channel.json_body.get("next_batch")
+
+        # fetch messages again and make sure that the message was deleted and we are returned an
+        # empty array
+        content = {"next_batch": next_batch_token}
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content=content,
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"], [])
+
+        # make sure we can delete the dehydrated device
+        channel = self.make_request(
+            "DELETE",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+
+        # ...and after deleting it is no longer available
+        channel = self.make_request(
+            "GET",
+            "_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device",
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 404)