summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16288.bugfix1
-rw-r--r--synapse/handlers/device.py7
-rw-r--r--synapse/storage/databases/main/registration.py20
-rw-r--r--tests/handlers/test_device.py10
4 files changed, 34 insertions, 4 deletions
diff --git a/changelog.d/16288.bugfix b/changelog.d/16288.bugfix
new file mode 100644
index 0000000000..f08d10d1f3
--- /dev/null
+++ b/changelog.d/16288.bugfix
@@ -0,0 +1 @@
+Fix bug introduced in Synapse 1.49.0 when using dehydrated devices ([MSC2697](https://github.com/matrix-org/matrix-spec-proposals/pull/2697)) and refresh tokens. Contributed by Hanadi.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index e2ae3da67e..0d3d5ebc86 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -758,12 +758,13 @@ class DeviceHandler(DeviceWorkerHandler):
 
         # If the dehydrated device was successfully deleted (the device ID
         # matched the stored dehydrated device), then modify the access
-        # token to use the dehydrated device's ID and copy the old device
-        # display name to the dehydrated device, and destroy the old device
-        # ID
+        # token and refresh token to use the dehydrated device's ID and
+        # copy the old device display name to the dehydrated device,
+        # and destroy the old device ID
         old_device_id = await self.store.set_device_for_access_token(
             access_token, device_id
         )
+        await self.store.set_device_for_refresh_token(user_id, old_device_id, device_id)
         old_device = await self.store.get_device(user_id, old_device_id)
         if old_device is None:
             raise errors.NotFoundError()
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index 7e85b73e8e..e34156dc55 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -2312,6 +2312,26 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         return next_id
 
+    async def set_device_for_refresh_token(
+        self, user_id: str, old_device_id: str, device_id: str
+    ) -> None:
+        """Moves refresh tokens from old device to current device
+
+        Args:
+            user_id: The user of the devices.
+            old_device_id: The old device.
+            device_id: The new device ID.
+        Returns:
+            None
+        """
+
+        await self.db_pool.simple_update(
+            "refresh_tokens",
+            keyvalues={"user_id": user_id, "device_id": old_device_id},
+            updatevalues={"device_id": device_id},
+            desc="set_device_for_refresh_token",
+        )
+
     def _set_device_for_access_token_txn(
         self, txn: LoggingTransaction, token: str, device_id: str
     ) -> str:
diff --git a/tests/handlers/test_device.py b/tests/handlers/test_device.py
index 79d327499b..d4ed068357 100644
--- a/tests/handlers/test_device.py
+++ b/tests/handlers/test_device.py
@@ -461,6 +461,7 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.message_handler = hs.get_device_message_handler()
         self.registration = hs.get_registration_handler()
         self.auth = hs.get_auth()
+        self.auth_handler = hs.get_auth_handler()
         self.store = hs.get_datastores().main
         return hs
 
@@ -487,11 +488,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
         self.assertEqual(device_data, {"device_data": {"foo": "bar"}})
 
         # Create a new login for the user and dehydrated the device
-        device_id, access_token, _expiration_time, _refresh_token = self.get_success(
+        device_id, access_token, _expiration_time, refresh_token = self.get_success(
             self.registration.register_device(
                 user_id=user_id,
                 device_id=None,
                 initial_display_name="new device",
+                should_issue_refresh_token=True,
             )
         )
 
@@ -522,6 +524,12 @@ class DehydrationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(user_info.device_id, retrieved_device_id)
 
+        # make sure the user device has the refresh token
+        assert refresh_token is not None
+        self.get_success(
+            self.auth_handler.refresh_token(refresh_token, 5 * 60 * 1000, 5 * 60 * 1000)
+        )
+
         # make sure the device has the display name that was set from the login
         res = self.get_success(self.handler.get_device(user_id, retrieved_device_id))