summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5321.bugfix1
-rw-r--r--synapse/crypto/keyring.py167
-rw-r--r--synapse/federation/federation_base.py4
-rw-r--r--synapse/federation/transport/server.py4
-rw-r--r--synapse/groups/attestations.py5
-rw-r--r--tests/crypto/test_keyring.py135
6 files changed, 228 insertions, 88 deletions
diff --git a/changelog.d/5321.bugfix b/changelog.d/5321.bugfix
new file mode 100644
index 0000000000..943a61956d
--- /dev/null
+++ b/changelog.d/5321.bugfix
@@ -0,0 +1 @@
+Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index b2f4cea536..cdec06c88e 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from collections import defaultdict
 
 import six
 from six import raise_from
@@ -70,6 +71,9 @@ class VerifyKeyRequest(object):
 
         json_object(dict): The JSON object to verify.
 
+        minimum_valid_until_ts (int): time at which we require the signing key to
+            be valid. (0 implies we don't care)
+
         deferred(Deferred[str, str, nacl.signing.VerifyKey]):
             A deferred (server_name, key_id, verify_key) tuple that resolves when
             a verify key has been fetched. The deferreds' callbacks are run with no
@@ -82,7 +86,8 @@ class VerifyKeyRequest(object):
     server_name = attr.ib()
     key_ids = attr.ib()
     json_object = attr.ib()
-    deferred = attr.ib()
+    minimum_valid_until_ts = attr.ib()
+    deferred = attr.ib(default=attr.Factory(defer.Deferred))
 
 
 class KeyLookupError(ValueError):
@@ -90,14 +95,16 @@ class KeyLookupError(ValueError):
 
 
 class Keyring(object):
-    def __init__(self, hs):
+    def __init__(self, hs, key_fetchers=None):
         self.clock = hs.get_clock()
 
-        self._key_fetchers = (
-            StoreKeyFetcher(hs),
-            PerspectivesKeyFetcher(hs),
-            ServerKeyFetcher(hs),
-        )
+        if key_fetchers is None:
+            key_fetchers = (
+                StoreKeyFetcher(hs),
+                PerspectivesKeyFetcher(hs),
+                ServerKeyFetcher(hs),
+            )
+        self._key_fetchers = key_fetchers
 
         # map from server name to Deferred. Has an entry for each server with
         # an ongoing key download; the Deferred completes once the download
@@ -106,9 +113,25 @@ class Keyring(object):
         # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
-    def verify_json_for_server(self, server_name, json_object):
+    def verify_json_for_server(self, server_name, json_object, validity_time):
+        """Verify that a JSON object has been signed by a given server
+
+        Args:
+            server_name (str): name of the server which must have signed this object
+
+            json_object (dict): object to be checked
+
+            validity_time (int): timestamp at which we require the signing key to
+                be valid. (0 implies we don't care)
+
+        Returns:
+            Deferred[None]: completes if the the object was correctly signed, otherwise
+                errbacks with an error
+        """
+        req = server_name, json_object, validity_time
+
         return logcontext.make_deferred_yieldable(
-            self.verify_json_objects_for_server([(server_name, json_object)])[0]
+            self.verify_json_objects_for_server((req,))[0]
         )
 
     def verify_json_objects_for_server(self, server_and_json):
@@ -116,10 +139,12 @@ class Keyring(object):
         necessary.
 
         Args:
-            server_and_json (list): List of pairs of (server_name, json_object)
+            server_and_json (iterable[Tuple[str, dict, int]):
+                Iterable of triplets of (server_name, json_object, validity_time)
+                validity_time is a timestamp at which the signing key must be valid.
 
         Returns:
-            List<Deferred>: for each input pair, a deferred indicating success
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
                 or failure to verify each json object's signature for the given
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
@@ -128,12 +153,12 @@ class Keyring(object):
         verify_requests = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(server_name, json_object):
+        def process(server_name, json_object, validity_time):
             """Process an entry in the request list
 
-            Given a (server_name, json_object) pair from the request list,
-            adds a key request to verify_requests, and returns a deferred which will
-            complete or fail (in the sentinel context) when verification completes.
+            Given a (server_name, json_object, validity_time) triplet from the request
+            list, adds a key request to verify_requests, and returns a deferred which
+            will complete or fail (in the sentinel context) when verification completes.
             """
             key_ids = signature_ids(json_object, server_name)
 
@@ -148,7 +173,7 @@ class Keyring(object):
 
             # add the key request to the queue, but don't start it off yet.
             verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, defer.Deferred()
+                server_name, key_ids, json_object, validity_time
             )
             verify_requests.append(verify_request)
 
@@ -160,8 +185,8 @@ class Keyring(object):
             return handle(verify_request)
 
         results = [
-            process(server_name, json_object)
-            for server_name, json_object in server_and_json
+            process(server_name, json_object, validity_time)
+            for server_name, json_object, validity_time in server_and_json
         ]
 
         if verify_requests:
@@ -298,8 +323,12 @@ class Keyring(object):
                         verify_request.deferred.errback(
                             SynapseError(
                                 401,
-                                "No key for %s with id %s"
-                                % (verify_request.server_name, verify_request.key_ids),
+                                "No key for %s with ids in %s (min_validity %i)"
+                                % (
+                                    verify_request.server_name,
+                                    verify_request.key_ids,
+                                    verify_request.minimum_valid_until_ts,
+                                ),
                                 Codes.UNAUTHORIZED,
                             )
                         )
@@ -323,18 +352,28 @@ class Keyring(object):
         Args:
             fetcher (KeyFetcher): fetcher to use to fetch the keys
             remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
-                Any successfully-completed requests will be reomved from the list.
+                Any successfully-completed requests will be removed from the list.
         """
-        # dict[str, set(str)]: keys to fetch for each server
-        missing_keys = {}
+        # dict[str, dict[str, int]]: keys to fetch.
+        # server_name -> key_id -> min_valid_ts
+        missing_keys = defaultdict(dict)
+
         for verify_request in remaining_requests:
             # any completed requests should already have been removed
             assert not verify_request.deferred.called
-            missing_keys.setdefault(verify_request.server_name, set()).update(
-                verify_request.key_ids
-            )
+            keys_for_server = missing_keys[verify_request.server_name]
 
-        results = yield fetcher.get_keys(missing_keys.items())
+            for key_id in verify_request.key_ids:
+                # If we have several requests for the same key, then we only need to
+                # request that key once, but we should do so with the greatest
+                # min_valid_until_ts of the requests, so that we can satisfy all of
+                # the requests.
+                keys_for_server[key_id] = max(
+                    keys_for_server.get(key_id, -1),
+                    verify_request.minimum_valid_until_ts
+                )
+
+        results = yield fetcher.get_keys(missing_keys)
 
         completed = list()
         for verify_request in remaining_requests:
@@ -344,25 +383,34 @@ class Keyring(object):
             # complete this VerifyKeyRequest.
             result_keys = results.get(server_name, {})
             for key_id in verify_request.key_ids:
-                key = result_keys.get(key_id)
-                if key:
-                    with PreserveLoggingContext():
-                        verify_request.deferred.callback(
-                            (server_name, key_id, key.verify_key)
-                        )
-                    completed.append(verify_request)
-                    break
+                fetch_key_result = result_keys.get(key_id)
+                if not fetch_key_result:
+                    # we didn't get a result for this key
+                    continue
+
+                if (
+                    fetch_key_result.valid_until_ts
+                    < verify_request.minimum_valid_until_ts
+                ):
+                    # key was not valid at this point
+                    continue
+
+                with PreserveLoggingContext():
+                    verify_request.deferred.callback(
+                        (server_name, key_id, fetch_key_result.verify_key)
+                    )
+                completed.append(verify_request)
+                break
 
         remaining_requests.difference_update(completed)
 
 
 class KeyFetcher(object):
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """
         Args:
-            server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
-                Note that the iterables may be iterated more than once.
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@@ -378,13 +426,15 @@ class StoreKeyFetcher(KeyFetcher):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
+
         keys_to_fetch = (
             (server_name, key_id)
-            for server_name, key_ids in server_name_and_key_ids
-            for key_id in key_ids
+            for server_name, keys_for_server in keys_to_fetch.items()
+            for key_id in keys_for_server.keys()
         )
+
         res = yield self.store.get_server_verify_keys(keys_to_fetch)
         keys = {}
         for (server_name, key_id), key in res.items():
@@ -508,14 +558,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.perspective_servers = self.config.perspectives
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
 
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name_and_key_ids, perspective_name, perspective_keys
+                    keys_to_fetch, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
             except KeyLookupError as e:
@@ -549,13 +599,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_indirect(
-        self, server_names_and_key_ids, perspective_name, perspective_keys
+        self, keys_to_fetch, perspective_name, perspective_keys
     ):
         """
         Args:
-            server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
+
             perspective_name (str): name of the notary server to query for the keys
+
             perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
                 notary server
 
@@ -569,12 +621,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         """
         logger.info(
             "Requesting keys %s from notary server %s",
-            server_names_and_key_ids,
+            keys_to_fetch.items(),
             perspective_name,
         )
-        # TODO(mark): Set the minimum_valid_until_ts to that needed by
-        # the events being validated or the current time if validating
-        # an incoming request.
+
         try:
             query_response = yield self.client.post_json(
                 destination=perspective_name,
@@ -582,9 +632,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 data={
                     u"server_keys": {
                         server_name: {
-                            key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
+                            key_id: {u"minimum_valid_until_ts": min_valid_ts}
+                            for key_id, min_valid_ts in server_keys.items()
                         }
-                        for server_name, key_ids in server_names_and_key_ids
+                        for server_name, server_keys in keys_to_fetch.items()
                     }
                 },
                 long_retries=True,
@@ -694,15 +745,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_http_client()
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
+        # TODO make this more resilient
         results = yield logcontext.make_deferred_yieldable(
             defer.gatherResults(
                 [
                     run_in_background(
-                        self.get_server_verify_key_v2_direct, server_name, key_ids
+                        self.get_server_verify_key_v2_direct,
+                        server_name,
+                        server_keys.keys(),
                     )
-                    for server_name, key_ids in server_name_and_key_ids
+                    for server_name, server_keys in keys_to_fetch.items()
                 ],
                 consumeErrors=True,
             ).addErrback(unwrapFirstError)
@@ -721,6 +775,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         keys = {}  # type: dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
+            # we may have found this key as a side-effect of asking for another.
             if requested_key_id in keys:
                 continue
 
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index cffa831d80..4b38f7c759 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     ]
 
     more_deferreds = keyring.verify_json_objects_for_server([
-        (p.sender_domain, p.redacted_pdu_json)
+        (p.sender_domain, p.redacted_pdu_json, 0)
         for p in pdus_to_check_sender
     ])
 
@@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
         ]
 
         more_deferreds = keyring.verify_json_objects_for_server([
-            (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
+            (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
             for p in pdus_to_check_event_id
         ])
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index d0efc4e0d3..0db8858cf1 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
 
 class Authenticator(object):
     def __init__(self, hs):
+        self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
@@ -102,6 +103,7 @@ class Authenticator(object):
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
     def authenticate_request(self, request, content):
+        now = self._clock.time_msec()
         json_request = {
             "method": request.method.decode('ascii'),
             "uri": request.uri.decode('ascii'),
@@ -138,7 +140,7 @@ class Authenticator(object):
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
 
-        yield self.keyring.verify_json_for_server(origin, json_request)
+        yield self.keyring.verify_json_for_server(origin, json_request, now)
 
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 786149be65..fa6b641ee1 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
 
         # TODO: We also want to check that *new* attestations that people give
         # us to store are valid for at least a little while.
-        if valid_until_ms < self.clock.time_msec():
+        now = self.clock.time_msec()
+        if valid_until_ms < now:
             raise SynapseError(400, "Attestation expired")
 
-        yield self.keyring.verify_json_for_server(server_name, attestation)
+        yield self.keyring.verify_json_for_server(server_name, attestation, now)
 
     def create_attestation(self, group_id, user_id):
         """Create an attestation for the group_id and user_id with default
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3933ad4347..096401938d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,6 +19,7 @@ from mock import Mock
 import canonicaljson
 import signedjson.key
 import signedjson.sign
+from signedjson.key import get_verify_key
 
 from twisted.internet import defer
 
@@ -137,7 +138,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 context_11.request = "11"
 
                 res_deferreds = kr.verify_json_objects_for_server(
-                    [("server10", json1), ("server11", {})]
+                    [("server10", json1, 0), ("server11", {}, 0)]
                 )
 
                 # the unsigned json should be rejected pretty quickly
@@ -174,7 +175,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 self.http_client.post_json.return_value = defer.Deferred()
 
                 res_deferreds_2 = kr.verify_json_objects_for_server(
-                    [("server10", json1)]
+                    [("server10", json1, 0)]
                 )
                 res_deferreds_2[0].addBoth(self.check_context, None)
                 yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -197,31 +198,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         kr = keyring.Keyring(self.hs)
 
         key1 = signedjson.key.generate_signing_key(1)
-        key1_id = "%s:%s" % (key1.alg, key1.version)
-
         r = self.hs.datastore.store_server_verify_keys(
             "server9",
             time.time() * 1000,
-            [
-                (
-                    "server9",
-                    key1_id,
-                    FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
-                ),
-            ],
+            [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
         )
         self.get_success(r)
+
         json1 = {}
         signedjson.sign.sign_json(json1, "server9", key1)
 
         # should fail immediately on an unsigned object
-        d = _verify_json_for_server(kr, "server9", {})
+        d = _verify_json_for_server(kr, "server9", {}, 0)
         self.failureResultOf(d, SynapseError)
 
-        d = _verify_json_for_server(kr, "server9", json1)
-        self.assertFalse(d.called)
+        # should suceed on a signed object
+        d = _verify_json_for_server(kr, "server9", json1, 500)
+        # self.assertFalse(d.called)
         self.get_success(d)
 
+    def test_verify_json_dedupes_key_requests(self):
+        """Two requests for the same key should be deduped."""
+        key1 = signedjson.key.generate_signing_key(1)
+
+        def get_keys(keys_to_fetch):
+            # there should only be one request object (with the max validity)
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+                    }
+                }
+            )
+
+        mock_fetcher = keyring.KeyFetcher()
+        mock_fetcher.get_keys = Mock(side_effect=get_keys)
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server1", key1)
+
+        # the first request should succeed; the second should fail because the key
+        # has expired
+        results = kr.verify_json_objects_for_server(
+            [("server1", json1, 500), ("server1", json1, 1500)]
+        )
+        self.assertEqual(len(results), 2)
+        self.get_success(results[0])
+        e = self.get_failure(results[1], SynapseError).value
+        self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+        self.assertEqual(e.code, 401)
+
+        # there should have been a single call to the fetcher
+        mock_fetcher.get_keys.assert_called_once()
+
+    def test_verify_json_falls_back_to_other_fetchers(self):
+        """If the first fetcher cannot provide a recent enough key, we fall back"""
+        key1 = signedjson.key.generate_signing_key(1)
+
+        def get_keys1(keys_to_fetch):
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+                    }
+                }
+            )
+
+        def get_keys2(keys_to_fetch):
+            self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+            return defer.succeed(
+                {
+                    "server1": {
+                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+                    }
+                }
+            )
+
+        mock_fetcher1 = keyring.KeyFetcher()
+        mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+        mock_fetcher2 = keyring.KeyFetcher()
+        mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+        kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+        json1 = {}
+        signedjson.sign.sign_json(json1, "server1", key1)
+
+        results = kr.verify_json_objects_for_server(
+            [("server1", json1, 1200), ("server1", json1, 1500)]
+        )
+        self.assertEqual(len(results), 2)
+        self.get_success(results[0])
+        e = self.get_failure(results[1], SynapseError).value
+        self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+        self.assertEqual(e.code, 401)
+
+        # there should have been a single call to each fetcher
+        mock_fetcher1.get_keys.assert_called_once()
+        mock_fetcher2.get_keys.assert_called_once()
+
 
 class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
     def make_homeserver(self, reactor, clock):
@@ -260,8 +338,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.get_json.side_effect = get_json
 
-        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
         self.assertEqual(k.verify_key, testverifykey)
@@ -288,9 +366,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         # change the server name: it should cause a rejection
         response["server_name"] = "OTHER_SERVER"
-        self.get_failure(
-            fetcher.get_keys(server_name_and_key_ids), KeyLookupError
-        )
+        self.get_failure(fetcher.get_keys(keys_to_fetch), KeyLookupError)
 
 
 class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
@@ -342,8 +418,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         self.http_client.post_json.side_effect = post_json
 
-        server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
-        keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+        keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+        keys = self.get_success(fetcher.get_keys(keys_to_fetch))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
         self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@@ -401,7 +477,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
         def get_key_from_perspectives(response):
             fetcher = PerspectivesKeyFetcher(self.hs)
-            server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+            keys_to_fetch = {SERVER_NAME: {"key1": 0}}
 
             def post_json(destination, path, data, **kwargs):
                 self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -410,9 +486,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
 
             self.http_client.post_json.side_effect = post_json
 
-            return self.get_success(
-                fetcher.get_keys(server_name_and_key_ids)
-            )
+            return self.get_success(fetcher.get_keys(keys_to_fetch))
 
         # start with a valid response so we can check we are testing the right thing
         response = build_response()
@@ -435,6 +509,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
 
 
+def get_key_id(key):
+    """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+    return "%s:%s" % (key.alg, key.version)
+
+
 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
     with LoggingContext("testctx") as ctx:
@@ -445,14 +524,16 @@ def run_in_context(f, *args, **kwargs):
     defer.returnValue(rv)
 
 
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(keyring, server_name, json_object, validity_time):
     """thin wrapper around verify_json_for_server which makes sure it is wrapped
     with the patched defer.inlineCallbacks.
     """
 
     @defer.inlineCallbacks
     def v():
-        rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+        rv1 = yield keyring.verify_json_for_server(
+            server_name, json_object, validity_time
+        )
         defer.returnValue(rv1)
 
     return run_in_context(v)