diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index d9df437e51..e4162f846b 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -28,6 +28,7 @@ from typing import (
cast,
)
+from canonicaljson import encode_canonical_json
from typing_extensions import Literal
from synapse.api.constants import EduTypes
@@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
)
def _store_dehydrated_device_txn(
- self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
+ self,
+ txn: LoggingTransaction,
+ user_id: str,
+ device_id: str,
+ device_data: str,
+ time: int,
+ keys: Optional[JsonDict] = None,
) -> Optional[str]:
+ # TODO: make keys non-optional once support for msc2697 is dropped
+ if keys:
+ device_keys = keys.get("device_keys", None)
+ if device_keys:
+ # Type ignore - this function is defined on EndToEndKeyStore which we do
+ # have access to due to hs.get_datastore() "magic"
+ self._set_e2e_device_keys_txn( # type: ignore[attr-defined]
+ txn, user_id, device_id, time, device_keys
+ )
+
+ one_time_keys = keys.get("one_time_keys", None)
+ if one_time_keys:
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append(
+ (
+ algorithm,
+ key_id,
+ encode_canonical_json(key_obj).decode("ascii"),
+ )
+ )
+ self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list)
+
+ fallback_keys = keys.get("fallback_keys", None)
+ if fallback_keys:
+ self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys)
+
old_device_id = self.db_pool.simple_select_one_onecol_txn(
txn,
table="dehydrated_devices",
@@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
keyvalues={"user_id": user_id},
values={"device_id": device_id, "device_data": device_data},
)
+
return old_device_id
async def store_dehydrated_device(
- self, user_id: str, device_id: str, device_data: JsonDict
+ self,
+ user_id: str,
+ device_id: str,
+ device_data: JsonDict,
+ time_now: int,
+ keys: Optional[dict] = None,
) -> Optional[str]:
"""Store a dehydrated device for a user.
@@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
user_id: the user that we are storing the device for
device_id: the ID of the dehydrated device
device_data: the dehydrated device information
+ time_now: current time at the request in milliseconds
+ keys: keys for the dehydrated device
+
Returns:
device id of the user's previous dehydrated device, if any
"""
+
return await self.db_pool.runInteraction(
"store_dehydrated_device_txn",
self._store_dehydrated_device_txn,
user_id,
device_id,
json_encoder.encode(device_data),
+ time_now,
+ keys,
)
async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool:
|