diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index da95c2ad6d..f2b3aceb49 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
from synapse.util.async import concurrently_execute
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
from synapse.events import FrozenEvent
import synapse.metrics
@@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
class FederationClient(FederationBase):
def __init__(self, hs):
super(FederationClient, self).__init__(hs)
+ self.pdu_destination_tried = {}
+ self._clock.looping_call(
+ self._clear_tried_cache, 60 * 1000,
+ )
+
+ def _clear_tried_cache(self):
+ """Clear pdu_destination_tried cache"""
+ now = self._clock.time_msec()
+
+ old_dict = self.pdu_destination_tried
+ self.pdu_destination_tried = {}
+
+ for event_id, destination_dict in old_dict.items():
+ destination_dict = {
+ dest: time
+ for dest, time in destination_dict.items()
+ if time + PDU_RETRY_TIME_MS > now
+ }
+ if destination_dict:
+ self.pdu_destination_tried[event_id] = destination_dict
+
def start_get_pdu_cache(self):
self._get_pdu_cache = ExpiringCache(
cache_name="get_pdu_cache",
@@ -201,10 +226,10 @@ class FederationClient(FederationBase):
]
# FIXME: We should handle signature failures more gracefully.
- pdus[:] = yield defer.gatherResults(
+ pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
self._check_sigs_and_hashes(pdus),
consumeErrors=True,
- ).addErrback(unwrapFirstError)
+ )).addErrback(unwrapFirstError)
defer.returnValue(pdus)
@@ -240,8 +265,15 @@ class FederationClient(FederationBase):
if ev:
defer.returnValue(ev)
+ pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
pdu = None
for destination in destinations:
+ now = self._clock.time_msec()
+ last_attempt = pdu_attempts.get(destination, 0)
+ if last_attempt + PDU_RETRY_TIME_MS > now:
+ continue
+
try:
limiter = yield get_retry_limiter(
destination,
@@ -269,25 +301,19 @@ class FederationClient(FederationBase):
break
- except SynapseError as e:
- logger.info(
- "Failed to get PDU %s from %s because %s",
- event_id, destination, e,
- )
- continue
- except CodeMessageException as e:
- if 400 <= e.code < 500:
- raise
+ pdu_attempts[destination] = now
+ except SynapseError as e:
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
)
- continue
except NotRetryingDestination as e:
logger.info(e.message)
continue
except Exception as e:
+ pdu_attempts[destination] = now
+
logger.info(
"Failed to get PDU %s from %s because %s",
event_id, destination, e,
@@ -406,7 +432,7 @@ class FederationClient(FederationBase):
events and the second is a list of event ids that we failed to fetch.
"""
if return_local:
- seen_events = yield self.store.get_events(event_ids)
+ seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
signed_events = seen_events.values()
else:
seen_events = yield self.store.have_events(event_ids)
@@ -432,14 +458,16 @@ class FederationClient(FederationBase):
batch = set(missing_events[i:i + batch_size])
deferreds = [
- self.get_pdu(
+ preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id in batch
]
- res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ res = yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
for success, result in res:
if success:
signed_events.append(result)
@@ -828,14 +856,16 @@ class FederationClient(FederationBase):
return srvs
deferreds = [
- self.get_pdu(
+ preserve_fn(self.get_pdu)(
destinations=random_server_list(),
event_id=e_id,
)
for e_id, depth in ordered_missing[:limit - len(signed_events)]
]
- res = yield defer.DeferredList(deferreds, consumeErrors=True)
+ res = yield preserve_context_over_deferred(
+ defer.DeferredList(deferreds, consumeErrors=True)
+ )
for (result, val), (e_id, _) in zip(res, ordered_missing):
if result and val:
signed_events.append(val)
|