summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15423.bugfix1
-rw-r--r--synapse/crypto/keyring.py26
-rw-r--r--synapse/storage/databases/main/keys.py6
-rw-r--r--tests/crypto/test_keyring.py4
-rw-r--r--tests/storage/test_keys.py18
-rw-r--r--tests/unittest.py16
6 files changed, 43 insertions, 28 deletions
diff --git a/changelog.d/15423.bugfix b/changelog.d/15423.bugfix
new file mode 100644
index 0000000000..dfb60ddd2f
--- /dev/null
+++ b/changelog.d/15423.bugfix
@@ -0,0 +1 @@
+Improve robustness when handling a perspective key response by deduplicating received server keys.
\ No newline at end of file
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index d710607c63..d2f99dc2ac 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -721,7 +721,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         )
 
         keys: Dict[str, Dict[str, FetchKeyResult]] = {}
-        added_keys: List[Tuple[str, str, FetchKeyResult]] = []
+        added_keys: Dict[Tuple[str, str], FetchKeyResult] = {}
 
         time_now_ms = self.clock.time_msec()
 
@@ -752,9 +752,27 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 # we continue to process the rest of the response
                 continue
 
-            added_keys.extend(
-                (server_name, key_id, key) for key_id, key in processed_response.items()
-            )
+            for key_id, key in processed_response.items():
+                dict_key = (server_name, key_id)
+                if dict_key in added_keys:
+                    already_present_key = added_keys[dict_key]
+                    logger.warning(
+                        "Duplicate server keys for %s (%s) from perspective %s (%r, %r)",
+                        server_name,
+                        key_id,
+                        perspective_name,
+                        already_present_key,
+                        key,
+                    )
+
+                    if already_present_key.valid_until_ts > key.valid_until_ts:
+                        # Favour the entry with the largest valid_until_ts,
+                        # as `old_verify_keys` are also collected from this
+                        # response.
+                        continue
+
+                added_keys[dict_key] = key
+
             keys.setdefault(server_name, {}).update(processed_response)
 
         await self.store.store_server_verify_keys(
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 0a19f607bd..89c37a4eb5 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -15,7 +15,7 @@
 
 import itertools
 import logging
-from typing import Any, Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
@@ -95,7 +95,7 @@ class KeyStore(SQLBaseStore):
         self,
         from_server: str,
         ts_added_ms: int,
-        verify_keys: Iterable[Tuple[str, str, FetchKeyResult]],
+        verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
     ) -> None:
         """Stores NACL verification keys for remote servers.
         Args:
@@ -108,7 +108,7 @@ class KeyStore(SQLBaseStore):
         key_values = []
         value_values = []
         invalidations = []
-        for server_name, key_id, fetch_result in verify_keys:
+        for (server_name, key_id), fetch_result in verify_keys.items():
             key_values.append((server_name, key_id))
             value_values.append(
                 (
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 1b9696748f..66102ab934 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -193,7 +193,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         r = self.hs.get_datastores().main.store_server_verify_keys(
             "server9",
             int(time.time() * 1000),
-            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
+            {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), 1000)},
         )
         self.get_success(r)
 
@@ -291,7 +291,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
             # None is not a valid value in FetchKeyResult, but we're abusing this
             # API to insert null values into the database. The nulls get converted
             # to 0 when fetched in KeyStore.get_server_verify_keys.
-            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), None))],  # type: ignore[arg-type]
+            {("server9", get_key_id(key1)): FetchKeyResult(get_verify_key(key1), None)},  # type: ignore[arg-type]
         )
         self.get_success(r)
 
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index ba68171ad7..5901d80f26 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -46,10 +46,10 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             store.store_server_verify_keys(
                 "from_server",
                 10,
-                [
-                    ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
-                    ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
-                ],
+                {
+                    ("server1", key_id_1): FetchKeyResult(KEY_1, 100),
+                    ("server1", key_id_2): FetchKeyResult(KEY_2, 200),
+                },
             )
         )
 
@@ -90,10 +90,10 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             store.store_server_verify_keys(
                 "from_server",
                 0,
-                [
-                    ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
-                    ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
-                ],
+                {
+                    ("srv1", key_id_1): FetchKeyResult(KEY_1, 100),
+                    ("srv1", key_id_2): FetchKeyResult(KEY_2, 200),
+                },
             )
         )
 
@@ -119,7 +119,7 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             signedjson.key.generate_signing_key("key2")
         )
         d = store.store_server_verify_keys(
-            "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
+            "from_server", 10, {("srv1", key_id_2): FetchKeyResult(new_key_2, 300)}
         )
         self.get_success(d)
 
diff --git a/tests/unittest.py b/tests/unittest.py
index 8a16fd3665..93fee1c0e6 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -793,16 +793,12 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
             hs.get_datastores().main.store_server_verify_keys(
                 from_server=self.OTHER_SERVER_NAME,
                 ts_added_ms=clock.time_msec(),
-                verify_keys=[
-                    (
-                        self.OTHER_SERVER_NAME,
-                        verify_key_id,
-                        FetchKeyResult(
-                            verify_key=verify_key,
-                            valid_until_ts=clock.time_msec() + 10000,
-                        ),
-                    )
-                ],
+                verify_keys={
+                    (self.OTHER_SERVER_NAME, verify_key_id): FetchKeyResult(
+                        verify_key=verify_key,
+                        valid_until_ts=clock.time_msec() + 10000,
+                    ),
+                },
             )
         )