summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/events/spamcheck.py4
-rw-r--r--synapse/federation/federation_server.py162
-rw-r--r--synapse/federation/transaction_queue.py48
-rw-r--r--synapse/handlers/federation.py57
-rw-r--r--synapse/handlers/room_member.py2
-rw-r--r--synapse/state.py13
-rw-r--r--synapse/util/async.py4
7 files changed, 183 insertions, 107 deletions
diff --git a/synapse/events/spamcheck.py b/synapse/events/spamcheck.py
index 595b1760f8..dccc579eac 100644
--- a/synapse/events/spamcheck.py
+++ b/synapse/events/spamcheck.py
@@ -46,7 +46,7 @@ class SpamChecker(object):
 
         return self.spam_checker.check_event_for_spam(event)
 
-    def user_may_invite(self, userid, room_id):
+    def user_may_invite(self, inviter_userid, invitee_userid, room_id):
         """Checks if a given user may send an invite
 
         If this method returns false, the invite will be rejected.
@@ -60,7 +60,7 @@ class SpamChecker(object):
         if self.spam_checker is None:
             return True
 
-        return self.spam_checker.user_may_invite(userid, room_id)
+        return self.spam_checker.user_may_invite(inviter_userid, invitee_userid, room_id)
 
     def user_may_create_room(self, userid):
         """Checks if a given user may create a room
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 51e3fdea06..f00d59e701 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -12,14 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-
 from twisted.internet import defer
 
 from .federation_base import FederationBase
 from .units import Transaction, Edu
 
-from synapse.util.async import Linearizer
+from synapse.util import async
 from synapse.util.logutils import log_function
 from synapse.util.caches.response_cache import ResponseCache
 from synapse.events import FrozenEvent
@@ -33,6 +31,9 @@ from synapse.crypto.event_signing import compute_event_signature
 import simplejson as json
 import logging
 
+# when processing incoming transactions, we try to handle multiple rooms in
+# parallel, up to this limit.
+TRANSACTION_CONCURRENCY_LIMIT = 10
 
 logger = logging.getLogger(__name__)
 
@@ -52,7 +53,8 @@ class FederationServer(FederationBase):
 
         self.auth = hs.get_auth()
 
-        self._server_linearizer = Linearizer("fed_server")
+        self._server_linearizer = async.Linearizer("fed_server")
+        self._transaction_linearizer = async.Linearizer("fed_txn_handler")
 
         # We cache responses to state queries, as they take a while and often
         # come in waves.
@@ -109,25 +111,41 @@ class FederationServer(FederationBase):
     @defer.inlineCallbacks
     @log_function
     def on_incoming_transaction(self, transaction_data):
+        # keep this as early as possible to make the calculated origin ts as
+        # accurate as possible.
+        request_time = self._clock.time_msec()
+
         transaction = Transaction(**transaction_data)
 
-        received_pdus_counter.inc_by(len(transaction.pdus))
+        if not transaction.transaction_id:
+            raise Exception("Transaction missing transaction_id")
+        if not transaction.origin:
+            raise Exception("Transaction missing origin")
 
-        for p in transaction.pdus:
-            if "unsigned" in p:
-                unsigned = p["unsigned"]
-                if "age" in unsigned:
-                    p["age"] = unsigned["age"]
-            if "age" in p:
-                p["age_ts"] = int(self._clock.time_msec()) - int(p["age"])
-                del p["age"]
+        logger.debug("[%s] Got transaction", transaction.transaction_id)
 
-        pdu_list = [
-            self.event_from_pdu_json(p) for p in transaction.pdus
-        ]
+        # use a linearizer to ensure that we don't process the same transaction
+        # multiple times in parallel.
+        with (yield self._transaction_linearizer.queue(
+                (transaction.origin, transaction.transaction_id),
+        )):
+            result = yield self._handle_incoming_transaction(
+                transaction, request_time,
+            )
 
-        logger.debug("[%s] Got transaction", transaction.transaction_id)
+        defer.returnValue(result)
+
+    @defer.inlineCallbacks
+    def _handle_incoming_transaction(self, transaction, request_time):
+        """ Process an incoming transaction and return the HTTP response
+
+        Args:
+            transaction (Transaction): incoming transaction
+            request_time (int): timestamp that the HTTP request arrived at
 
+        Returns:
+            Deferred[(int, object)]: http response code and body
+        """
         response = yield self.transaction_actions.have_responded(transaction)
 
         if response:
@@ -140,42 +158,50 @@ class FederationServer(FederationBase):
 
         logger.debug("[%s] Transaction is new", transaction.transaction_id)
 
-        results = []
-
-        for pdu in pdu_list:
-            # check that it's actually being sent from a valid destination to
-            # workaround bug #1753 in 0.18.5 and 0.18.6
-            if transaction.origin != get_domain_from_id(pdu.event_id):
-                # We continue to accept join events from any server; this is
-                # necessary for the federation join dance to work correctly.
-                # (When we join over federation, the "helper" server is
-                # responsible for sending out the join event, rather than the
-                # origin. See bug #1893).
-                if not (
-                    pdu.type == 'm.room.member' and
-                    pdu.content and
-                    pdu.content.get("membership", None) == 'join'
-                ):
-                    logger.info(
-                        "Discarding PDU %s from invalid origin %s",
-                        pdu.event_id, transaction.origin
-                    )
-                    continue
-                else:
-                    logger.info(
-                        "Accepting join PDU %s from %s",
-                        pdu.event_id, transaction.origin
-                    )
+        received_pdus_counter.inc_by(len(transaction.pdus))
 
-            try:
-                yield self._handle_received_pdu(transaction.origin, pdu)
-                results.append({})
-            except FederationError as e:
-                self.send_failure(e, transaction.origin)
-                results.append({"error": str(e)})
-            except Exception as e:
-                results.append({"error": str(e)})
-                logger.exception("Failed to handle PDU")
+        pdus_by_room = {}
+
+        for p in transaction.pdus:
+            if "unsigned" in p:
+                unsigned = p["unsigned"]
+                if "age" in unsigned:
+                    p["age"] = unsigned["age"]
+            if "age" in p:
+                p["age_ts"] = request_time - int(p["age"])
+                del p["age"]
+
+            event = self.event_from_pdu_json(p)
+            room_id = event.room_id
+            pdus_by_room.setdefault(room_id, []).append(event)
+
+        pdu_results = {}
+
+        # we can process different rooms in parallel (which is useful if they
+        # require callouts to other servers to fetch missing events), but
+        # impose a limit to avoid going too crazy with ram/cpu.
+        @defer.inlineCallbacks
+        def process_pdus_for_room(room_id):
+            logger.debug("Processing PDUs for %s", room_id)
+            for pdu in pdus_by_room[room_id]:
+                event_id = pdu.event_id
+                try:
+                    yield self._handle_received_pdu(
+                        transaction.origin, pdu
+                    )
+                    pdu_results[event_id] = {}
+                except FederationError as e:
+                    logger.warn("Error handling PDU %s: %s", event_id, e)
+                    self.send_failure(e, transaction.origin)
+                    pdu_results[event_id] = {"error": str(e)}
+                except Exception as e:
+                    pdu_results[event_id] = {"error": str(e)}
+                    logger.exception("Failed to handle PDU %s", event_id)
+
+        yield async.concurrently_execute(
+            process_pdus_for_room, pdus_by_room.keys(),
+            TRANSACTION_CONCURRENCY_LIMIT,
+        )
 
         if hasattr(transaction, "edus"):
             for edu in (Edu(**x) for x in transaction.edus):
@@ -188,14 +214,12 @@ class FederationServer(FederationBase):
             for failure in getattr(transaction, "pdu_failures", []):
                 logger.info("Got failure %r", failure)
 
-        logger.debug("Returning: %s", str(results))
-
         response = {
-            "pdus": dict(zip(
-                (p.event_id for p in pdu_list), results
-            )),
+            "pdus": pdu_results,
         }
 
+        logger.debug("Returning: %s", str(response))
+
         yield self.transaction_actions.set_response(
             transaction,
             200, response
@@ -520,6 +544,30 @@ class FederationServer(FederationBase):
         Returns (Deferred): completes with None
         Raises: FederationError if the signatures / hash do not match
     """
+        # check that it's actually being sent from a valid destination to
+        # workaround bug #1753 in 0.18.5 and 0.18.6
+        if origin != get_domain_from_id(pdu.event_id):
+            # We continue to accept join events from any server; this is
+            # necessary for the federation join dance to work correctly.
+            # (When we join over federation, the "helper" server is
+            # responsible for sending out the join event, rather than the
+            # origin. See bug #1893).
+            if not (
+                pdu.type == 'm.room.member' and
+                pdu.content and
+                pdu.content.get("membership", None) == 'join'
+            ):
+                logger.info(
+                    "Discarding PDU %s from invalid origin %s",
+                    pdu.event_id, origin
+                )
+                return
+            else:
+                logger.info(
+                    "Accepting join PDU %s from %s",
+                    pdu.event_id, origin
+                )
+
         # Check signature.
         try:
             pdu = yield self._check_sigs_and_hash(pdu)
diff --git a/synapse/federation/transaction_queue.py b/synapse/federation/transaction_queue.py
index 003eaba893..7a3c9cbb70 100644
--- a/synapse/federation/transaction_queue.py
+++ b/synapse/federation/transaction_queue.py
@@ -20,8 +20,8 @@ from .persistence import TransactionActions
 from .units import Transaction, Edu
 
 from synapse.api.errors import HttpResponseException
+from synapse.util import logcontext
 from synapse.util.async import run_on_reactor
-from synapse.util.logcontext import preserve_context_over_fn, preserve_fn
 from synapse.util.retryutils import NotRetryingDestination, get_retry_limiter
 from synapse.util.metrics import measure_func
 from synapse.handlers.presence import format_user_presence_state, get_interested_remotes
@@ -231,11 +231,9 @@ class TransactionQueue(object):
                 (pdu, order)
             )
 
-            preserve_context_over_fn(
-                self._attempt_new_transaction, destination
-            )
+            self._attempt_new_transaction(destination)
 
-    @preserve_fn  # the caller should not yield on this
+    @logcontext.preserve_fn  # the caller should not yield on this
     @defer.inlineCallbacks
     def send_presence(self, states):
         """Send the new presence states to the appropriate destinations.
@@ -299,7 +297,7 @@ class TransactionQueue(object):
                     state.user_id: state for state in states
                 })
 
-                preserve_fn(self._attempt_new_transaction)(destination)
+                self._attempt_new_transaction(destination)
 
     def send_edu(self, destination, edu_type, content, key=None):
         edu = Edu(
@@ -321,9 +319,7 @@ class TransactionQueue(object):
         else:
             self.pending_edus_by_dest.setdefault(destination, []).append(edu)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_failure(self, failure, destination):
         if destination == self.server_name or destination == "localhost":
@@ -336,9 +332,7 @@ class TransactionQueue(object):
             destination, []
         ).append(failure)
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def send_device_messages(self, destination):
         if destination == self.server_name or destination == "localhost":
@@ -347,15 +341,24 @@ class TransactionQueue(object):
         if not self.can_send_to(destination):
             return
 
-        preserve_context_over_fn(
-            self._attempt_new_transaction, destination
-        )
+        self._attempt_new_transaction(destination)
 
     def get_current_token(self):
         return 0
 
-    @defer.inlineCallbacks
     def _attempt_new_transaction(self, destination):
+        """Try to start a new transaction to this destination
+
+        If there is already a transaction in progress to this destination,
+        returns immediately. Otherwise kicks off the process of sending a
+        transaction in the background.
+
+        Args:
+            destination (str):
+
+        Returns:
+            None
+        """
         # list of (pending_pdu, deferred, order)
         if destination in self.pending_transactions:
             # XXX: pending_transactions can get stuck on by a never-ending
@@ -368,6 +371,19 @@ class TransactionQueue(object):
             )
             return
 
+        logger.debug("TX [%s] Starting transaction loop", destination)
+
+        # Drop the logcontext before starting the transaction. It doesn't
+        # really make sense to log all the outbound transactions against
+        # whatever path led us to this point: that's pretty arbitrary really.
+        #
+        # (this also means we can fire off _perform_transaction without
+        # yielding)
+        with logcontext.PreserveLoggingContext():
+            self._transaction_transmission_loop(destination)
+
+    @defer.inlineCallbacks
+    def _transaction_transmission_loop(self, destination):
         pending_pdus = []
         try:
             self.pending_transactions[destination] = 1
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 737fe518ef..63c56a4a32 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 """Contains handlers for federation events."""
-import synapse.util.logcontext
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -26,10 +25,7 @@ from synapse.api.errors import (
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
-    preserve_fn, preserve_context_over_deferred
-)
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor, Linearizer
@@ -592,9 +588,9 @@ class FederationHandler(BaseHandler):
                     missing_auth - failed_to_fetch
                 )
 
-                results = yield preserve_context_over_deferred(defer.gatherResults(
+                results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
                     [
-                        preserve_fn(self.replication_layer.get_pdu)(
+                        logcontext.preserve_fn(self.replication_layer.get_pdu)(
                             [dest],
                             event_id,
                             outlier=True,
@@ -786,10 +782,14 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
-        states = yield preserve_context_over_deferred(defer.gatherResults([
-            preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
-            for e in event_ids
-        ]))
+        states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+            [
+                logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
+                    room_id, [e]
+                )
+                for e in event_ids
+            ], consumeErrors=True,
+        ))
         states = dict(zip(event_ids, [s.state for s in states]))
 
         state_map = yield self.store.get_events(
@@ -942,9 +942,7 @@ class FederationHandler(BaseHandler):
             # lots of requests for missing prev_events which we do actually
             # have. Hence we fire off the deferred, but don't wait for it.
 
-            synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
-                room_queue
-            )
+            logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
 
         defer.returnValue(True)
 
@@ -1071,6 +1069,9 @@ class FederationHandler(BaseHandler):
         """
         event = pdu
 
+        if event.state_key is None:
+            raise SynapseError(400, "The invite event did not have a state key")
+
         is_blocked = yield self.store.is_room_blocked(event.room_id)
         if is_blocked:
             raise SynapseError(403, "This room has been blocked on this server")
@@ -1078,9 +1079,11 @@ class FederationHandler(BaseHandler):
         if self.hs.config.block_non_admin_invites:
             raise SynapseError(403, "This server does not accept room invites")
 
-        if not self.spam_checker.user_may_invite(event.sender, event.room_id):
+        if not self.spam_checker.user_may_invite(
+            event.sender, event.state_key, event.room_id,
+        ):
             raise SynapseError(
-                403, "This user is not permitted to send invites to this server"
+                403, "This user is not permitted to send invites to this server/user"
             )
 
         membership = event.content.get("membership")
@@ -1091,9 +1094,6 @@ class FederationHandler(BaseHandler):
         if sender_domain != origin:
             raise SynapseError(400, "The invite event was not from the server sending it")
 
-        if event.state_key is None:
-            raise SynapseError(400, "The invite event did not have a state key")
-
         if not self.is_mine_id(event.state_key):
             raise SynapseError(400, "The invite event must be for this server")
 
@@ -1436,7 +1436,7 @@ class FederationHandler(BaseHandler):
         if not backfilled:
             # this intentionally does not yield: we don't care about the result
             # and don't need to wait for it.
-            preserve_fn(self.pusher_pool.on_new_notifications)(
+            logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
                 event_stream_id, max_stream_id
             )
 
@@ -1449,16 +1449,16 @@ class FederationHandler(BaseHandler):
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
         """
-        contexts = yield preserve_context_over_deferred(defer.gatherResults(
+        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self._prep_event)(
+                logcontext.preserve_fn(self._prep_event)(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
                     auth_events=ev_info.get("auth_events"),
                 )
                 for ev_info in event_infos
-            ]
+            ], consumeErrors=True,
         ))
 
         yield self.store.persist_events(
@@ -1766,18 +1766,17 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
 
-            different_events = yield preserve_context_over_deferred(defer.gatherResults(
-                [
-                    preserve_fn(self.store.get_event)(
+            different_events = yield logcontext.make_deferred_yieldable(
+                defer.gatherResults([
+                    logcontext.preserve_fn(self.store.get_event)(
                         d,
                         allow_none=True,
                         allow_rejected=False,
                     )
                     for d in different_auth
                     if d in have_events and not have_events[d]
-                ],
-                consumeErrors=True
-            )).addErrback(unwrapFirstError)
+                ], consumeErrors=True)
+            ).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 37985fa1f9..36a8ef8ce0 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -225,7 +225,7 @@ class RoomMemberHandler(BaseHandler):
                     block_invite = True
 
                 if not self.spam_checker.user_may_invite(
-                    requester.user.to_string(), room_id,
+                    requester.user.to_string(), target.to_string(), room_id,
                 ):
                     logger.info("Blocking invite due to spam checker")
                     block_invite = True
diff --git a/synapse/state.py b/synapse/state.py
index 390799fbd5..dcdcdef65e 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -288,6 +288,9 @@ class StateHandler(object):
         """
         logger.debug("resolve_state_groups event_ids %s", event_ids)
 
+        # map from state group id to the state in that state group (where
+        # 'state' is a map from state key to event id)
+        # dict[int, dict[(str, str), str]]
         state_groups_ids = yield self.store.get_state_groups_ids(
             room_id, event_ids
         )
@@ -320,11 +323,15 @@ class StateHandler(object):
                 "Resolving state for %s with %d groups", room_id, len(state_groups_ids)
             )
 
+            # build a map from state key to the event_ids which set that state.
+            # dict[(str, str), set[str])
             state = {}
             for st in state_groups_ids.values():
                 for key, e_id in st.items():
                     state.setdefault(key, set()).add(e_id)
 
+            # build a map from state key to the event_ids which set that state,
+            # including only those where there are state keys in conflict.
             conflicted_state = {
                 k: list(v)
                 for k, v in state.items()
@@ -494,8 +501,14 @@ def _resolve_with_state_fac(unconflicted_state, conflicted_state,
 
     logger.info("Asking for %d conflicted events", len(needed_events))
 
+    # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+    # the state events which are in conflict.
     state_map = yield state_map_factory(needed_events)
 
+    # get the ids of the auth events which allow us to authenticate the
+    # conflicted state, picking only from the unconflicting state.
+    #
+    # dict[(str, str), str]: a map from state key to event id
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
diff --git a/synapse/util/async.py b/synapse/util/async.py
index 1453faf0ef..bb252f75d7 100644
--- a/synapse/util/async.py
+++ b/synapse/util/async.py
@@ -19,7 +19,7 @@ from twisted.internet import defer, reactor
 from .logcontext import (
     PreserveLoggingContext, preserve_fn, preserve_context_over_deferred,
 )
-from synapse.util import unwrapFirstError
+from synapse.util import logcontext, unwrapFirstError
 
 from contextlib import contextmanager
 
@@ -155,7 +155,7 @@ def concurrently_execute(func, args, limit):
         except StopIteration:
             pass
 
-    return preserve_context_over_deferred(defer.gatherResults([
+    return logcontext.make_deferred_yieldable(defer.gatherResults([
         preserve_fn(_concurrently_execute_inner)()
         for _ in xrange(limit)
     ], consumeErrors=True)).addErrback(unwrapFirstError)