summary refs log tree commit diff
path: root/synapse/storage/databases/main/devices.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/devices.py')
-rw-r--r--synapse/storage/databases/main/devices.py51
1 files changed, 49 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
     cast,
 )
 
+from canonicaljson import encode_canonical_json
 from typing_extensions import Literal
 
 from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+        self,
+        txn: LoggingTransaction,
+        user_id: str,
+        device_id: str,
+        device_data: str,
+        time: int,
+        keys: Optional[JsonDict] = None,
     ) -> Optional[str]:
+        # TODO: make keys non-optional once support for msc2697 is dropped
+        if keys:
+            device_keys = keys.get("device_keys", None)
+            if device_keys:
+                # Type ignore - this function is defined on EndToEndKeyStore which we do
+                # have access to due to hs.get_datastore() "magic"
+                self._set_e2e_device_keys_txn(  # type: ignore[attr-defined]
+                    txn, user_id, device_id, time, device_keys
+                )
+
+            one_time_keys = keys.get("one_time_keys", None)
+            if one_time_keys:
+                key_list = []
+                for key_id, key_obj in one_time_keys.items():
+                    algorithm, key_id = key_id.split(":")
+                    key_list.append(
+                        (
+                            algorithm,
+                            key_id,
+                            encode_canonical_json(key_obj).decode("ascii"),
+                        )
+                    )
+                self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+            fallback_keys = keys.get("fallback_keys", None)
+            if fallback_keys:
+                self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
             table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             keyvalues={"user_id": user_id},
             values={"device_id": device_id, "device_data": device_data},
         )
+
         return old_device_id
 
     async def store_dehydrated_device(
-        self, user_id: str, device_id: str, device_data: JsonDict
+        self,
+        user_id: str,
+        device_id: str,
+        device_data: JsonDict,
+        time_now: int,
+        keys: Optional[dict] = None,
     ) -> Optional[str]:
         """Store a dehydrated device for a user.
 
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
             user_id: the user that we are storing the device for
             device_id: the ID of the dehydrated device
             device_data: the dehydrated device information
+            time_now: current time at the request in milliseconds
+            keys: keys for the dehydrated device
+
         Returns:
             device id of the user's previous dehydrated device, if any
         """
+
         return await self.db_pool.runInteraction(
             "store_dehydrated_device_txn",
             self._store_dehydrated_device_txn,
             user_id,
             device_id,
             json_encoder.encode(device_data),
+            time_now,
+            keys,
         )
 
     async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: