summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py81
1 files changed, 54 insertions, 27 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 18f87cad67..7711cded01 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 """Contains handlers for federation events."""
-import synapse.util.logcontext
 from signedjson.key import decode_verify_key_bytes
 from signedjson.sign import verify_signed_json
 from unpaddedbase64 import decode_base64
@@ -26,10 +25,7 @@ from synapse.api.errors import (
 )
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
-from synapse.util import unwrapFirstError
-from synapse.util.logcontext import (
-    preserve_fn, preserve_context_over_deferred
-)
+from synapse.util import unwrapFirstError, logcontext
 from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor, Linearizer
@@ -77,6 +73,7 @@ class FederationHandler(BaseHandler):
         self.action_generator = hs.get_action_generator()
         self.is_mine_id = hs.is_mine_id
         self.pusher_pool = hs.get_pusherpool()
+        self.spam_checker = hs.get_spam_checker()
 
         self.replication_layer.set_handler(self)
 
@@ -125,6 +122,28 @@ class FederationHandler(BaseHandler):
             self.room_queues[pdu.room_id].append((pdu, origin))
             return
 
+        # If we're no longer in the room just ditch the event entirely. This
+        # is probably an old server that has come back and thinks we're still
+        # in the room (or we've been rejoined to the room by a state reset).
+        #
+        # If we were never in the room then maybe our database got vaped and
+        # we should check if we *are* in fact in the room. If we are then we
+        # can magically rejoin the room.
+        is_in_room = yield self.auth.check_host_in_room(
+            pdu.room_id,
+            self.server_name
+        )
+        if not is_in_room:
+            was_in_room = yield self.store.was_host_joined(
+                pdu.room_id, self.server_name,
+            )
+            if was_in_room:
+                logger.info(
+                    "Ignoring PDU %s for room %s from %s as we've left the room!",
+                    pdu.event_id, pdu.room_id, origin,
+                )
+                return
+
         state = None
 
         auth_chain = []
@@ -591,9 +610,9 @@ class FederationHandler(BaseHandler):
                     missing_auth - failed_to_fetch
                 )
 
-                results = yield preserve_context_over_deferred(defer.gatherResults(
+                results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
                     [
-                        preserve_fn(self.replication_layer.get_pdu)(
+                        logcontext.preserve_fn(self.replication_layer.get_pdu)(
                             [dest],
                             event_id,
                             outlier=True,
@@ -785,10 +804,14 @@ class FederationHandler(BaseHandler):
         event_ids = list(extremities.keys())
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
-        states = yield preserve_context_over_deferred(defer.gatherResults([
-            preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
-            for e in event_ids
-        ]))
+        states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
+            [
+                logcontext.preserve_fn(self.state_handler.resolve_state_groups)(
+                    room_id, [e]
+                )
+                for e in event_ids
+            ], consumeErrors=True,
+        ))
         states = dict(zip(event_ids, [s.state for s in states]))
 
         state_map = yield self.store.get_events(
@@ -941,9 +964,7 @@ class FederationHandler(BaseHandler):
             # lots of requests for missing prev_events which we do actually
             # have. Hence we fire off the deferred, but don't wait for it.
 
-            synapse.util.logcontext.preserve_fn(self._handle_queued_pdus)(
-                room_queue
-            )
+            logcontext.preserve_fn(self._handle_queued_pdus)(room_queue)
 
         defer.returnValue(True)
 
@@ -1070,6 +1091,9 @@ class FederationHandler(BaseHandler):
         """
         event = pdu
 
+        if event.state_key is None:
+            raise SynapseError(400, "The invite event did not have a state key")
+
         is_blocked = yield self.store.is_room_blocked(event.room_id)
         if is_blocked:
             raise SynapseError(403, "This room has been blocked on this server")
@@ -1077,6 +1101,13 @@ class FederationHandler(BaseHandler):
         if self.hs.config.block_non_admin_invites:
             raise SynapseError(403, "This server does not accept room invites")
 
+        if not self.spam_checker.user_may_invite(
+            event.sender, event.state_key, event.room_id,
+        ):
+            raise SynapseError(
+                403, "This user is not permitted to send invites to this server/user"
+            )
+
         membership = event.content.get("membership")
         if event.type != EventTypes.Member or membership != Membership.INVITE:
             raise SynapseError(400, "The event was not an m.room.member invite event")
@@ -1085,9 +1116,6 @@ class FederationHandler(BaseHandler):
         if sender_domain != origin:
             raise SynapseError(400, "The invite event was not from the server sending it")
 
-        if event.state_key is None:
-            raise SynapseError(400, "The invite event did not have a state key")
-
         if not self.is_mine_id(event.state_key):
             raise SynapseError(400, "The invite event must be for this server")
 
@@ -1430,7 +1458,7 @@ class FederationHandler(BaseHandler):
         if not backfilled:
             # this intentionally does not yield: we don't care about the result
             # and don't need to wait for it.
-            preserve_fn(self.pusher_pool.on_new_notifications)(
+            logcontext.preserve_fn(self.pusher_pool.on_new_notifications)(
                 event_stream_id, max_stream_id
             )
 
@@ -1443,16 +1471,16 @@ class FederationHandler(BaseHandler):
         a bunch of outliers, but not a chunk of individual events that depend
         on each other for state calculations.
         """
-        contexts = yield preserve_context_over_deferred(defer.gatherResults(
+        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
             [
-                preserve_fn(self._prep_event)(
+                logcontext.preserve_fn(self._prep_event)(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
                     auth_events=ev_info.get("auth_events"),
                 )
                 for ev_info in event_infos
-            ]
+            ], consumeErrors=True,
         ))
 
         yield self.store.persist_events(
@@ -1760,18 +1788,17 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
 
-            different_events = yield preserve_context_over_deferred(defer.gatherResults(
-                [
-                    preserve_fn(self.store.get_event)(
+            different_events = yield logcontext.make_deferred_yieldable(
+                defer.gatherResults([
+                    logcontext.preserve_fn(self.store.get_event)(
                         d,
                         allow_none=True,
                         allow_rejected=False,
                     )
                     for d in different_auth
                     if d in have_events and not have_events[d]
-                ],
-                consumeErrors=True
-            )).addErrback(unwrapFirstError)
+                ], consumeErrors=True)
+            ).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)