summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13683.bugfix1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py41
2 files changed, 22 insertions, 20 deletions
diff --git a/changelog.d/13683.bugfix b/changelog.d/13683.bugfix
new file mode 100644
index 0000000000..538534fec1
--- /dev/null
+++ b/changelog.d/13683.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which meant that keys for unwhitelisted servers were not returned by `/_matrix/key/v2/query`.
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f597157581..7f8ad29566 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -135,13 +135,6 @@ class RemoteKey(DirectServeJsonResource):
 
         store_queries = []
         for server_name, key_ids in query.items():
-            if (
-                self.federation_domain_whitelist is not None
-                and server_name not in self.federation_domain_whitelist
-            ):
-                logger.debug("Federation denied with %s", server_name)
-                continue
-
             if not key_ids:
                 key_ids = (None,)
             for key_id in key_ids:
@@ -153,21 +146,28 @@ class RemoteKey(DirectServeJsonResource):
 
         time_now_ms = self.clock.time_msec()
 
-        # Note that the value is unused.
+        # Map server_name->key_id->int. Note that the value of the init is unused.
+        # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
         for (server_name, key_id, _), key_results in cached.items():
             results = [(result["ts_added_ms"], result) for result in key_results]
 
-            if not results and key_id is not None:
-                cache_misses.setdefault(server_name, {})[key_id] = 0
+            if key_id is None:
+                # all keys were requested. Just return what we have without worrying
+                # about validity
+                for _, result in results:
+                    # Cast to bytes since postgresql returns a memoryview.
+                    json_results.add(bytes(result["key_json"]))
                 continue
 
-            if key_id is not None:
+            miss = False
+            if not results:
+                miss = True
+            else:
                 ts_added_ms, most_recent_result = max(results)
                 ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
-                miss = False
                 if req_valid_until is not None:
                     if ts_valid_until_ms < req_valid_until:
                         logger.debug(
@@ -211,19 +211,20 @@ class RemoteKey(DirectServeJsonResource):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-
-                if miss:
-                    cache_misses.setdefault(server_name, {})[key_id] = 0
                 # Cast to bytes since postgresql returns a memoryview.
                 json_results.add(bytes(most_recent_result["key_json"]))
-            else:
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+
+            if miss and query_remote_on_cache_miss:
+                # only bother attempting to fetch keys from servers on our whitelist
+                if (
+                    self.federation_domain_whitelist is None
+                    or server_name in self.federation_domain_whitelist
+                ):
+                    cache_misses.setdefault(server_name, {})[key_id] = 0
 
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
-        if cache_misses and query_remote_on_cache_miss:
+        if cache_misses:
             await yieldable_gather_results(
                 lambda t: self.fetcher.get_keys(*t),
                 (