summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py99
1 files changed, 60 insertions, 39 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f641ab7ef5..993b04099e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -120,16 +119,6 @@ class VerifyJsonRequest:
             key_ids=key_ids,
         )
 
-    def to_fetch_key_request(self) -> "_FetchKeyRequest":
-        """Create a key fetch request for all keys needed to satisfy the
-        verification request.
-        """
-        return _FetchKeyRequest(
-            server_name=self.server_name,
-            minimum_valid_until_ts=self.minimum_valid_until_ts,
-            key_ids=self.key_ids,
-        )
-
 
 class KeyLookupError(ValueError):
     pass
@@ -179,8 +168,22 @@ class Keyring:
             clock=hs.get_clock(),
             process_batch_callback=self._inner_fetch_key_requests,
         )
-        self.verify_key = get_verify_key(hs.signing_key)
-        self.hostname = hs.hostname
+
+        self._hostname = hs.hostname
+
+        # build a FetchKeyResult for each of our own keys, to shortcircuit the
+        # fetcher.
+        self._local_verify_keys: Dict[str, FetchKeyResult] = {}
+        for key_id, key in hs.config.key.old_signing_keys.items():
+            self._local_verify_keys[key_id] = FetchKeyResult(
+                verify_key=key, valid_until_ts=key.expired_ts
+            )
+
+        vk = get_verify_key(hs.signing_key)
+        self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
+            verify_key=vk,
+            valid_until_ts=2 ** 63,  # fake future timestamp
+        )
 
     async def verify_json_for_server(
         self,
@@ -267,22 +270,32 @@ class Keyring:
                 Codes.UNAUTHORIZED,
             )
 
-        # If we are the originating server don't fetch verify key for self over federation
-        if verify_request.server_name == self.hostname:
-            await self._process_json(self.verify_key, verify_request)
-            return
+        found_keys: Dict[str, FetchKeyResult] = {}
 
-        # Add the keys we need to verify to the queue for retrieval. We queue
-        # up requests for the same server so we don't end up with many in flight
-        # requests for the same keys.
-        key_request = verify_request.to_fetch_key_request()
-        found_keys_by_server = await self._server_queue.add_to_queue(
-            key_request, key=verify_request.server_name
-        )
+        # If we are the originating server, short-circuit the key-fetch for any keys
+        # we already have
+        if verify_request.server_name == self._hostname:
+            for key_id in verify_request.key_ids:
+                if key_id in self._local_verify_keys:
+                    found_keys[key_id] = self._local_verify_keys[key_id]
+
+        key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
+        if key_ids_to_find:
+            # Add the keys we need to verify to the queue for retrieval. We queue
+            # up requests for the same server so we don't end up with many in flight
+            # requests for the same keys.
+            key_request = _FetchKeyRequest(
+                server_name=verify_request.server_name,
+                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
+                key_ids=list(key_ids_to_find),
+            )
+            found_keys_by_server = await self._server_queue.add_to_queue(
+                key_request, key=verify_request.server_name
+            )
 
-        # Since we batch up requests the returned set of keys may contain keys
-        # from other servers, so we pull out only the ones we care about.s
-        found_keys = found_keys_by_server.get(verify_request.server_name, {})
+            # Since we batch up requests the returned set of keys may contain keys
+            # from other servers, so we pull out only the ones we care about.
+            found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
 
         # Verify each signature we got valid keys for, raising if we can't
         # verify any of them.
@@ -654,21 +667,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             perspective_name,
         )
 
+        request: JsonDict = {}
+        for queue_value in keys_to_fetch:
+            # there may be multiple requests for each server, so we have to merge
+            # them intelligently.
+            request_for_server = {
+                key_id: {
+                    "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                }
+                for key_id in queue_value.key_ids
+            }
+            request.setdefault(queue_value.server_name, {}).update(request_for_server)
+
+        logger.debug("Request to notary server %s: %s", perspective_name, request)
+
         try:
             query_response = await self.client.post_json(
                 destination=perspective_name,
                 path="/_matrix/key/v2/query",
-                data={
-                    "server_keys": {
-                        queue_value.server_name: {
-                            key_id: {
-                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
-                            }
-                            for key_id in queue_value.key_ids
-                        }
-                        for queue_value in keys_to_fetch
-                    }
-                },
+                data={"server_keys": request},
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             # these both have str() representations which we can't really improve upon
@@ -676,6 +693,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         except HttpResponseException as e:
             raise KeyLookupError("Remote server returned an error: %s" % (e,))
 
+        logger.debug(
+            "Response from notary server %s: %s", perspective_name, query_response
+        )
+
         keys: Dict[str, Dict[str, FetchKeyResult]] = {}
         added_keys: List[Tuple[str, str, FetchKeyResult]] = []