summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py7
-rw-r--r--synapse/federation/federation_client.py64
-rw-r--r--synapse/federation/transaction_queue.py381
3 files changed, 216 insertions, 236 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index da2f5e8cfd..2339cc9034 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -23,6 +23,7 @@ from synapse.crypto.event_signing import check_event_content_hash
 from synapse.api.errors import SynapseError
 
 from synapse.util import unwrapFirstError
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 
 import logging
 
@@ -102,10 +103,10 @@ class FederationBase(object):
                 warn, pdu
             )
 
-        valid_pdus = yield defer.gatherResults(
+        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
             deferreds,
             consumeErrors=True
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         if include_none:
             defer.returnValue(valid_pdus)
@@ -129,7 +130,7 @@ class FederationBase(object):
             for pdu in pdus
         ]
 
-        deferreds = self.keyring.verify_json_objects_for_server([
+        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index da95c2ad6d..f2b3aceb49 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -27,6 +27,7 @@ from synapse.util import unwrapFirstError
 from synapse.util.async import concurrently_execute
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
+from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
 from synapse.events import FrozenEvent
 import synapse.metrics
 
@@ -51,10 +52,34 @@ sent_edus_counter = metrics.register_counter("sent_edus")
 sent_queries_counter = metrics.register_counter("sent_queries", labels=["type"])
 
 
+PDU_RETRY_TIME_MS = 1 * 60 * 1000
+
+
 class FederationClient(FederationBase):
     def __init__(self, hs):
         super(FederationClient, self).__init__(hs)
 
+        self.pdu_destination_tried = {}
+        self._clock.looping_call(
+            self._clear_tried_cache, 60 * 1000,
+        )
+
+    def _clear_tried_cache(self):
+        """Clear pdu_destination_tried cache"""
+        now = self._clock.time_msec()
+
+        old_dict = self.pdu_destination_tried
+        self.pdu_destination_tried = {}
+
+        for event_id, destination_dict in old_dict.items():
+            destination_dict = {
+                dest: time
+                for dest, time in destination_dict.items()
+                if time + PDU_RETRY_TIME_MS > now
+            }
+            if destination_dict:
+                self.pdu_destination_tried[event_id] = destination_dict
+
     def start_get_pdu_cache(self):
         self._get_pdu_cache = ExpiringCache(
             cache_name="get_pdu_cache",
@@ -201,10 +226,10 @@ class FederationClient(FederationBase):
         ]
 
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield defer.gatherResults(
+        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        ).addErrback(unwrapFirstError)
+        )).addErrback(unwrapFirstError)
 
         defer.returnValue(pdus)
 
@@ -240,8 +265,15 @@ class FederationClient(FederationBase):
             if ev:
                 defer.returnValue(ev)
 
+        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
         pdu = None
         for destination in destinations:
+            now = self._clock.time_msec()
+            last_attempt = pdu_attempts.get(destination, 0)
+            if last_attempt + PDU_RETRY_TIME_MS > now:
+                continue
+
             try:
                 limiter = yield get_retry_limiter(
                     destination,
@@ -269,25 +301,19 @@ class FederationClient(FederationBase):
 
                         break
 
-            except SynapseError as e:
-                logger.info(
-                    "Failed to get PDU %s from %s because %s",
-                    event_id, destination, e,
-                )
-                continue
-            except CodeMessageException as e:
-                if 400 <= e.code < 500:
-                    raise
+                pdu_attempts[destination] = now
 
+            except SynapseError as e:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
                 )
-                continue
             except NotRetryingDestination as e:
                 logger.info(e.message)
                 continue
             except Exception as e:
+                pdu_attempts[destination] = now
+
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
                     event_id, destination, e,
@@ -406,7 +432,7 @@ class FederationClient(FederationBase):
             events and the second is a list of event ids that we failed to fetch.
         """
         if return_local:
-            seen_events = yield self.store.get_events(event_ids)
+            seen_events = yield self.store.get_events(event_ids, allow_rejected=True)
             signed_events = seen_events.values()
         else:
             seen_events = yield self.store.have_events(event_ids)
@@ -432,14 +458,16 @@ class FederationClient(FederationBase):
             batch = set(missing_events[i:i + batch_size])
 
             deferreds = [
-                self.get_pdu(
+                preserve_fn(self.get_pdu)(
                     destinations=random_server_list(),
                     event_id=e_id,
                 )
                 for e_id in batch
             ]
 
-            res = yield defer.DeferredList(deferreds, consumeErrors=True)
+            res = yield preserve_context_over_deferred(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
             for success, result in res:
                 if success:
                     signed_events.append(result)
@@ -828,14 +856,16 @@ class FederationClient(FederationBase):
                 return srvs
 
             deferreds = [
-                self.get_pdu(
+                preserve_fn(self.get_pdu)(
                     destinations=random_server_list(),
                     event_id=e_id,
                 )
                 for e_id, depth in ordered_missing[:limit - len(signed_events)]
             ]
 
-            res = yield defer.DeferredList(deferreds, consumeErrors=True)
+            res = yield preserve_context_over_deferred(
+                defer.DeferredList(deferreds, consumeErrors=True)
+            )
             for (result, val), (e_id, _) in zip(res, ordered_missing):
                 if result and val:
                     signed_events.append(val)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 5787f854d4..cb2ef0210c 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -21,11 +21,11 @@ from .units import Transaction
 
 from synapse.api.errors import HttpResponseException
 from synapse.util.async import run_on_reactor
-from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import preserve_context_over_fn
 from synapse.util.retryutils import (
     get_retry_limiter, NotRetryingDestination,
 )
+from synapse.util.metrics import measure_func
 import synapse.metrics
 
 import logging
@@ -51,7 +51,7 @@ class TransactionQueue(object):
 
         self.transport_layer = transport_layer
 
-        self._clock = hs.get_clock()
+        self.clock = hs.get_clock()
 
         # Is a mapping from destinations -> deferreds. Used to keep track
         # of which destinations have transactions in flight and when they are
@@ -82,7 +82,7 @@ class TransactionQueue(object):
         self.pending_failures_by_dest = {}
 
         # HACK to get unique tx id
-        self._next_txn_id = int(self._clock.time_msec())
+        self._next_txn_id = int(self.clock.time_msec())
 
     def can_send_to(self, destination):
         """Can we send messages to the given server?
@@ -119,266 +119,215 @@ class TransactionQueue(object):
         if not destinations:
             return
 
-        deferreds = []
-
         for destination in destinations:
-            deferred = defer.Deferred()
             self.pending_pdus_by_dest.setdefault(destination, []).append(
-                (pdu, deferred, order)
+                (pdu, order)
             )
 
-            def chain(failure):
-                if not deferred.called:
-                    deferred.errback(failure)
-
-            def log_failure(f):
-                logger.warn("Failed to send pdu to %s: %s", destination, f.value)
-
-            deferred.addErrback(log_failure)
-
-            with PreserveLoggingContext():
-                self._attempt_new_transaction(destination).addErrback(chain)
-
-            deferreds.append(deferred)
+            preserve_context_over_fn(
+                self._attempt_new_transaction, destination
+            )
 
-    # NO inlineCallbacks
     def enqueue_edu(self, edu):
         destination = edu.destination
 
         if not self.can_send_to(destination):
             return
 
-        deferred = defer.Deferred()
-        self.pending_edus_by_dest.setdefault(destination, []).append(
-            (edu, deferred)
-        )
+        self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        def chain(failure):
-            if not deferred.called:
-                deferred.errback(failure)
-
-        def log_failure(f):
-            logger.warn("Failed to send edu to %s: %s", destination, f.value)
-
-        deferred.addErrback(log_failure)
-
-        with PreserveLoggingContext():
-            self._attempt_new_transaction(destination).addErrback(chain)
-
-        return deferred
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
 
-    @defer.inlineCallbacks
     def enqueue_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
             return
 
-        deferred = defer.Deferred()
-
         if not self.can_send_to(destination):
             return
 
         self.pending_failures_by_dest.setdefault(
             destination, []
-        ).append(
-            (failure, deferred)
-        )
-
-        def chain(f):
-            if not deferred.called:
-                deferred.errback(f)
-
-        def log_failure(f):
-            logger.warn("Failed to send failure to %s: %s", destination, f.value)
-
-        deferred.addErrback(log_failure)
-
-        with PreserveLoggingContext():
-            self._attempt_new_transaction(destination).addErrback(chain)
+        ).append(failure)
 
-        yield deferred
+        preserve_context_over_fn(
+            self._attempt_new_transaction, destination
+        )
 
     @defer.inlineCallbacks
-    @log_function
     def _attempt_new_transaction(self, destination):
         yield run_on_reactor()
+        while True:
+            # list of (pending_pdu, deferred, order)
+            if destination in self.pending_transactions:
+                # XXX: pending_transactions can get stuck on by a never-ending
+                # request at which point pending_pdus_by_dest just keeps growing.
+                # we need application-layer timeouts of some flavour of these
+                # requests
+                logger.debug(
+                    "TX [%s] Transaction already in progress",
+                    destination
+                )
+                return
 
-        # list of (pending_pdu, deferred, order)
-        if destination in self.pending_transactions:
-            # XXX: pending_transactions can get stuck on by a never-ending
-            # request at which point pending_pdus_by_dest just keeps growing.
-            # we need application-layer timeouts of some flavour of these
-            # requests
-            logger.debug(
-                "TX [%s] Transaction already in progress",
-                destination
-            )
-            return
-
-        pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
-        pending_edus = self.pending_edus_by_dest.pop(destination, [])
-        pending_failures = self.pending_failures_by_dest.pop(destination, [])
+            pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
+            pending_edus = self.pending_edus_by_dest.pop(destination, [])
+            pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
-        if pending_pdus:
-            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                         destination, len(pending_pdus))
+            if pending_pdus:
+                logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                             destination, len(pending_pdus))
 
-        if not pending_pdus and not pending_edus and not pending_failures:
-            logger.debug("TX [%s] Nothing to send", destination)
-            return
+            if not pending_pdus and not pending_edus and not pending_failures:
+                logger.debug("TX [%s] Nothing to send", destination)
+                return
 
-        try:
-            self.pending_transactions[destination] = 1
+            yield self._send_new_transaction(
+                destination, pending_pdus, pending_edus, pending_failures
+            )
 
-            logger.debug("TX [%s] _attempt_new_transaction", destination)
+    @measure_func("_send_new_transaction")
+    @defer.inlineCallbacks
+    def _send_new_transaction(self, destination, pending_pdus, pending_edus,
+                              pending_failures):
 
             # Sort based on the order field
-            pending_pdus.sort(key=lambda t: t[2])
-
+            pending_pdus.sort(key=lambda t: t[1])
             pdus = [x[0] for x in pending_pdus]
-            edus = [x[0] for x in pending_edus]
-            failures = [x[0].get_dict() for x in pending_failures]
-            deferreds = [
-                x[1]
-                for x in pending_pdus + pending_edus + pending_failures
-            ]
-
-            txn_id = str(self._next_txn_id)
-
-            limiter = yield get_retry_limiter(
-                destination,
-                self._clock,
-                self.store,
-            )
+            edus = pending_edus
+            failures = [x.get_dict() for x in pending_failures]
 
-            logger.debug(
-                "TX [%s] {%s} Attempting new transaction"
-                " (pdus: %d, edus: %d, failures: %d)",
-                destination, txn_id,
-                len(pending_pdus),
-                len(pending_edus),
-                len(pending_failures)
-            )
+            try:
+                self.pending_transactions[destination] = 1
 
-            logger.debug("TX [%s] Persisting transaction...", destination)
+                logger.debug("TX [%s] _attempt_new_transaction", destination)
 
-            transaction = Transaction.create_new(
-                origin_server_ts=int(self._clock.time_msec()),
-                transaction_id=txn_id,
-                origin=self.server_name,
-                destination=destination,
-                pdus=pdus,
-                edus=edus,
-                pdu_failures=failures,
-            )
+                txn_id = str(self._next_txn_id)
 
-            self._next_txn_id += 1
+                limiter = yield get_retry_limiter(
+                    destination,
+                    self.clock,
+                    self.store,
+                )
 
-            yield self.transaction_actions.prepare_to_send(transaction)
+                logger.debug(
+                    "TX [%s] {%s} Attempting new transaction"
+                    " (pdus: %d, edus: %d, failures: %d)",
+                    destination, txn_id,
+                    len(pending_pdus),
+                    len(pending_edus),
+                    len(pending_failures)
+                )
 
-            logger.debug("TX [%s] Persisted transaction", destination)
-            logger.info(
-                "TX [%s] {%s} Sending transaction [%s],"
-                " (PDUs: %d, EDUs: %d, failures: %d)",
-                destination, txn_id,
-                transaction.transaction_id,
-                len(pending_pdus),
-                len(pending_edus),
-                len(pending_failures),
-            )
+                logger.debug("TX [%s] Persisting transaction...", destination)
 
-            with limiter:
-                # Actually send the transaction
-
-                # FIXME (erikj): This is a bit of a hack to make the Pdu age
-                # keys work
-                def json_data_cb():
-                    data = transaction.get_dict()
-                    now = int(self._clock.time_msec())
-                    if "pdus" in data:
-                        for p in data["pdus"]:
-                            if "age_ts" in p:
-                                unsigned = p.setdefault("unsigned", {})
-                                unsigned["age"] = now - int(p["age_ts"])
-                                del p["age_ts"]
-                    return data
-
-                try:
-                    response = yield self.transport_layer.send_transaction(
-                        transaction, json_data_cb
-                    )
-                    code = 200
-
-                    if response:
-                        for e_id, r in response.get("pdus", {}).items():
-                            if "error" in r:
-                                logger.warn(
-                                    "Transaction returned error for %s: %s",
-                                    e_id, r,
-                                )
-                except HttpResponseException as e:
-                    code = e.code
-                    response = e.response
+                transaction = Transaction.create_new(
+                    origin_server_ts=int(self.clock.time_msec()),
+                    transaction_id=txn_id,
+                    origin=self.server_name,
+                    destination=destination,
+                    pdus=pdus,
+                    edus=edus,
+                    pdu_failures=failures,
+                )
+
+                self._next_txn_id += 1
+
+                yield self.transaction_actions.prepare_to_send(transaction)
 
+                logger.debug("TX [%s] Persisted transaction", destination)
                 logger.info(
-                    "TX [%s] {%s} got %d response",
-                    destination, txn_id, code
+                    "TX [%s] {%s} Sending transaction [%s],"
+                    " (PDUs: %d, EDUs: %d, failures: %d)",
+                    destination, txn_id,
+                    transaction.transaction_id,
+                    len(pending_pdus),
+                    len(pending_edus),
+                    len(pending_failures),
                 )
 
-                logger.debug("TX [%s] Sent transaction", destination)
-                logger.debug("TX [%s] Marking as delivered...", destination)
+                with limiter:
+                    # Actually send the transaction
+
+                    # FIXME (erikj): This is a bit of a hack to make the Pdu age
+                    # keys work
+                    def json_data_cb():
+                        data = transaction.get_dict()
+                        now = int(self.clock.time_msec())
+                        if "pdus" in data:
+                            for p in data["pdus"]:
+                                if "age_ts" in p:
+                                    unsigned = p.setdefault("unsigned", {})
+                                    unsigned["age"] = now - int(p["age_ts"])
+                                    del p["age_ts"]
+                        return data
+
+                    try:
+                        response = yield self.transport_layer.send_transaction(
+                            transaction, json_data_cb
+                        )
+                        code = 200
+
+                        if response:
+                            for e_id, r in response.get("pdus", {}).items():
+                                if "error" in r:
+                                    logger.warn(
+                                        "Transaction returned error for %s: %s",
+                                        e_id, r,
+                                    )
+                    except HttpResponseException as e:
+                        code = e.code
+                        response = e.response
+
+                    logger.info(
+                        "TX [%s] {%s} got %d response",
+                        destination, txn_id, code
+                    )
 
-            yield self.transaction_actions.delivered(
-                transaction, code, response
-            )
+                    logger.debug("TX [%s] Sent transaction", destination)
+                    logger.debug("TX [%s] Marking as delivered...", destination)
 
-            logger.debug("TX [%s] Marked as delivered", destination)
-
-            logger.debug("TX [%s] Yielding to callbacks...", destination)
-
-            for deferred in deferreds:
-                if code == 200:
-                    deferred.callback(None)
-                else:
-                    deferred.errback(RuntimeError("Got status %d" % code))
-
-                # Ensures we don't continue until all callbacks on that
-                # deferred have fired
-                try:
-                    yield deferred
-                except:
-                    pass
-
-            logger.debug("TX [%s] Yielded to callbacks", destination)
-        except NotRetryingDestination:
-            logger.info(
-                "TX [%s] not ready for retry yet - "
-                "dropping transaction for now",
-                destination,
-            )
-        except RuntimeError as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
-        except Exception as e:
-            # We capture this here as there as nothing actually listens
-            # for this finishing functions deferred.
-            logger.warn(
-                "TX [%s] Problem in _attempt_transaction: %s",
-                destination,
-                e,
-            )
+                yield self.transaction_actions.delivered(
+                    transaction, code, response
+                )
 
-            for deferred in deferreds:
-                if not deferred.called:
-                    deferred.errback(e)
+                logger.debug("TX [%s] Marked as delivered", destination)
+
+                if code != 200:
+                    for p in pdus:
+                        logger.info(
+                            "Failed to send event %s to %s", p.event_id, destination
+                        )
+            except NotRetryingDestination:
+                logger.info(
+                    "TX [%s] not ready for retry yet - "
+                    "dropping transaction for now",
+                    destination,
+                )
+            except RuntimeError as e:
+                # We capture this here as there as nothing actually listens
+                # for this finishing functions deferred.
+                logger.warn(
+                    "TX [%s] Problem in _attempt_transaction: %s",
+                    destination,
+                    e,
+                )
+
+                for p in pdus:
+                    logger.info("Failed to send event %s to %s", p.event_id, destination)
+            except Exception as e:
+                # We capture this here as there as nothing actually listens
+                # for this finishing functions deferred.
+                logger.warn(
+                    "TX [%s] Problem in _attempt_transaction: %s",
+                    destination,
+                    e,
+                )
 
-        finally:
-            # We want to be *very* sure we delete this after we stop processing
-            self.pending_transactions.pop(destination, None)
+                for p in pdus:
+                    logger.info("Failed to send event %s to %s", p.event_id, destination)
 
-            # Check to see if there is anything else to send.
-            self._attempt_new_transaction(destination)
+            finally:
+                # We want to be *very* sure we delete this after we stop processing
+                self.pending_transactions.pop(destination, None)