summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/device.py7
-rw-r--r--synapse/storage/databases/main/registration.py20
2 files changed, 24 insertions, 3 deletions
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index e2ae3da67e..0d3d5ebc86 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -758,12 +758,13 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # If the dehydrated device was successfully deleted (the device ID
         # matched the stored dehydrated device), then modify the access
-        # token to use the dehydrated device's ID and copy the old device
-        # display name to the dehydrated device, and destroy the old device
-        # ID
+        # token and refresh token to use the dehydrated device's ID and
+        # copy the old device display name to the dehydrated device,
+        # and destroy the old device ID
         old_device_id = await self.store.set_device_for_access_token(
             access_token, device_id
         )
+        await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
         old_device = await self.store.get_device(user_id, old_device_id)
         if old_device is None:
             raise errors.NotFoundError()
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7e85b73e8e..e34156dc55 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -2312,6 +2312,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         return next_id
 
+    async def set_device_for_refresh_token(
+        self, user_id: str, old_device_id: str, device_id: str
+    ) -> None:
+        """Moves refresh tokens from old device to current device
+
+        Args:
+            user_id: The user of the devices.
+            old_device_id: The old device.
+            device_id: The new device ID.
+        Returns:
+            None
+        """
+
+        await self.db_pool.simple_update(
+            "refresh_tokens",
+            keyvalues={"user_id": user_id, "device_id": old_device_id},
+            updatevalues={"device_id": device_id},
+            desc="set_device_for_refresh_token",
+        )
+
     def _set_device_for_access_token_txn(
         self, txn: LoggingTransaction, token: str, device_id: str
     ) -> str: