diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f641ab7ef5..4cda439ad9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -120,16 +119,6 @@ class VerifyJsonRequest:
key_ids=key_ids,
)
- def to_fetch_key_request(self) -> "_FetchKeyRequest":
- """Create a key fetch request for all keys needed to satisfy the
- verification request.
- """
- return _FetchKeyRequest(
- server_name=self.server_name,
- minimum_valid_until_ts=self.minimum_valid_until_ts,
- key_ids=self.key_ids,
- )
-
class KeyLookupError(ValueError):
pass
@@ -179,8 +168,22 @@ class Keyring:
clock=hs.get_clock(),
process_batch_callback=self._inner_fetch_key_requests,
)
- self.verify_key = get_verify_key(hs.signing_key)
- self.hostname = hs.hostname
+
+ self._hostname = hs.hostname
+
+ # build a FetchKeyResult for each of our own keys, to shortcircuit the
+ # fetcher.
+ self._local_verify_keys: Dict[str, FetchKeyResult] = {}
+ for key_id, key in hs.config.key.old_signing_keys.items():
+ self._local_verify_keys[key_id] = FetchKeyResult(
+ verify_key=key, valid_until_ts=key.expired_ts
+ )
+
+ vk = get_verify_key(hs.signing_key)
+ self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
+ verify_key=vk,
+ valid_until_ts=2 ** 63, # fake future timestamp
+ )
async def verify_json_for_server(
self,
@@ -267,22 +270,32 @@ class Keyring:
Codes.UNAUTHORIZED,
)
- # If we are the originating server don't fetch verify key for self over federation
- if verify_request.server_name == self.hostname:
- await self._process_json(self.verify_key, verify_request)
- return
+ found_keys: Dict[str, FetchKeyResult] = {}
- # Add the keys we need to verify to the queue for retrieval. We queue
- # up requests for the same server so we don't end up with many in flight
- # requests for the same keys.
- key_request = verify_request.to_fetch_key_request()
- found_keys_by_server = await self._server_queue.add_to_queue(
- key_request, key=verify_request.server_name
- )
+ # If we are the originating server, short-circuit the key-fetch for any keys
+ # we already have
+ if verify_request.server_name == self._hostname:
+ for key_id in verify_request.key_ids:
+ if key_id in self._local_verify_keys:
+ found_keys[key_id] = self._local_verify_keys[key_id]
+
+ key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
+ if key_ids_to_find:
+ # Add the keys we need to verify to the queue for retrieval. We queue
+ # up requests for the same server so we don't end up with many in flight
+ # requests for the same keys.
+ key_request = _FetchKeyRequest(
+ server_name=verify_request.server_name,
+ minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
+ key_ids=list(key_ids_to_find),
+ )
+ found_keys_by_server = await self._server_queue.add_to_queue(
+ key_request, key=verify_request.server_name
+ )
- # Since we batch up requests the returned set of keys may contain keys
- # from other servers, so we pull out only the ones we care about.s
- found_keys = found_keys_by_server.get(verify_request.server_name, {})
+ # Since we batch up requests the returned set of keys may contain keys
+ # from other servers, so we pull out only the ones we care about.
+ found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
# Verify each signature we got valid keys for, raising if we can't
# verify any of them.
|