diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 3fb99f7125..ea00c830c0 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -22,6 +22,8 @@ from syutil.crypto.signing_key import (
from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
+from synapse.util.retryutils import get_retry_limiter
+
from OpenSSL import crypto
import logging
@@ -87,69 +89,75 @@ class Keyring(object):
return
# Try to fetch the key from the remote server.
- # TODO(markjh): Ratelimit requests to a given server.
-
- (response, tls_certificate) = yield fetch_server_key(
- server_name, self.hs.tls_context_factory
- )
- # Check the response.
-
- x509_certificate_bytes = crypto.dump_certificate(
- crypto.FILETYPE_ASN1, tls_certificate
+ limiter = yield get_retry_limiter(
+ server_name,
+ self.clock,
+ self.store,
)
- if ("signatures" not in response
- or server_name not in response["signatures"]):
- raise ValueError("Key response not signed by remote server")
-
- if "tls_certificate" not in response:
- raise ValueError("Key response missing TLS certificate")
+ with limiter:
+ (response, tls_certificate) = yield fetch_server_key(
+ server_name, self.hs.tls_context_factory
+ )
- tls_certificate_b64 = response["tls_certificate"]
+ # Check the response.
- if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
- raise ValueError("TLS certificate doesn't match")
+ x509_certificate_bytes = crypto.dump_certificate(
+ crypto.FILETYPE_ASN1, tls_certificate
+ )
- verify_keys = {}
- for key_id, key_base64 in response["verify_keys"].items():
- if is_signing_algorithm_supported(key_id):
- key_bytes = decode_base64(key_base64)
- verify_key = decode_verify_key_bytes(key_id, key_bytes)
- verify_keys[key_id] = verify_key
+ if ("signatures" not in response
+ or server_name not in response["signatures"]):
+ raise ValueError("Key response not signed by remote server")
+
+ if "tls_certificate" not in response:
+ raise ValueError("Key response missing TLS certificate")
+
+ tls_certificate_b64 = response["tls_certificate"]
+
+ if encode_base64(x509_certificate_bytes) != tls_certificate_b64:
+ raise ValueError("TLS certificate doesn't match")
+
+ verify_keys = {}
+ for key_id, key_base64 in response["verify_keys"].items():
+ if is_signing_algorithm_supported(key_id):
+ key_bytes = decode_base64(key_base64)
+ verify_key = decode_verify_key_bytes(key_id, key_bytes)
+ verify_keys[key_id] = verify_key
+
+ for key_id in response["signatures"][server_name]:
+ if key_id not in response["verify_keys"]:
+ raise ValueError(
+ "Key response must include verification keys for all"
+ " signatures"
+ )
+ if key_id in verify_keys:
+ verify_signed_json(
+ response,
+ server_name,
+ verify_keys[key_id]
+ )
+
+ # Cache the result in the datastore.
+
+ time_now_ms = self.clock.time_msec()
+
+ yield self.store.store_server_certificate(
+ server_name,
+ server_name,
+ time_now_ms,
+ tls_certificate,
+ )
- for key_id in response["signatures"][server_name]:
- if key_id not in response["verify_keys"]:
- raise ValueError(
- "Key response must include verification keys for all"
- " signatures"
+ for key_id, key in verify_keys.items():
+ yield self.store.store_server_verify_key(
+ server_name, server_name, time_now_ms, key
)
- if key_id in verify_keys:
- verify_signed_json(
- response,
- server_name,
- verify_keys[key_id]
- )
-
- # Cache the result in the datastore.
-
- time_now_ms = self.clock.time_msec()
-
- yield self.store.store_server_certificate(
- server_name,
- server_name,
- time_now_ms,
- tls_certificate,
- )
-
- for key_id, key in verify_keys.items():
- yield self.store.store_server_verify_key(
- server_name, server_name, time_now_ms, key
- )
- for key_id in key_ids:
- if key_id in verify_keys:
- defer.returnValue(verify_keys[key_id])
- return
+ for key_id in key_ids:
+ if key_id in verify_keys:
+ defer.returnValue(verify_keys[key_id])
+ return
- raise ValueError("No verification key found for given key ids")
+ raise ValueError("No verification key found for given key ids")
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 70c9a6f46b..c5b0274c2f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -23,6 +23,8 @@ from synapse.api.errors import CodeMessageException
from synapse.util.logutils import log_function
from synapse.events import FrozenEvent
+from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+
import logging
@@ -163,24 +165,34 @@ class FederationClient(FederationBase):
pdu = None
for destination in destinations:
try:
- transaction_data = yield self.transport_layer.get_event(
- destination, event_id
+ limiter = yield get_retry_limiter(
+ destination,
+ self._clock,
+ self.store,
)
- logger.debug("transaction_data %r", transaction_data)
+ with limiter:
+ transaction_data = yield self.transport_layer.get_event(
+ destination, event_id
+ )
- pdu_list = [
- self.event_from_pdu_json(p, outlier=outlier)
- for p in transaction_data["pdus"]
- ]
+ logger.debug("transaction_data %r", transaction_data)
- if pdu_list:
- pdu = pdu_list[0]
+ pdu_list = [
+ self.event_from_pdu_json(p, outlier=outlier)
+ for p in transaction_data["pdus"]
+ ]
- # Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ if pdu_list:
+ pdu = pdu_list[0]
- break
+ # Check signatures are correct.
+ pdu = yield self._check_sigs_and_hash(pdu)
+
+ break
+ except NotRetryingDestination as e:
+ logger.info(e.message)
+ continue
except CodeMessageException:
raise
except Exception as e:
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index bb20f2ebab..7d02afe163 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -22,6 +22,9 @@ from .units import Transaction
from synapse.api.errors import HttpResponseException
from synapse.util.logutils import log_function
from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.retryutils import (
+ get_retry_limiter, NotRetryingDestination,
+)
import logging
@@ -138,25 +141,6 @@ class TransactionQueue(object):
@defer.inlineCallbacks
@log_function
def _attempt_new_transaction(self, destination):
-
- (retry_last_ts, retry_interval) = (0, 0)
- retry_timings = yield self.store.get_destination_retry_timings(
- destination
- )
- if retry_timings:
- (retry_last_ts, retry_interval) = (
- retry_timings.retry_last_ts, retry_timings.retry_interval
- )
- if retry_last_ts + retry_interval > int(self._clock.time_msec()):
- logger.info(
- "TX [%s] not ready for retry yet - "
- "dropping transaction for now",
- destination,
- )
- return
- else:
- logger.info("TX [%s] is ready for retry", destination)
-
if destination in self.pending_transactions:
# XXX: pending_transactions can get stuck on by a never-ending
# request at which point pending_pdus_by_dest just keeps growing.
@@ -204,77 +188,79 @@ class TransactionQueue(object):
]
try:
- self.pending_transactions[destination] = 1
-
- logger.debug("TX [%s] Persisting transaction...", destination)
-
- transaction = Transaction.create_new(
- origin_server_ts=int(self._clock.time_msec()),
- transaction_id=str(self._next_txn_id),
- origin=self.server_name,
- destination=destination,
- pdus=pdus,
- edus=edus,
- pdu_failures=failures,
+ limiter = yield get_retry_limiter(
+ destination,
+ self._clock,
+ self.store,
)
- self._next_txn_id += 1
+ with limiter:
+ self.pending_transactions[destination] = 1
- yield self.transaction_actions.prepare_to_send(transaction)
+ logger.debug("TX [%s] Persisting transaction...", destination)
- logger.debug("TX [%s] Persisted transaction", destination)
- logger.info(
- "TX [%s] Sending transaction [%s]",
- destination,
- transaction.transaction_id,
- )
+ transaction = Transaction.create_new(
+ origin_server_ts=int(self._clock.time_msec()),
+ transaction_id=str(self._next_txn_id),
+ origin=self.server_name,
+ destination=destination,
+ pdus=pdus,
+ edus=edus,
+ pdu_failures=failures,
+ )
+
+ self._next_txn_id += 1
- # Actually send the transaction
-
- # FIXME (erikj): This is a bit of a hack to make the Pdu age
- # keys work
- def json_data_cb():
- data = transaction.get_dict()
- now = int(self._clock.time_msec())
- if "pdus" in data:
- for p in data["pdus"]:
- if "age_ts" in p:
- unsigned = p.setdefault("unsigned", {})
- unsigned["age"] = now - int(p["age_ts"])
- del p["age_ts"]
- return data
-
- try:
- response = yield self.transport_layer.send_transaction(
- transaction, json_data_cb
+ yield self.transaction_actions.prepare_to_send(transaction)
+
+ logger.debug("TX [%s] Persisted transaction", destination)
+ logger.info(
+ "TX [%s] Sending transaction [%s]",
+ destination,
+ transaction.transaction_id,
)
- code = 200
- except HttpResponseException as e:
- code = e.code
- response = e.response
- logger.info("TX [%s] got %d response", destination, code)
+ # Actually send the transaction
+
+ # FIXME (erikj): This is a bit of a hack to make the Pdu age
+ # keys work
+ def json_data_cb():
+ data = transaction.get_dict()
+ now = int(self._clock.time_msec())
+ if "pdus" in data:
+ for p in data["pdus"]:
+ if "age_ts" in p:
+ unsigned = p.setdefault("unsigned", {})
+ unsigned["age"] = now - int(p["age_ts"])
+ del p["age_ts"]
+ return data
+
+ try:
+ response = yield self.transport_layer.send_transaction(
+ transaction, json_data_cb
+ )
+ code = 200
+ except HttpResponseException as e:
+ code = e.code
+ response = e.response
- logger.debug("TX [%s] Sent transaction", destination)
- logger.debug("TX [%s] Marking as delivered...", destination)
+ logger.info("TX [%s] got %d response", destination, code)
- yield self.transaction_actions.delivered(
- transaction, code, response
- )
+ logger.debug("TX [%s] Sent transaction", destination)
+ logger.debug("TX [%s] Marking as delivered...", destination)
+
+ yield self.transaction_actions.delivered(
+ transaction, code, response
+ )
+
+ logger.debug("TX [%s] Marked as delivered", destination)
- logger.debug("TX [%s] Marked as delivered", destination)
logger.debug("TX [%s] Yielding to callbacks...", destination)
for deferred in deferreds:
if code == 200:
- if retry_last_ts:
- # this host is alive! reset retry schedule
- yield self.store.set_destination_retry_timings(
- destination, 0, 0
- )
deferred.callback(None)
else:
- self.set_retrying(destination, retry_interval)
deferred.errback(RuntimeError("Got status %d" % code))
# Ensures we don't continue until all callbacks on that
@@ -285,6 +271,12 @@ class TransactionQueue(object):
pass
logger.debug("TX [%s] Yielded to callbacks", destination)
+ except NotRetryingDestination:
+ logger.info(
+ "TX [%s] not ready for retry yet - "
+ "dropping transaction for now",
+ destination,
+ )
except RuntimeError as e:
# We capture this here as there as nothing actually listens
# for this finishing functions deferred.
@@ -302,8 +294,6 @@ class TransactionQueue(object):
e,
)
- self.set_retrying(destination, retry_interval)
-
for deferred in deferreds:
if not deferred.called:
deferred.errback(e)
@@ -314,22 +304,3 @@ class TransactionQueue(object):
# Check to see if there is anything else to send.
self._attempt_new_transaction(destination)
-
- @defer.inlineCallbacks
- def set_retrying(self, destination, retry_interval):
- # track that this destination is having problems and we should
- # give it a chance to recover before trying it again
-
- if retry_interval:
- retry_interval *= 2
- # plateau at hourly retries for now
- if retry_interval >= 60 * 60 * 1000:
- retry_interval = 60 * 60 * 1000
- else:
- retry_interval = 2000 # try again at first after 2 seconds
-
- yield self.store.set_destination_retry_timings(
- destination,
- int(self._clock.time_msec()),
- retry_interval
- )
diff --git a/synapse/util/retryutils.py b/synapse/util/retryutils.py
new file mode 100644
index 0000000000..bba524ebd8
--- /dev/null
+++ b/synapse/util/retryutils.py
@@ -0,0 +1,108 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.internet import defer
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class NotRetryingDestination(Exception):
+ def __init__(self, retry_last_ts, retry_interval, destination):
+ msg = "Not retrying server %s." % (destination,)
+ super(NotRetryingDestination, self).__init__(msg)
+
+ self.retry_last_ts = retry_last_ts
+ self.retry_interval = retry_interval
+ self.destination = destination
+
+
+@defer.inlineCallbacks
+def get_retry_limiter(destination, clock, store, **kwargs):
+ retry_last_ts, retry_interval = (0, 0)
+
+ retry_timings = yield store.get_destination_retry_timings(
+ destination
+ )
+
+ if retry_timings:
+ retry_last_ts, retry_interval = (
+ retry_timings.retry_last_ts, retry_timings.retry_interval
+ )
+
+ now = int(clock.time_msec())
+
+ if retry_last_ts + retry_interval > now:
+ raise NotRetryingDestination(
+ retry_last_ts=retry_last_ts,
+ retry_interval=retry_interval,
+ destination=destination,
+ )
+
+ defer.returnValue(
+ RetryDestinationLimiter(
+ destination,
+ clock,
+ store,
+ retry_interval,
+ **kwargs
+ )
+ )
+
+
+class RetryDestinationLimiter(object):
+ def __init__(self, destination, clock, store, retry_interval,
+ min_retry_interval=20000, max_retry_interval=60 * 60 * 1000,
+ multiplier_retry_interval=2):
+ self.clock = clock
+ self.store = store
+ self.destination = destination
+
+ self.retry_interval = retry_interval
+ self.min_retry_interval = min_retry_interval
+ self.max_retry_interval = max_retry_interval
+ self.multiplier_retry_interval = multiplier_retry_interval
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ def err(self, failure):
+ logger.exception(
+ "Failed to store set_destination_retry_timings",
+ failure.value
+ )
+
+ if exc_type is None and exc_val is None and exc_tb is None:
+ # We connected successfully.
+ retry_last_ts = 0
+ self.retry_interval = 0
+ else:
+ # We couldn't connect.
+ if self.retry_interval:
+ self.retry_interval *= self.multiplier_retry_interval
+
+ if self.retry_interval >= self.max_retry_interval:
+ self.retry_interval = self.max_retry_interval
+ else:
+ self.retry_interval = self.min_retry_interval
+
+ retry_last_ts = int(self._clock.time_msec()),
+
+ self.store.set_destination_retry_timings(
+ self.destination, retry_last_ts, self.retry_interval
+ ).addErrback(err)
|