summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/crypto/keyring.py69
1 files changed, 41 insertions, 28 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f641ab7ef5..4cda439ad9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -120,16 +119,6 @@ class VerifyJsonRequest:
             key_ids=key_ids,
         )
 
-    def to_fetch_key_request(self) -> "_FetchKeyRequest":
-        """Create a key fetch request for all keys needed to satisfy the
-        verification request.
-        """
-        return _FetchKeyRequest(
-            server_name=self.server_name,
-            minimum_valid_until_ts=self.minimum_valid_until_ts,
-            key_ids=self.key_ids,
-        )
-
 
 class KeyLookupError(ValueError):
     pass
@@ -179,8 +168,22 @@ class Keyring:
             clock=hs.get_clock(),
             process_batch_callback=self._inner_fetch_key_requests,
         )
-        self.verify_key = get_verify_key(hs.signing_key)
-        self.hostname = hs.hostname
+
+        self._hostname = hs.hostname
+
+        # build a FetchKeyResult for each of our own keys, to shortcircuit the
+        # fetcher.
+        self._local_verify_keys: Dict[str, FetchKeyResult] = {}
+        for key_id, key in hs.config.key.old_signing_keys.items():
+            self._local_verify_keys[key_id] = FetchKeyResult(
+                verify_key=key, valid_until_ts=key.expired_ts
+            )
+
+        vk = get_verify_key(hs.signing_key)
+        self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
+            verify_key=vk,
+            valid_until_ts=2 ** 63,  # fake future timestamp
+        )
 
     async def verify_json_for_server(
         self,
@@ -267,22 +270,32 @@ class Keyring:
                 Codes.UNAUTHORIZED,
             )
 
-        # If we are the originating server don't fetch verify key for self over federation
-        if verify_request.server_name == self.hostname:
-            await self._process_json(self.verify_key, verify_request)
-            return
+        found_keys: Dict[str, FetchKeyResult] = {}
 
-        # Add the keys we need to verify to the queue for retrieval. We queue
-        # up requests for the same server so we don't end up with many in flight
-        # requests for the same keys.
-        key_request = verify_request.to_fetch_key_request()
-        found_keys_by_server = await self._server_queue.add_to_queue(
-            key_request, key=verify_request.server_name
-        )
+        # If we are the originating server, short-circuit the key-fetch for any keys
+        # we already have
+        if verify_request.server_name == self._hostname:
+            for key_id in verify_request.key_ids:
+                if key_id in self._local_verify_keys:
+                    found_keys[key_id] = self._local_verify_keys[key_id]
+
+        key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
+        if key_ids_to_find:
+            # Add the keys we need to verify to the queue for retrieval. We queue
+            # up requests for the same server so we don't end up with many in flight
+            # requests for the same keys.
+            key_request = _FetchKeyRequest(
+                server_name=verify_request.server_name,
+                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
+                key_ids=list(key_ids_to_find),
+            )
+            found_keys_by_server = await self._server_queue.add_to_queue(
+                key_request, key=verify_request.server_name
+            )
 
-        # Since we batch up requests the returned set of keys may contain keys
-        # from other servers, so we pull out only the ones we care about.s
-        found_keys = found_keys_by_server.get(verify_request.server_name, {})
+            # Since we batch up requests the returned set of keys may contain keys
+            # from other servers, so we pull out only the ones we care about.
+            found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
 
         # Verify each signature we got valid keys for, raising if we can't
         # verify any of them.