summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11379.bugfix1
-rw-r--r--synapse/crypto/keyring.py69
-rw-r--r--tests/crypto/test_keyring.py56
3 files changed, 95 insertions, 31 deletions
diff --git a/changelog.d/11379.bugfix b/changelog.d/11379.bugfix
new file mode 100644
index 0000000000..a49d4eb776
--- /dev/null
+++ b/changelog.d/11379.bugfix
@@ -0,0 +1 @@
+Fix an issue introduced in v1.47.0 which prevented servers re-joining rooms they had previously left, if their signing keys were replaced.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f641ab7ef5..4cda439ad9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -1,5 +1,4 @@
-# Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2017, 2018 New Vector Ltd
+# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -120,16 +119,6 @@ class VerifyJsonRequest:
             key_ids=key_ids,
         )
 
-    def to_fetch_key_request(self) -> "_FetchKeyRequest":
-        """Create a key fetch request for all keys needed to satisfy the
-        verification request.
-        """
-        return _FetchKeyRequest(
-            server_name=self.server_name,
-            minimum_valid_until_ts=self.minimum_valid_until_ts,
-            key_ids=self.key_ids,
-        )
-
 
 class KeyLookupError(ValueError):
     pass
@@ -179,8 +168,22 @@ class Keyring:
             clock=hs.get_clock(),
             process_batch_callback=self._inner_fetch_key_requests,
         )
-        self.verify_key = get_verify_key(hs.signing_key)
-        self.hostname = hs.hostname
+
+        self._hostname = hs.hostname
+
+        # build a FetchKeyResult for each of our own keys, to shortcircuit the
+        # fetcher.
+        self._local_verify_keys: Dict[str, FetchKeyResult] = {}
+        for key_id, key in hs.config.key.old_signing_keys.items():
+            self._local_verify_keys[key_id] = FetchKeyResult(
+                verify_key=key, valid_until_ts=key.expired_ts
+            )
+
+        vk = get_verify_key(hs.signing_key)
+        self._local_verify_keys[f"{vk.alg}:{vk.version}"] = FetchKeyResult(
+            verify_key=vk,
+            valid_until_ts=2 ** 63,  # fake future timestamp
+        )
 
     async def verify_json_for_server(
         self,
@@ -267,22 +270,32 @@ class Keyring:
                 Codes.UNAUTHORIZED,
             )
 
-        # If we are the originating server don't fetch verify key for self over federation
-        if verify_request.server_name == self.hostname:
-            await self._process_json(self.verify_key, verify_request)
-            return
+        found_keys: Dict[str, FetchKeyResult] = {}
 
-        # Add the keys we need to verify to the queue for retrieval. We queue
-        # up requests for the same server so we don't end up with many in flight
-        # requests for the same keys.
-        key_request = verify_request.to_fetch_key_request()
-        found_keys_by_server = await self._server_queue.add_to_queue(
-            key_request, key=verify_request.server_name
-        )
+        # If we are the originating server, short-circuit the key-fetch for any keys
+        # we already have
+        if verify_request.server_name == self._hostname:
+            for key_id in verify_request.key_ids:
+                if key_id in self._local_verify_keys:
+                    found_keys[key_id] = self._local_verify_keys[key_id]
+
+        key_ids_to_find = set(verify_request.key_ids) - found_keys.keys()
+        if key_ids_to_find:
+            # Add the keys we need to verify to the queue for retrieval. We queue
+            # up requests for the same server so we don't end up with many in flight
+            # requests for the same keys.
+            key_request = _FetchKeyRequest(
+                server_name=verify_request.server_name,
+                minimum_valid_until_ts=verify_request.minimum_valid_until_ts,
+                key_ids=list(key_ids_to_find),
+            )
+            found_keys_by_server = await self._server_queue.add_to_queue(
+                key_request, key=verify_request.server_name
+            )
 
-        # Since we batch up requests the returned set of keys may contain keys
-        # from other servers, so we pull out only the ones we care about.s
-        found_keys = found_keys_by_server.get(verify_request.server_name, {})
+            # Since we batch up requests the returned set of keys may contain keys
+            # from other servers, so we pull out only the ones we care about.
+            found_keys.update(found_keys_by_server.get(verify_request.server_name, {}))
 
         # Verify each signature we got valid keys for, raising if we can't
         # verify any of them.
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index cbecc1c20f..4d1e154578 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -1,4 +1,4 @@
-# Copyright 2017 New Vector Ltd
+# Copyright 2017-2021 The Matrix.org Foundation C.I.C
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -40,7 +40,7 @@ from synapse.storage.keys import FetchKeyResult
 
 from tests import unittest
 from tests.test_utils import make_awaitable
-from tests.unittest import logcontext_clean
+from tests.unittest import logcontext_clean, override_config
 
 
 class MockPerspectiveServer:
@@ -197,7 +197,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # self.assertFalse(d.called)
         self.get_success(d)
 
-    def test_verify_for_server_locally(self):
+    def test_verify_for_local_server(self):
         """Ensure that locally signed JSON can be verified without fetching keys
         over federation
         """
@@ -209,6 +209,56 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
         self.get_success(d)
 
+    OLD_KEY = signedjson.key.generate_signing_key("old")
+
+    @override_config(
+        {
+            "old_signing_keys": {
+                f"{OLD_KEY.alg}:{OLD_KEY.version}": {
+                    "key": encode_verify_key_base64(OLD_KEY.verify_key),
+                    "expired_ts": 1000,
+                }
+            }
+        }
+    )
+    def test_verify_for_local_server_old_key(self):
+        """Can also use keys in old_signing_keys for verification"""
+        json1 = {}
+        signedjson.sign.sign_json(json1, self.hs.hostname, self.OLD_KEY)
+
+        kr = keyring.Keyring(self.hs)
+        d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+        self.get_success(d)
+
+    def test_verify_for_local_server_unknown_key(self):
+        """Local keys that we no longer have should be fetched via the fetcher"""
+
+        # the key we'll sign things with (nb, not known to the Keyring)
+        key2 = signedjson.key.generate_signing_key("2")
+
+        # set up a mock fetcher which will return the key
+        async def get_keys(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, self.hs.hostname)
+            self.assertEqual(key_ids, [get_key_id(key2)])
+
+            return {get_key_id(key2): FetchKeyResult(get_verify_key(key2), 1200)}
+
+        mock_fetcher = Mock()
+        mock_fetcher.get_keys = Mock(side_effect=get_keys)
+        kr = keyring.Keyring(
+            self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
+        )
+
+        # sign the json
+        json1 = {}
+        signedjson.sign.sign_json(json1, self.hs.hostname, key2)
+
+        # ... and check we can verify it.
+        d = kr.verify_json_for_server(self.hs.hostname, json1, 0)
+        self.get_success(d)
+
     def test_verify_json_for_server_with_null_valid_until_ms(self):
         """Tests that we correctly handle key requests for keys we've stored
         with a null `ts_valid_until_ms`