summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16010.misc1
-rw-r--r--synapse/handlers/device.py14
-rw-r--r--synapse/handlers/devicemessage.py13
-rw-r--r--synapse/rest/client/devices.py16
-rw-r--r--synapse/storage/databases/main/devices.py51
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py170
-rw-r--r--tests/handlers/test_device.py9
-rw-r--r--tests/rest/client/test_devices.py77
8 files changed, 254 insertions, 97 deletions
diff --git a/changelog.d/16010.misc b/changelog.d/16010.misc
new file mode 100644
index 0000000000..1e1a148069
--- /dev/null
+++ b/changelog.d/16010.misc
@@ -0,0 +1 @@
+Update dehydrated devices implementation.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index b7bf70a72d..5ae427d52c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler):
         self.federation_sender = hs.get_federation_sender()
         self._account_data_handler = hs.get_account_data_handler()
         self._storage_controllers = hs.get_storage_controllers()
+        self.db_pool = hs.get_datastores().main.db_pool
 
         self.device_list_updater = DeviceListUpdater(hs, self)
 
@@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler):
         device_id: Optional[str],
         device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
+        keys_for_device: Optional[JsonDict] = None,
     ) -> str:
-        """Store a dehydrated device for a user.  If the user had a previous
-        dehydrated device, it is removed.
+        """Store a dehydrated device for a user, optionally storing the keys associated with
+        it as well.  If the user had a previous dehydrated device, it is removed.
 
         Args:
             user_id: the user that we are storing the device for
             device_id: device id supplied by client
             device_data: the dehydrated device information
             initial_device_display_name: The display name to use for the device
+            keys_for_device: keys for the dehydrated device
         Returns:
             device id of the dehydrated device
         """
@@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler):
             device_id,
             initial_device_display_name,
         )
+
+        time_now = self.clock.time_msec()
+
         old_device_id = await self.store.store_dehydrated_device(
-            user_id, device_id, device_data
+            user_id, device_id, device_data, time_now, keys_for_device
         )
+
         if old_device_id is not None:
             await self.delete_devices(user_id, [old_device_id])
+
         return device_id
 
     async def rehydrate_device(
diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py
index 15e94a03cb..17ff8821d9 100644
--- a/synapse/handlers/devicemessage.py
+++ b/synapse/handlers/devicemessage.py
@@ -367,19 +367,6 @@ class DeviceMessageHandler:
                     errcode=Codes.INVALID_PARAM,
                 )
 
-            # if we have a since token, delete any to-device messages before that token
-            # (since we now know that the device has received them)
-            deleted = await self.store.delete_messages_for_device(
-                user_id, device_id, since_stream_id
-            )
-            logger.debug(
-                "Deleted %d to-device messages up to %d for user_id %s device_id %s",
-                deleted,
-                since_stream_id,
-                user_id,
-                device_id,
-            )
-
         to_token = self.event_sources.get_current_token().to_device_key
 
         messages, stream_id = await self.store.get_messages_for_device(
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 51f17f80da..925f037743 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -29,7 +29,6 @@ from synapse.http.servlet import (
     parse_integer,
 )
 from synapse.http.site import SynapseRequest
-from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet
 from synapse.rest.client._base import client_patterns, interactive_auth_handler
 from synapse.rest.client.models import AuthenticationData
 from synapse.rest.models import RequestBodyModel
@@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet):
         self.e2e_keys_handler = hs.get_e2e_keys_handler()
         self.device_handler = handler
 
-        if hs.config.worker.worker_app is None:
-            # if main process
-            self.key_uploader = self.e2e_keys_handler.upload_keys_for_user
-        else:
-            # then a worker
-            self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs)
-
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         requester = await self.auth.get_user_by_req(request)
 
@@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet):
                 "Device key(s) not found, these must be provided.",
             )
 
-        # TODO: Those two operations, creating a device and storing the
-        # device's keys should be atomic.
         device_id = await self.device_handler.store_dehydrated_device(
             requester.user.to_string(),
             submission.device_id,
             submission.device_data.dict(),
             submission.initial_device_display_name,
-        )
-
-        # TODO: Do we need to do something with the result here?
-        await self.key_uploader(
-            user_id=user_id, device_id=submission.device_id, keys=submission.dict()
+            device_info,
         )
 
         return 200, {"device_id": device_id}
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
     cast,
 )
 
+from canonicaljson import encode_canonical_json
 from typing_extensions import Literal
 
 from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        device_data: str,
+        time: int,
+        keys: Optional[JsonDict] = None,
     ) -> Optional[str]:
+        # TODO: make keys non-optional once support for msc2697 is dropped
+        if keys:
+            device_keys = keys.get("device_keys", None)
+            if device_keys:
+                # Type ignore - this function is defined on EndToEndKeyStore which we do
+                # have access to due to hs.get_datastore() "magic"
+                self._set_e2e_device_keys_txn(  # type: ignore[attr-defined]
+                    txn, user_id, device_id, time, device_keys
+                )
+
+            one_time_keys = keys.get("one_time_keys", None)
+            if one_time_keys:
+                key_list = []
+                for key_id, key_obj in one_time_keys.items():
+                    algorithm, key_id = key_id.split(":")
+                    key_list.append(
+                        (
+                            algorithm,
+                            key_id,
+                            encode_canonical_json(key_obj).decode("ascii"),
+                        )
+                    )
+                self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+            fallback_keys = keys.get("fallback_keys", None)
+            if fallback_keys:
+                self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             keyvalues={"user_id": user_id},
             values={"device_id": device_id, "device_data": device_data},
         )
+
         return old_device_id
 
     async def store_dehydrated_device(
-        self, user_id: str, device_id: str, device_data: JsonDict
+        self,
+        user_id: str,
+        device_id: str,
+        device_data: JsonDict,
+        time_now: int,
+        keys: Optional[dict] = None,
     ) -> Optional[str]:
         """Store a dehydrated device for a user.
 
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             user_id: the user that we are storing the device for
             device_id: the ID of the dehydrated device
             device_data: the dehydrated device information
+            time_now: current time at the request in milliseconds
+            keys: keys for the dehydrated device
+
         Returns:
             device id of the user's previous dehydrated device, if any
         """
+
         return await self.db_pool.runInteraction(
             "store_dehydrated_device_txn",
             self._store_dehydrated_device_txn,
             user_id,
             device_id,
             json_encoder.encode(device_data),
+            time_now,
+            keys,
         )
 
     async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
-        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("new_keys", str(new_keys))
-            # We are protected from race between lookup and insertion due to
-            # a unique constraint. If there is a race of two calls to
-            # `add_e2e_one_time_keys` then they'll conflict and we will only
-            # insert one set.
-            self.db_pool.simple_insert_many_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                keys=(
-                    "user_id",
-                    "device_id",
-                    "algorithm",
-                    "key_id",
-                    "ts_added_ms",
-                    "key_json",
-                ),
-                values=[
-                    (user_id, device_id, algorithm, key_id, time_now, json_bytes)
-                    for algorithm, key_id, json_bytes in new_keys
-                ],
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
         await self.db_pool.runInteraction(
-            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+            "add_e2e_one_time_keys_insert",
+            self._add_e2e_one_time_keys_txn,
+            user_id,
+            device_id,
+            time_now,
+            new_keys,
+        )
+
+    def _add_e2e_one_time_keys_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        new_keys: Iterable[Tuple[str, str, str]],
+    ) -> None:
+        """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+        Args:
+             user_id: id of user to get keys for
+             device_id: id of device to get keys for
+             time_now: insertion time to record (ms since epoch)
+             new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+             that the key JSON must be in canonical JSON form
+        """
+        set_tag("user_id", user_id)
+        set_tag("device_id", device_id)
+        set_tag("new_keys", str(new_keys))
+        # We are protected from race between lookup and insertion due to
+        # a unique constraint. If there is a race of two calls to
+        # `add_e2e_one_time_keys` then they'll conflict and we will only
+        # insert one set.
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="e2e_one_time_keys_json",
+            keys=(
+                "user_id",
+                "device_id",
+                "algorithm",
+                "key_id",
+                "ts_added_ms",
+                "key_json",
+            ),
+            values=[
+                (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+                for algorithm, key_id, json_bytes in new_keys
+            ],
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.count_e2e_one_time_keys, (user_id, device_id)
         )
 
     @cached(max_entries=10000)
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         device_id: str,
         fallback_keys: JsonDict,
     ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                    of the form "algorithm:id") to key data.
+        """
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
         # FIXME: make sure that only one key per algorithm is uploaded
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
+
+            Args:
+                user_id: user_id of the user to store keys for
+                device_id: device_id of the device to store keys for
+                time_now: time at the request to store the keys
+                device_keys: the keys to store
         """
 
-        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", str(device_keys))
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys",
+            self._set_e2e_device_keys_txn,
+            user_id,
+            device_id,
+            time_now,
+            device_keys,
+        )
 
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
+    def _set_e2e_device_keys_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        device_keys: JsonDict,
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
 
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+        Args:
+             user_id: user_id of the user to store keys for
+             device_id: device_id of the device to store keys for
+             time_now: time at the request to store the keys
+             device_keys: the keys to store
+        """
+        set_tag("user_id", user_id)
+        set_tag("device_id", device_id)
+        set_tag("time_now", time_now)
+        set_tag("device_keys", str(device_keys))
+
+        old_key_json = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="e2e_device_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            retcol="key_json",
+            allow_none=True,
+        )
 
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
+        # In py3 we need old_key_json to match new_key_json type. The DB
+        # returns unicode while encode_canonical_json returns bytes.
+        new_key_json = encode_canonical_json(device_keys).decode("utf-8")
 
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
+        if old_key_json == new_key_json:
+            log_kv({"Message": "Device key already stored."})
+            return False
 
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="e2e_device_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            values={"ts_added_ms": time_now, "key_json": new_key_json},
         )
+        log_kv({"message": "Device keys stored."})
+        return True
 
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 647ee09279..e1e58fa6e6 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -566,15 +566,16 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(res["events"]), 1)
         self.assertEqual(res["events"][0]["content"]["body"], "foo")
 
-        # Fetch the message of the dehydrated device again, which should return nothing
-        # and delete the old messages
+        # Fetch the message of the dehydrated device again, which should return
+        # the same message as it has not been deleted
         res = self.get_success(
             self.message_handler.get_events_for_dehydrated_device(
                 requester=requester,
                 device_id=stored_dehydrated_device_id,
-                since_token=res["next_batch"],
+                since_token=None,
                 limit=10,
             )
         )
         self.assertTrue(len(res["next_batch"]) > 1)
-        self.assertEqual(len(res["events"]), 0)
+        self.assertEqual(len(res["events"]), 1)
+        self.assertEqual(res["events"][0]["content"]["body"], "foo")
diff --git a/tests/rest/client/test_devices.py b/tests/rest/client/test_devices.py
index 3cf29c10ea..60099f8c59 100644
--- a/tests/rest/client/test_devices.py
+++ b/tests/rest/client/test_devices.py
@@ -20,7 +20,7 @@ from synapse.api.errors import NotFoundError
 from synapse.rest import admin, devices, room, sync
 from synapse.rest.client import account, keys, login, register
 from synapse.server import HomeServer
-from synapse.types import JsonDict, create_requester
+from synapse.types import JsonDict, UserID, create_requester
 from synapse.util import Clock
 
 from tests import unittest
@@ -282,6 +282,17 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
                     "<user_id>": {"<algorithm>:<device_id>": "<signature_base64>"}
                 },
             },
+            "fallback_keys": {
+                "alg1:device1": "f4llb4ckk3y",
+                "signed_<algorithm>:<device_id>": {
+                    "fallback": "true",
+                    "key": "f4llb4ckk3y",
+                    "signatures": {
+                        "<user_id>": {"<algorithm>:<device_id>": "<key_base64>"}
+                    },
+                },
+            },
+            "one_time_keys": {"alg1:k1": "0net1m3k3y"},
         }
         channel = self.make_request(
             "PUT",
@@ -312,6 +323,55 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         }
         self.assertEqual(device_data, expected_device_data)
 
+        # test that the keys are correctly uploaded
+        channel = self.make_request(
+            "POST",
+            "/_matrix/client/r0/keys/query",
+            {
+                "device_keys": {
+                    user: ["device1"],
+                },
+            },
+            token,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(
+            channel.json_body["device_keys"][user][device_id]["keys"],
+            content["device_keys"]["keys"],
+        )
+        # first claim should return the onetime key we uploaded
+        res = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:k1": "0net1m3k3y"}}},
+            },
+        )
+        # second claim should return fallback key
+        res2 = self.get_success(
+            self.hs.get_e2e_keys_handler().claim_one_time_keys(
+                {user: {device_id: {"alg1": 1}}},
+                UserID.from_string(user),
+                timeout=None,
+                always_include_fallback_keys=False,
+            )
+        )
+        self.assertEqual(
+            res2,
+            {
+                "failures": {},
+                "one_time_keys": {user: {device_id: {"alg1:device1": "f4llb4ckk3y"}}},
+            },
+        )
+
         # create another device for the user
         (
             new_device_id,
@@ -348,10 +408,21 @@ class DehydratedDeviceTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200)
         expected_content = {"body": "test_message"}
         self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
+
+        # fetch messages again and make sure that the message was not deleted
+        channel = self.make_request(
+            "POST",
+            f"_matrix/client/unstable/org.matrix.msc3814.v1/dehydrated_device/{device_id}/events",
+            content={},
+            access_token=token,
+            shorthand=False,
+        )
+        self.assertEqual(channel.code, 200)
+        self.assertEqual(channel.json_body["events"][0]["content"], expected_content)
         next_batch_token = channel.json_body.get("next_batch")
 
-        # fetch messages again and make sure that the message was deleted and we are returned an
-        # empty array
+        # make sure fetching messages with next batch token works - there are no unfetched
+        # messages so we should receive an empty array
         content = {"next_batch": next_batch_token}
         channel = self.make_request(
             "POST",