diff --git a/changelog.d/5321.bugfix b/changelog.d/5321.bugfix
new file mode 100644
index 0000000000..943a61956d
--- /dev/null
+++ b/changelog.d/5321.bugfix
@@ -0,0 +1 @@
+Ensure that we have an up-to-date copy of the signing key when validating incoming federation requests.
diff --git a/changelog.d/5328.misc b/changelog.d/5328.misc
new file mode 100644
index 0000000000..e1b9dc58a3
--- /dev/null
+++ b/changelog.d/5328.misc
@@ -0,0 +1 @@
+The base classes for the v1 and v2_alpha REST APIs have been unified.
diff --git a/changelog.d/5332.misc b/changelog.d/5332.misc
new file mode 100644
index 0000000000..dcfac4eac9
--- /dev/null
+++ b/changelog.d/5332.misc
@@ -0,0 +1 @@
+Improve docstrings on MatrixFederationClient.
diff --git a/changelog.d/5333.bugfix b/changelog.d/5333.bugfix
new file mode 100644
index 0000000000..cb05a6dd63
--- /dev/null
+++ b/changelog.d/5333.bugfix
@@ -0,0 +1 @@
+Fix various problems which made the signing-key notary server time out for some requests.
\ No newline at end of file
diff --git a/changelog.d/5334.bugfix b/changelog.d/5334.bugfix
new file mode 100644
index 0000000000..ed141e0918
--- /dev/null
+++ b/changelog.d/5334.bugfix
@@ -0,0 +1 @@
+Fix a bug which would make certain operations (such as room joins) block for 20 minutes while attempting to fetch verification keys.
diff --git a/changelog.d/5335.bugfix b/changelog.d/5335.bugfix
new file mode 100644
index 0000000000..7318cbe35e
--- /dev/null
+++ b/changelog.d/5335.bugfix
@@ -0,0 +1 @@
+Fix a bug where we could rapidly mark a server as unreachable even though it was only down for a few minutes.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index b2f4cea536..0fd15287e7 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
+from collections import defaultdict
import six
from six import raise_from
@@ -45,6 +46,7 @@ from synapse.api.errors import (
)
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext, unwrapFirstError
+from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.logcontext import (
LoggingContext,
PreserveLoggingContext,
@@ -70,6 +72,9 @@ class VerifyKeyRequest(object):
json_object(dict): The JSON object to verify.
+ minimum_valid_until_ts (int): time at which we require the signing key to
+ be valid. (0 implies we don't care)
+
deferred(Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@@ -82,7 +87,8 @@ class VerifyKeyRequest(object):
server_name = attr.ib()
key_ids = attr.ib()
json_object = attr.ib()
- deferred = attr.ib()
+ minimum_valid_until_ts = attr.ib()
+ deferred = attr.ib(default=attr.Factory(defer.Deferred))
class KeyLookupError(ValueError):
@@ -90,14 +96,16 @@ class KeyLookupError(ValueError):
class Keyring(object):
- def __init__(self, hs):
+ def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()
- self._key_fetchers = (
- StoreKeyFetcher(hs),
- PerspectivesKeyFetcher(hs),
- ServerKeyFetcher(hs),
- )
+ if key_fetchers is None:
+ key_fetchers = (
+ StoreKeyFetcher(hs),
+ PerspectivesKeyFetcher(hs),
+ ServerKeyFetcher(hs),
+ )
+ self._key_fetchers = key_fetchers
# map from server name to Deferred. Has an entry for each server with
# an ongoing key download; the Deferred completes once the download
@@ -106,9 +114,25 @@ class Keyring(object):
# These are regular, logcontext-agnostic Deferreds.
self.key_downloads = {}
- def verify_json_for_server(self, server_name, json_object):
+ def verify_json_for_server(self, server_name, json_object, validity_time):
+ """Verify that a JSON object has been signed by a given server
+
+ Args:
+ server_name (str): name of the server which must have signed this object
+
+ json_object (dict): object to be checked
+
+ validity_time (int): timestamp at which we require the signing key to
+ be valid. (0 implies we don't care)
+
+ Returns:
+            Deferred[None]: completes if the object was correctly signed, otherwise
+ errbacks with an error
+ """
+ req = server_name, json_object, validity_time
+
return logcontext.make_deferred_yieldable(
- self.verify_json_objects_for_server([(server_name, json_object)])[0]
+ self.verify_json_objects_for_server((req,))[0]
)
def verify_json_objects_for_server(self, server_and_json):
@@ -116,10 +140,12 @@ class Keyring(object):
necessary.
Args:
- server_and_json (list): List of pairs of (server_name, json_object)
+            server_and_json (iterable[Tuple[str, dict, int]]):
+ Iterable of triplets of (server_name, json_object, validity_time)
+ validity_time is a timestamp at which the signing key must be valid.
Returns:
- List<Deferred>: for each input pair, a deferred indicating success
+ List<Deferred[None]>: for each input triplet, a deferred indicating success
or failure to verify each json object's signature for the given
server_name. The deferreds run their callbacks in the sentinel
logcontext.
@@ -128,12 +154,12 @@ class Keyring(object):
verify_requests = []
handle = preserve_fn(_handle_key_deferred)
- def process(server_name, json_object):
+ def process(server_name, json_object, validity_time):
"""Process an entry in the request list
- Given a (server_name, json_object) pair from the request list,
- adds a key request to verify_requests, and returns a deferred which will
- complete or fail (in the sentinel context) when verification completes.
+ Given a (server_name, json_object, validity_time) triplet from the request
+ list, adds a key request to verify_requests, and returns a deferred which
+ will complete or fail (in the sentinel context) when verification completes.
"""
key_ids = signature_ids(json_object, server_name)
@@ -144,11 +170,16 @@ class Keyring(object):
)
)
- logger.debug("Verifying for %s with key_ids %s", server_name, key_ids)
+ logger.debug(
+ "Verifying for %s with key_ids %s, min_validity %i",
+ server_name,
+ key_ids,
+ validity_time,
+ )
# add the key request to the queue, but don't start it off yet.
verify_request = VerifyKeyRequest(
- server_name, key_ids, json_object, defer.Deferred()
+ server_name, key_ids, json_object, validity_time
)
verify_requests.append(verify_request)
@@ -160,8 +191,8 @@ class Keyring(object):
return handle(verify_request)
results = [
- process(server_name, json_object)
- for server_name, json_object in server_and_json
+ process(server_name, json_object, validity_time)
+ for server_name, json_object, validity_time in server_and_json
]
if verify_requests:
@@ -298,8 +329,12 @@ class Keyring(object):
verify_request.deferred.errback(
SynapseError(
401,
- "No key for %s with id %s"
- % (verify_request.server_name, verify_request.key_ids),
+ "No key for %s with ids in %s (min_validity %i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ ),
Codes.UNAUTHORIZED,
)
)
@@ -323,18 +358,28 @@ class Keyring(object):
Args:
fetcher (KeyFetcher): fetcher to use to fetch the keys
remaining_requests (set[VerifyKeyRequest]): outstanding key requests.
- Any successfully-completed requests will be reomved from the list.
+ Any successfully-completed requests will be removed from the list.
"""
- # dict[str, set(str)]: keys to fetch for each server
- missing_keys = {}
+ # dict[str, dict[str, int]]: keys to fetch.
+ # server_name -> key_id -> min_valid_ts
+ missing_keys = defaultdict(dict)
+
for verify_request in remaining_requests:
# any completed requests should already have been removed
assert not verify_request.deferred.called
- missing_keys.setdefault(verify_request.server_name, set()).update(
- verify_request.key_ids
- )
+ keys_for_server = missing_keys[verify_request.server_name]
- results = yield fetcher.get_keys(missing_keys.items())
+ for key_id in verify_request.key_ids:
+ # If we have several requests for the same key, then we only need to
+ # request that key once, but we should do so with the greatest
+ # min_valid_until_ts of the requests, so that we can satisfy all of
+ # the requests.
+ keys_for_server[key_id] = max(
+ keys_for_server.get(key_id, -1),
+ verify_request.minimum_valid_until_ts
+ )
+
+ results = yield fetcher.get_keys(missing_keys)
completed = list()
for verify_request in remaining_requests:
@@ -344,25 +389,34 @@ class Keyring(object):
# complete this VerifyKeyRequest.
result_keys = results.get(server_name, {})
for key_id in verify_request.key_ids:
- key = result_keys.get(key_id)
- if key:
- with PreserveLoggingContext():
- verify_request.deferred.callback(
- (server_name, key_id, key.verify_key)
- )
- completed.append(verify_request)
- break
+ fetch_key_result = result_keys.get(key_id)
+ if not fetch_key_result:
+ # we didn't get a result for this key
+ continue
+
+ if (
+ fetch_key_result.valid_until_ts
+ < verify_request.minimum_valid_until_ts
+ ):
+ # key was not valid at this point
+ continue
+
+ with PreserveLoggingContext():
+ verify_request.deferred.callback(
+ (server_name, key_id, fetch_key_result.verify_key)
+ )
+ completed.append(verify_request)
+ break
remaining_requests.difference_update(completed)
class KeyFetcher(object):
- def get_keys(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
"""
Args:
- server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
- list of (server_name, iterable[key_id]) tuples to fetch keys for
- Note that the iterables may be iterated more than once.
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
@@ -378,13 +432,15 @@ class StoreKeyFetcher(KeyFetcher):
self.store = hs.get_datastore()
@defer.inlineCallbacks
- def get_keys(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
+
keys_to_fetch = (
(server_name, key_id)
- for server_name, key_ids in server_name_and_key_ids
- for key_id in key_ids
+ for server_name, keys_for_server in keys_to_fetch.items()
+ for key_id in keys_for_server.keys()
)
+
res = yield self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
@@ -508,14 +564,14 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.perspective_servers = self.config.perspectives
@defer.inlineCallbacks
- def get_keys(self, server_name_and_key_ids):
+ def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name_and_key_ids, perspective_name, perspective_keys
+ keys_to_fetch, perspective_name, perspective_keys
)
defer.returnValue(result)
except KeyLookupError as e:
@@ -549,13 +605,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
@defer.inlineCallbacks
def get_server_verify_key_v2_indirect(
- self, server_names_and_key_ids, perspective_name, perspective_keys
+ self, keys_to_fetch, perspective_name, perspective_keys
):
"""
Args:
- server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
- list of (server_name, iterable[key_id]) tuples to fetch keys for
+ keys_to_fetch (dict[str, dict[str, int]]):
+ the keys to be fetched. server_name -> key_id -> min_valid_ts
+
perspective_name (str): name of the notary server to query for the keys
+
perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
notary server
@@ -569,12 +627,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""
logger.info(
"Requesting keys %s from notary server %s",
- server_names_and_key_ids,
+ keys_to_fetch.items(),
perspective_name,
)
- # TODO(mark): Set the minimum_valid_until_ts to that needed by
- # the events being validated or the current time if validating
- # an incoming request.
+
try:
query_response = yield self.client.post_json(
destination=perspective_name,
@@ -582,12 +638,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
data={
u"server_keys": {
server_name: {
- key_id: {u"minimum_valid_until_ts": 0} for key_id in key_ids
+ key_id: {u"minimum_valid_until_ts": min_valid_ts}
+ for key_id, min_valid_ts in server_keys.items()
}
- for server_name, key_ids in server_names_and_key_ids
+ for server_name, server_keys in keys_to_fetch.items()
}
},
- long_retries=True,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -693,34 +749,54 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_http_client()
- @defer.inlineCallbacks
- def get_keys(self, server_name_and_key_ids):
- """see KeyFetcher.get_keys"""
- results = yield logcontext.make_deferred_yieldable(
- defer.gatherResults(
- [
- run_in_background(
- self.get_server_verify_key_v2_direct, server_name, key_ids
- )
- for server_name, key_ids in server_name_and_key_ids
- ],
- consumeErrors=True,
- ).addErrback(unwrapFirstError)
- )
+ def get_keys(self, keys_to_fetch):
+ """
+ Args:
+ keys_to_fetch (dict[str, iterable[str]]):
+ the keys to be fetched. server_name -> key_ids
- merged = {}
- for result in results:
- merged.update(result)
+ Returns:
+ Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ map from server_name -> key_id -> FetchKeyResult
+ """
+
+ results = {}
- defer.returnValue(
- {server_name: keys for server_name, keys in merged.items() if keys}
+ @defer.inlineCallbacks
+ def get_key(key_to_fetch_item):
+ server_name, key_ids = key_to_fetch_item
+ try:
+ keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ results[server_name] = keys
+ except KeyLookupError as e:
+ logger.warning(
+ "Error looking up keys %s from %s: %s", key_ids, server_name, e
+ )
+ except Exception:
+ logger.exception("Error getting keys %s from %s", key_ids, server_name)
+
+ return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
+ lambda _: results
)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ """
+
+ Args:
+ server_name (str):
+ key_ids (iterable[str]):
+
+ Returns:
+ Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+
+ Raises:
+ KeyLookupError if there was a problem making the lookup
+ """
keys = {} # type: dict[str, FetchKeyResult]
for requested_key_id in key_ids:
+ # we may have found this key as a side-effect of asking for another.
if requested_key_id in keys:
continue
@@ -731,6 +807,19 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
ignore_backoff=True,
+
+ # we only give the remote server 10s to respond. It should be an
+ # easy request to handle, so if it doesn't reply within 10s, it's
+ # probably not going to.
+ #
+ # Furthermore, when we are acting as a notary server, we cannot
+ # wait all day for all of the origin servers, as the requesting
+ # server will otherwise time out before we can respond.
+ #
+ # (Note that get_json may make 4 attempts, so this can still take
+ # almost 45 seconds to fetch the headers, plus up to another 60s to
+ # read the response).
+ timeout=10000,
)
except (NotRetryingDestination, RequestSendFailed) as e:
raise_from(KeyLookupError("Failed to connect to remote server"), e)
@@ -755,7 +844,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
)
keys.update(response_keys)
- defer.returnValue({server_name: keys})
+ defer.returnValue(keys)
@defer.inlineCallbacks
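
A minimal sketch (not part of the patch; build_missing_keys and key_satisfies are illustrative names) of the two rules the keyring changes above rely on: duplicate requests for the same key collapse into a single fetch keyed on the largest minimum_valid_until_ts, and a fetched key only satisfies a request whose minimum validity it covers.

    from collections import defaultdict

    def build_missing_keys(verify_requests):
        # server_name -> key_id -> min_valid_until_ts, as handed to fetcher.get_keys
        missing_keys = defaultdict(dict)
        for req in verify_requests:
            keys_for_server = missing_keys[req.server_name]
            for key_id in req.key_ids:
                # several requests for one key become a single fetch, carrying the
                # greatest minimum_valid_until_ts so that one result can satisfy all
                keys_for_server[key_id] = max(
                    keys_for_server.get(key_id, -1), req.minimum_valid_until_ts
                )
        return missing_keys

    def key_satisfies(fetch_key_result, verify_request):
        # a fetched key is only accepted if it is valid at least until the
        # request's minimum_valid_until_ts (0 means the caller doesn't care)
        return fetch_key_result.valid_until_ts >= verify_request.minimum_valid_until_ts
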
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index cffa831d80..4b38f7c759 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -265,7 +265,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]
more_deferreds = keyring.verify_json_objects_for_server([
- (p.sender_domain, p.redacted_pdu_json)
+ (p.sender_domain, p.redacted_pdu_json, 0)
for p in pdus_to_check_sender
])
@@ -298,7 +298,7 @@ def _check_sigs_on_pdus(keyring, room_version, pdus):
]
more_deferreds = keyring.verify_json_objects_for_server([
- (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json)
+ (get_domain_from_id(p.pdu.event_id), p.redacted_pdu_json, 0)
for p in pdus_to_check_event_id
])
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index d0efc4e0d3..0db8858cf1 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -94,6 +94,7 @@ class NoAuthenticationError(AuthenticationError):
class Authenticator(object):
def __init__(self, hs):
+ self._clock = hs.get_clock()
self.keyring = hs.get_keyring()
self.server_name = hs.hostname
self.store = hs.get_datastore()
@@ -102,6 +103,7 @@ class Authenticator(object):
# A method just so we can pass 'self' as the authenticator to the Servlets
@defer.inlineCallbacks
def authenticate_request(self, request, content):
+ now = self._clock.time_msec()
json_request = {
"method": request.method.decode('ascii'),
"uri": request.uri.decode('ascii'),
@@ -138,7 +140,7 @@ class Authenticator(object):
401, "Missing Authorization headers", Codes.UNAUTHORIZED,
)
- yield self.keyring.verify_json_for_server(origin, json_request)
+ yield self.keyring.verify_json_for_server(origin, json_request, now)
logger.info("Request from %s", origin)
request.authenticated_entity = origin
diff --git a/synapse/groups/attestations.py b/synapse/groups/attestations.py
index 786149be65..fa6b641ee1 100644
--- a/synapse/groups/attestations.py
+++ b/synapse/groups/attestations.py
@@ -97,10 +97,11 @@ class GroupAttestationSigning(object):
# TODO: We also want to check that *new* attestations that people give
# us to store are valid for at least a little while.
- if valid_until_ms < self.clock.time_msec():
+ now = self.clock.time_msec()
+ if valid_until_ms < now:
raise SynapseError(400, "Attestation expired")
- yield self.keyring.verify_json_for_server(server_name, attestation)
+ yield self.keyring.verify_json_for_server(server_name, attestation, now)
def create_attestation(self, group_id, user_id):
"""Create an attestation for the group_id and user_id with default
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 8197619a78..663ea72a7a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -285,7 +285,24 @@ class MatrixFederationHttpClient(object):
request (MatrixFederationRequest): details of request to be sent
timeout (int|None): number of milliseconds to wait for the response headers
- (including connecting to the server). 60s by default.
+ (including connecting to the server), *for each attempt*.
+ 60s by default.
+
+ long_retries (bool): whether to use the long retry algorithm.
+
+ The regular retry algorithm makes 4 attempts, with intervals
+ [0.5s, 1s, 2s].
+
+ The long retry algorithm makes 11 attempts, with intervals
+ [4s, 16s, 60s, 60s, ...]
+
+ Both algorithms add -20%/+40% jitter to the retry intervals.
+
+ Note that the above intervals are *in addition* to the time spent
+ waiting for the request to complete (up to `timeout` ms).
+
+ NB: the long retry algorithm takes over 20 minutes to complete, with
+ a default timeout of 60s!
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
@@ -566,10 +583,14 @@ class MatrixFederationHttpClient(object):
the request body. This will be encoded as JSON.
json_data_callback (callable): A callable returning the dict to
use as the request body.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
backoff_on_404 (bool): True if we should count a 404 response as
@@ -627,15 +648,22 @@ class MatrixFederationHttpClient(object):
Args:
destination (str): The remote server to send the HTTP request
to.
+
path (str): The HTTP path.
+
data (dict): A dict containing the data that will be used as
the request body. This will be encoded as JSON.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+
args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
@@ -686,14 +714,19 @@ class MatrixFederationHttpClient(object):
Args:
destination (str): The remote server to send the HTTP request
to.
+
path (str): The HTTP path.
+
args (dict|None): A dictionary used to create query strings, defaults to
None.
- timeout (int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout and that the request will
- be retried.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data
and try the request anyway.
+
try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
@@ -742,12 +775,18 @@ class MatrixFederationHttpClient(object):
destination (str): The remote server to send the HTTP request
to.
path (str): The HTTP path.
- long_retries (bool): A boolean that indicates whether we should
- retry for a short or long time.
- timeout(int): How long to try (in ms) the destination for before
- giving up. None indicates no timeout.
+
+ long_retries (bool): whether to use the long retry algorithm. See
+ docs on _send_request for details.
+
+ timeout (int|None): number of milliseconds to wait for the response headers
+ (including connecting to the server), *for each attempt*.
+ self._default_timeout (60s) by default.
+
ignore_backoff (bool): true to ignore the historical backoff data and
try the request anyway.
+
+ args (dict): query params
Returns:
Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
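
Rough arithmetic behind the retry notes in the docstrings above (illustrative only; the real loop lives in _send_request): with the default 60s per-attempt timeout, the long retry algorithm's worst case is well over 20 minutes, which is why the key-fetching code above drops long_retries and uses a 10s timeout instead.

    def worst_case_seconds(attempts, intervals, timeout_s=60, jitter=1.4):
        # each attempt can spend up to timeout_s waiting for the response headers,
        # and each retry interval can be stretched by up to +40% jitter
        return attempts * timeout_s + sum(i * jitter for i in intervals)

    short_retry = worst_case_seconds(4, [0.5, 1, 2])          # ~245s
    long_retry = worst_case_seconds(11, [4, 16] + [60] * 8)   # ~1360s, i.e. >20 minutes
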
diff --git a/synapse/rest/client/v1/voip.py b/synapse/rest/client/v1/voip.py
index 0975df84cf..6381049210 100644
--- a/synapse/rest/client/v1/voip.py
+++ b/synapse/rest/client/v1/voip.py
@@ -29,6 +29,7 @@ class VoipRestServlet(RestServlet):
def __init__(self, hs):
super(VoipRestServlet, self).__init__()
self.hs = hs
+ self.auth = hs.get_auth()
@defer.inlineCallbacks
def on_GET(self, request):
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 21c3c807b9..8a730bbc35 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -20,7 +20,7 @@ from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from synapse.api.errors import Codes, SynapseError
-from synapse.crypto.keyring import KeyLookupError, ServerKeyFetcher
+from synapse.crypto.keyring import ServerKeyFetcher
from synapse.http.server import respond_with_json_bytes, wrap_json_request_handler
from synapse.http.servlet import parse_integer, parse_json_object_from_request
@@ -215,15 +215,7 @@ class RemoteKey(Resource):
json_results.add(bytes(result["key_json"]))
if cache_misses and query_remote_on_cache_miss:
- for server_name, key_ids in cache_misses.items():
- try:
- yield self.fetcher.get_server_verify_key_v2_direct(
- server_name, key_ids
- )
- except KeyLookupError as e:
- logger.info("Failed to fetch key: %s", e)
- except Exception:
- logger.exception("Failed to get key for %r", server_name)
+ yield self.fetcher.get_keys(cache_misses)
yield self.query_keys(
request, query, query_remote_on_cache_miss=False
)
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
index 26cce7d197..f6dfa77d8f 100644
--- a/synapse/util/retryutils.py
+++ b/synapse/util/retryutils.py
@@ -46,8 +46,7 @@ class NotRetryingDestination(Exception):
@defer.inlineCallbacks
-def get_retry_limiter(destination, clock, store, ignore_backoff=False,
- **kwargs):
+def get_retry_limiter(destination, clock, store, ignore_backoff=False, **kwargs):
"""For a given destination check if we have previously failed to
send a request there and are waiting before retrying the destination.
If we are not ready to retry the destination, this will raise a
@@ -60,8 +59,7 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
clock (synapse.util.clock): timing source
store (synapse.storage.transactions.TransactionStore): datastore
ignore_backoff (bool): true to ignore the historical backoff data and
- try the request anyway. We will still update the next
- retry_interval on success/failure.
+ try the request anyway. We will still reset the retry_interval on success.
Example usage:
@@ -75,13 +73,12 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
"""
retry_last_ts, retry_interval = (0, 0)
- retry_timings = yield store.get_destination_retry_timings(
- destination
- )
+ retry_timings = yield store.get_destination_retry_timings(destination)
if retry_timings:
retry_last_ts, retry_interval = (
- retry_timings["retry_last_ts"], retry_timings["retry_interval"]
+ retry_timings["retry_last_ts"],
+ retry_timings["retry_interval"],
)
now = int(clock.time_msec())
@@ -93,22 +90,31 @@ def get_retry_limiter(destination, clock, store, ignore_backoff=False,
destination=destination,
)
+ # if we are ignoring the backoff data, we should also not increment the backoff
+ # when we get another failure - otherwise a server can very quickly reach the
+ # maximum backoff even though it might only have been down briefly
+ backoff_on_failure = not ignore_backoff
+
defer.returnValue(
RetryDestinationLimiter(
- destination,
- clock,
- store,
- retry_interval,
- **kwargs
+ destination, clock, store, retry_interval, backoff_on_failure, **kwargs
)
)
class RetryDestinationLimiter(object):
- def __init__(self, destination, clock, store, retry_interval,
- min_retry_interval=10 * 60 * 1000,
- max_retry_interval=24 * 60 * 60 * 1000,
- multiplier_retry_interval=5, backoff_on_404=False):
+ def __init__(
+ self,
+ destination,
+ clock,
+ store,
+ retry_interval,
+ min_retry_interval=10 * 60 * 1000,
+ max_retry_interval=24 * 60 * 60 * 1000,
+ multiplier_retry_interval=5,
+ backoff_on_404=False,
+ backoff_on_failure=True,
+ ):
"""Marks the destination as "down" if an exception is thrown in the
context, except for CodeMessageException with code < 500.
@@ -128,6 +134,9 @@ class RetryDestinationLimiter(object):
multiplier_retry_interval (int): The multiplier to use to increase
the retry interval after a failed request.
backoff_on_404 (bool): Back off if we get a 404
+
+ backoff_on_failure (bool): set to False if we should not increase the
+ retry interval on a failure.
"""
self.clock = clock
self.store = store
@@ -138,6 +147,7 @@ class RetryDestinationLimiter(object):
self.max_retry_interval = max_retry_interval
self.multiplier_retry_interval = multiplier_retry_interval
self.backoff_on_404 = backoff_on_404
+ self.backoff_on_failure = backoff_on_failure
def __enter__(self):
pass
@@ -173,10 +183,13 @@ class RetryDestinationLimiter(object):
if not self.retry_interval:
return
- logger.debug("Connection to %s was successful; clearing backoff",
- self.destination)
+ logger.debug(
+ "Connection to %s was successful; clearing backoff", self.destination
+ )
retry_last_ts = 0
self.retry_interval = 0
+ elif not self.backoff_on_failure:
+ return
else:
# We couldn't connect.
if self.retry_interval:
@@ -190,7 +203,10 @@ class RetryDestinationLimiter(object):
logger.info(
"Connection to %s was unsuccessful (%s(%s)); backoff now %i",
- self.destination, exc_type, exc_val, self.retry_interval
+ self.destination,
+ exc_type,
+ exc_val,
+ self.retry_interval,
)
retry_last_ts = int(self.clock.time_msec())
@@ -201,9 +217,7 @@ class RetryDestinationLimiter(object):
self.destination, retry_last_ts, self.retry_interval
)
except Exception:
- logger.exception(
- "Failed to store destination_retry_timings",
- )
+ logger.exception("Failed to store destination_retry_timings")
# we deliberately do this in the background.
synapse.util.logcontext.run_in_background(store_retry_timings)
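
An illustrative sketch of the backoff update RetryDestinationLimiter applies on exit (simplified; the full logic, including the success path, lives in the limiter's exit handling): with backoff_on_failure=False, as now used for ignore_backoff requests, a failure no longer grows the stored retry interval, so a server that was only briefly down cannot be rapidly pushed to the 24-hour maximum backoff.

    def next_retry_interval(current_interval, *, backoff_on_failure=True,
                            min_retry_interval=10 * 60 * 1000,
                            max_retry_interval=24 * 60 * 60 * 1000,
                            multiplier_retry_interval=5):
        if not backoff_on_failure:
            # leave the stored interval untouched on failure
            return current_interval
        if current_interval:
            return min(current_interval * multiplier_retry_interval, max_retry_interval)
        return min_retry_interval

    # with backoff_on_failure=True, repeated failures step 10min -> 50min -> ... -> 24h
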
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index 3933ad4347..4cff7e36c8 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -19,16 +19,13 @@ from mock import Mock
import canonicaljson
import signedjson.key
import signedjson.sign
+from signedjson.key import get_verify_key
from twisted.internet import defer
from synapse.api.errors import SynapseError
from synapse.crypto import keyring
-from synapse.crypto.keyring import (
- KeyLookupError,
- PerspectivesKeyFetcher,
- ServerKeyFetcher,
-)
+from synapse.crypto.keyring import PerspectivesKeyFetcher, ServerKeyFetcher
from synapse.storage.keys import FetchKeyResult
from synapse.util import logcontext
from synapse.util.logcontext import LoggingContext
@@ -137,7 +134,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
context_11.request = "11"
res_deferreds = kr.verify_json_objects_for_server(
- [("server10", json1), ("server11", {})]
+ [("server10", json1, 0), ("server11", {}, 0)]
)
# the unsigned json should be rejected pretty quickly
@@ -174,7 +171,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.return_value = defer.Deferred()
res_deferreds_2 = kr.verify_json_objects_for_server(
- [("server10", json1)]
+ [("server10", json1, 0)]
)
res_deferreds_2[0].addBoth(self.check_context, None)
yield logcontext.make_deferred_yieldable(res_deferreds_2[0])
@@ -197,31 +194,108 @@ class KeyringTestCase(unittest.HomeserverTestCase):
kr = keyring.Keyring(self.hs)
key1 = signedjson.key.generate_signing_key(1)
- key1_id = "%s:%s" % (key1.alg, key1.version)
-
r = self.hs.datastore.store_server_verify_keys(
"server9",
time.time() * 1000,
- [
- (
- "server9",
- key1_id,
- FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
- ),
- ],
+ [("server9", get_key_id(key1), FetchKeyResult(get_verify_key(key1), 1000))],
)
self.get_success(r)
+
json1 = {}
signedjson.sign.sign_json(json1, "server9", key1)
# should fail immediately on an unsigned object
- d = _verify_json_for_server(kr, "server9", {})
+ d = _verify_json_for_server(kr, "server9", {}, 0)
self.failureResultOf(d, SynapseError)
- d = _verify_json_for_server(kr, "server9", json1)
- self.assertFalse(d.called)
+        # should succeed on a signed object
+ d = _verify_json_for_server(kr, "server9", json1, 500)
+ # self.assertFalse(d.called)
self.get_success(d)
+ def test_verify_json_dedupes_key_requests(self):
+ """Two requests for the same key should be deduped."""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys(keys_to_fetch):
+ # there should only be one request object (with the max validity)
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher = keyring.KeyFetcher()
+ mock_fetcher.get_keys = Mock(side_effect=get_keys)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher,))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+        # the first request should succeed; the second should fail because the
+        # key is only valid until 1200, before the requested minimum of 1500
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 500), ("server1", json1, 1500)]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to the fetcher
+ mock_fetcher.get_keys.assert_called_once()
+
+ def test_verify_json_falls_back_to_other_fetchers(self):
+ """If the first fetcher cannot provide a recent enough key, we fall back"""
+ key1 = signedjson.key.generate_signing_key(1)
+
+ def get_keys1(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
+ }
+ }
+ )
+
+ def get_keys2(keys_to_fetch):
+ self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
+ return defer.succeed(
+ {
+ "server1": {
+ get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
+ }
+ }
+ )
+
+ mock_fetcher1 = keyring.KeyFetcher()
+ mock_fetcher1.get_keys = Mock(side_effect=get_keys1)
+ mock_fetcher2 = keyring.KeyFetcher()
+ mock_fetcher2.get_keys = Mock(side_effect=get_keys2)
+ kr = keyring.Keyring(self.hs, key_fetchers=(mock_fetcher1, mock_fetcher2))
+
+ json1 = {}
+ signedjson.sign.sign_json(json1, "server1", key1)
+
+ results = kr.verify_json_objects_for_server(
+ [("server1", json1, 1200), ("server1", json1, 1500)]
+ )
+ self.assertEqual(len(results), 2)
+ self.get_success(results[0])
+ e = self.get_failure(results[1], SynapseError).value
+ self.assertEqual(e.errcode, "M_UNAUTHORIZED")
+ self.assertEqual(e.code, 401)
+
+ # there should have been a single call to each fetcher
+ mock_fetcher1.get_keys.assert_called_once()
+ mock_fetcher2.get_keys.assert_called_once()
+
class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor, clock):
@@ -260,8 +334,8 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.get_json.side_effect = get_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
self.assertEqual(k.verify_key, testverifykey)
@@ -286,11 +360,11 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
)
- # change the server name: it should cause a rejection
+ # change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
- self.get_failure(
- fetcher.get_keys(server_name_and_key_ids), KeyLookupError
- )
+
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
+ self.assertEqual(keys, {})
class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
@@ -342,8 +416,8 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
- keys = self.get_success(fetcher.get_keys(server_name_and_key_ids))
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
+ keys = self.get_success(fetcher.get_keys(keys_to_fetch))
self.assertIn(SERVER_NAME, keys)
k = keys[SERVER_NAME][testverifykey_id]
self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
@@ -401,7 +475,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
def get_key_from_perspectives(response):
fetcher = PerspectivesKeyFetcher(self.hs)
- server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
+ keys_to_fetch = {SERVER_NAME: {"key1": 0}}
def post_json(destination, path, data, **kwargs):
self.assertEqual(destination, self.mock_perspective_server.server_name)
@@ -410,9 +484,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.http_client.post_json.side_effect = post_json
- return self.get_success(
- fetcher.get_keys(server_name_and_key_ids)
- )
+ return self.get_success(fetcher.get_keys(keys_to_fetch))
# start with a valid response so we can check we are testing the right thing
response = build_response()
@@ -435,6 +507,11 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig")
+def get_key_id(key):
+ """Get the matrix ID tag for a given SigningKey or VerifyKey"""
+ return "%s:%s" % (key.alg, key.version)
+
+
@defer.inlineCallbacks
def run_in_context(f, *args, **kwargs):
with LoggingContext("testctx") as ctx:
@@ -445,14 +522,16 @@ def run_in_context(f, *args, **kwargs):
defer.returnValue(rv)
-def _verify_json_for_server(keyring, server_name, json_object):
+def _verify_json_for_server(keyring, server_name, json_object, validity_time):
"""thin wrapper around verify_json_for_server which makes sure it is wrapped
with the patched defer.inlineCallbacks.
"""
@defer.inlineCallbacks
def v():
- rv1 = yield keyring.verify_json_for_server(server_name, json_object)
+ rv1 = yield keyring.verify_json_for_server(
+ server_name, json_object, validity_time
+ )
defer.returnValue(rv1)
return run_in_context(v)
|