summary refs log tree commit diff
path: root/tests/crypto/test_keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/crypto/test_keyring.py')
-rw-r--r--tests/crypto/test_keyring.py135
1 files changed, 108 insertions, 27 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3933ad4347..096401938d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
 import canonicaljson
 import signedjson.key
 import signedjson.sign
+from signedjson.key import get_verify_key
 
 from twisted.internet import defer
 
@@ -137,7 +138,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 context_11.request = "11"
 
                 res_deferreds = kr.verify_json_objects_for_server(
-                    [("server10", json1), ("server11", {})]
+                    [("server10", json1, 0), ("server11", {}, 0)]
                 )
 
                 # the unsigned json should be rejected pretty quickly
@@ -174,7 +175,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 self.http_client.post_json.return_value = defer.Deferred()
 
                 res_deferreds_2 = kr.verify_json_objects_for_server(
-                    [("server10", json1)]
+                    [("server10", json1, 0)]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
                 yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -197,31 +198,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key(1)
-        key1_id = "%s:%s" % (key1.alg, key1.version)
-
         r = self.hs.datastore.store_server_verify_keys(
             "server9",
             time.time() * 1000,
-            [
-                (
-                    "server9",
-                    key1_id,
-                    FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
-                ),
-            ],
+            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
         )
         self.get_success(r)
+
         json1 = {}
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {})
+        d = _verify_json_for_server(kr, "server9", {}, 0)
         self.failureResultOf(d, SynapseError)
 
-        d = _verify_json_for_server(kr, "server9", json1)
-        self.assertFalse(d.called)
+        # should suceed on a signed object
+        d = _verify_json_for_server(kr, "server9", json1, 500)
+        # self.assertFalse(d.called)
         self.get_success(d)
 
+    def test_verify_json_dedupes_key_requests(self):
+        """Two requests for the same key should be deduped."""
+        key1 = signedjson.key.generate_signing_key(1)
+
+        def get_keys(keys_to_fetch):
+            # there should only be one request object (with the max validity)
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+                    }
+                }
+            )
+
+        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher.get_keys = Mock(side_effect=get_keys)
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server1", key1)
+
+        # the first request should succeed; the second should fail because the key
+        # has expired
+        results = kr.verify_json_objects_for_server(
+            [("server1", json1, 500), ("server1", json1, 1500)]
+        )
+        self.assertEqual(len(results), 2)
+        self.get_success(results[0])
+        e = self.get_failure(results[1], SynapseError).value
+        self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+        self.assertEqual(e.code, 401)
+
+        # there should have been a single call to the fetcher
+        mock_fetcher.get_keys.assert_called_once()
+
+    def test_verify_json_falls_back_to_other_fetchers(self):
+        """If the first fetcher cannot provide a recent enough key, we fall back"""
+        key1 = signedjson.key.generate_signing_key(1)
+
+        def get_keys1(keys_to_fetch):
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+                    }
+                }
+            )
+
+        def get_keys2(keys_to_fetch):
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+                    }
+                }
+            )
+
+        mock_fetcher1 = keyring.KeyFetcher()
+        mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+        mock_fetcher2 = keyring.KeyFetcher()
+        mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server1", key1)
+
+        results = kr.verify_json_objects_for_server(
+            [("server1", json1, 1200), ("server1", json1, 1500)]
+        )
+        self.assertEqual(len(results), 2)
+        self.get_success(results[0])
+        e = self.get_failure(results[1], SynapseError).value
+        self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+        self.assertEqual(e.code, 401)
+
+        # there should have been a single call to each fetcher
+        mock_fetcher1.get_keys.assert_called_once()
+        mock_fetcher2.get_keys.assert_called_once()
+
 
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
@@ -260,8 +338,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.get_json.side_effect = get_json
 
-        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
@@ -288,9 +366,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         # change the server name: it should cause a rejection
         response["server_name"] = "OTHER_SERVER"
-        self.get_failure(
-            fetcher.get_keys(server_name_and_key_ids), KeyLookupError
-        )
+        self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
 
 
 class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
@@ -342,8 +418,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.post_json.side_effect = post_json
 
-        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@@ -401,7 +477,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
-            server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+            keys_to_fetch = {SERVER_NAME: {"key1": 0}}
 
             def post_json(destination, path, data, **kwargs):
                 self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -410,9 +486,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
             self.http_client.post_json.side_effect = post_json
 
-            return self.get_success(
-                fetcher.get_keys(server_name_and_key_ids)
-            )
+            return self.get_success(fetcher.get_keys(keys_to_fetch))
 
         # start with a valid response so we can check we are testing the right thing
         response = build_response()
@@ -435,6 +509,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
 
 
+def get_key_id(key):
+    """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+    return "%s:%s" % (key.alg, key.version)
+
+
 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
     with LoggingContext("testctx") as ctx:
@@ -445,14 +524,16 @@ def run_in_context(f, *args, **kwargs):
     defer.returnValue(rv)
 
 
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(keyring, server_name, json_object, validity_time):
     """thin wrapper around verify_json_for_server which makes sure it is wrapped
     with the patched defer.inlineCallbacks.
     """
 
     @defer.inlineCallbacks
     def v():
-        rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+        rv1 = yield keyring.verify_json_for_server(
+            server_name, json_object, validity_time
+        )
         defer.returnValue(rv1)
 
     return run_in_context(v)