diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index f3a713f5fa..b7bf70a72d 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -722,6 +722,22 @@ class DeviceHandler(DeviceWorkerHandler):
return {"success": True}
+ async def delete_dehydrated_device(self, user_id: str, device_id: str) -> None:
+ """
+ Delete a stored dehydrated device.
+
+ Args:
+ user_id: the user_id to delete the device from
+ device_id: id of the dehydrated device to delete
+ """
+ success = await self.store.remove_dehydrated_device(user_id, device_id)
+
+ if not success:
+ raise errors.NotFoundError()
+
+ await self.delete_devices(user_id, [device_id])
+ await self.store.delete_e2e_keys_by_device(user_id=user_id, device_id=device_id)
+
@wrap_as_background_process("_handle_new_device_update_async")
async def _handle_new_device_update_async(self) -> None:
"""Called when we have a new local device list update that we need to
diff --git a/synapse/rest/client/devices.py b/synapse/rest/client/devices.py
index 690d2ec406..dd3f7fd666 100644
--- a/synapse/rest/client/devices.py
+++ b/synapse/rest/client/devices.py
@@ -513,10 +513,8 @@ class DehydratedDeviceV2Servlet(RestServlet):
if dehydrated_device is not None:
(device_id, device_data) = dehydrated_device
- result = await self.device_handler.rehydrate_device(
- requester.user.to_string(),
- self.auth.get_access_token_from_request(request),
- device_id,
+ await self.device_handler.delete_dehydrated_device(
+ requester.user.to_string(), device_id
)
result = {"device_id": device_id}
@@ -538,6 +536,14 @@ class DehydratedDeviceV2Servlet(RestServlet):
requester = await self.auth.get_user_by_req(request)
user_id = requester.user.to_string()
+ old_dehydrated_device = await self.device_handler.get_dehydrated_device(user_id)
+
+ # if an old device exists, delete it before creating a new one
+ if old_dehydrated_device:
+ await self.device_handler.delete_dehydrated_device(
+ user_id, old_dehydrated_device[0]
+ )
+
device_info = submission.dict()
if "device_keys" not in device_info.keys():
raise SynapseError(
|