summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py61
1 files changed, 43 insertions, 18 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 69310d9035..86cd4af9bd 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -154,17 +154,21 @@ class Keyring:
 
         if key_fetchers is None:
             key_fetchers = (
+                # Fetch keys from the database.
                 StoreKeyFetcher(hs),
+                # Fetch keys from a configured Perspectives server.
                 PerspectivesKeyFetcher(hs),
+                # Fetch keys from the origin server directly.
                 ServerKeyFetcher(hs),
             )
         self._key_fetchers = key_fetchers
 
-        self._server_queue: BatchingQueue[
+        self._fetch_keys_queue: BatchingQueue[
             _FetchKeyRequest, Dict[str, Dict[str, FetchKeyResult]]
         ] = BatchingQueue(
             "keyring_server",
             clock=hs.get_clock(),
+            # The method called to fetch each key
             process_batch_callback=self._inner_fetch_key_requests,
         )
 
@@ -287,7 +291,7 @@ class Keyring:
                 minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
                 key_ids=list(key_ids_to_find),
             )
-            found_keys_by_server = await self._server_queue.add_to_queue(
+            found_keys_by_server = await self._fetch_keys_queue.add_to_queue(
                 key_request, key=verify_request.server_name
             )
 
@@ -352,7 +356,17 @@ class Keyring:
     async def _inner_fetch_key_requests(
         self, requests: List[_FetchKeyRequest]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """Processing function for the queue of `_FetchKeyRequest`."""
+        """Processing function for the queue of `_FetchKeyRequest`.
+
+        Takes a list of key fetch requests, de-duplicates them and then carries out
+        each request by invoking self._inner_fetch_key_request.
+
+        Args:
+            requests: A list of requests for homeserver verify keys.
+
+        Returns:
+            {server name: {key id: fetch key result}}
+        """
 
         logger.debug("Starting fetch for %s", requests)
 
@@ -397,8 +411,23 @@ class Keyring:
     async def _inner_fetch_key_request(
         self, verify_request: _FetchKeyRequest
     ) -> Dict[str, FetchKeyResult]:
-        """Attempt to fetch the given key by calling each key fetcher one by
-        one.
+        """Attempt to fetch the given key by calling each key fetcher one by one.
+
+        If a key is found, check whether its `valid_until_ts` attribute satisfies the
+        `minimum_valid_until_ts` attribute of the `verify_request`. If it does, we
+        refrain from asking subsequent fetchers for that key.
+
+        Even if the above check fails, we still return the found key - the caller may
+        still find the invalid key result useful. In this case, we continue to ask
+        subsequent fetchers for the invalid key, in case they return a valid result
+        for it. This can happen when fetching a stale key result from the database,
+        before querying the origin server for an up-to-date result.
+
+        Args:
+            verify_request: The request for a verify key. Can include multiple key IDs.
+
+        Returns:
+            A map of {key_id: the key fetch result}.
         """
         logger.debug("Starting fetch for %s", verify_request)
 
@@ -420,26 +449,22 @@ class Keyring:
                 if not key:
                     continue
 
-                # If we already have a result for the given key ID we keep the
+                # If we already have a result for the given key ID, we keep the
                 # one with the highest `valid_until_ts`.
                 existing_key = found_keys.get(key_id)
-                if existing_key:
-                    if key.valid_until_ts <= existing_key.valid_until_ts:
-                        continue
+                if existing_key and existing_key.valid_until_ts > key.valid_until_ts:
+                    continue
+
+                # Check if this key's expiry timestamp is valid for the verify request.
+                if key.valid_until_ts >= verify_request.minimum_valid_until_ts:
+                    # Stop looking for this key from subsequent fetchers.
+                    missing_key_ids.discard(key_id)
 
-                # We always store the returned key even if it doesn't the
+                # We always store the returned key even if it doesn't meet the
                 # `minimum_valid_until_ts` requirement, as some verification
                 # requests may still be able to be satisfied by it.
-                #
-                # We still keep looking for the key from other fetchers in that
-                # case though.
                 found_keys[key_id] = key
 
-                if key.valid_until_ts < verify_request.minimum_valid_until_ts:
-                    continue
-
-                missing_key_ids.discard(key_id)
-
         return found_keys