diff options
author | Shay <hillerys@element.io> | 2023-08-08 12:04:46 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-08 12:04:46 -0700 |
commit | 0328b56468fe12c4d86ef636b60964527a510160 (patch) | |
tree | 2cff86a1e4518c90f234d4010419c049082a147f /synapse | |
parent | Fixup changelog (diff) | |
download | synapse-0328b56468fe12c4d86ef636b60964527a510160.tar.xz |
Support MSC3814: Dehydrated Devices Part 2 (#16010)
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/handlers/device.py | 14 | ||||
-rw-r--r-- | synapse/handlers/devicemessage.py | 13 | ||||
-rw-r--r-- | synapse/rest/client/devices.py | 16 | ||||
-rw-r--r-- | synapse/storage/databases/main/devices.py | 51 | ||||
-rw-r--r-- | synapse/storage/databases/main/end_to_end_keys.py | 170 |
5 files changed, 174 insertions, 90 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py index b7bf70a72d..5ae427d52c 100644 --- a/synapse/handlers/device.py +++ b/synapse/handlers/device.py @@ -385,6 +385,7 @@ class DeviceHandler(DeviceWorkerHandler): self.federation_sender = hs.get_federation_sender() self._account_data_handler = hs.get_account_data_handler() self._storage_controllers = hs.get_storage_controllers() + self.db_pool = hs.get_datastores().main.db_pool self.device_list_updater = DeviceListUpdater(hs, self) @@ -656,15 +657,17 @@ class DeviceHandler(DeviceWorkerHandler): device_id: Optional[str], device_data: JsonDict, initial_device_display_name: Optional[str] = None, + keys_for_device: Optional[JsonDict] = None, ) -> str: - """Store a dehydrated device for a user. If the user had a previous - dehydrated device, it is removed. + """Store a dehydrated device for a user, optionally storing the keys associated with + it as well. If the user had a previous dehydrated device, it is removed. Args: user_id: the user that we are storing the device for device_id: device id supplied by client device_data: the dehydrated device information initial_device_display_name: The display name to use for the device + keys_for_device: keys for the dehydrated device Returns: device id of the dehydrated device """ @@ -673,11 +676,16 @@ class DeviceHandler(DeviceWorkerHandler): device_id, initial_device_display_name, ) + + time_now = self.clock.time_msec() + old_device_id = await self.store.store_dehydrated_device( - user_id, device_id, device_data + user_id, device_id, device_data, time_now, keys_for_device ) + if old_device_id is not None: await self.delete_devices(user_id, [old_device_id]) + return device_id async def rehydrate_device( diff --git a/synapse/handlers/devicemessage.py b/synapse/handlers/devicemessage.py index 15e94a03cb..17ff8821d9 100644 --- a/synapse/handlers/devicemessage.py +++ b/synapse/handlers/devicemessage.py @@ -367,19 +367,6 @@ class DeviceMessageHandler: errcode=Codes.INVALID_PARAM, ) - # if we have a since token, delete any to-device messages before that token - # (since we now know that the device has received them) - deleted = await self.store.delete_messages_for_device( - user_id, device_id, since_stream_id - ) - logger.debug( - "Deleted %d to-device messages up to %d for user_id %s device_id %s", - deleted, - since_stream_id, - user_id, - device_id, - ) - to_token = self.event_sources.get_current_token().to_device_key messages, stream_id = await self.store.get_messages_for_device( diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py index 51f17f80da..925f037743 100644 --- a/synapse/rest/client/devices.py +++ b/synapse/rest/client/devices.py @@ -29,7 +29,6 @@ from synapse.http.servlet import ( parse_integer, ) from synapse.http.site import SynapseRequest -from synapse.replication.http.devices import ReplicationUploadKeysForUserRestServlet from synapse.rest.client._base import client_patterns, interactive_auth_handler from synapse.rest.client.models import AuthenticationData from synapse.rest.models import RequestBodyModel @@ -480,13 +479,6 @@ class DehydratedDeviceV2Servlet(RestServlet): self.e2e_keys_handler = hs.get_e2e_keys_handler() self.device_handler = handler - if hs.config.worker.worker_app is None: - # if main process - self.key_uploader = self.e2e_keys_handler.upload_keys_for_user - else: - # then a worker - self.key_uploader = ReplicationUploadKeysForUserRestServlet.make_client(hs) - async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: requester = await self.auth.get_user_by_req(request) @@ -549,18 +541,12 @@ class DehydratedDeviceV2Servlet(RestServlet): "Device key(s) not found, these must be provided.", ) - # TODO: Those two operations, creating a device and storing the - # device's keys should be atomic. device_id = await self.device_handler.store_dehydrated_device( requester.user.to_string(), submission.device_id, submission.device_data.dict(), submission.initial_device_display_name, - ) - - # TODO: Do we need to do something with the result here? - await self.key_uploader( - user_id=user_id, device_id=submission.device_id, keys=submission.dict() + device_info, ) return 200, {"device_id": device_id} diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py index d9df437e51..e4162f846b 100644 --- a/synapse/storage/databases/main/devices.py +++ b/synapse/storage/databases/main/devices.py @@ -28,6 +28,7 @@ from typing import ( cast, ) +from canonicaljson import encode_canonical_json from typing_extensions import Literal from synapse.api.constants import EduTypes @@ -1188,8 +1189,42 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): ) def _store_dehydrated_device_txn( - self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + device_data: str, + time: int, + keys: Optional[JsonDict] = None, ) -> Optional[str]: + # TODO: make keys non-optional once support for msc2697 is dropped + if keys: + device_keys = keys.get("device_keys", None) + if device_keys: + # Type ignore - this function is defined on EndToEndKeyStore which we do + # have access to due to hs.get_datastore() "magic" + self._set_e2e_device_keys_txn( # type: ignore[attr-defined] + txn, user_id, device_id, time, device_keys + ) + + one_time_keys = keys.get("one_time_keys", None) + if one_time_keys: + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append( + ( + algorithm, + key_id, + encode_canonical_json(key_obj).decode("ascii"), + ) + ) + self._add_e2e_one_time_keys_txn(txn, user_id, device_id, time, key_list) + + fallback_keys = keys.get("fallback_keys", None) + if fallback_keys: + self._set_e2e_fallback_keys_txn(txn, user_id, device_id, fallback_keys) + old_device_id = self.db_pool.simple_select_one_onecol_txn( txn, table="dehydrated_devices", @@ -1203,10 +1238,16 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): keyvalues={"user_id": user_id}, values={"device_id": device_id, "device_data": device_data}, ) + return old_device_id async def store_dehydrated_device( - self, user_id: str, device_id: str, device_data: JsonDict + self, + user_id: str, + device_id: str, + device_data: JsonDict, + time_now: int, + keys: Optional[dict] = None, ) -> Optional[str]: """Store a dehydrated device for a user. @@ -1214,15 +1255,21 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore): user_id: the user that we are storing the device for device_id: the ID of the dehydrated device device_data: the dehydrated device information + time_now: current time at the request in milliseconds + keys: keys for the dehydrated device + Returns: device id of the user's previous dehydrated device, if any """ + return await self.db_pool.runInteraction( "store_dehydrated_device_txn", self._store_dehydrated_device_txn, user_id, device_id, json_encoder.encode(device_data), + time_now, + keys, ) async def remove_dehydrated_device(self, user_id: str, device_id: str) -> bool: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index 91ae9c457d..b49dea577c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -522,36 +522,57 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker new_keys: keys to add - each a tuple of (algorithm, key_id, key json) """ - def _add_e2e_one_time_keys(txn: LoggingTransaction) -> None: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("new_keys", str(new_keys)) - # We are protected from race between lookup and insertion due to - # a unique constraint. If there is a race of two calls to - # `add_e2e_one_time_keys` then they'll conflict and we will only - # insert one set. - self.db_pool.simple_insert_many_txn( - txn, - table="e2e_one_time_keys_json", - keys=( - "user_id", - "device_id", - "algorithm", - "key_id", - "ts_added_ms", - "key_json", - ), - values=[ - (user_id, device_id, algorithm, key_id, time_now, json_bytes) - for algorithm, key_id, json_bytes in new_keys - ], - ) - self._invalidate_cache_and_stream( - txn, self.count_e2e_one_time_keys, (user_id, device_id) - ) - await self.db_pool.runInteraction( - "add_e2e_one_time_keys_insert", _add_e2e_one_time_keys + "add_e2e_one_time_keys_insert", + self._add_e2e_one_time_keys_txn, + user_id, + device_id, + time_now, + new_keys, + ) + + def _add_e2e_one_time_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + new_keys: Iterable[Tuple[str, str, str]], + ) -> None: + """Insert some new one time keys for a device. Errors if any of the keys already exist. + + Args: + user_id: id of user to get keys for + device_id: id of device to get keys for + time_now: insertion time to record (ms since epoch) + new_keys: keys to add - each a tuple of (algorithm, key_id, key json) - note + that the key JSON must be in canonical JSON form + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("new_keys", str(new_keys)) + # We are protected from race between lookup and insertion due to + # a unique constraint. If there is a race of two calls to + # `add_e2e_one_time_keys` then they'll conflict and we will only + # insert one set. + self.db_pool.simple_insert_many_txn( + txn, + table="e2e_one_time_keys_json", + keys=( + "user_id", + "device_id", + "algorithm", + "key_id", + "ts_added_ms", + "key_json", + ), + values=[ + (user_id, device_id, algorithm, key_id, time_now, json_bytes) + for algorithm, key_id, json_bytes in new_keys + ], + ) + self._invalidate_cache_and_stream( + txn, self.count_e2e_one_time_keys, (user_id, device_id) ) @cached(max_entries=10000) @@ -723,6 +744,14 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker device_id: str, fallback_keys: JsonDict, ) -> None: + """Set the user's e2e fallback keys. + + Args: + user_id: the user whose keys are being set + device_id: the device whose keys are being set + fallback_keys: the keys to set. This is a map from key ID (which is + of the form "algorithm:id") to key data. + """ # fallback_keys will usually only have one item in it, so using a for # loop (as opposed to calling simple_upsert_many_txn) won't be too bad # FIXME: make sure that only one key per algorithm is uploaded @@ -1304,42 +1333,69 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore): ) -> bool: """Stores device keys for a device. Returns whether there was a change or the keys were already in the database. + + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store """ - def _set_e2e_device_keys_txn(txn: LoggingTransaction) -> bool: - set_tag("user_id", user_id) - set_tag("device_id", device_id) - set_tag("time_now", time_now) - set_tag("device_keys", str(device_keys)) + return await self.db_pool.runInteraction( + "set_e2e_device_keys", + self._set_e2e_device_keys_txn, + user_id, + device_id, + time_now, + device_keys, + ) - old_key_json = self.db_pool.simple_select_one_onecol_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - retcol="key_json", - allow_none=True, - ) + def _set_e2e_device_keys_txn( + self, + txn: LoggingTransaction, + user_id: str, + device_id: str, + time_now: int, + device_keys: JsonDict, + ) -> bool: + """Stores device keys for a device. Returns whether there was a change + or the keys were already in the database. - # In py3 we need old_key_json to match new_key_json type. The DB - # returns unicode while encode_canonical_json returns bytes. - new_key_json = encode_canonical_json(device_keys).decode("utf-8") + Args: + user_id: user_id of the user to store keys for + device_id: device_id of the device to store keys for + time_now: time at the request to store the keys + device_keys: the keys to store + """ + set_tag("user_id", user_id) + set_tag("device_id", device_id) + set_tag("time_now", time_now) + set_tag("device_keys", str(device_keys)) + + old_key_json = self.db_pool.simple_select_one_onecol_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + retcol="key_json", + allow_none=True, + ) - if old_key_json == new_key_json: - log_kv({"Message": "Device key already stored."}) - return False + # In py3 we need old_key_json to match new_key_json type. The DB + # returns unicode while encode_canonical_json returns bytes. + new_key_json = encode_canonical_json(device_keys).decode("utf-8") - self.db_pool.simple_upsert_txn( - txn, - table="e2e_device_keys_json", - keyvalues={"user_id": user_id, "device_id": device_id}, - values={"ts_added_ms": time_now, "key_json": new_key_json}, - ) - log_kv({"message": "Device keys stored."}) - return True + if old_key_json == new_key_json: + log_kv({"Message": "Device key already stored."}) + return False - return await self.db_pool.runInteraction( - "set_e2e_device_keys", _set_e2e_device_keys_txn + self.db_pool.simple_upsert_txn( + txn, + table="e2e_device_keys_json", + keyvalues={"user_id": user_id, "device_id": device_id}, + values={"ts_added_ms": time_now, "key_json": new_key_json}, ) + log_kv({"message": "Device keys stored."}) + return True async def delete_e2e_keys_by_device(self, user_id: str, device_id: str) -> None: def delete_e2e_keys_by_device_txn(txn: LoggingTransaction) -> None: |