author     Shay <hillerys@element.io>     2023-08-08 12:04:46 -0700
committer  GitHub <noreply@github.com>    2023-08-08 12:04:46 -0700
commit     0328b56468fe12c4d86ef636b60964527a510160 (patch)
tree       2cff86a1e4518c90f234d4010419c049082a147f /synapse
parent     Fixup changelog (diff)
download   synapse-0328b56468fe12c4d86ef636b60964527a510160.tar.xz
Support MSC3814: Dehydrated Devices Part 2 (#16010)
Diffstat (limited to 'synapse')
-rw-r--r--  synapse/handlers/device.py                           14
-rw-r--r--  synapse/handlers/devicemessage.py                    13
-rw-r--r--  synapse/rest/client/devices.py                       16
-rw-r--r--  synapse/storage/databases/main/devices.py            51
-rw-r--r--  synapse/storage/databases/main/end_to_end_keys.py   170
5 files changed, 174 insertions(+), 90 deletions(-)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b7bf70a72d..5ae427d52c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
         self.federation_sender = hs.get_federation_sender()
         self._account_data_handler = hs.get_account_data_handler()
         self._storage_controllers = hs.get_storage_controllers()
+        self.db_pool = hs.get_datastores().main.db_pool
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
         device_id: Optional[str],
         device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
+        keys_for_device: Optional[JsonDict] = None,
     ) -> str:
-        """Store a dehydrated device for a user.  If the user had a previous
-        dehydrated device, it is removed.
+        """Store a dehydrated device for a user, optionally storing the keys associated with
+        it as well.  If the user had a previous dehydrated device, it is removed.
 
         Args:
             user_id: the user that we are storing the device for
             device_id: device id supplied by client
             device_data: the dehydrated device information
             initial_device_display_name: The display name to use for the device
+            keys_for_device: keys for the dehydrated device
         Returns:
             device id of the dehydrated device
         """
@@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
             device_id,
             initial_device_display_name,
         )
+
+        time_now = self.clock.time_msec()
+
         old_device_id = await self.store.store_dehydrated_device(
-            user_id, device_id, device_data
+            user_id, device_id, device_data, time_now, keys_for_device
         )
+
         if old_device_id is not None:
             await self.delete_devices(user_id, [old_device_id])
+
         return device_id
 
     async def rehydrate_device(
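
Note: with the handler change above, a single call now records the dehydrated device and, optionally, its end-to-end keys. A rough sketch of the new call shape follows; the identifiers, device_data algorithm, and key payloads are placeholders, not code from the commit.

async def example_store(device_handler) -> str:
    # Placeholder values; the real payloads come from the client's request.
    keys_for_device = {
        "device_keys": {"user_id": "@alice:example.org", "device_id": "DEHYDRATED"},
        "one_time_keys": {"signed_curve25519:AAAAAA": {"key": "base64+key"}},
        "fallback_keys": {"signed_curve25519:BBBBBB": {"key": "base64+key"}},
    }
    # One call stores the device row and its keys in the same DB transaction;
    # if the user already had a dehydrated device, it is deleted afterwards.
    return await device_handler.store_dehydrated_device(
        "@alice:example.org",
        "DEHYDRATED",
        {"algorithm": "placeholder.dehydration.algorithm"},
        initial_device_display_name="Dehydrated device",
        keys_for_device=keys_for_device,
    )
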
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 15e94a03cb..17ff8821d9 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -367,19 +367,6 @@ class DeviceMessageHandler:
                     errcode=Codes.INVALID_PARAM,
                 )
 
-            # if we have a since token, delete any to-device messages before that token
-            # (since we now know that the device has received them)
-            deleted = await self.store.delete_messages_for_device(
-                user_id, device_id, since_stream_id
-            )
-            logger.debug(
-                "Deleted %d to-device messages up to %d for user_id %s device_id %s",
-                deleted,
-                since_stream_id,
-                user_id,
-                device_id,
-            )
-
         to_token = self.event_sources.get_current_token().to_device_key
 
         messages, stream_id = await self.store.get_messages_for_device(
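
Note: with the block above removed, fetching to-device messages here no longer deletes anything as a side effect. If cleanup up to the acknowledged stream position is wanted, it has to be requested explicitly; a minimal, hypothetical sketch using the same store method the removed code called (user and device ids are placeholders):

async def ack_messages(store, since_stream_id: int) -> int:
    # Mirrors the call that previously ran automatically in the handler above.
    deleted = await store.delete_messages_for_device(
        "@alice:example.org", "DEHYDRATED", since_stream_id
    )
    return deleted  # number of to-device messages removed
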
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 51f17f80da..925f037743 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -29,7 +29,6 @@ from synapse.http.servlet import (
     parse_integer,
 )
 from synapse.http.site import SynapseRequest
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.rest.client.models import AuthenticationData
 from synapse.rest.models import RequestBodyModel
@@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = handler
 
-        if hs.config.worker.worker_app is None:
-            # if main process
-            self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
-        else:
-            # then a worker
-            self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
-
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
@@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
                 "Device key(s) not found, these must be provided.",
             )
 
-        # TODO: Those two operations, creating a device and storing the
-        # device's keys should be atomic.
         device_id = await self.device_handler.store_dehydrated_device(
             requester.user.to_string(),
             submission.device_id,
             submission.device_data.dict(),
             submission.initial_device_display_name,
-        )
-
-        # TODO: Do we need to do something with the result here?
-        await self.key_uploader(
-            user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+            device_info,
         )
 
         return 200, {"device_id": device_id}
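
Note: the servlet now forwards the parsed key material (device_info) into store_dehydrated_device instead of making a separate key-upload call, which also removes the worker/main-process branching for the uploader. A hypothetical request body is sketched below; field names are taken from the submission model and key handling visible in this diff, all values are placeholders, and other fields may exist.

example_body = {
    "device_id": "DEHYDRATED",
    "initial_device_display_name": "Dehydrated device",
    "device_data": {"algorithm": "placeholder.dehydration.algorithm"},
    # Device keys are required; the servlet rejects bodies without them.
    "device_keys": {"user_id": "@alice:example.org", "device_id": "DEHYDRATED"},
    "one_time_keys": {"signed_curve25519:AAAAAA": {"key": "base64+key"}},
    "fallback_keys": {"signed_curve25519:BBBBBB": {"key": "base64+key"}},
}
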
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
     cast,
 )
 
+from canonicaljson import encode_canonical_json
 from typing_extensions import Literal
 
 from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        device_data: str,
+        time: int,
+        keys: Optional[JsonDict] = None,
     ) -> Optional[str]:
+        # TODO: make keys non-optional once support for msc2697 is dropped
+        if keys:
+            device_keys = keys.get("device_keys", None)
+            if device_keys:
+                # Type ignore - this function is defined on EndToEndKeyStore which we do
+                # have access to due to hs.get_datastore() "magic"
+                self._set_e2e_device_keys_txn(  # type: ignore[attr-defined]
+                    txn, user_id, device_id, time, device_keys
+                )
+
+            one_time_keys = keys.get("one_time_keys", None)
+            if one_time_keys:
+                key_list = []
+                for key_id, key_obj in one_time_keys.items():
+                    algorithm, key_id = key_id.split(":")
+                    key_list.append(
+                        (
+                            algorithm,
+                            key_id,
+                            encode_canonical_json(key_obj).decode("ascii"),
+                        )
+                    )
+                self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+            fallback_keys = keys.get("fallback_keys", None)
+            if fallback_keys:
+                self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             keyvalues={"user_id": user_id},
             values={"device_id": device_id, "device_data": device_data},
         )
+
         return old_device_id
 
     async def store_dehydrated_device(
-        self, user_id: str, device_id: str, device_data: JsonDict
+        self,
+        user_id: str,
+        device_id: str,
+        device_data: JsonDict,
+        time_now: int,
+        keys: Optional[dict] = None,
     ) -> Optional[str]:
         """Store a dehydrated device for a user.
 
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             user_id: the user that we are storing the device for
             device_id: the ID of the dehydrated device
             device_data: the dehydrated device information
+            time_now: current time at the request in milliseconds
+            keys: keys for the dehydrated device
+
         Returns:
             device id of the user's previous dehydrated device, if any
         """
+
         return await self.db_pool.runInteraction(
             "store_dehydrated_device_txn",
             self._store_dehydrated_device_txn,
             user_id,
             device_id,
             json_encoder.encode(device_data),
+            time_now,
+            keys,
         )
 
     async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
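
Note: the new transaction body above reuses the end-to-end key helpers, splitting each uploaded one-time key id into its algorithm and key id and canonicalising the key JSON before insertion. A standalone illustration of that flattening step (the sample key is a placeholder):

from canonicaljson import encode_canonical_json

one_time_keys = {
    "signed_curve25519:AAAAAA": {"key": "base64+public+key", "signatures": {}},
}

key_list = []
for key_id, key_obj in one_time_keys.items():
    algorithm, key_id = key_id.split(":")
    key_list.append(
        (algorithm, key_id, encode_canonical_json(key_obj).decode("ascii"))
    )

# key_list == [("signed_curve25519", "AAAAAA",
#               '{"key":"base64+public+key","signatures":{}}')]
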
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
-        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("new_keys", str(new_keys))
-            # We are protected from race between lookup and insertion due to
-            # a unique constraint. If there is a race of two calls to
-            # `add_e2e_one_time_keys` then they'll conflict and we will only
-            # insert one set.
-            self.db_pool.simple_insert_many_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                keys=(
-                    "user_id",
-                    "device_id",
-                    "algorithm",
-                    "key_id",
-                    "ts_added_ms",
-                    "key_json",
-                ),
-                values=[
-                    (user_id, device_id, algorithm, key_id, time_now, json_bytes)
-                    for algorithm, key_id, json_bytes in new_keys
-                ],
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
         await self.db_pool.runInteraction(
-            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+            "add_e2e_one_time_keys_insert",
+            self._add_e2e_one_time_keys_txn,
+            user_id,
+            device_id,
+            time_now,
+            new_keys,
+        )
+
+    def _add_e2e_one_time_keys_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        new_keys: Iterable[Tuple[str, str, str]],
+    ) -> None:
+        """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+        Args:
+             user_id: id of user to get keys for
+             device_id: id of device to get keys for
+             time_now: insertion time to record (ms since epoch)
+             new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+             that the key JSON must be in canonical JSON form
+        """
+        set_tag("user_id", user_id)
+        set_tag("device_id", device_id)
+        set_tag("new_keys", str(new_keys))
+        # We are protected from race between lookup and insertion due to
+        # a unique constraint. If there is a race of two calls to
+        # `add_e2e_one_time_keys` then they'll conflict and we will only
+        # insert one set.
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="e2e_one_time_keys_json",
+            keys=(
+                "user_id",
+                "device_id",
+                "algorithm",
+                "key_id",
+                "ts_added_ms",
+                "key_json",
+            ),
+            values=[
+                (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+                for algorithm, key_id, json_bytes in new_keys
+            ],
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.count_e2e_one_time_keys, (user_id, device_id)
         )
 
     @cached(max_entries=10000)
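
Note: the public add_e2e_one_time_keys entry point keeps its signature and behaviour; it now simply wraps the transaction helper so other transactions can reuse the same insert logic. A hedged usage sketch with placeholder values:

async def upload_example(store) -> None:
    await store.add_e2e_one_time_keys(
        "@alice:example.org",
        "DEHYDRATED",
        time_now=1691521486000,  # ms since epoch, placeholder
        new_keys=[
            ("signed_curve25519", "AAAAAA", '{"key":"base64+key"}'),
        ],
    )
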
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         device_id: str,
         fallback_keys: JsonDict,
     ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                    of the form "algorithm:id") to key data.
+        """
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
         # FIXME: make sure that only one key per algorithm is uploaded
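
Note: as the new docstring above says, fallback_keys maps "algorithm:key_id" to key data. A hypothetical payload matching that shape; the "fallback" flag and the signatures block are illustrative, not taken from this diff.

fallback_keys = {
    "signed_curve25519:CCCCCC": {
        "key": "base64+public+key",
        "fallback": True,
        "signatures": {
            "@alice:example.org": {"ed25519:DEHYDRATED": "base64+signature"}
        },
    },
}
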
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
+
+            Args:
+                user_id: user_id of the user to store keys for
+                device_id: device_id of the device to store keys for
+                time_now: time at the request to store the keys
+                device_keys: the keys to store
         """
 
-        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", str(device_keys))
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys",
+            self._set_e2e_device_keys_txn,
+            user_id,
+            device_id,
+            time_now,
+            device_keys,
+        )
 
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
+    def _set_e2e_device_keys_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        device_keys: JsonDict,
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
 
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+        Args:
+             user_id: user_id of the user to store keys for
+             device_id: device_id of the device to store keys for
+             time_now: time at the request to store the keys
+             device_keys: the keys to store
+        """
+        set_tag("user_id", user_id)
+        set_tag("device_id", device_id)
+        set_tag("time_now", time_now)
+        set_tag("device_keys", str(device_keys))
+
+        old_key_json = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="e2e_device_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            retcol="key_json",
+            allow_none=True,
+        )
 
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
+        # In py3 we need old_key_json to match new_key_json type. The DB
+        # returns unicode while encode_canonical_json returns bytes.
+        new_key_json = encode_canonical_json(device_keys).decode("utf-8")
 
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
+        if old_key_json == new_key_json:
+            log_kv({"Message": "Device key already stored."})
+            return False
 
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="e2e_device_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            values={"ts_added_ms": time_now, "key_json": new_key_json},
         )
+        log_kv({"message": "Device keys stored."})
+        return True
 
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
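
Note: taken together, the refactors in this file turn the previous self-contained interactions into transaction-level helpers. A minimal sketch of how they compose inside one interaction, which is what _store_dehydrated_device_txn relies on; store and all values below are placeholders, not code from the commit.

async def store_keys_example(store) -> None:
    user_id, device_id = "@alice:example.org", "DEHYDRATED"  # placeholders
    time_now = 1691521486000  # ms since epoch, placeholder

    def _txn(txn) -> None:
        store._set_e2e_device_keys_txn(
            txn, user_id, device_id, time_now, {"keys": {}, "signatures": {}}
        )
        store._add_e2e_one_time_keys_txn(
            txn,
            user_id,
            device_id,
            time_now,
            [("signed_curve25519", "AAAAAA", '{"key":"base64+key"}')],
        )
        store._set_e2e_fallback_keys_txn(
            txn, user_id, device_id, {"signed_curve25519:BBBBBB": {"key": "base64+key"}}
        )

    # All three writes commit or roll back together.
    await store.db_pool.runInteraction("dehydrated_device_keys_example", _txn)
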