summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2023-07-24 08:23:19 -0700
committerGitHub <noreply@github.com>2023-07-24 08:23:19 -0700
commit641ff9ef7eaa7f1a632b983f4d36bb28dc23484d (patch)
tree3d61b97c9cbedd0c804c04de32463e57513cabc5 /synapse
parentFix broken Arch Linux package link (#15981) (diff)
downloadsynapse-641ff9ef7eaa7f1a632b983f4d36bb28dc23484d.tar.xz
Support MSC3814: Dehydrated Devices (#15929)

Signed-off-by: Nicolas Werner <n.werner@famedly.com>
Co-authored-by: Nicolas Werner <n.werner@famedly.com>
Co-authored-by: Nicolas Werner <89468146+nico-famedly@users.noreply.github.com>
Co-authored-by: Hubert Chathi <hubert@uhoreg.ca>
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/experimental.py21
-rw-r--r--synapse/handlers/device.py4
-rw-r--r--synapse/handlers/devicemessage.py108
-rw-r--r--synapse/rest/client/devices.py232
4 files changed, 356 insertions, 9 deletions
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 0970f22a75..1695ed8ca3 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -247,6 +247,27 @@ class ExperimentalConfig(Config):
         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
 
+        # MSC2697 (device dehydration)
+        # Enabled by default since this option was added after adding the feature.
+        # It is not recommended that both MSC2697 and MSC3814 both be enabled at
+        # once.
+        self.msc2697_enabled: bool = experimental.get("msc2697_enabled", True)
+
+        # MSC3814 (dehydrated devices with SSSS)
+        # This is an alternative method to achieve the same goals as MSC2697.
+        # It is not recommended that both MSC2697 and MSC3814 both be enabled at
+        # once.
+        self.msc3814_enabled: bool = experimental.get("msc3814_enabled", False)
+
+        if self.msc2697_enabled and self.msc3814_enabled:
+            raise ConfigError(
+                "MSC2697 and MSC3814 should not both be enabled.",
+                (
+                    "experimental_features",
+                    "msc3814_enabled",
+                ),
+            )
+
         # MSC3244 (room version capabilities)
         self.msc3244_enabled: bool = experimental.get("msc3244_enabled", True)
 
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index d73d9dca08..f3a713f5fa 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -653,6 +653,7 @@ class DeviceHandler(DeviceWorkerHandler):
     async def store_dehydrated_device(
         self,
         user_id: str,
+        device_id: Optional[str],
         device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
     ) -> str:
@@ -661,6 +662,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         Args:
             user_id: the user that we are storing the device for
+            device_id: device id supplied by client
             device_data: the dehydrated device information
             initial_device_display_name: The display name to use for the device
         Returns:
@@ -668,7 +670,7 @@ class DeviceHandler(DeviceWorkerHandler):
         """
         device_id = await self.check_device_registered(
             user_id,
-            None,
+            device_id,
             initial_device_display_name,
         )
         old_device_id = await self.store.store_dehydrated_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 3caf9b31cc..15e94a03cb 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -13,10 +13,11 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict
+from http import HTTPStatus
+from typing import TYPE_CHECKING, Any, Dict, Optional
 
 from synapse.api.constants import EduTypes, EventContentFields, ToDeviceEventTypes
-from synapse.api.errors import SynapseError
+from synapse.api.errors import Codes, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.logging.context import run_in_background
 from synapse.logging.opentracing import (
@@ -48,6 +49,9 @@ class DeviceMessageHandler:
         self.store = hs.get_datastores().main
         self.notifier = hs.get_notifier()
         self.is_mine = hs.is_mine
+        if hs.config.experimental.msc3814_enabled:
+            self.event_sources = hs.get_event_sources()
+            self.device_handler = hs.get_device_handler()
 
         # We only need to poke the federation sender explicitly if its on the
         # same instance. Other federation sender instances will get notified by
@@ -303,3 +307,103 @@ class DeviceMessageHandler:
                 # Enqueue a new federation transaction to send the new
                 # device messages to each remote destination.
                 self.federation_sender.send_device_messages(destination)
+
+    async def get_events_for_dehydrated_device(
+        self,
+        requester: Requester,
+        device_id: str,
+        since_token: Optional[str],
+        limit: int,
+    ) -> JsonDict:
+        """Fetches up to `limit` events sent to `device_id` starting from `since_token`
+        and returns the new since token. If there are no more messages, returns an empty
+        array.
+
+        Args:
+            requester: the user requesting the messages
+            device_id: ID of the dehydrated device
+            since_token: stream id to start from when fetching messages
+            limit: the number of messages to fetch
+        Returns:
+            A dict containing the to-device messages, as well as a token that the client
+            can provide in the next call to fetch the next batch of messages
+        """
+
+        user_id = requester.user.to_string()
+
+        # only allow fetching messages for the dehydrated device id currently associated
+        # with the user
+        dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
+        if dehydrated_device is None:
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "No dehydrated device exists",
+                Codes.FORBIDDEN,
+            )
+
+        dehydrated_device_id, _ = dehydrated_device
+        if device_id != dehydrated_device_id:
+            raise SynapseError(
+                HTTPStatus.FORBIDDEN,
+                "You may only fetch messages for your dehydrated device",
+                Codes.FORBIDDEN,
+            )
+
+        since_stream_id = 0
+        if since_token:
+            if not since_token.startswith("d"):
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "from parameter %r has an invalid format" % (since_token,),
+                    errcode=Codes.INVALID_PARAM,
+                )
+
+            try:
+                since_stream_id = int(since_token[1:])
+            except Exception:
+                raise SynapseError(
+                    HTTPStatus.BAD_REQUEST,
+                    "from parameter %r has an invalid format" % (since_token,),
+                    errcode=Codes.INVALID_PARAM,
+                )
+
+            # if we have a since token, delete any to-device messages before that token
+            # (since we now know that the device has received them)
+            deleted = await self.store.delete_messages_for_device(
+                user_id, device_id, since_stream_id
+            )
+            logger.debug(
+                "Deleted %d to-device messages up to %d for user_id %s device_id %s",
+                deleted,
+                since_stream_id,
+                user_id,
+                device_id,
+            )
+
+        to_token = self.event_sources.get_current_token().to_device_key
+
+        messages, stream_id = await self.store.get_messages_for_device(
+            user_id, device_id, since_stream_id, to_token, limit
+        )
+
+        for message in messages:
+            # Remove the message id before sending to client
+            message_id = message.pop("message_id", None)
+            if message_id:
+                set_tag(SynapseTags.TO_DEVICE_EDU_ID, message_id)
+
+        logger.debug(
+            "Returning %d to-device messages between %d and %d (current token: %d) for "
+            "dehydrated device %s, user_id %s",
+            len(messages),
+            since_stream_id,
+            stream_id,
+            to_token,
+            device_id,
+            user_id,
+        )
+
+        return {
+            "events": messages,
+            "next_batch": f"d{stream_id}",
+        }
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 38dff9703f..690d2ec406 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -14,19 +14,22 @@
 # limitations under the License.
 
 import logging
+from http import HTTPStatus
 from typing import TYPE_CHECKING, List, Optional, Tuple
 
 from pydantic import Extra, StrictStr
 
 from synapse.api import errors
-from synapse.api.errors import NotFoundError, UnrecognizedRequestError
+from synapse.api.errors import NotFoundError, SynapseError, UnrecognizedRequestError
 from synapse.handlers.device import DeviceHandler
 from synapse.http.server import HttpServer
 from synapse.http.servlet import (
     RestServlet,
     parse_and_validate_json_object_from_request,
+    parse_integer,
 )
 from synapse.http.site import SynapseRequest
+from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.rest.client.models import AuthenticationData
 from synapse.rest.models import RequestBodyModel
@@ -229,6 +232,8 @@ class DehydratedDeviceDataModel(RequestBodyModel):
 class DehydratedDeviceServlet(RestServlet):
     """Retrieve or store a dehydrated device.
 
+    Implements either MSC2697 or MSC3814.
+
     GET /org.matrix.msc2697.v2/dehydrated_device
 
     HTTP/1.1 200 OK
@@ -261,9 +266,7 @@ class DehydratedDeviceServlet(RestServlet):
 
     """
 
-    PATTERNS = client_patterns("/org.matrix.msc2697.v2/dehydrated_device$", releases=())
-
-    def __init__(self, hs: "HomeServer"):
+    def __init__(self, hs: "HomeServer", msc2697: bool = True):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
@@ -271,6 +274,13 @@ class DehydratedDeviceServlet(RestServlet):
         assert isinstance(handler, DeviceHandler)
         self.device_handler = handler
 
+        self.PATTERNS = client_patterns(
+            "/org.matrix.msc2697.v2/dehydrated_device$"
+            if msc2697
+            else "/org.matrix.msc3814.v1/dehydrated_device$",
+            releases=(),
+        )
+
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
         dehydrated_device = await self.device_handler.get_dehydrated_device(
@@ -293,6 +303,7 @@ class DehydratedDeviceServlet(RestServlet):
 
         device_id = await self.device_handler.store_dehydrated_device(
             requester.user.to_string(),
+            None,
             submission.device_data.dict(),
             submission.initial_device_display_name,
         )
@@ -347,6 +358,210 @@ class ClaimDehydratedDeviceServlet(RestServlet):
         return 200, result
 
 
+class DehydratedDeviceEventsServlet(RestServlet):
+    PATTERNS = client_patterns(
+        "/org.matrix.msc3814.v1/dehydrated_device/(?P<device_id>[^/]*)/events$",
+        releases=(),
+    )
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.message_handler = hs.get_device_message_handler()
+        self.auth = hs.get_auth()
+        self.store = hs.get_datastores().main
+
+    class PostBody(RequestBodyModel):
+        next_batch: Optional[StrictStr]
+
+    async def on_POST(
+        self, request: SynapseRequest, device_id: str
+    ) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        next_batch = parse_and_validate_json_object_from_request(
+            request, self.PostBody
+        ).next_batch
+        limit = parse_integer(request, "limit", 100)
+
+        msgs = await self.message_handler.get_events_for_dehydrated_device(
+            requester=requester,
+            device_id=device_id,
+            since_token=next_batch,
+            limit=limit,
+        )
+
+        return 200, msgs
+
+
+class DehydratedDeviceV2Servlet(RestServlet):
+    """Upload, retrieve, or delete a dehydrated device.
+
+    GET /org.matrix.msc3814.v1/dehydrated_device
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id",
+      "device_data": {
+        "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+        "account": "dehydrated_device"
+      }
+    }
+
+    PUT /org.matrix.msc3814.v1/dehydrated_device
+    Content-Type: application/json
+
+    {
+        "device_id": "dehydrated_device_id",
+        "device_data": {
+            "algorithm": "org.matrix.msc2697.v1.dehydration.v1.olm",
+            "account": "dehydrated_device"
+        },
+        "device_keys": {
+            "user_id": "<user_id>",
+            "device_id": "<device_id>",
+            "valid_until_ts": <millisecond_timestamp>,
+            "algorithms": [
+                "m.olm.curve25519-aes-sha2",
+            ]
+            "keys": {
+                "<algorithm>:<device_id>": "<key_base64>",
+            },
+            "signatures:" {
+                "<user_id>" {
+                    "<algorithm>:<device_id>": "<signature_base64>"
+                }
+            }
+        },
+        "fallback_keys": {
+            "<algorithm>:<device_id>": "<key_base64>",
+            "signed_<algorithm>:<device_id>": {
+                "fallback": true,
+                "key": "<key_base64>",
+                "signatures": {
+                    "<user_id>": {
+                        "<algorithm>:<device_id>": "<key_base64>"
+                    }
+                }
+            }
+        }
+        "one_time_keys": {
+            "<algorithm>:<key_id>": "<key_base64>"
+        },
+
+    }
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id"
+    }
+
+    DELETE /org.matrix.msc3814.v1/dehydrated_device
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id",
+    }
+    """
+
+    PATTERNS = [
+        *client_patterns("/org.matrix.msc3814.v1/dehydrated_device$", releases=()),
+    ]
+
+    def __init__(self, hs: "HomeServer"):
+        super().__init__()
+        self.hs = hs
+        self.auth = hs.get_auth()
+        handler = hs.get_device_handler()
+        assert isinstance(handler, DeviceHandler)
+        self.e2e_keys_handler = hs.get_e2e_keys_handler()
+        self.device_handler = handler
+
+        if hs.config.worker.worker_app is None:
+            # if main process
+            self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
+        else:
+            # then a worker
+            self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
+
+    async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        dehydrated_device = await self.device_handler.get_dehydrated_device(
+            requester.user.to_string()
+        )
+
+        if dehydrated_device is not None:
+            (device_id, device_data) = dehydrated_device
+            result = {"device_id": device_id, "device_data": device_data}
+            return 200, result
+        else:
+            raise errors.NotFoundError("No dehydrated device available")
+
+    async def on_DELETE(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        requester = await self.auth.get_user_by_req(request)
+
+        dehydrated_device = await self.device_handler.get_dehydrated_device(
+            requester.user.to_string()
+        )
+
+        if dehydrated_device is not None:
+            (device_id, device_data) = dehydrated_device
+
+            result = await self.device_handler.rehydrate_device(
+                requester.user.to_string(),
+                self.auth.get_access_token_from_request(request),
+                device_id,
+            )
+
+            result = {"device_id": device_id}
+
+            return 200, result
+        else:
+            raise errors.NotFoundError("No dehydrated device available")
+
+    class PutBody(RequestBodyModel):
+        device_data: DehydratedDeviceDataModel
+        device_id: StrictStr
+        initial_device_display_name: Optional[StrictStr]
+
+        class Config:
+            extra = Extra.allow
+
+    async def on_PUT(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
+        submission = parse_and_validate_json_object_from_request(request, self.PutBody)
+        requester = await self.auth.get_user_by_req(request)
+        user_id = requester.user.to_string()
+
+        device_info = submission.dict()
+        if "device_keys" not in device_info.keys():
+            raise SynapseError(
+                HTTPStatus.BAD_REQUEST,
+                "Device key(s) not found, these must be provided.",
+            )
+
+        # TODO: Those two operations, creating a device and storing the
+        # device's keys should be atomic.
+        device_id = await self.device_handler.store_dehydrated_device(
+            requester.user.to_string(),
+            submission.device_id,
+            submission.device_data.dict(),
+            submission.initial_device_display_name,
+        )
+
+        # TODO: Do we need to do something with the result here?
+        await self.key_uploader(
+            user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+        )
+
+        return 200, {"device_id": device_id}
+
+
 def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     if (
         hs.config.worker.worker_app is None
@@ -354,7 +569,12 @@ def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None:
     ):
         DeleteDevicesRestServlet(hs).register(http_server)
     DevicesRestServlet(hs).register(http_server)
+
     if hs.config.worker.worker_app is None:
         DeviceRestServlet(hs).register(http_server)
-        DehydratedDeviceServlet(hs).register(http_server)
-        ClaimDehydratedDeviceServlet(hs).register(http_server)
+        if hs.config.experimental.msc2697_enabled:
+            DehydratedDeviceServlet(hs, msc2697=True).register(http_server)
+            ClaimDehydratedDeviceServlet(hs).register(http_server)
+        if hs.config.experimental.msc3814_enabled:
+            DehydratedDeviceV2Servlet(hs).register(http_server)
+            DehydratedDeviceEventsServlet(hs).register(http_server)