diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d710607c63..d2f99dc2ac 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -721,7 +721,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
keys: Dict[str, Dict[str, FetchKeyResult]] = {}
- added_keys: List[Tuple[str, str, FetchKeyResult]] = []
+ added_keys: Dict[Tuple[str, str], FetchKeyResult] = {}
time_now_ms = self.clock.time_msec()
@@ -752,9 +752,27 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
# we continue to process the rest of the response
continue
- added_keys.extend(
- (server_name, key_id, key) for key_id, key in processed_response.items()
- )
+ for key_id, key in processed_response.items():
+ dict_key = (server_name, key_id)
+ if dict_key in added_keys:
+ already_present_key = added_keys[dict_key]
+ logger.warning(
+ "Duplicate server keys for %s (%s) from perspective %s (%r, %r)",
+ server_name,
+ key_id,
+ perspective_name,
+ already_present_key,
+ key,
+ )
+
+ if already_present_key.valid_until_ts > key.valid_until_ts:
+ # Favour the entry with the largest valid_until_ts,
+ # as `old_verify_keys` are also collected from this
+ # response.
+ continue
+
+ added_keys[dict_key] = key
+
keys.setdefault(server_name, {}).update(processed_response)
await self.store.store_server_verify_keys(
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 0a19f607bd..89c37a4eb5 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -15,7 +15,7 @@
import itertools
import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
from signedjson.key import decode_verify_key_bytes
@@ -95,7 +95,7 @@ class KeyStore(SQLBaseStore):
self,
from_server: str,
ts_added_ms: int,
- verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+ verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
) -> None:
"""Stores NACL verification keys for remote servers.
Args:
@@ -108,7 +108,7 @@ class KeyStore(SQLBaseStore):
key_values = []
value_values = []
invalidations = []
- for server_name, key_id, fetch_result in verify_keys:
+ for (server_name, key_id), fetch_result in verify_keys.items():
key_values.append((server_name, key_id))
value_values.append(
(
|