summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7289.bugfix1
-rw-r--r--changelog.d/7311.feature1
-rw-r--r--docs/workers.md2
-rw-r--r--synapse/app/generic_worker.py6
-rw-r--r--synapse/federation/transport/client.py49
-rw-r--r--synapse/handlers/e2e_keys.py148
-rw-r--r--synapse/rest/client/v2_alpha/account_data.py8
7 files changed, 197 insertions, 18 deletions
diff --git a/changelog.d/7289.bugfix b/changelog.d/7289.bugfix
new file mode 100644
index 0000000000..84699e50a9
--- /dev/null
+++ b/changelog.d/7289.bugfix
@@ -0,0 +1 @@
+Fix a bug with cross-signing devices belonging to remote users who did not share a room with any user on the local homeserver.
diff --git a/changelog.d/7311.feature b/changelog.d/7311.feature
new file mode 100644
index 0000000000..c3adc1d6e7
--- /dev/null
+++ b/changelog.d/7311.feature
@@ -0,0 +1 @@
+Add support for handling GET requests for account_data on a worker.
diff --git a/docs/workers.md b/docs/workers.md
index 2ce2259b22..cc0b23197f 100644
--- a/docs/workers.md
+++ b/docs/workers.md
@@ -286,6 +286,8 @@ Additionally, the following REST endpoints can be handled for GET requests:
 
     ^/_matrix/client/(api/v1|r0|unstable)/pushrules/.*$
     ^/_matrix/client/(api/v1|r0|unstable)/groups/.*$
+    ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/account_data/
+    ^/_matrix/client/(api/v1|r0|unstable)/user/[^/]*/rooms/[^/]*/account_data/
 
 Additionally, the following REST endpoints can be handled, but all requests must
 be routed to the same instance:
diff --git a/synapse/app/generic_worker.py b/synapse/app/generic_worker.py
index dcd0709a02..37afd2f810 100644
--- a/synapse/app/generic_worker.py
+++ b/synapse/app/generic_worker.py
@@ -110,6 +110,10 @@ from synapse.rest.client.v1.voip import VoipRestServlet
 from synapse.rest.client.v2_alpha import groups, sync, user_directory
 from synapse.rest.client.v2_alpha._base import client_patterns
 from synapse.rest.client.v2_alpha.account import ThreepidRestServlet
+from synapse.rest.client.v2_alpha.account_data import (
+    AccountDataServlet,
+    RoomAccountDataServlet,
+)
 from synapse.rest.client.v2_alpha.keys import KeyChangesServlet, KeyQueryServlet
 from synapse.rest.client.v2_alpha.register import RegisterRestServlet
 from synapse.rest.client.versions import VersionsRestServlet
@@ -501,6 +505,8 @@ class GenericWorkerServer(HomeServer):
                     ProfileDisplaynameRestServlet(self).register(resource)
                     ProfileRestServlet(self).register(resource)
                     KeyUploadServlet(self).register(resource)
+                    AccountDataServlet(self).register(resource)
+                    RoomAccountDataServlet(self).register(resource)
 
                     sync.register_servlets(self, resource)
                     events.register_servlets(self, resource)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dc563538de..383e3fdc8b 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -399,20 +399,30 @@ class TransportLayerClient(object):
             {
               "device_keys": {
                 "<user_id>": ["<device_id>"]
-            } }
+              }
+            }
 
         Response:
             {
               "device_keys": {
                 "<user_id>": {
                   "<device_id>": {...}
-            } } }
+                }
+              },
+              "master_key": {
+                "<user_id>": {...}
+                }
+              },
+              "self_signing_key": {
+                "<user_id>": {...}
+              }
+            }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the device keys.
+            A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/keys/query")
 
@@ -429,14 +439,30 @@ class TransportLayerClient(object):
         Response:
             {
               "stream_id": "...",
-              "devices": [ { ... } ]
+              "devices": [ { ... } ],
+              "master_key": {
+                "user_id": "<user_id>",
+                "usage": [...],
+                "keys": {...},
+                "signatures": {
+                  "<user_id>": {...}
+                }
+              },
+              "self_signing_key": {
+                "user_id": "<user_id>",
+                "usage": [...],
+                "keys": {...},
+                "signatures": {
+                  "<user_id>": {...}
+                }
+              }
             }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the device keys.
+            A dict containing device and cross-signing keys.
         """
         path = _create_v1_path("/user/devices/%s", user_id)
 
@@ -454,8 +480,10 @@ class TransportLayerClient(object):
             {
               "one_time_keys": {
                 "<user_id>": {
-                    "<device_id>": "<algorithm>"
-            } } }
+                  "<device_id>": "<algorithm>"
+                }
+              }
+            }
 
         Response:
             {
@@ -463,13 +491,16 @@ class TransportLayerClient(object):
                 "<user_id>": {
                   "<device_id>": {
                     "<algorithm>:<key_id>": "<key_base64>"
-            } } } }
+                  }
+                }
+              }
+            }
 
         Args:
             destination(str): The server to query.
             query_content(dict): The user ids to query.
         Returns:
-            A dict containg the one-time keys.
+            A dict containing the one-time keys.
         """
 
         path = _create_v1_path("/user/keys/claim")
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 8d7075f2eb..8f1bc0323c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -174,8 +174,8 @@ class E2eKeysHandler(object):
             """This is called when we are querying the device list of a user on
             a remote homeserver and their device list is not in the device list
             cache. If we share a room with this user and we're not querying for
-            specific user we will update the cache
-            with their device list."""
+            specific user we will update the cache with their device list.
+            """
 
             destination_query = remote_queries_not_in_cache[destination]
 
@@ -961,13 +961,19 @@ class E2eKeysHandler(object):
         return signature_list, failures
 
     @defer.inlineCallbacks
-    def _get_e2e_cross_signing_verify_key(self, user_id, key_type, from_user_id=None):
-        """Fetch the cross-signing public key from storage and interpret it.
+    def _get_e2e_cross_signing_verify_key(
+        self, user_id: str, key_type: str, from_user_id: str = None
+    ):
+        """Fetch locally or remotely query for a cross-signing public key.
+
+        First, attempt to fetch the cross-signing public key from storage.
+        If that fails, query the keys from the homeserver they belong to
+        and update our local copy.
 
         Args:
-            user_id (str): the user whose key should be fetched
-            key_type (str): the type of key to fetch
-            from_user_id (str): the user that we are fetching the keys for.
+            user_id: the user whose key should be fetched
+            key_type: the type of key to fetch
+            from_user_id: the user that we are fetching the keys for.
                 This affects what signatures are fetched.
 
         Returns:
@@ -976,16 +982,140 @@ class E2eKeysHandler(object):
 
         Raises:
             NotFoundError: if the key is not found
+            SynapseError: if `user_id` is invalid
         """
+        user = UserID.from_string(user_id)
         key = yield self.store.get_e2e_cross_signing_key(
             user_id, key_type, from_user_id
         )
+
+        if key:
+            # We found a copy of this key in our database. Decode and return it
+            key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+            return key, key_id, verify_key
+
+        # If we couldn't find the key locally, and we're looking for keys of
+        # another user then attempt to fetch the missing key from the remote
+        # user's server.
+        #
+        # We may run into this in possible edge cases where a user tries to
+        # cross-sign a remote user, but does not share any rooms with them yet.
+        # Thus, we would not have their key list yet. We instead fetch the key,
+        # store it and notify clients of new, associated device IDs.
+        if self.is_mine(user) or key_type not in ["master", "self_signing"]:
+            # Note that master and self_signing keys are the only cross-signing keys we
+            # can request over federation
+            raise NotFoundError("No %s key found for %s" % (key_type, user_id))
+
+        (
+            key,
+            key_id,
+            verify_key,
+        ) = yield self._retrieve_cross_signing_keys_for_remote_user(user, key_type)
+
         if key is None:
-            logger.debug("no %s key found for %s", key_type, user_id)
             raise NotFoundError("No %s key found for %s" % (key_type, user_id))
-        key_id, verify_key = get_verify_key_from_cross_signing_key(key)
+
         return key, key_id, verify_key
 
+    @defer.inlineCallbacks
+    def _retrieve_cross_signing_keys_for_remote_user(
+        self, user: UserID, desired_key_type: str,
+    ):
+        """Queries cross-signing keys for a remote user and saves them to the database
+
+        Only the key specified by `key_type` will be returned, while all retrieved keys
+        will be saved regardless
+
+        Args:
+            user: The user to query remote keys for
+            desired_key_type: The type of key to receive. One of "master", "self_signing"
+
+        Returns:
+            Deferred[Tuple[Optional[Dict], Optional[str], Optional[VerifyKey]]]: A tuple
+            of the retrieved key content, the key's ID and the matching VerifyKey.
+            If the key cannot be retrieved, all values in the tuple will instead be None.
+        """
+        try:
+            remote_result = yield self.federation.query_user_devices(
+                user.domain, user.to_string()
+            )
+        except Exception as e:
+            logger.warning(
+                "Unable to query %s for cross-signing keys of user %s: %s %s",
+                user.domain,
+                user.to_string(),
+                type(e),
+                e,
+            )
+            return None, None, None
+
+        # Process each of the retrieved cross-signing keys
+        desired_key = None
+        desired_key_id = None
+        desired_verify_key = None
+        retrieved_device_ids = []
+        for key_type in ["master", "self_signing"]:
+            key_content = remote_result.get(key_type + "_key")
+            if not key_content:
+                continue
+
+            # Ensure these keys belong to the correct user
+            if "user_id" not in key_content:
+                logger.warning(
+                    "Invalid %s key retrieved, missing user_id field: %s",
+                    key_type,
+                    key_content,
+                )
+                continue
+            if user.to_string() != key_content["user_id"]:
+                logger.warning(
+                    "Found %s key of user %s when querying for keys of user %s",
+                    key_type,
+                    key_content["user_id"],
+                    user.to_string(),
+                )
+                continue
+
+            # Validate the key contents
+            try:
+                # verify_key is a VerifyKey from signedjson, which uses
+                # .version to denote the portion of the key ID after the
+                # algorithm and colon, which is the device ID
+                key_id, verify_key = get_verify_key_from_cross_signing_key(key_content)
+            except ValueError as e:
+                logger.warning(
+                    "Invalid %s key retrieved: %s - %s %s",
+                    key_type,
+                    key_content,
+                    type(e),
+                    e,
+                )
+                continue
+
+            # Note down the device ID attached to this key
+            retrieved_device_ids.append(verify_key.version)
+
+            # If this is the desired key type, save it and its ID/VerifyKey
+            if key_type == desired_key_type:
+                desired_key = key_content
+                desired_verify_key = verify_key
+                desired_key_id = key_id
+
+            # At the same time, store this key in the db for subsequent queries
+            yield self.store.set_e2e_cross_signing_key(
+                user.to_string(), key_type, key_content
+            )
+
+        # Notify clients that new devices for this user have been discovered
+        if retrieved_device_ids:
+            # XXX is this necessary?
+            yield self.device_handler.notify_device_update(
+                user.to_string(), retrieved_device_ids
+            )
+
+        return desired_key, desired_key_id, desired_verify_key
+
 
 def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity
diff --git a/synapse/rest/client/v2_alpha/account_data.py b/synapse/rest/client/v2_alpha/account_data.py
index 64eb7fec3b..c1d4cd0caf 100644
--- a/synapse/rest/client/v2_alpha/account_data.py
+++ b/synapse/rest/client/v2_alpha/account_data.py
@@ -38,8 +38,12 @@ class AccountDataServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
+        self._is_worker = hs.config.worker_app is not None
 
     async def on_PUT(self, request, user_id, account_data_type):
+        if self._is_worker:
+            raise Exception("Cannot handle PUT /account_data on worker")
+
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")
@@ -86,8 +90,12 @@ class RoomAccountDataServlet(RestServlet):
         self.auth = hs.get_auth()
         self.store = hs.get_datastore()
         self.notifier = hs.get_notifier()
+        self._is_worker = hs.config.worker_app is not None
 
     async def on_PUT(self, request, user_id, room_id, account_data_type):
+        if self._is_worker:
+            raise Exception("Cannot handle PUT /account_data on worker")
+
         requester = await self.auth.get_user_by_req(request)
         if user_id != requester.user.to_string():
             raise AuthError(403, "Cannot add account data for other users.")