summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py5
-rw-r--r--synapse/federation/federation_client.py42
-rw-r--r--synapse/federation/federation_server.py46
-rw-r--r--synapse/federation/transaction_queue.py31
-rw-r--r--synapse/federation/transport/client.py6
-rw-r--r--synapse/federation/transport/server.py8
6 files changed, 83 insertions, 55 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 21a763214b..f0430b2cb1 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -24,6 +24,8 @@ from synapse.crypto.event_signing import check_event_content_hash
 
 from synapse.api.errors import SynapseError
 
+from synapse.util import unwrapFirstError
+
 import logging
 
 
@@ -78,6 +80,7 @@ class FederationBase(object):
                             destinations=[pdu.origin],
                             event_id=pdu.event_id,
                             outlier=outlier,
+                            timeout=10000,
                         )
 
                         if new_pdu:
@@ -94,7 +97,7 @@ class FederationBase(object):
         yield defer.gatherResults(
             [do(pdu) for pdu in pdus],
             consumeErrors=True
-        )
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(signed_pdus)
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 904c7c0945..d3b46b24c1 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -22,6 +22,7 @@ from .units import Edu
 from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
+from synapse.util import unwrapFirstError
 from synapse.util.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
@@ -164,16 +165,17 @@ class FederationClient(FederationBase):
             for p in transaction_data["pdus"]
         ]
 
-        for i, pdu in enumerate(pdus):
-            pdus[i] = yield self._check_sigs_and_hash(pdu)
-
-            # FIXME: We should handle signature failures more gracefully.
+        # FIXME: We should handle signature failures more gracefully.
+        pdus[:] = yield defer.gatherResults(
+            [self._check_sigs_and_hash(pdu) for pdu in pdus],
+            consumeErrors=True,
+        ).addErrback(unwrapFirstError)
 
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
     @log_function
-    def get_pdu(self, destinations, event_id, outlier=False):
+    def get_pdu(self, destinations, event_id, outlier=False, timeout=None):
         """Requests the PDU with given origin and ID from the remote home
         servers.
 
@@ -189,6 +191,8 @@ class FederationClient(FederationBase):
             outlier (bool): Indicates whether the PDU is an `outlier`, i.e. if
                 it's from an arbitary point in the context as opposed to part
                 of the current block of PDUs. Defaults to `False`
+            timeout (int): How long to try (in ms) each destination for before
+                moving to the next destination. None indicates no timeout.
 
         Returns:
             Deferred: Results in the requested PDU.
@@ -212,7 +216,7 @@ class FederationClient(FederationBase):
 
                 with limiter:
                     transaction_data = yield self.transport_layer.get_event(
-                        destination, event_id
+                        destination, event_id, timeout=timeout,
                     )
 
                     logger.debug("transaction_data %r", transaction_data)
@@ -222,7 +226,7 @@ class FederationClient(FederationBase):
                         for p in transaction_data["pdus"]
                     ]
 
-                    if pdu_list:
+                    if pdu_list and pdu_list[0]:
                         pdu = pdu_list[0]
 
                         # Check signatures are correct.
@@ -255,7 +259,7 @@ class FederationClient(FederationBase):
                 )
                 continue
 
-        if self._get_pdu_cache is not None:
+        if self._get_pdu_cache is not None and pdu:
             self._get_pdu_cache[event_id] = pdu
 
         defer.returnValue(pdu)
@@ -370,13 +374,17 @@ class FederationClient(FederationBase):
                     for p in content.get("auth_chain", [])
                 ]
 
-                signed_state = yield self._check_sigs_and_hash_and_fetch(
-                    destination, state, outlier=True
-                )
-
-                signed_auth = yield self._check_sigs_and_hash_and_fetch(
-                    destination, auth_chain, outlier=True
-                )
+                signed_state, signed_auth = yield defer.gatherResults(
+                    [
+                        self._check_sigs_and_hash_and_fetch(
+                            destination, state, outlier=True
+                        ),
+                        self._check_sigs_and_hash_and_fetch(
+                            destination, auth_chain, outlier=True
+                        )
+                    ],
+                    consumeErrors=True
+                ).addErrback(unwrapFirstError)
 
                 auth_chain.sort(key=lambda e: e.depth)
 
@@ -518,7 +526,7 @@ class FederationClient(FederationBase):
             # Are we missing any?
 
             seen_events = set(earliest_events_ids)
-            seen_events.update(e.event_id for e in signed_events)
+            seen_events.update(e.event_id for e in signed_events if e)
 
             missing_events = {}
             for e in itertools.chain(latest_events, signed_events):
@@ -561,7 +569,7 @@ class FederationClient(FederationBase):
 
             res = yield defer.DeferredList(deferreds, consumeErrors=True)
             for (result, val), (e_id, _) in zip(res, ordered_missing):
-                if result:
+                if result and val:
                     signed_events.append(val)
                 else:
                     failed_to_fetch.add(e_id)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 2b46188c91..cd79e23f4b 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -20,7 +20,6 @@ from .federation_base import FederationBase
 from .units import Transaction, Edu
 
 from synapse.util.logutils import log_function
-from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
 import synapse.metrics
 
@@ -123,29 +122,28 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        with PreserveLoggingContext():
-            results = []
-
-            for pdu in pdu_list:
-                d = self._handle_new_pdu(transaction.origin, pdu)
-
-                try:
-                    yield d
-                    results.append({})
-                except FederationError as e:
-                    self.send_failure(e, transaction.origin)
-                    results.append({"error": str(e)})
-                except Exception as e:
-                    results.append({"error": str(e)})
-                    logger.exception("Failed to handle PDU")
-
-            if hasattr(transaction, "edus"):
-                for edu in [Edu(**x) for x in transaction.edus]:
-                    self.received_edu(
-                        transaction.origin,
-                        edu.edu_type,
-                        edu.content
-                    )
+        results = []
+
+        for pdu in pdu_list:
+            d = self._handle_new_pdu(transaction.origin, pdu)
+
+            try:
+                yield d
+                results.append({})
+            except FederationError as e:
+                self.send_failure(e, transaction.origin)
+                results.append({"error": str(e)})
+            except Exception as e:
+                results.append({"error": str(e)})
+                logger.exception("Failed to handle PDU")
+
+        if hasattr(transaction, "edus"):
+            for edu in [Edu(**x) for x in transaction.edus]:
+                self.received_edu(
+                    transaction.origin,
+                    edu.edu_type,
+                    edu.content
+                )
 
             for failure in getattr(transaction, "pdu_failures", []):
                 logger.info("Got failure %r", failure)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index ca04822fb3..32fa5e8c15 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -207,13 +207,13 @@ class TransactionQueue(object):
             # request at which point pending_pdus_by_dest just keeps growing.
             # we need application-layer timeouts of some flavour of these
             # requests
-            logger.info(
+            logger.debug(
                 "TX [%s] Transaction already in progress",
                 destination
             )
             return
 
-        logger.info("TX [%s] _attempt_new_transaction", destination)
+        logger.debug("TX [%s] _attempt_new_transaction", destination)
 
         # list of (pending_pdu, deferred, order)
         pending_pdus = self.pending_pdus_by_dest.pop(destination, [])
@@ -221,11 +221,11 @@ class TransactionQueue(object):
         pending_failures = self.pending_failures_by_dest.pop(destination, [])
 
         if pending_pdus:
-            logger.info("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
-                        destination, len(pending_pdus))
+            logger.debug("TX [%s] len(pending_pdus_by_dest[dest]) = %d",
+                         destination, len(pending_pdus))
 
         if not pending_pdus and not pending_edus and not pending_failures:
-            logger.info("TX [%s] Nothing to send", destination)
+            logger.debug("TX [%s] Nothing to send", destination)
             return
 
         # Sort based on the order field
@@ -242,6 +242,8 @@ class TransactionQueue(object):
         try:
             self.pending_transactions[destination] = 1
 
+            txn_id = str(self._next_txn_id)
+
             limiter = yield get_retry_limiter(
                 destination,
                 self._clock,
@@ -249,9 +251,9 @@ class TransactionQueue(object):
             )
 
             logger.debug(
-                "TX [%s] Attempting new transaction"
+                "TX [%s] {%s} Attempting new transaction"
                 " (pdus: %d, edus: %d, failures: %d)",
-                destination,
+                destination, txn_id,
                 len(pending_pdus),
                 len(pending_edus),
                 len(pending_failures)
@@ -261,7 +263,7 @@ class TransactionQueue(object):
 
             transaction = Transaction.create_new(
                 origin_server_ts=int(self._clock.time_msec()),
-                transaction_id=str(self._next_txn_id),
+                transaction_id=txn_id,
                 origin=self.server_name,
                 destination=destination,
                 pdus=pdus,
@@ -275,9 +277,13 @@ class TransactionQueue(object):
 
             logger.debug("TX [%s] Persisted transaction", destination)
             logger.info(
-                "TX [%s] Sending transaction [%s]",
-                destination,
+                "TX [%s] {%s} Sending transaction [%s],"
+                " (PDUs: %d, EDUs: %d, failures: %d)",
+                destination, txn_id,
                 transaction.transaction_id,
+                len(pending_pdus),
+                len(pending_edus),
+                len(pending_failures),
             )
 
             with limiter:
@@ -313,7 +319,10 @@ class TransactionQueue(object):
                     code = e.code
                     response = e.response
 
-                logger.info("TX [%s] got %d response", destination, code)
+                logger.info(
+                    "TX [%s] {%s} got %d response",
+                    destination, txn_id, code
+                )
 
                 logger.debug("TX [%s] Sent transaction", destination)
                 logger.debug("TX [%s] Marking as delivered...", destination)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 80d03012b7..610a4c3163 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -50,13 +50,15 @@ class TransportLayerClient(object):
         )
 
     @log_function
-    def get_event(self, destination, event_id):
+    def get_event(self, destination, event_id, timeout=None):
         """ Requests the pdu with give id and origin from the given server.
 
         Args:
             destination (str): The host name of the remote home server we want
                 to get the state from.
             event_id (str): The id of the event being requested.
+            timeout (int): How long to try (in ms) the destination for before
+                giving up. None indicates no timeout.
 
         Returns:
             Deferred: Results in a dict received from the remote homeserver.
@@ -65,7 +67,7 @@ class TransportLayerClient(object):
                      destination, event_id)
 
         path = PREFIX + "/event/%s/" % (event_id, )
-        return self.client.get_json(destination, path=path)
+        return self.client.get_json(destination, path=path, timeout=timeout)
 
     @log_function
     def backfill(self, destination, room_id, event_tuples, limit):
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 2bfe0f3c9b..af87805f34 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -196,6 +196,14 @@ class FederationSendServlet(BaseFederationServlet):
                 transaction_id, str(transaction_data)
             )
 
+            logger.info(
+                "Received txn %s from %s. (PDUs: %d, EDUs: %d, failures: %d)",
+                transaction_id, origin,
+                len(transaction_data.get("pdus", [])),
+                len(transaction_data.get("edus", [])),
+                len(transaction_data.get("failures", [])),
+            )
+
             # We should ideally be getting this from the security layer.
             # origin = body["origin"]