diff options
-rw-r--r-- | changelog.d/13683.bugfix | 1 | ||||
-rw-r--r-- | synapse/rest/key/v2/remote_key_resource.py | 41 |
2 files changed, 22 insertions, 20 deletions
diff --git a/changelog.d/13683.bugfix b/changelog.d/13683.bugfix new file mode 100644 index 0000000000..538534fec1 --- /dev/null +++ b/changelog.d/13683.bugfix @@ -0,0 +1 @@ +Fix a long-standing bug which meant that keys for unwhitelisted servers were not returned by `/_matrix/key/v2/query`. diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py index f597157581..7f8ad29566 100644 --- a/synapse/rest/key/v2/remote_key_resource.py +++ b/synapse/rest/key/v2/remote_key_resource.py @@ -135,13 +135,6 @@ class RemoteKey(DirectServeJsonResource): store_queries = [] for server_name, key_ids in query.items(): - if ( - self.federation_domain_whitelist is not None - and server_name not in self.federation_domain_whitelist - ): - logger.debug("Federation denied with %s", server_name) - continue - if not key_ids: key_ids = (None,) for key_id in key_ids: @@ -153,21 +146,28 @@ class RemoteKey(DirectServeJsonResource): time_now_ms = self.clock.time_msec() - # Note that the value is unused. + # Map server_name->key_id->int. Note that the value of the init is unused. + # XXX: why don't we just use a set? cache_misses: Dict[str, Dict[str, int]] = {} for (server_name, key_id, _), key_results in cached.items(): results = [(result["ts_added_ms"], result) for result in key_results] - if not results and key_id is not None: - cache_misses.setdefault(server_name, {})[key_id] = 0 + if key_id is None: + # all keys were requested. Just return what we have without worrying + # about validity + for _, result in results: + # Cast to bytes since postgresql returns a memoryview. + json_results.add(bytes(result["key_json"])) continue - if key_id is not None: + miss = False + if not results: + miss = True + else: ts_added_ms, most_recent_result = max(results) ts_valid_until_ms = most_recent_result["ts_valid_until_ms"] req_key = query.get(server_name, {}).get(key_id, {}) req_valid_until = req_key.get("minimum_valid_until_ts") - miss = False if req_valid_until is not None: if ts_valid_until_ms < req_valid_until: logger.debug( @@ -211,19 +211,20 @@ class RemoteKey(DirectServeJsonResource): ts_valid_until_ms, time_now_ms, ) - - if miss: - cache_misses.setdefault(server_name, {})[key_id] = 0 # Cast to bytes since postgresql returns a memoryview. json_results.add(bytes(most_recent_result["key_json"])) - else: - for _, result in results: - # Cast to bytes since postgresql returns a memoryview. - json_results.add(bytes(result["key_json"])) + + if miss and query_remote_on_cache_miss: + # only bother attempting to fetch keys from servers on our whitelist + if ( + self.federation_domain_whitelist is None + or server_name in self.federation_domain_whitelist + ): + cache_misses.setdefault(server_name, {})[key_id] = 0 # If there is a cache miss, request the missing keys, then recurse (and # ensure the result is sent). - if cache_misses and query_remote_on_cache_miss: + if cache_misses: await yieldable_gather_results( lambda t: self.fetcher.get_keys(*t), ( |