summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/crypto/keyring.py239
-rw-r--r--synapse/federation/federation_base.py4
-rw-r--r--synapse/federation/transport/server.py4
-rw-r--r--synapse/groups/attestations.py5
-rw-r--r--synapse/http/matrixfederationclient.py71
-rw-r--r--synapse/rest/client/v1/voip.py1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py12
-rw-r--r--synapse/util/retryutils.py60
8 files changed, 267 insertions, 129 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index b2f4cea536..0fd15287e7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import logging
+from collections import defaultdict
 
 import six
 from six import raise_from
@@ -45,6 +46,7 @@ from synapse.api.errors import (
 )
 from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
 from synapse.util.logcontext import (
     LoggingContext,
     PreserveLoggingContext,
@@ -70,6 +72,9 @@ class VerifyKeyRequest(object):
 
         json_object(dict): The JSON object to verify.
 
+        minimum_valid_until_ts (int): time at which we require the signing key to
+            be valid. (0 implies we don't care)
+
         deferred(Deferred[str, str, nacl.signing.VerifyKey]):
             A deferred (server_name, key_id, verify_key) tuple that resolves when
             a verify key has been fetched. The deferreds' callbacks are run with no
@@ -82,7 +87,8 @@ class VerifyKeyRequest(object):
     server_name = attr.ib()
     key_ids = attr.ib()
     json_object = attr.ib()
-    deferred = attr.ib()
+    minimum_valid_until_ts = attr.ib()
+    deferred = attr.ib(default=attr.Factory(defer.Deferred))
 
 
 class KeyLookupError(ValueError):
@@ -90,14 +96,16 @@ class KeyLookupError(ValueError):
 
 
 class Keyring(object):
-    def __init__(self, hs):
+    def __init__(self, hs, key_fetchers=None):
         self.clock = hs.get_clock()
 
-        self._key_fetchers = (
-            StoreKeyFetcher(hs),
-            PerspectivesKeyFetcher(hs),
-            ServerKeyFetcher(hs),
-        )
+        if key_fetchers is None:
+            key_fetchers = (
+                StoreKeyFetcher(hs),
+                PerspectivesKeyFetcher(hs),
+                ServerKeyFetcher(hs),
+            )
+        self._key_fetchers = key_fetchers
 
         # map from server name to Deferred. Has an entry for each server with
         # an ongoing key download; the Deferred completes once the download
@@ -106,9 +114,25 @@ class Keyring(object):
         # These are regular, logcontext-agnostic Deferreds.
         self.key_downloads = {}
 
-    def verify_json_for_server(self, server_name, json_object):
+    def verify_json_for_server(self, server_name, json_object, validity_time):
+        """Verify that a JSON object has been signed by a given server
+
+        Args:
+            server_name (str): name of the server which must have signed this object
+
+            json_object (dict): object to be checked
+
+            validity_time (int): timestamp at which we require the signing key to
+                be valid. (0 implies we don't care)
+
+        Returns:
+            Deferred[None]: completes if the the object was correctly signed, otherwise
+                errbacks with an error
+        """
+        req = server_name, json_object, validity_time
+
         return logcontext.make_deferred_yieldable(
-            self.verify_json_objects_for_server([(server_name, json_object)])[0]
+            self.verify_json_objects_for_server((req,))[0]
         )
 
     def verify_json_objects_for_server(self, server_and_json):
@@ -116,10 +140,12 @@ class Keyring(object):
         necessary.
 
         Args:
-            server_and_json (list): List of pairs of (server_name, json_object)
+            server_and_json (iterable[Tuple[str, dict, int]):
+                Iterable of triplets of (server_name, json_object, validity_time)
+                validity_time is a timestamp at which the signing key must be valid.
 
         Returns:
-            List<Deferred>: for each input pair, a deferred indicating success
+            List<Deferred[None]>: for each input triplet, a deferred indicating success
                 or failure to verify each json object's signature for the given
                 server_name. The deferreds run their callbacks in the sentinel
                 logcontext.
@@ -128,12 +154,12 @@ class Keyring(object):
         verify_requests = []
         handle = preserve_fn(_handle_key_deferred)
 
-        def process(server_name, json_object):
+        def process(server_name, json_object, validity_time):
             """Process an entry in the request list
 
-            Given a (server_name, json_object) pair from the request list,
-            adds a key request to verify_requests, and returns a deferred which will
-            complete or fail (in the sentinel context) when verification completes.
+            Given a (server_name, json_object, validity_time) triplet from the request
+            list, adds a key request to verify_requests, and returns a deferred which
+            will complete or fail (in the sentinel context) when verification completes.
             """
             key_ids = signature_ids(json_object, server_name)
 
@@ -144,11 +170,16 @@ class Keyring(object):
                     )
                 )
 
-            logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
+            logger.debug(
+                "Verifying for %s with key_ids %s, min_validity %i",
+                server_name,
+                key_ids,
+                validity_time,
+            )
 
             # add the key request to the queue, but don't start it off yet.
             verify_request = VerifyKeyRequest(
-                server_name, key_ids, json_object, defer.Deferred()
+                server_name, key_ids, json_object, validity_time
             )
             verify_requests.append(verify_request)
 
@@ -160,8 +191,8 @@ class Keyring(object):
             return handle(verify_request)
 
         results = [
-            process(server_name, json_object)
-            for server_name, json_object in server_and_json
+            process(server_name, json_object, validity_time)
+            for server_name, json_object, validity_time in server_and_json
         ]
 
         if verify_requests:
@@ -298,8 +329,12 @@ class Keyring(object):
                         verify_request.deferred.errback(
                             SynapseError(
                                 401,
-                                "No key for %s with id %s"
-                                % (verify_request.server_name, verify_request.key_ids),
+                                "No key for %s with ids in %s (min_validity %i)"
+                                % (
+                                    verify_request.server_name,
+                                    verify_request.key_ids,
+                                    verify_request.minimum_valid_until_ts,
+                                ),
                                 Codes.UNAUTHORIZED,
                             )
                         )
@@ -323,18 +358,28 @@ class Keyring(object):
         Args:
             fetcher (KeyFetcher): fetcher to use to fetch the keys
             remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
-                Any successfully-completed requests will be reomved from the list.
+                Any successfully-completed requests will be removed from the list.
         """
-        # dict[str, set(str)]: keys to fetch for each server
-        missing_keys = {}
+        # dict[str, dict[str, int]]: keys to fetch.
+        # server_name -> key_id -> min_valid_ts
+        missing_keys = defaultdict(dict)
+
         for verify_request in remaining_requests:
             # any completed requests should already have been removed
             assert not verify_request.deferred.called
-            missing_keys.setdefault(verify_request.server_name, set()).update(
-                verify_request.key_ids
-            )
+            keys_for_server = missing_keys[verify_request.server_name]
 
-        results = yield fetcher.get_keys(missing_keys.items())
+            for key_id in verify_request.key_ids:
+                # If we have several requests for the same key, then we only need to
+                # request that key once, but we should do so with the greatest
+                # min_valid_until_ts of the requests, so that we can satisfy all of
+                # the requests.
+                keys_for_server[key_id] = max(
+                    keys_for_server.get(key_id, -1),
+                    verify_request.minimum_valid_until_ts
+                )
+
+        results = yield fetcher.get_keys(missing_keys)
 
         completed = list()
         for verify_request in remaining_requests:
@@ -344,25 +389,34 @@ class Keyring(object):
             # complete this VerifyKeyRequest.
             result_keys = results.get(server_name, {})
             for key_id in verify_request.key_ids:
-                key = result_keys.get(key_id)
-                if key:
-                    with PreserveLoggingContext():
-                        verify_request.deferred.callback(
-                            (server_name, key_id, key.verify_key)
-                        )
-                    completed.append(verify_request)
-                    break
+                fetch_key_result = result_keys.get(key_id)
+                if not fetch_key_result:
+                    # we didn't get a result for this key
+                    continue
+
+                if (
+                    fetch_key_result.valid_until_ts
+                    < verify_request.minimum_valid_until_ts
+                ):
+                    # key was not valid at this point
+                    continue
+
+                with PreserveLoggingContext():
+                    verify_request.deferred.callback(
+                        (server_name, key_id, fetch_key_result.verify_key)
+                    )
+                completed.append(verify_request)
+                break
 
         remaining_requests.difference_update(completed)
 
 
 class KeyFetcher(object):
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """
         Args:
-            server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
-                Note that the iterables may be iterated more than once.
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
 
         Returns:
             Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@@ -378,13 +432,15 @@ class StoreKeyFetcher(KeyFetcher):
         self.store = hs.get_datastore()
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
+
         keys_to_fetch = (
             (server_name, key_id)
-            for server_name, key_ids in server_name_and_key_ids
-            for key_id in key_ids
+            for server_name, keys_for_server in keys_to_fetch.items()
+            for key_id in keys_for_server.keys()
         )
+
         res = yield self.store.get_server_verify_keys(keys_to_fetch)
         keys = {}
         for (server_name, key_id), key in res.items():
@@ -508,14 +564,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.perspective_servers = self.config.perspectives
 
     @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
+    def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
 
         @defer.inlineCallbacks
         def get_key(perspective_name, perspective_keys):
             try:
                 result = yield self.get_server_verify_key_v2_indirect(
-                    server_name_and_key_ids, perspective_name, perspective_keys
+                    keys_to_fetch, perspective_name, perspective_keys
                 )
                 defer.returnValue(result)
             except KeyLookupError as e:
@@ -549,13 +605,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_indirect(
-        self, server_names_and_key_ids, perspective_name, perspective_keys
+        self, keys_to_fetch, perspective_name, perspective_keys
     ):
         """
         Args:
-            server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
-                list of (server_name, iterable[key_id]) tuples to fetch keys for
+            keys_to_fetch (dict[str, dict[str, int]]):
+                the keys to be fetched. server_name -> key_id -> min_valid_ts
+
             perspective_name (str): name of the notary server to query for the keys
+
             perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
                 notary server
 
@@ -569,12 +627,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         """
         logger.info(
             "Requesting keys %s from notary server %s",
-            server_names_and_key_ids,
+            keys_to_fetch.items(),
             perspective_name,
         )
-        # TODO(mark): Set the minimum_valid_until_ts to that needed by
-        # the events being validated or the current time if validating
-        # an incoming request.
+
         try:
             query_response = yield self.client.post_json(
                 destination=perspective_name,
@@ -582,12 +638,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 data={
                     u"server_keys": {
                         server_name: {
-                            key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
+                            key_id: {u"minimum_valid_until_ts": min_valid_ts}
+                            for key_id, min_valid_ts in server_keys.items()
                         }
-                        for server_name, key_ids in server_names_and_key_ids
+                        for server_name, server_keys in keys_to_fetch.items()
                     }
                 },
-                long_retries=True,
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -693,34 +749,54 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
 
-    @defer.inlineCallbacks
-    def get_keys(self, server_name_and_key_ids):
-        """see KeyFetcher.get_keys"""
-        results = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults(
-                [
-                    run_in_background(
-                        self.get_server_verify_key_v2_direct, server_name, key_ids
-                    )
-                    for server_name, key_ids in server_name_and_key_ids
-                ],
-                consumeErrors=True,
-            ).addErrback(unwrapFirstError)
-        )
+    def get_keys(self, keys_to_fetch):
+        """
+        Args:
+            keys_to_fetch (dict[str, iterable[str]]):
+                the keys to be fetched. server_name -> key_ids
 
-        merged = {}
-        for result in results:
-            merged.update(result)
+        Returns:
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+                map from server_name -> key_id -> FetchKeyResult
+        """
+
+        results = {}
 
-        defer.returnValue(
-            {server_name: keys for server_name, keys in merged.items() if keys}
+        @defer.inlineCallbacks
+        def get_key(key_to_fetch_item):
+            server_name, key_ids = key_to_fetch_item
+            try:
+                keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+                results[server_name] = keys
+            except KeyLookupError as e:
+                logger.warning(
+                    "Error looking up keys %s from %s: %s", key_ids, server_name, e
+                )
+            except Exception:
+                logger.exception("Error getting keys %s from %s", key_ids, server_name)
+
+        return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+            lambda _: results
         )
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
+        """
+
+        Args:
+            server_name (str):
+            key_ids (iterable[str]):
+
+        Returns:
+            Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+        Raises:
+            KeyLookupError if there was a problem making the lookup
+        """
         keys = {}  # type: dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
+            # we may have found this key as a side-effect of asking for another.
             if requested_key_id in keys:
                 continue
 
@@ -731,6 +807,19 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
                     path="/_matrix/key/v2/server/"
                     + urllib.parse.quote(requested_key_id),
                     ignore_backoff=True,
+
+                    # we only give the remote server 10s to respond. It should be an
+                    # easy request to handle, so if it doesn't reply within 10s, it's
+                    # probably not going to.
+                    #
+                    # Furthermore, when we are acting as a notary server, we cannot
+                    # wait all day for all of the origin servers, as the requesting
+                    # server will otherwise time out before we can respond.
+                    #
+                    # (Note that get_json may make 4 attempts, so this can still take
+                    # almost 45 seconds to fetch the headers, plus up to another 60s to
+                    # read the response).
+                    timeout=10000,
                 )
             except (NotRetryingDestination, RequestSendFailed) as e:
                 raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -755,7 +844,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             )
             keys.update(response_keys)
 
-        defer.returnValue({server_name: keys})
+        defer.returnValue(keys)
 
 
 @defer.inlineCallbacks
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index cffa831d80..4b38f7c759 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
     ]
 
     more_deferreds = keyring.verify_json_objects_for_server([
-        (p.sender_domain, p.redacted_pdu_json)
+        (p.sender_domain, p.redacted_pdu_json, 0)
         for p in pdus_to_check_sender
     ])
 
@@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
         ]
 
         more_deferreds = keyring.verify_json_objects_for_server([
-            (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
+            (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
             for p in pdus_to_check_event_id
         ])
 
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index d0efc4e0d3..0db8858cf1 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
 
 class Authenticator(object):
     def __init__(self, hs):
+        self._clock = hs.get_clock()
         self.keyring = hs.get_keyring()
         self.server_name = hs.hostname
         self.store = hs.get_datastore()
@@ -102,6 +103,7 @@ class Authenticator(object):
     # A method just so we can pass 'self' as the authenticator to the Servlets
     @defer.inlineCallbacks
     def authenticate_request(self, request, content):
+        now = self._clock.time_msec()
         json_request = {
             "method": request.method.decode('ascii'),
             "uri": request.uri.decode('ascii'),
@@ -138,7 +140,7 @@ class Authenticator(object):
                 401, "Missing Authorization headers", Codes.UNAUTHORIZED,
             )
 
-        yield self.keyring.verify_json_for_server(origin, json_request)
+        yield self.keyring.verify_json_for_server(origin, json_request, now)
 
         logger.info("Request from %s", origin)
         request.authenticated_entity = origin
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 786149be65..fa6b641ee1 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
 
         # TODO: We also want to check that *new* attestations that people give
         # us to store are valid for at least a little while.
-        if valid_until_ms < self.clock.time_msec():
+        now = self.clock.time_msec()
+        if valid_until_ms < now:
             raise SynapseError(400, "Attestation expired")
 
-        yield self.keyring.verify_json_for_server(server_name, attestation)
+        yield self.keyring.verify_json_for_server(server_name, attestation, now)
 
     def create_attestation(self, group_id, user_id):
         """Create an attestation for the group_id and user_id with default
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 8197619a78..663ea72a7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -285,7 +285,24 @@ class MatrixFederationHttpClient(object):
             request (MatrixFederationRequest): details of request to be sent
 
             timeout (int|None): number of milliseconds to wait for the response headers
-                (including connecting to the server). 60s by default.
+                (including connecting to the server), *for each attempt*.
+                60s by default.
+
+            long_retries (bool): whether to use the long retry algorithm.
+
+                The regular retry algorithm makes 4 attempts, with intervals
+                [0.5s, 1s, 2s].
+
+                The long retry algorithm makes 11 attempts, with intervals
+                [4s, 16s, 60s, 60s, ...]
+
+                Both algorithms add -20%/+40% jitter to the retry intervals.
+
+                Note that the above intervals are *in addition* to the time spent
+                waiting for the request to complete (up to `timeout` ms).
+
+                NB: the long retry algorithm takes over 20 minutes to complete, with
+                a default timeout of 60s!
 
             ignore_backoff (bool): true to ignore the historical backoff data
                 and try the request anyway.
@@ -566,10 +583,14 @@ class MatrixFederationHttpClient(object):
                 the request body. This will be encoded as JSON.
             json_data_callback (callable): A callable returning the dict to
                 use as the request body.
-            long_retries (bool): A boolean that indicates whether we should
-                retry for a short or long time.
-            timeout(int): How long to try (in ms) the destination for before
-                giving up. None indicates no timeout.
+
+            long_retries (bool): whether to use the long retry algorithm. See
+                docs on _send_request for details.
+
+            timeout (int|None): number of milliseconds to wait for the response headers
+                (including connecting to the server), *for each attempt*.
+                self._default_timeout (60s) by default.
+
             ignore_backoff (bool): true to ignore the historical backoff data
                 and try the request anyway.
             backoff_on_404 (bool): True if we should count a 404 response as
@@ -627,15 +648,22 @@ class MatrixFederationHttpClient(object):
         Args:
             destination (str): The remote server to send the HTTP request
                 to.
+
             path (str): The HTTP path.
+
             data (dict): A dict containing the data that will be used as
                 the request body. This will be encoded as JSON.
-            long_retries (bool): A boolean that indicates whether we should
-                retry for a short or long time.
-            timeout(int): How long to try (in ms) the destination for before
-                giving up. None indicates no timeout.
+
+            long_retries (bool): whether to use the long retry algorithm. See
+                docs on _send_request for details.
+
+            timeout (int|None): number of milliseconds to wait for the response headers
+                (including connecting to the server), *for each attempt*.
+                self._default_timeout (60s) by default.
+
             ignore_backoff (bool): true to ignore the historical backoff data and
                 try the request anyway.
+
             args (dict): query params
         Returns:
             Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@@ -686,14 +714,19 @@ class MatrixFederationHttpClient(object):
         Args:
             destination (str): The remote server to send the HTTP request
                 to.
+
             path (str): The HTTP path.
+
             args (dict|None): A dictionary used to create query strings, defaults to
                 None.
-            timeout (int): How long to try (in ms) the destination for before
-                giving up. None indicates no timeout and that the request will
-                be retried.
+
+            timeout (int|None): number of milliseconds to wait for the response headers
+                (including connecting to the server), *for each attempt*.
+                self._default_timeout (60s) by default.
+
             ignore_backoff (bool): true to ignore the historical backoff data
                 and try the request anyway.
+
             try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end of
                 the request. Workaround for #3622 in Synapse <= v0.99.3.
@@ -742,12 +775,18 @@ class MatrixFederationHttpClient(object):
             destination (str): The remote server to send the HTTP request
                 to.
             path (str): The HTTP path.
-            long_retries (bool): A boolean that indicates whether we should
-                retry for a short or long time.
-            timeout(int): How long to try (in ms) the destination for before
-                giving up. None indicates no timeout.
+
+            long_retries (bool): whether to use the long retry algorithm. See
+                docs on _send_request for details.
+
+            timeout (int|None): number of milliseconds to wait for the response headers
+                (including connecting to the server), *for each attempt*.
+                self._default_timeout (60s) by default.
+
             ignore_backoff (bool): true to ignore the historical backoff data and
                 try the request anyway.
+
+            args (dict): query params
         Returns:
             Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 0975df84cf..6381049210 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -29,6 +29,7 @@ class VoipRestServlet(RestServlet):
     def __init__(self, hs):
         super(VoipRestServlet, self).__init__()
         self.hs = hs
+        self.auth = hs.get_auth()
 
     @defer.inlineCallbacks
     def on_GET(self, request):
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 21c3c807b9..8a730bbc35 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
 from twisted.web.server import NOT_DONE_YET
 
 from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher
+from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 
@@ -215,15 +215,7 @@ class RemoteKey(Resource):
                     json_results.add(bytes(result["key_json"]))
 
         if cache_misses and query_remote_on_cache_miss:
-            for server_name, key_ids in cache_misses.items():
-                try:
-                    yield self.fetcher.get_server_verify_key_v2_direct(
-                        server_name, key_ids
-                    )
-                except KeyLookupError as e:
-                    logger.info("Failed to fetch key: %s", e)
-                except Exception:
-                    logger.exception("Failed to get key for %r", server_name)
+            yield self.fetcher.get_keys(cache_misses)
             yield self.query_keys(
                 request, query, query_remote_on_cache_miss=False
             )
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 26cce7d197..f6dfa77d8f 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -46,8 +46,7 @@ class NotRetryingDestination(Exception):
 
 
 @defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False,
-                      **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
     """For a given destination check if we have previously failed to
     send a request there and are waiting before retrying the destination.
     If we are not ready to retry the destination, this will raise a
@@ -60,8 +59,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
         clock (synapse.util.clock): timing source
         store (synapse.storage.transactions.TransactionStore): datastore
         ignore_backoff (bool): true to ignore the historical backoff data and
-            try the request anyway. We will still update the next
-            retry_interval on success/failure.
+            try the request anyway. We will still reset the retry_interval on success.
 
     Example usage:
 
@@ -75,13 +73,12 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
     """
     retry_last_ts, retry_interval = (0, 0)
 
-    retry_timings = yield store.get_destination_retry_timings(
-        destination
-    )
+    retry_timings = yield store.get_destination_retry_timings(destination)
 
     if retry_timings:
         retry_last_ts, retry_interval = (
-            retry_timings["retry_last_ts"], retry_timings["retry_interval"]
+            retry_timings["retry_last_ts"],
+            retry_timings["retry_interval"],
         )
 
         now = int(clock.time_msec())
@@ -93,22 +90,31 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
                 destination=destination,
             )
 
+    # if we are ignoring the backoff data, we should also not increment the backoff
+    # when we get another failure - otherwise a server can very quickly reach the
+    # maximum backoff even though it might only have been down briefly
+    backoff_on_failure = not ignore_backoff
+
     defer.returnValue(
         RetryDestinationLimiter(
-            destination,
-            clock,
-            store,
-            retry_interval,
-            **kwargs
+            destination, clock, store, retry_interval, backoff_on_failure, **kwargs
         )
     )
 
 
 class RetryDestinationLimiter(object):
-    def __init__(self, destination, clock, store, retry_interval,
-                 min_retry_interval=10 * 60 * 1000,
-                 max_retry_interval=24 * 60 * 60 * 1000,
-                 multiplier_retry_interval=5, backoff_on_404=False):
+    def __init__(
+        self,
+        destination,
+        clock,
+        store,
+        retry_interval,
+        min_retry_interval=10 * 60 * 1000,
+        max_retry_interval=24 * 60 * 60 * 1000,
+        multiplier_retry_interval=5,
+        backoff_on_404=False,
+        backoff_on_failure=True,
+    ):
         """Marks the destination as "down" if an exception is thrown in the
         context, except for CodeMessageException with code < 500.
 
@@ -128,6 +134,9 @@ class RetryDestinationLimiter(object):
             multiplier_retry_interval (int): The multiplier to use to increase
                 the retry interval after a failed request.
             backoff_on_404 (bool): Back off if we get a 404
+
+            backoff_on_failure (bool): set to False if we should not increase the
+                retry interval on a failure.
         """
         self.clock = clock
         self.store = store
@@ -138,6 +147,7 @@ class RetryDestinationLimiter(object):
         self.max_retry_interval = max_retry_interval
         self.multiplier_retry_interval = multiplier_retry_interval
         self.backoff_on_404 = backoff_on_404
+        self.backoff_on_failure = backoff_on_failure
 
     def __enter__(self):
         pass
@@ -173,10 +183,13 @@ class RetryDestinationLimiter(object):
             if not self.retry_interval:
                 return
 
-            logger.debug("Connection to %s was successful; clearing backoff",
-                         self.destination)
+            logger.debug(
+                "Connection to %s was successful; clearing backoff", self.destination
+            )
             retry_last_ts = 0
             self.retry_interval = 0
+        elif not self.backoff_on_failure:
+            return
         else:
             # We couldn't connect.
             if self.retry_interval:
@@ -190,7 +203,10 @@ class RetryDestinationLimiter(object):
 
             logger.info(
                 "Connection to %s was unsuccessful (%s(%s)); backoff now %i",
-                self.destination, exc_type, exc_val, self.retry_interval
+                self.destination,
+                exc_type,
+                exc_val,
+                self.retry_interval,
             )
             retry_last_ts = int(self.clock.time_msec())
 
@@ -201,9 +217,7 @@ class RetryDestinationLimiter(object):
                     self.destination, retry_last_ts, self.retry_interval
                 )
             except Exception:
-                logger.exception(
-                    "Failed to store destination_retry_timings",
-                )
+                logger.exception("Failed to store destination_retry_timings")
 
         # we deliberately do this in the background.
         synapse.util.logcontext.run_in_background(store_retry_timings)