summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2020-04-17 12:07:19 +0100
committerAndrew Morgan <andrew@amorgan.xyz>2020-04-17 12:08:09 +0100
commitcb56a51adaca22a24b0bd92c145a24b14f20076c (patch)
treecd263aa75c6f723f519d2945e040294b18b8fa3d /synapse/handlers/e2e_keys.py
parentlint (diff)
downloadsynapse-cb56a51adaca22a24b0bd92c145a24b14f20076c.tar.xz
Factor key retrieval out into a separate function
Diffstat (limited to '')
-rw-r--r--synapse/handlers/e2e_keys.py104
1 files changed, 61 insertions, 43 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index f80f0188c7..493fcb4d9d 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 import logging
+from typing import Dict, Optional
 
 from six import iteritems
 
@@ -962,7 +963,7 @@ class E2eKeysHandler(object):
 
     @defer.inlineCallbacks
     def _get_e2e_cross_signing_verify_key(
-        self, user_id: str, desired_key_type: str, from_user_id: str = None
+        self, user_id: str, key_type: str, from_user_id: str = None
     ):
         """Fetch or request the given cross-signing public key.
 
@@ -972,7 +973,7 @@ class E2eKeysHandler(object):
 
         Args:
             user_id: the user whose key should be fetched
-            desired_key_type: the type of key to fetch
+            key_type: the type of key to fetch
             from_user_id: the user that we are fetching the keys for.
                 This affects what signatures are fetched.
 
@@ -986,7 +987,7 @@ class E2eKeysHandler(object):
         """
         user = UserID.from_string(user_id)
         key = yield self.store.get_e2e_cross_signing_key(
-            user_id, desired_key_type, from_user_id
+            user_id, key_type, from_user_id
         )
 
         # If we still can't find the key, and we're looking for keys of another user
@@ -996,64 +997,81 @@ class E2eKeysHandler(object):
         # cross-sign a remote user, but does not share any rooms with them yet.
         # Thus, we would not have their key list yet. We fetch the key here and
         # store it just in case.
-        supported_remote_key_types = ["master", "self_signing"]
         if (
             key is None
             and not self.is_mine(user)
             # We only get "master" and "self_signing" keys from remote servers
-            and desired_key_type in supported_remote_key_types
+            and key_type in ["master", "self_signing"]
         ):
-            remote_result = None
-            try:
-                remote_result = yield self.federation.query_user_devices(
-                    user.domain, user_id
-                )
-            except Exception as e:
-                logger.warning(
-                    "Unable to query %s for cross-signing keys of user %s: %s %s",
-                    user.domain,
-                    user_id,
-                    type(e),
-                    e,
-                )
-
-            if remote_result is not None:
-                # Process each of the retrieved cross-signing keys
-                for key_type in supported_remote_key_types:
-                    key_content = remote_result.get(key_type + "_key")
-                    if not key_content:
-                        continue
-
-                    # If this is the desired key type, return it
-                    if key_type == desired_key_type:
-                        key = key_content
-
-                    # At the same time, store this key in the db for
-                    # subsequent queries
-                    yield self.store.set_e2e_cross_signing_key(
-                        user_id, key_type, key_content
-                    )
+            key = yield self._retrieve_cross_signing_keys_for_remote_user(
+                user, key_type
+            )
 
         if key is None:
-            logger.debug("No %s key found for %s", desired_key_type, user_id)
-            raise NotFoundError("No %s key found for %s" % (desired_key_type, user_id))
+            logger.debug("No %s key found for %s", key_type, user_id)
+            raise NotFoundError("No %s key found for %s" % (key_type, user_id))
 
         try:
             key_id, verify_key = get_verify_key_from_cross_signing_key(key)
         except ValueError as e:
             logger.debug(
-                "Invalid %s key retrieved: %s - %s %s",
-                desired_key_type,
-                key,
-                type(e),
-                e,
+                "Invalid %s key retrieved: %s - %s %s", key_type, key, type(e), e,
             )
             raise SynapseError(
-                502, "Invalid %s key retrieved from remote server", desired_key_type
+                502, "Invalid %s key retrieved from remote server", key_type
             )
 
         return key, key_id, verify_key
 
+    @defer.inlineCallbacks
+    def _retrieve_cross_signing_keys_for_remote_user(
+        self, user: UserID, desired_key_type: str,
+    ) -> Optional[Dict]:
+        """Queries cross-signing keys for a remote user and saves them to the database
+
+        Only the key specified by `key_type` will be returned, while all retrieved keys
+        will be saved regardless
+
+        Args:
+            user: The user to query remote keys for
+            desired_key_type: The type of key to receive. One of "master", "self_signing"
+
+        Returns:
+            The retrieved key content, or None if the key could not be retrieved
+        """
+        try:
+            remote_result = yield self.federation.query_user_devices(
+                user.domain, user.to_string()
+            )
+        except Exception as e:
+            logger.warning(
+                "Unable to query %s for cross-signing keys of user %s: %s %s",
+                user.domain,
+                user.to_string(),
+                type(e),
+                e,
+            )
+            return None
+
+        # Process each of the retrieved cross-signing keys
+        key = None
+        for key_type in ["master", "self_signing"]:
+            key_content = remote_result.get(key_type + "_key")
+            if not key_content:
+                continue
+
+            # If this is the desired key type, return it
+            if key_type == desired_key_type:
+                key = key_content
+
+            # At the same time, store this key in the db for
+            # subsequent queries
+            yield self.store.set_e2e_cross_signing_key(
+                user.to_string(), key_type, key_content
+            )
+
+        return key
+
 
 def _check_cross_signing_key(key, user_id, key_type, signing_key=None):
     """Check a cross-signing key uploaded by a user.  Performs some basic sanity