summary refs log tree commit diff
diff options
context:
space:
mode:
authorHubert Chathi <hubert@uhoreg.ca>2020-08-17 18:13:11 -0400
committerHubert Chathi <hubert@uhoreg.ca>2020-08-17 18:13:11 -0400
commitc7c8f2822d252019376f73e48107f35a23baebb3 (patch)
tree238c4fbf92396b2f0a90eba390d6b0b052916de1
parentMerge remote-tracking branch 'origin/develop' into dehydration (diff)
downloadsynapse-c7c8f2822d252019376f73e48107f35a23baebb3.tar.xz
newer version of dehydration proposal, add doc improvements and other minor fixes
-rw-r--r--changelog.d/7955.feature2
-rw-r--r--synapse/handlers/device.py59
-rw-r--r--synapse/rest/client/v1/login.py70
-rw-r--r--synapse/storage/databases/main/devices.py67
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py5
5 files changed, 171 insertions, 32 deletions
diff --git a/changelog.d/7955.feature b/changelog.d/7955.feature
index d6d04619e5..7d726046fe 100644
--- a/changelog.d/7955.feature
+++ b/changelog.d/7955.feature
@@ -1 +1 @@
-Add support for device dehydration.
+Add support for device dehydration. (MSC2697)
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 8995be8446..e1fd39356c 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -494,9 +494,19 @@ class DeviceHandler(DeviceWorkerHandler):
     async def store_dehydrated_device(
         self,
         user_id: str,
-        device_data: str,
+        device_data: JsonDict,
         initial_device_display_name: Optional[str] = None,
     ) -> str:
+        """Store a dehydrated device for a user.  If the user had a previous
+        dehydrated device, it is removed.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_data: the dehydrated device information
+            initial_device_display_name: The display name to use for the device
+        Returns:
+            device id of the dehydrated device
+        """
         device_id = await self.check_device_registered(
             user_id, None, initial_device_display_name,
         )
@@ -507,17 +517,43 @@ class DeviceHandler(DeviceWorkerHandler):
             await self.delete_device(user_id, old_device_id)
         return device_id
 
-    async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]:
+    async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
         return await self.store.get_dehydrated_device(user_id)
 
-    async def get_dehydration_token(
+    async def create_dehydration_token(
         self, user_id: str, device_id: str, login_submission: JsonDict
     ) -> str:
+        """Create a token for a client to fulfill a dehydration request.
+
+        Args:
+            user_id: the user that we are creating the token for
+            device_id: the device ID for the dehydrated device.  This is to
+                ensure that the device still exists when the user tells us
+                they want to use the dehydrated device.
+            login_submission: the contents of the login request.
+        Returns:
+            the dehydration token
+        """
         return await self.store.create_dehydration_token(
-            user_id, device_id, json.dumps(login_submission)
+            user_id, device_id, login_submission
         )
 
     async def rehydrate_device(self, token: str) -> dict:
+        """Process a rehydration request from the user.
+
+        Args:
+            token: the dehydration token
+        Returns:
+            the login result, including the user's access token and device ID
+        """
         # FIXME: if can't find token, return 404
         token_info = await self.store.clear_dehydration_token(token, True)
 
@@ -538,7 +574,7 @@ class DeviceHandler(DeviceWorkerHandler):
             )
 
             return {
-                "user_id": token_info.get("user_id"),
+                "user_id": token_info["user_id"],
                 "access_token": access_token,
                 "home_server": self.hs.hostname,
                 "device_id": device_id,
@@ -546,7 +582,7 @@ class DeviceHandler(DeviceWorkerHandler):
 
         else:
             # create device and access token from original login submission
-            login_submission = token_info.get("login_submission")
+            login_submission = token_info["login_submission"]
             device_id = login_submission.get("device_id")
             initial_display_name = login_submission.get("initial_device_display_name")
             device_id, access_token = await registration_handler.register_device(
@@ -554,17 +590,24 @@ class DeviceHandler(DeviceWorkerHandler):
             )
 
             return {
-                "user_id": token.info.get("user_id"),
+                "user_id": token.info["user_id"],
                 "access_token": access_token,
                 "home_server": self.hs.hostname,
                 "device_id": device_id,
             }
 
     async def cancel_rehydrate(self, token: str) -> dict:
+        """Cancel a rehydration request from the user and complete the user's login.
+
+        Args:
+            token: the dehydration token
+        Returns:
+            the login result, including the user's access token and device ID
+        """
         # FIXME: if can't find token, return 404
         token_info = await self.store.clear_dehydration_token(token, False)
         # create device and access token from original login submission
-        login_submission = token_info.get("login_submission")
+        login_submission = token_info["login_submission"]
         device_id = login_submission.get("device_id")
         initial_display_name = login_submission.get("initial_device_display_name")
         device_id, access_token = await self.registration_handler.register_device(
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 9e6748d4d1..68fece986b 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -341,12 +341,14 @@ class LoginRestServlet(RestServlet):
             user_id = canonical_uid
 
         if login_submission.get("org.matrix.msc2697.restore_device"):
+            # user requested to rehydrate a device, so check if there they have
+            # a dehydrated device, and if so, allow them to try to rehydrate it
             (
                 device_id,
                 dehydrated_device,
             ) = await self.device_handler.get_dehydrated_device(user_id)
             if dehydrated_device:
-                token = await self.device_handler.get_dehydration_token(
+                token = await self.device_handler.create_dehydration_token(
                     user_id, device_id, login_submission
                 )
                 result = {
@@ -424,33 +426,75 @@ class LoginRestServlet(RestServlet):
 
 
 class RestoreDeviceServlet(RestServlet):
+    """Complete a rehydration request, either by letting the client use the
+    dehydrated device, or by creating a new device for the user.
+
+    POST /org.matrix.msc2697/restore_device
+    Content-Type: application/json
+
+    {
+      "rehydrate": true,
+      "dehydration_token": "an_opaque_token"
+    }
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    { // same format as the result from a /login request
+      "user_id": "@alice:example.org",
+      "device_id": "dehydrated_device",
+      "access_token": "another_opaque_token"
+    }
+
+    """
+
     PATTERNS = client_patterns("/org.matrix.msc2697/restore_device")
 
     def __init__(self, hs):
         super(RestoreDeviceServlet, self).__init__()
         self.hs = hs
         self.device_handler = hs.get_device_handler()
+        self._well_known_builder = WellKnownBuilder(hs)
 
     async def on_POST(self, request: SynapseRequest):
         submission = parse_json_object_from_request(request)
 
         if submission.get("rehydrate"):
-            return (
-                200,
-                await self.device_handler.rehydrate_device(
-                    submission.get("dehydration_token")
-                ),
+            result = await self.device_handler.rehydrate_device(
+                submission["dehydration_token"]
             )
         else:
-            return (
-                200,
-                await self.device_handler.cancel_rehydrate(
-                    submission.get("dehydration_token")
-                ),
+            result = await self.device_handler.cancel_rehydrate(
+                submission["dehydration_token"]
             )
+        well_known_data = self._well_known_builder.get_well_known()
+        if well_known_data:
+            result["well_known"] = well_known_data
+        return (200, result)
 
 
 class StoreDeviceServlet(RestServlet):
+    """Store a dehydrated device.
+
+    POST /org.matrix.msc2697/device/dehydrate
+    Content-Type: application/json
+
+    {
+      "device_data": {
+        "algorithm": "m.dehydration.v1.olm",
+        "account": "dehydrated_device"
+      }
+    }
+
+    HTTP/1.1 200 OK
+    Content-Type: application/json
+
+    {
+      "device_id": "dehydrated_device_id"
+    }
+
+    """
+
     PATTERNS = client_patterns("/org.matrix.msc2697/device/dehydrate")
 
     def __init__(self, hs):
@@ -464,7 +508,9 @@ class StoreDeviceServlet(RestServlet):
         requester = await self.auth.get_user_by_req(request)
 
         device_id = await self.device_handler.store_dehydrated_device(
-            requester.user.to_string(), submission.get("device_data")
+            requester.user.to_string(),
+            submission["device_data"],
+            submission.get("initial_device_display_name", None)
         )
         return 200, {"device_id": device_id}
 
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index eb9d772d6e..d6c6f0ac34 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -35,7 +35,7 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_tuple_comparison_clause,
 )
-from synapse.types import Collection, get_verify_key_from_cross_signing_key
+from synapse.types import Collection, JsonDict, get_verify_key_from_cross_signing_key
 from synapse.util.caches.descriptors import (
     Cache,
     cached,
@@ -728,7 +728,15 @@ class DeviceWorkerStore(SQLBaseStore):
             _mark_remote_user_device_list_as_unsubscribed_txn,
         )
 
-    async def get_dehydrated_device(self, user_id: str) -> Tuple[str, str]:
+    async def get_dehydrated_device(self, user_id: str) -> Tuple[str, JsonDict]:
+        """Retrieve the information for a dehydrated device.
+
+        Args:
+            user_id: the user whose dehydrated device we are looking for
+        Returns:
+            a tuple whose first item is the device ID, and the second item is
+            the dehydrated device information
+        """
         # FIXME: make sure device ID still exists in devices table
         row = await self.db_pool.simple_select_one(
             table="dehydrated_devices",
@@ -736,7 +744,7 @@ class DeviceWorkerStore(SQLBaseStore):
             retcols=["device_id", "device_data"],
             allow_none=True,
         )
-        return (row["device_id"], row["device_data"]) if row else (None, None)
+        return (row["device_id"], json.loads(row["device_data"])) if row else (None, None)
 
     def _store_dehydrated_device_txn(
         self, txn, user_id: str, device_id: str, device_data: str
@@ -768,19 +776,39 @@ class DeviceWorkerStore(SQLBaseStore):
         return old_device_id
 
     async def store_dehydrated_device(
-        self, user_id: str, device_id: str, device_data: str
+        self, user_id: str, device_id: str, device_data: JsonDict
     ) -> Optional[str]:
+        """Store a dehydrated device for a user.
+
+        Args:
+            user_id: the user that we are storing the device for
+            device_data: the dehydrated device information
+            initial_device_display_name: The display name to use for the device
+        Returns:
+            device id of the user's previous dehydrated device, if any
+        """
         return await self.db_pool.runInteraction(
             "store_dehydrated_device_txn",
             self._store_dehydrated_device_txn,
             user_id,
             device_id,
-            device_data,
+            json.dumps(device_data),
         )
 
     async def create_dehydration_token(
-        self, user_id: str, device_id: str, login_submission: str
+        self, user_id: str, device_id: str, login_submission: JsonDict
     ) -> str:
+        """Create a token for a client to fulfill a dehydration request.
+
+        Args:
+            user_id: the user that we are creating the token for
+            device_id: the device ID for the dehydrated device.  This is to
+                ensure that the device still exists when the user tells us
+                they want to use the dehydrated device.
+            login_submission: the contents of the login request.
+        Returns:
+            the dehydration token
+        """
         # FIXME: expire any old tokens
 
         attempts = 0
@@ -794,7 +822,7 @@ class DeviceWorkerStore(SQLBaseStore):
                         "token": token,
                         "user_id": user_id,
                         "device_id": device_id,
-                        "login_submission": login_submission,
+                        "login_submission": json.dumps(login_submission),
                         "creation_time": self.hs.get_clock().time_msec(),
                     },
                     desc="create_dehydration_token",
@@ -814,16 +842,18 @@ class DeviceWorkerStore(SQLBaseStore):
         self.db_pool.simple_delete_one_txn(
             txn, "dehydration_token", {"token": token},
         )
+        token_info["login_submission"] = json.loads(token_info["login_submission"])
 
         if dehydrate:
-            device = self.db_pool.simple_select_one_txn(
+            device_id = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 "dehydrated_devices",
-                {"user_id": token_info["user_id"]},
-                ["device_id", "device_data"],
+                keyvalues={"user_id": token_info["user_id"]},
+                retcol="device_id",
                 allow_none=True,
             )
-            if device and device["device_id"] == token_info["device_id"]:
+            token_info["dehydrated"] = False
+            if device_id == token_info["device_id"]:
                 count = self.db_pool.simple_delete_txn(
                     txn,
                     "dehydrated_devices",
@@ -838,6 +868,21 @@ class DeviceWorkerStore(SQLBaseStore):
         return token_info
 
     async def clear_dehydration_token(self, token: str, dehydrate: bool) -> dict:
+        """Use a dehydration token.  If the client wishes to use the dehydrated
+        device, it will also remove the dehydrated device.
+
+        Args:
+            token: the dehydration token
+            dehydrate: whether the client wishes to use the dehydrated device
+        Returns:
+            A dict giving the information related to the token.  It will have
+            the following properties:
+            - user_id: the user associated from the token
+            - device_id: the ID of the dehydrated device
+            - login_submission: the original submission to /login
+            - dehydrated: (only present if the "dehydrate" parameter is True).
+              Whether the dehydrated device can be used by the client.
+        """
         return await self.db_pool.runInteraction(
             "get_users_whose_devices_changed",
             self._clear_dehydration_token_txn,
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 40354b8304..23f04a4887 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -641,6 +641,11 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
             self._invalidate_cache_and_stream(
                 txn, self.count_e2e_one_time_keys, (user_id, device_id)
             )
+            self.db_pool.simple_delete_txn(
+                txn,
+                table="dehydrated_devices",
+                keyvalues={"user_id": user_id, "device_id": device_id},
+            )
 
         return self.db_pool.runInteraction(
             "delete_e2e_keys_by_device", delete_e2e_keys_by_device_txn