summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/devices.py51
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py170
2 files changed, 162 insertions, 59 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
     cast,
 )
 
+from canonicaljson import encode_canonical_json
 from typing_extensions import Literal
 
 from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        device_data: str,
+        time: int,
+        keys: Optional[JsonDict] = None,
     ) -> Optional[str]:
+        # TODO: make keys non-optional once support for msc2697 is dropped
+        if keys:
+            device_keys = keys.get("device_keys", None)
+            if device_keys:
+                # Type ignore - this function is defined on EndToEndKeyStore which we do
+                # have access to due to hs.get_datastore() "magic"
+                self._set_e2e_device_keys_txn(  # type: ignore[attr-defined]
+                    txn, user_id, device_id, time, device_keys
+                )
+
+            one_time_keys = keys.get("one_time_keys", None)
+            if one_time_keys:
+                key_list = []
+                for key_id, key_obj in one_time_keys.items():
+                    algorithm, key_id = key_id.split(":")
+                    key_list.append(
+                        (
+                            algorithm,
+                            key_id,
+                            encode_canonical_json(key_obj).decode("ascii"),
+                        )
+                    )
+                self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+            fallback_keys = keys.get("fallback_keys", None)
+            if fallback_keys:
+                self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             keyvalues={"user_id": user_id},
             values={"device_id": device_id, "device_data": device_data},
         )
+
         return old_device_id
 
     async def store_dehydrated_device(
-        self, user_id: str, device_id: str, device_data: JsonDict
+        self,
+        user_id: str,
+        device_id: str,
+        device_data: JsonDict,
+        time_now: int,
+        keys: Optional[dict] = None,
     ) -> Optional[str]:
         """Store a dehydrated device for a user.
 
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             user_id: the user that we are storing the device for
             device_id: the ID of the dehydrated device
             device_data: the dehydrated device information
+            time_now: current time at the request in milliseconds
+            keys: keys for the dehydrated device
+
         Returns:
             device id of the user's previous dehydrated device, if any
         """
+
         return await self.db_pool.runInteraction(
             "store_dehydrated_device_txn",
             self._store_dehydrated_device_txn,
             user_id,
             device_id,
             json_encoder.encode(device_data),
+            time_now,
+            keys,
         )
 
     async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 91ae9c457d..b49dea577c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
             new_keys: keys to add - each a tuple of (algorithm, key_id, key json)
         """
 
-        def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None:
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("new_keys", str(new_keys))
-            # We are protected from race between lookup and insertion due to
-            # a unique constraint. If there is a race of two calls to
-            # `add_e2e_one_time_keys` then they'll conflict and we will only
-            # insert one set.
-            self.db_pool.simple_insert_many_txn(
-                txn,
-                table="e2e_one_time_keys_json",
-                keys=(
-                    "user_id",
-                    "device_id",
-                    "algorithm",
-                    "key_id",
-                    "ts_added_ms",
-                    "key_json",
-                ),
-                values=[
-                    (user_id, device_id, algorithm, key_id, time_now, json_bytes)
-                    for algorithm, key_id, json_bytes in new_keys
-                ],
-            )
-            self._invalidate_cache_and_stream(
-                txn, self.count_e2e_one_time_keys, (user_id, device_id)
-            )
-
         await self.db_pool.runInteraction(
-            "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys
+            "add_e2e_one_time_keys_insert",
+            self._add_e2e_one_time_keys_txn,
+            user_id,
+            device_id,
+            time_now,
+            new_keys,
+        )
+
+    def _add_e2e_one_time_keys_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        new_keys: Iterable[Tuple[str, str, str]],
+    ) -> None:
+        """Insert some new one time keys for a device. Errors if any of the keys already exist.
+
+        Args:
+             user_id: id of user to get keys for
+             device_id: id of device to get keys for
+             time_now: insertion time to record (ms since epoch)
+             new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note
+             that the key JSON must be in canonical JSON form
+        """
+        set_tag("user_id", user_id)
+        set_tag("device_id", device_id)
+        set_tag("new_keys", str(new_keys))
+        # We are protected from race between lookup and insertion due to
+        # a unique constraint. If there is a race of two calls to
+        # `add_e2e_one_time_keys` then they'll conflict and we will only
+        # insert one set.
+        self.db_pool.simple_insert_many_txn(
+            txn,
+            table="e2e_one_time_keys_json",
+            keys=(
+                "user_id",
+                "device_id",
+                "algorithm",
+                "key_id",
+                "ts_added_ms",
+                "key_json",
+            ),
+            values=[
+                (user_id, device_id, algorithm, key_id, time_now, json_bytes)
+                for algorithm, key_id, json_bytes in new_keys
+            ],
+        )
+        self._invalidate_cache_and_stream(
+            txn, self.count_e2e_one_time_keys, (user_id, device_id)
         )
 
     @cached(max_entries=10000)
@@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
         device_id: str,
         fallback_keys: JsonDict,
     ) -> None:
+        """Set the user's e2e fallback keys.
+
+        Args:
+            user_id: the user whose keys are being set
+            device_id: the device whose keys are being set
+            fallback_keys: the keys to set.  This is a map from key ID (which is
+                    of the form "algorithm:id") to key data.
+        """
         # fallback_keys will usually only have one item in it, so using a for
         # loop (as opposed to calling simple_upsert_many_txn) won't be too bad
         # FIXME: make sure that only one key per algorithm is uploaded
@@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
     ) -> bool:
         """Stores device keys for a device. Returns whether there was a change
         or the keys were already in the database.
+
+            Args:
+                user_id: user_id of the user to store keys for
+                device_id: device_id of the device to store keys for
+                time_now: time at the request to store the keys
+                device_keys: the keys to store
         """
 
-        def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool:
-            set_tag("user_id", user_id)
-            set_tag("device_id", device_id)
-            set_tag("time_now", time_now)
-            set_tag("device_keys", str(device_keys))
+        return await self.db_pool.runInteraction(
+            "set_e2e_device_keys",
+            self._set_e2e_device_keys_txn,
+            user_id,
+            device_id,
+            time_now,
+            device_keys,
+        )
 
-            old_key_json = self.db_pool.simple_select_one_onecol_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                retcol="key_json",
-                allow_none=True,
-            )
+    def _set_e2e_device_keys_txn(
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        time_now: int,
+        device_keys: JsonDict,
+    ) -> bool:
+        """Stores device keys for a device. Returns whether there was a change
+        or the keys were already in the database.
 
-            # In py3 we need old_key_json to match new_key_json type. The DB
-            # returns unicode while encode_canonical_json returns bytes.
-            new_key_json = encode_canonical_json(device_keys).decode("utf-8")
+        Args:
+             user_id: user_id of the user to store keys for
+             device_id: device_id of the device to store keys for
+             time_now: time at the request to store the keys
+             device_keys: the keys to store
+        """
+        set_tag("user_id", user_id)
+        set_tag("device_id", device_id)
+        set_tag("time_now", time_now)
+        set_tag("device_keys", str(device_keys))
+
+        old_key_json = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="e2e_device_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            retcol="key_json",
+            allow_none=True,
+        )
 
-            if old_key_json == new_key_json:
-                log_kv({"Message": "Device key already stored."})
-                return False
+        # In py3 we need old_key_json to match new_key_json type. The DB
+        # returns unicode while encode_canonical_json returns bytes.
+        new_key_json = encode_canonical_json(device_keys).decode("utf-8")
 
-            self.db_pool.simple_upsert_txn(
-                txn,
-                table="e2e_device_keys_json",
-                keyvalues={"user_id": user_id, "device_id": device_id},
-                values={"ts_added_ms": time_now, "key_json": new_key_json},
-            )
-            log_kv({"message": "Device keys stored."})
-            return True
+        if old_key_json == new_key_json:
+            log_kv({"Message": "Device key already stored."})
+            return False
 
-        return await self.db_pool.runInteraction(
-            "set_e2e_device_keys", _set_e2e_device_keys_txn
+        self.db_pool.simple_upsert_txn(
+            txn,
+            table="e2e_device_keys_json",
+            keyvalues={"user_id": user_id, "device_id": device_id},
+            values={"ts_added_ms": time_now, "key_json": new_key_json},
         )
+        log_kv({"message": "Device keys stored."})
+        return True
 
     async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None:
         def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: