summary refs log tree commit diff
path: root/tests/crypto
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-06-02 16:37:59 +0100
committerGitHub <noreply@github.com>2021-06-02 16:37:59 +0100
commitfc3d2dc269a79e0404d0a9867e5042354d59147f (patch)
tree53ac2672942797a76a2f673040b0535c318d6c43 /tests/crypto
parentDo not show invite-only rooms in spaces summary (unless joined/invited). (#10... (diff)
downloadsynapse-fc3d2dc269a79e0404d0a9867e5042354d59147f.tar.xz
Rewrite the KeyRing (#10035)
Diffstat (limited to 'tests/crypto')
-rw-r--r--tests/crypto/test_keyring.py170
1 files changed, 86 insertions, 84 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 2775dfd880..745c295d3b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
+from typing import Dict, List
 from unittest.mock import Mock
 
 import attr
@@ -21,7 +22,6 @@ import signedjson.sign
 from nacl.signing import SigningKey
 from signedjson.key import encode_verify_key_base64, get_verify_key
 
-from twisted.internet import defer
 from twisted.internet.defer import Deferred, ensureDeferred
 
 from synapse.api.errors import SynapseError
@@ -92,23 +92,23 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # deferred completes.
         first_lookup_deferred = Deferred()
 
-        async def first_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request.id, "context_11")
-            self.assertEqual(keys_to_fetch, {"server10": {get_key_id(key1): 0}})
+        async def first_lookup_fetch(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            # self.assertEquals(current_context().request.id, "context_11")
+            self.assertEqual(server_name, "server10")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 0)
 
             await make_deferred_yieldable(first_lookup_deferred)
-            return {
-                "server10": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
-                }
-            }
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.side_effect = first_lookup_fetch
 
         async def first_lookup():
             with LoggingContext("context_11", request=FakeRequest("context_11")):
                 res_deferreds = kr.verify_json_objects_for_server(
-                    [("server10", json1, 0, "test10"), ("server11", {}, 0, "test11")]
+                    [("server10", json1, 0), ("server11", {}, 0)]
                 )
 
                 # the unsigned json should be rejected pretty quickly
@@ -126,18 +126,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
         d0 = ensureDeferred(first_lookup())
 
+        self.pump()
+
         mock_fetcher.get_keys.assert_called_once()
 
         # a second request for a server with outstanding requests
         # should block rather than start a second call
 
-        async def second_lookup_fetch(keys_to_fetch):
-            self.assertEquals(current_context().request.id, "context_12")
-            return {
-                "server10": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)
-                }
-            }
+        async def second_lookup_fetch(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            # self.assertEquals(current_context().request.id, "context_12")
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 100)}
 
         mock_fetcher.get_keys.reset_mock()
         mock_fetcher.get_keys.side_effect = second_lookup_fetch
@@ -146,7 +146,13 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         async def second_lookup():
             with LoggingContext("context_12", request=FakeRequest("context_12")):
                 res_deferreds_2 = kr.verify_json_objects_for_server(
-                    [("server10", json1, 0, "test")]
+                    [
+                        (
+                            "server10",
+                            json1,
+                            0,
+                        )
+                    ]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
                 second_lookup_state[0] = 1
@@ -183,11 +189,11 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        d = kr.verify_json_for_server("server9", {}, 0)
         self.get_failure(d, SynapseError)
 
         # should succeed on a signed object
-        d = _verify_json_for_server(kr, "server9", json1, 500, "test signed")
+        d = kr.verify_json_for_server("server9", json1, 500)
         # self.assertFalse(d.called)
         self.get_success(d)
 
@@ -214,24 +220,24 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned")
+        d = kr.verify_json_for_server("server9", {}, 0)
         self.get_failure(d, SynapseError)
 
         # should fail on a signed object with a non-zero minimum_valid_until_ms,
         # as it tries to refetch the keys and fails.
-        d = _verify_json_for_server(
-            kr, "server9", json1, 500, "test signed non-zero min"
-        )
+        d = kr.verify_json_for_server("server9", json1, 500)
         self.get_failure(d, SynapseError)
 
         # We expect the keyring tried to refetch the key once.
         mock_fetcher.get_keys.assert_called_once_with(
-            {"server9": {get_key_id(key1): 500}}
+            "server9", [get_key_id(key1)], 500
         )
 
         # should succeed on a signed object with a 0 minimum_valid_until_ms
-        d = _verify_json_for_server(
-            kr, "server9", json1, 0, "test signed with zero min"
+        d = kr.verify_json_for_server(
+            "server9",
+            json1,
+            0,
         )
         self.get_success(d)
 
@@ -239,15 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key(1)
 
-        async def get_keys(keys_to_fetch):
+        async def get_keys(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
             # there should only be one request object (with the max validity)
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
 
-            return {
-                "server1": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                }
-            }
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
 
         mock_fetcher = Mock()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -259,7 +265,14 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         # the first request should succeed; the second should fail because the key
         # has expired
         results = kr.verify_json_objects_for_server(
-            [("server1", json1, 500, "test1"), ("server1", json1, 1500, "test2")]
+            [
+                (
+                    "server1",
+                    json1,
+                    500,
+                ),
+                ("server1", json1, 1500),
+            ]
         )
         self.assertEqual(len(results), 2)
         self.get_success(results[0])
@@ -274,19 +287,21 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """If the first fetcher cannot provide a recent enough key, we fall back"""
         key1 = signedjson.key.generate_signing_key(1)
 
-        async def get_keys1(keys_to_fetch):
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return {
-                "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
-            }
-
-        async def get_keys2(keys_to_fetch):
-            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return {
-                "server1": {
-                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                }
-            }
+        async def get_keys1(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+
+        async def get_keys2(
+            server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+        ) -> Dict[str, FetchKeyResult]:
+            self.assertEqual(server_name, "server1")
+            self.assertEqual(key_ids, [get_key_id(key1)])
+            self.assertEqual(minimum_valid_until_ts, 1500)
+            return {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)}
 
         mock_fetcher1 = Mock()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
@@ -298,7 +313,18 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         signedjson.sign.sign_json(json1, "server1", key1)
 
         results = kr.verify_json_objects_for_server(
-            [("server1", json1, 1200, "test1"), ("server1", json1, 1500, "test2")]
+            [
+                (
+                    "server1",
+                    json1,
+                    1200,
+                ),
+                (
+                    "server1",
+                    json1,
+                    1500,
+                ),
+            ]
         )
         self.assertEqual(len(results), 2)
         self.get_success(results[0])
@@ -349,9 +375,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.get_json.side_effect = get_json
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -378,7 +403,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
 
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
         self.assertEqual(keys, {})
 
 
@@ -465,10 +490,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        self.assertIn(SERVER_NAME, keys)
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        self.assertIn(testverifykey_id, keys)
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -515,10 +539,9 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
 
-        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
-        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
-        self.assertIn(SERVER_NAME, keys)
-        k = keys[SERVER_NAME][testverifykey_id]
+        keys = self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
+        self.assertIn(testverifykey_id, keys)
+        k = keys[testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
         self.assertEqual(k.verify_key.alg, "ed25519")
@@ -559,14 +582,13 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
-            keys_to_fetch = {SERVER_NAME: {"key1": 0}}
             self.expect_outgoing_key_query(SERVER_NAME, "key1", response)
-            return self.get_success(fetcher.get_keys(keys_to_fetch))
+            return self.get_success(fetcher.get_keys(SERVER_NAME, ["key1"], 0))
 
         # start with a valid response so we can check we are testing the right thing
         response = build_response()
         keys = get_key_from_perspectives(response)
-        k = keys[SERVER_NAME][testverifykey_id]
+        k = keys[testverifykey_id]
         self.assertEqual(k.verify_key, testverifykey)
 
         # remove the perspectives server's signature
@@ -585,23 +607,3 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 def get_key_id(key):
     """Get the matrix ID tag for a given SigningKey or VerifyKey"""
     return "%s:%s" % (key.alg, key.version)
-
-
-@defer.inlineCallbacks
-def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx"):
-        rv = yield f(*args, **kwargs)
-    return rv
-
-
-def _verify_json_for_server(kr, *args):
-    """thin wrapper around verify_json_for_server which makes sure it is wrapped
-    with the patched defer.inlineCallbacks.
-    """
-
-    @defer.inlineCallbacks
-    def v():
-        rv1 = yield kr.verify_json_for_server(*args)
-        return rv1
-
-    return run_in_context(v)