summary refs log tree commit diff
path: root/synapse/federation
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation')
-rw-r--r--synapse/federation/federation_base.py139
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/federation_server.py162
-rw-r--r--synapse/federation/transaction_queue.py48
-rw-r--r--synapse/federation/transport/server.py24
5 files changed, 231 insertions, 150 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 2339cc9034..a0f5d40eb3 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -12,28 +12,20 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
-from twisted.internet import defer
-
-from synapse.events.utils import prune_event
-
-from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
-
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
-
 import logging
 
+from synapse.api.errors import SynapseError
+from synapse.crypto.event_signing import check_event_content_hash
+from synapse.events.utils import prune_event
+from synapse.util import unwrapFirstError, logcontext
+from twisted.internet import defer
 
 logger = logging.getLogger(__name__)
 
 
 class FederationBase(object):
     def __init__(self, hs):
-        pass
+        self.spam_checker = hs.get_spam_checker()
 
     @defer.inlineCallbacks
     def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
@@ -57,56 +49,52 @@ class FederationBase(object):
         """
         deferreds = self._check_sigs_and_hashes(pdus)
 
-        def callback(pdu):
-            return pdu
+        @defer.inlineCallbacks
+        def handle_check_result(pdu, deferred):
+            try:
+                res = yield logcontext.make_deferred_yieldable(deferred)
+            except SynapseError:
+                res = None
 
-        def errback(failure, pdu):
-            failure.trap(SynapseError)
-            return None
-
-        def try_local_db(res, pdu):
             if not res:
                 # Check local db.
-                return self.store.get_event(
+                res = yield self.store.get_event(
                     pdu.event_id,
                     allow_rejected=True,
                     allow_none=True,
                 )
-            return res
 
-        def try_remote(res, pdu):
             if not res and pdu.origin != origin:
-                return self.get_pdu(
-                    destinations=[pdu.origin],
-                    event_id=pdu.event_id,
-                    outlier=outlier,
-                    timeout=10000,
-                ).addErrback(lambda e: None)
-            return res
-
-        def warn(res, pdu):
+                try:
+                    res = yield self.get_pdu(
+                        destinations=[pdu.origin],
+                        event_id=pdu.event_id,
+                        outlier=outlier,
+                        timeout=10000,
+                    )
+                except SynapseError:
+                    pass
+
             if not res:
                 logger.warn(
                     "Failed to find copy of %s with valid signature",
                     pdu.event_id,
                 )
-            return res
 
-        for pdu, deferred in zip(pdus, deferreds):
-            deferred.addCallbacks(
-                callback, errback, errbackArgs=[pdu]
-            ).addCallback(
-                try_local_db, pdu
-            ).addCallback(
-                try_remote, pdu
-            ).addCallback(
-                warn, pdu
-            )
+            defer.returnValue(res)
 
-        valid_pdus = yield preserve_context_over_deferred(defer.gatherResults(
-            deferreds,
-            consumeErrors=True
-        )).addErrback(unwrapFirstError)
+        handle = logcontext.preserve_fn(handle_check_result)
+        deferreds2 = [
+            handle(pdu, deferred)
+            for pdu, deferred in zip(pdus, deferreds)
+        ]
+
+        valid_pdus = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                deferreds2,
+                consumeErrors=True,
+            )
+        ).addErrback(unwrapFirstError)
 
         if include_none:
             defer.returnValue(valid_pdus)
@@ -114,15 +102,24 @@ class FederationBase(object):
             defer.returnValue([p for p in valid_pdus if p])
 
     def _check_sigs_and_hash(self, pdu):
-        return self._check_sigs_and_hashes([pdu])[0]
+        return logcontext.make_deferred_yieldable(
+            self._check_sigs_and_hashes([pdu])[0],
+        )
 
     def _check_sigs_and_hashes(self, pdus):
-        """Throws a SynapseError if a PDU does not have the correct
-        signatures.
+        """Checks that each of the received events is correctly signed by the
+        sending server.
+
+        Args:
+            pdus (list[FrozenEvent]): the events to be checked
 
         Returns:
-            FrozenEvent: Either the given event or it redacted if it failed the
-            content hash check.
+            list[Deferred]: for each input event, a deferred which:
+              * returns the original event if the checks pass
+              * returns a redacted version of the event (if the signature
+                matched but the hash did not)
+              * throws a SynapseError if the signature check failed.
+            The deferreds run their callbacks in the sentinel logcontext.
         """
 
         redacted_pdus = [
@@ -130,26 +127,38 @@ class FederationBase(object):
             for pdu in pdus
         ]
 
-        deferreds = preserve_fn(self.keyring.verify_json_objects_for_server)([
+        deferreds = self.keyring.verify_json_objects_for_server([
             (p.origin, p.get_pdu_json())
             for p in redacted_pdus
         ])
 
+        ctx = logcontext.LoggingContext.current_context()
+
         def callback(_, pdu, redacted):
-            if not check_event_content_hash(pdu):
-                logger.warn(
-                    "Event content has been tampered, redacting %s: %s",
-                    pdu.event_id, pdu.get_pdu_json()
-                )
-                return redacted
-            return pdu
+            with logcontext.PreserveLoggingContext(ctx):
+                if not check_event_content_hash(pdu):
+                    logger.warn(
+                        "Event content has been tampered, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                if self.spam_checker.check_event_for_spam(pdu):
+                    logger.warn(
+                        "Event contains spam, redacting %s: %s",
+                        pdu.event_id, pdu.get_pdu_json()
+                    )
+                    return redacted
+
+                return pdu
 
         def errback(failure, pdu):
             failure.trap(SynapseError)
-            logger.warn(
-                "Signature check failed for %s",
-                pdu.event_id,
-            )
+            with logcontext.PreserveLoggingContext(ctx):
+                logger.warn(
+                    "Signature check failed for %s",
+                    pdu.event_id,
+                )
             return failure
 
         for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 861441708b..7c5e5d957f 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -22,7 +22,7 @@ from synapse.api.constants import Membership
 from synapse.api.errors import (
     CodeMessageException, HttpResponseException, SynapseError,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.caches.expiringcache import ExpiringCache
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
@@ -189,10 +189,10 @@ class FederationClient(FederationBase):
         ]
 
         # FIXME: We should handle signature failures more gracefully.
-        pdus[:] = yield preserve_context_over_deferred(defer.gatherResults(
+        pdus[:] = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             self._check_sigs_and_hashes(pdus),
             consumeErrors=True,
-        )).addErrback(unwrapFirstError)
+        ).addErrback(unwrapFirstError))
 
         defer.returnValue(pdus)
 
@@ -252,7 +252,7 @@ class FederationClient(FederationBase):
                     pdu = pdu_list[0]
 
                     # Check signatures are correct.
-                    signed_pdu = yield self._check_sigs_and_hashes([pdu])[0]
+                    signed_pdu = yield self._check_sigs_and_hash(pdu)
 
                     break
 
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 51e3fdea06..f00d59e701 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,14 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from twisted.internet import defer
 
 from .federation_base import FederationBase
 from .units import Transaction, Edu
 
-from synapse.util.async import Linearizer
+from synapse.util import async
 from synapse.util.logutils import log_function
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
@@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
 import simplejson as json
 import logging
 
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
 
 logger = logging.getLogger(__name__)
 
@@ -52,7 +53,8 @@ class FederationServer(FederationBase):
 
         self.auth = hs.get_auth()
 
-        self._server_linearizer = Linearizer("fed_server")
+        self._server_linearizer = async.Linearizer("fed_server")
+        self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
         # We cache responses to state queries, as they take a while and often
         # come in waves.
@@ -109,25 +111,41 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_incoming_transaction(self, transaction_data):
+        # keep this as early as possible to make the calculated origin ts as
+        # accurate as possible.
+        request_time = self._clock.time_msec()
+
         transaction = Transaction(**transaction_data)
 
-        received_pdus_counter.inc_by(len(transaction.pdus))
+        if not transaction.transaction_id:
+            raise Exception("Transaction missing transaction_id")
+        if not transaction.origin:
+            raise Exception("Transaction missing origin")
 
-        for p in transaction.pdus:
-            if "unsigned" in p:
-                unsigned = p["unsigned"]
-                if "age" in unsigned:
-                    p["age"] = unsigned["age"]
-            if "age" in p:
-                p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
-                del p["age"]
+        logger.debug("[%s] Got transaction", transaction.transaction_id)
 
-        pdu_list = [
-            self.event_from_pdu_json(p) for p in transaction.pdus
-        ]
+        # use a linearizer to ensure that we don't process the same transaction
+        # multiple times in parallel.
+        with (yield self._transaction_linearizer.queue(
+                (transaction.origin, transaction.transaction_id),
+        )):
+            result = yield self._handle_incoming_transaction(
+                transaction, request_time,
+            )
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _handle_incoming_transaction(self, transaction, request_time):
+        """ Process an incoming transaction and return the HTTP response
+
+        Args:
+            transaction (Transaction): incoming transaction
+            request_time (int): timestamp that the HTTP request arrived at
 
+        Returns:
+            Deferred[(int, object)]: http response code and body
+        """
         response = yield self.transaction_actions.have_responded(transaction)
 
         if response:
@@ -140,42 +158,50 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        results = []
-
-        for pdu in pdu_list:
-            # check that it's actually being sent from a valid destination to
-            # workaround bug #1753 in 0.18.5 and 0.18.6
-            if transaction.origin != get_domain_from_id(pdu.event_id):
-                # We continue to accept join events from any server; this is
-                # necessary for the federation join dance to work correctly.
-                # (When we join over federation, the "helper" server is
-                # responsible for sending out the join event, rather than the
-                # origin. See bug #1893).
-                if not (
-                    pdu.type == 'm.room.member' and
-                    pdu.content and
-                    pdu.content.get("membership", None) == 'join'
-                ):
-                    logger.info(
-                        "Discarding PDU %s from invalid origin %s",
-                        pdu.event_id, transaction.origin
-                    )
-                    continue
-                else:
-                    logger.info(
-                        "Accepting join PDU %s from %s",
-                        pdu.event_id, transaction.origin
-                    )
+        received_pdus_counter.inc_by(len(transaction.pdus))
 
-            try:
-                yield self._handle_received_pdu(transaction.origin, pdu)
-                results.append({})
-            except FederationError as e:
-                self.send_failure(e, transaction.origin)
-                results.append({"error": str(e)})
-            except Exception as e:
-                results.append({"error": str(e)})
-                logger.exception("Failed to handle PDU")
+        pdus_by_room = {}
+
+        for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
+            if "age" in p:
+                p["age_ts"] = request_time - int(p["age"])
+                del p["age"]
+
+            event = self.event_from_pdu_json(p)
+            room_id = event.room_id
+            pdus_by_room.setdefault(room_id, []).append(event)
+
+        pdu_results = {}
+
+        # we can process different rooms in parallel (which is useful if they
+        # require callouts to other servers to fetch missing events), but
+        # impose a limit to avoid going too crazy with ram/cpu.
+        @defer.inlineCallbacks
+        def process_pdus_for_room(room_id):
+            logger.debug("Processing PDUs for %s", room_id)
+            for pdu in pdus_by_room[room_id]:
+                event_id = pdu.event_id
+                try:
+                    yield self._handle_received_pdu(
+                        transaction.origin, pdu
+                    )
+                    pdu_results[event_id] = {}
+                except FederationError as e:
+                    logger.warn("Error handling PDU %s: %s", event_id, e)
+                    self.send_failure(e, transaction.origin)
+                    pdu_results[event_id] = {"error": str(e)}
+                except Exception as e:
+                    pdu_results[event_id] = {"error": str(e)}
+                    logger.exception("Failed to handle PDU %s", event_id)
+
+        yield async.concurrently_execute(
+            process_pdus_for_room, pdus_by_room.keys(),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
@@ -188,14 +214,12 @@ class FederationServer(FederationBase):
             for failure in getattr(transaction, "pdu_failures", []):
                 logger.info("Got failure %r", failure)
 
-        logger.debug("Returning: %s", str(results))
-
         response = {
-            "pdus": dict(zip(
-                (p.event_id for p in pdu_list), results
-            )),
+            "pdus": pdu_results,
         }
 
+        logger.debug("Returning: %s", str(response))
+
         yield self.transaction_actions.set_response(
             transaction,
             200, response
@@ -520,6 +544,30 @@ class FederationServer(FederationBase):
         Returns (Deferred): completes with None
         Raises: FederationError if the signatures / hash do not match
     """
+        # check that it's actually being sent from a valid destination to
+        # workaround bug #1753 in 0.18.5 and 0.18.6
+        if origin != get_domain_from_id(pdu.event_id):
+            # We continue to accept join events from any server; this is
+            # necessary for the federation join dance to work correctly.
+            # (When we join over federation, the "helper" server is
+            # responsible for sending out the join event, rather than the
+            # origin. See bug #1893).
+            if not (
+                pdu.type == 'm.room.member' and
+                pdu.content and
+                pdu.content.get("membership", None) == 'join'
+            ):
+                logger.info(
+                    "Discarding PDU %s from invalid origin %s",
+                    pdu.event_id, origin
+                )
+                return
+            else:
+                logger.info(
+                    "Accepting join PDU %s from %s",
+                    pdu.event_id, origin
+                )
+
         # Check signature.
         try:
             pdu = yield self._check_sigs_and_hash(pdu)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 003eaba893..7a3c9cbb70 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -20,8 +20,8 @@ from .persistence import TransactionActions
 from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
+from synapse.util import logcontext
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@@ -231,11 +231,9 @@ class TransactionQueue(object):
                 (pdu, order)
             )
 
-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
+            self._attempt_new_transaction(destination)
 
-    @preserve_fn  # the caller should not yield on this
+    @logcontext.preserve_fn  # the caller should not yield on this
     @defer.inlineCallbacks
     def send_presence(self, states):
         """Send the new presence states to the appropriate destinations.
@@ -299,7 +297,7 @@ class TransactionQueue(object):
                     state.user_id: state for state in states
                 })
 
-                preserve_fn(self._attempt_new_transaction)(destination)
+                self._attempt_new_transaction(destination)
 
     def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
@@ -321,9 +319,7 @@ class TransactionQueue(object):
         else:
             self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
@@ -336,9 +332,7 @@ class TransactionQueue(object):
             destination, []
         ).append(failure)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
@@ -347,15 +341,24 @@ class TransactionQueue(object):
         if not self.can_send_to(destination):
             return
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def get_current_token(self):
         return 0
 
-    @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
+        """Try to start a new transaction to this destination
+
+        If there is already a transaction in progress to this destination,
+        returns immediately. Otherwise kicks off the process of sending a
+        transaction in the background.
+
+        Args:
+            destination (str):
+
+        Returns:
+            None
+        """
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
             # XXX: pending_transactions can get stuck on by a never-ending
@@ -368,6 +371,19 @@ class TransactionQueue(object):
             )
             return
 
+        logger.debug("TX [%s] Starting transaction loop", destination)
+
+        # Drop the logcontext before starting the transaction. It doesn't
+        # really make sense to log all the outbound transactions against
+        # whatever path led us to this point: that's pretty arbitrary really.
+        #
+        # (this also means we can fire off _perform_transaction without
+        # yielding)
+        with logcontext.PreserveLoggingContext():
+            self._transaction_transmission_loop(destination)
+
+    @defer.inlineCallbacks
+    def _transaction_transmission_loop(self, destination):
         pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 3d676e7d8b..a78f01e442 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -153,12 +153,10 @@ class Authenticator(object):
 class BaseFederationServlet(object):
     REQUIRE_AUTH = True
 
-    def __init__(self, handler, authenticator, ratelimiter, server_name,
-                 room_list_handler):
+    def __init__(self, handler, authenticator, ratelimiter, server_name):
         self.handler = handler
         self.authenticator = authenticator
         self.ratelimiter = ratelimiter
-        self.room_list_handler = room_list_handler
 
     def _wrap(self, func):
         authenticator = self.authenticator
@@ -590,7 +588,7 @@ class PublicRoomList(BaseFederationServlet):
         else:
             network_tuple = ThirdPartyInstanceID(None, None)
 
-        data = yield self.room_list_handler.get_local_public_room_list(
+        data = yield self.handler.get_local_public_room_list(
             limit, since_token,
             network_tuple=network_tuple
         )
@@ -611,7 +609,7 @@ class FederationVersionServlet(BaseFederationServlet):
         }))
 
 
-SERVLET_CLASSES = (
+FEDERATION_SERVLET_CLASSES = (
     FederationSendServlet,
     FederationPullServlet,
     FederationEventServlet,
@@ -634,17 +632,27 @@ SERVLET_CLASSES = (
     FederationThirdPartyInviteExchangeServlet,
     On3pidBindServlet,
     OpenIdUserInfo,
-    PublicRoomList,
     FederationVersionServlet,
 )
 
+ROOM_LIST_CLASSES = (
+    PublicRoomList,
+)
+
 
 def register_servlets(hs, resource, authenticator, ratelimiter):
-    for servletclass in SERVLET_CLASSES:
+    for servletclass in FEDERATION_SERVLET_CLASSES:
         servletclass(
             handler=hs.get_replication_layer(),
             authenticator=authenticator,
             ratelimiter=ratelimiter,
             server_name=hs.hostname,
-            room_list_handler=hs.get_room_list_handler(),
+        ).register(resource)
+
+    for servletclass in ROOM_LIST_CLASSES:
+        servletclass(
+            handler=hs.get_room_list_handler(),
+            authenticator=authenticator,
+            ratelimiter=ratelimiter,
+            server_name=hs.hostname,
         ).register(resource)