summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py233
1 files changed, 137 insertions, 96 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 267fedf114..ff83c608e7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -26,20 +26,21 @@ from synapse.api.errors import (
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import PreserveLoggingContext, preserve_fn
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
 from synapse.crypto.event_signing import (
     compute_event_signature, add_hashes_and_signatures,
 )
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
 
 from synapse.events.utils import prune_event
 
 from synapse.util.retryutils import NotRetryingDestination
 
 from synapse.push.action_generator import ActionGenerator
+from synapse.util.distributor import user_joined_room
 
 from twisted.internet import defer
 
@@ -49,10 +50,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-def user_joined_room(distributor, user, room_id):
-    return distributor.fire("user_joined_room", user, room_id)
-
-
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
@@ -69,10 +66,6 @@ class FederationHandler(BaseHandler):
 
         self.hs = hs
 
-        self.distributor.observe("user_joined_room", self.user_joined_room)
-
-        self.waiting_for_join_list = {}
-
         self.store = hs.get_datastore()
         self.replication_layer = hs.get_replication_layer()
         self.state_handler = hs.get_state_handler()
@@ -102,8 +95,7 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, state=None,
-                       auth_chain=None):
+    def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
         """
@@ -174,11 +166,7 @@ class FederationHandler(BaseHandler):
                     })
                     seen_ids.add(e.event_id)
 
-                yield self._handle_new_events(
-                    origin,
-                    event_infos,
-                    outliers=True
-                )
+                yield self._handle_new_events(origin, event_infos)
 
             try:
                 context, event_stream_id, max_stream_id = yield self._handle_new_event(
@@ -288,7 +276,14 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def backfill(self, dest, room_id, limit, extremities=[]):
         """ Trigger a backfill request to `dest` for the given `room_id`
+
+        This will attempt to get more events from the remote. This may return
+        be successfull and still return no events if the other side has no new
+        events to offer.
         """
+        if dest == self.server_name:
+            raise SynapseError(400, "Can't backfill from self.")
+
         if not extremities:
             extremities = yield self.store.get_oldest_events_in_room(room_id)
 
@@ -299,6 +294,16 @@ class FederationHandler(BaseHandler):
             extremities=extremities,
         )
 
+        # Don't bother processing events we already have.
+        seen_events = yield self.store.have_events_in_timeline(
+            set(e.event_id for e in events)
+        )
+
+        events = [e for e in events if e.event_id not in seen_events]
+
+        if not events:
+            defer.returnValue([])
+
         event_map = {e.event_id: e for e in events}
 
         event_ids = set(e.event_id for e in events)
@@ -358,6 +363,7 @@ class FederationHandler(BaseHandler):
         for a in auth_events.values():
             if a.event_id in seen_events:
                 continue
+            a.internal_metadata.outlier = True
             ev_infos.append({
                 "event": a,
                 "auth_events": {
@@ -378,20 +384,23 @@ class FederationHandler(BaseHandler):
                 }
             })
 
+        yield self._handle_new_events(
+            dest, ev_infos,
+            backfilled=True,
+        )
+
         events.sort(key=lambda e: e.depth)
 
         for event in events:
             if event in events_to_state:
                 continue
 
-            ev_infos.append({
-                "event": event,
-            })
-
-        yield self._handle_new_events(
-            dest, ev_infos,
-            backfilled=True,
-        )
+            # We store these one at a time since each event depends on the
+            # previous to work out the state.
+            # TODO: We can probably do something more clever here.
+            yield self._handle_new_event(
+                dest, event
+            )
 
         defer.returnValue(events)
 
@@ -440,7 +449,7 @@ class FederationHandler(BaseHandler):
             joined_domains = {}
             for u, d in joined_users:
                 try:
-                    dom = UserID.from_string(u).domain
+                    dom = get_domain_from_id(u)
                     old_d = joined_domains.get(dom)
                     if old_d:
                         joined_domains[dom] = min(d, old_d)
@@ -455,7 +464,7 @@ class FederationHandler(BaseHandler):
 
         likely_domains = [
             domain for domain, depth in curr_domains
-            if domain is not self.server_name
+            if domain != self.server_name
         ]
 
         @defer.inlineCallbacks
@@ -463,11 +472,15 @@ class FederationHandler(BaseHandler):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    events = yield self.backfill(
+                    yield self.backfill(
                         dom, room_id,
                         limit=100,
                         extremities=[e for e in extremities.keys()]
                     )
+                    # If this succeeded then we probably already have the
+                    # appropriate stuff.
+                    # TODO: We can probably do something more intelligent here.
+                    defer.returnValue(True)
                 except SynapseError as e:
                     logger.info(
                         "Failed to backfill from %s because %s",
@@ -493,8 +506,6 @@ class FederationHandler(BaseHandler):
                     )
                     continue
 
-                if events:
-                    defer.returnValue(True)
             defer.returnValue(False)
 
         success = yield try_backfill(likely_domains)
@@ -666,9 +677,14 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        event, context = yield self._create_new_client_event(
-            builder=builder,
-        )
+        try:
+            message_handler = self.hs.get_handlers().message_handler
+            event, context = yield message_handler._create_new_client_event(
+                builder=builder,
+            )
+        except AuthError as e:
+            logger.warn("Failed to create join %r because %s", event, e)
+            raise e
 
         self.auth.check(event, auth_events=context.current_state)
 
@@ -724,9 +740,7 @@ class FederationHandler(BaseHandler):
             try:
                 if k[0] == EventTypes.Member:
                     if s.content["membership"] == Membership.JOIN:
-                        destinations.add(
-                            UserID.from_string(s.state_key).domain
-                        )
+                        destinations.add(get_domain_from_id(s.state_key))
             except:
                 logger.warn(
                     "Failed to get destination from event %s", s.event_id
@@ -761,6 +775,7 @@ class FederationHandler(BaseHandler):
         event = pdu
 
         event.internal_metadata.outlier = True
+        event.internal_metadata.invite_from_remote = True
 
         event.signatures.update(
             compute_event_signature(
@@ -788,13 +803,19 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
-        origin, event = yield self._make_and_verify_event(
-            target_hosts,
-            room_id,
-            user_id,
-            "leave"
-        )
-        signed_event = self._sign_event(event)
+        try:
+            origin, event = yield self._make_and_verify_event(
+                target_hosts,
+                room_id,
+                user_id,
+                "leave"
+            )
+            signed_event = self._sign_event(event)
+        except SynapseError:
+            raise
+        except CodeMessageException as e:
+            logger.warn("Failed to reject invite: %s", e)
+            raise SynapseError(500, "Failed to reject invite")
 
         # Try the host we successfully got a response to /make_join/
         # request first.
@@ -804,10 +825,16 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.replication_layer.send_leave(
-            target_hosts,
-            signed_event
-        )
+        try:
+            yield self.replication_layer.send_leave(
+                target_hosts,
+                signed_event
+            )
+        except SynapseError:
+            raise
+        except CodeMessageException as e:
+            logger.warn("Failed to reject invite: %s", e)
+            raise SynapseError(500, "Failed to reject invite")
 
         context = yield self.state_handler.compute_event_context(event)
 
@@ -883,11 +910,16 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        event, context = yield self._create_new_client_event(
+        message_handler = self.hs.get_handlers().message_handler
+        event, context = yield message_handler._create_new_client_event(
             builder=builder,
         )
 
-        self.auth.check(event, auth_events=context.current_state)
+        try:
+            self.auth.check(event, auth_events=context.current_state)
+        except AuthError as e:
+            logger.warn("Failed to create new leave %r because %s", event, e)
+            raise e
 
         defer.returnValue(event)
 
@@ -934,9 +966,7 @@ class FederationHandler(BaseHandler):
             try:
                 if k[0] == EventTypes.Member:
                     if s.content["membership"] == Membership.LEAVE:
-                        destinations.add(
-                            UserID.from_string(s.state_key).domain
-                        )
+                        destinations.add(get_domain_from_id(s.state_key))
             except:
                 logger.warn(
                     "Failed to get destination from event %s", s.event_id
@@ -1057,21 +1087,10 @@ class FederationHandler(BaseHandler):
     def get_min_depth_for_context(self, context):
         return self.store.get_min_depth(context)
 
-    @log_function
-    def user_joined_room(self, user, room_id):
-        waiters = self.waiting_for_join_list.get(
-            (user.to_string(), room_id),
-            []
-        )
-        while waiters:
-            waiters.pop().callback(None)
-
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_event(self, origin, event, state=None, auth_events=None):
-
-        outlier = event.internal_metadata.is_outlier()
-
+    def _handle_new_event(self, origin, event, state=None, auth_events=None,
+                          backfilled=False):
         context = yield self._prep_event(
             origin, event,
             state=state,
@@ -1081,20 +1100,30 @@ class FederationHandler(BaseHandler):
         if not event.internal_metadata.is_outlier():
             action_generator = ActionGenerator(self.hs)
             yield action_generator.handle_push_actions_for_event(
-                event, context, self
+                event, context
             )
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            is_new_state=not outlier,
+            backfilled=backfilled,
+        )
+
+        # this intentionally does not yield: we don't care about the result
+        # and don't need to wait for it.
+        preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+            event_stream_id, max_stream_id
         )
 
         defer.returnValue((context, event_stream_id, max_stream_id))
 
     @defer.inlineCallbacks
-    def _handle_new_events(self, origin, event_infos, backfilled=False,
-                           outliers=False):
+    def _handle_new_events(self, origin, event_infos, backfilled=False):
+        """Creates the appropriate contexts and persists events. The events
+        should not depend on one another, e.g. this should be used to persist
+        a bunch of outliers, but not a chunk of individual events that depend
+        on each other for state calculations.
+        """
         contexts = yield defer.gatherResults(
             [
                 self._prep_event(
@@ -1113,7 +1142,6 @@ class FederationHandler(BaseHandler):
                 for ev_info, context in itertools.izip(event_infos, contexts)
             ],
             backfilled=backfilled,
-            is_new_state=(not outliers and not backfilled),
         )
 
     @defer.inlineCallbacks
@@ -1128,11 +1156,9 @@ class FederationHandler(BaseHandler):
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
-            ctx = yield self.state_handler.compute_event_context(
-                e, outlier=True,
-            )
-            events_to_context[e.event_id] = ctx
             e.internal_metadata.outlier = True
+            ctx = yield self.state_handler.compute_event_context(e)
+            events_to_context[e.event_id] = ctx
 
         event_map = {
             e.event_id: e
@@ -1176,16 +1202,14 @@ class FederationHandler(BaseHandler):
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
             ],
-            is_new_state=False,
         )
 
         new_event_context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=False,
+            event, old_state=state
         )
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event, new_event_context,
-            is_new_state=True,
             current_state=state,
         )
 
@@ -1193,10 +1217,9 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
-        outlier = event.internal_metadata.is_outlier()
 
         context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=outlier,
+            event, old_state=state,
         )
 
         if not auth_events:
@@ -1482,8 +1505,9 @@ class FederationHandler(BaseHandler):
 
         try:
             self.auth.check(event, auth_events=auth_events)
-        except AuthError:
-            raise
+        except AuthError as e:
+            logger.warn("Failed auth resolution for %r because %s", event, e)
+            raise e
 
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
@@ -1653,13 +1677,21 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)
-            event, context = yield self._create_new_client_event(builder=builder)
+            message_handler = self.hs.get_handlers().message_handler
+            event, context = yield message_handler._create_new_client_event(
+                builder=builder
+            )
 
             event, context = yield self.add_display_name_to_third_party_invite(
                 event_dict, event, context
             )
 
-            self.auth.check(event, context.current_state)
+            try:
+                self.auth.check(event, context.current_state)
+            except AuthError as e:
+                logger.warn("Denying new third party invite %r because %s", event, e)
+                raise e
+
             yield self._check_signature(event, auth_events=context.current_state)
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
@@ -1676,7 +1708,8 @@ class FederationHandler(BaseHandler):
     def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
         builder = self.event_builder_factory.new(event_dict)
 
-        event, context = yield self._create_new_client_event(
+        message_handler = self.hs.get_handlers().message_handler
+        event, context = yield message_handler._create_new_client_event(
             builder=builder,
         )
 
@@ -1684,7 +1717,11 @@ class FederationHandler(BaseHandler):
             event_dict, event, context
         )
 
-        self.auth.check(event, auth_events=context.current_state)
+        try:
+            self.auth.check(event, auth_events=context.current_state)
+        except AuthError as e:
+            logger.warn("Denying third party invite %r because %s", event, e)
+            raise e
         yield self._check_signature(event, auth_events=context.current_state)
 
         returned_invite = yield self.send_invite(origin, event)
@@ -1711,20 +1748,23 @@ class FederationHandler(BaseHandler):
         event_dict["content"]["third_party_invite"]["display_name"] = display_name
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
-        event, context = yield self._create_new_client_event(builder=builder)
+        message_handler = self.hs.get_handlers().message_handler
+        event, context = yield message_handler._create_new_client_event(builder=builder)
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
     def _check_signature(self, event, auth_events):
         """
         Checks that the signature in the event is consistent with its invite.
-        :param event (Event): The m.room.member event to check
-        :param auth_events (dict<(event type, state_key), event>)
 
-        :raises
-            AuthError if signature didn't match any keys, or key has been
+        Args:
+            event (Event): The m.room.member event to check
+            auth_events (dict<(event type, state_key), event>):
+
+        Raises:
+            AuthError: if signature didn't match any keys, or key has been
                 revoked,
-            SynapseError if a transient error meant a key couldn't be checked
+            SynapseError: if a transient error meant a key couldn't be checked
                 for revocation.
         """
         signed = event.content["third_party_invite"]["signed"]
@@ -1766,12 +1806,13 @@ class FederationHandler(BaseHandler):
         """
         Checks whether public_key has been revoked.
 
-        :param public_key (str): base-64 encoded public key.
-        :param url (str): Key revocation URL.
+        Args:
+            public_key (str): base-64 encoded public key.
+            url (str): Key revocation URL.
 
-        :raises
-            AuthError if they key has been revoked.
-            SynapseError if a transient error meant a key couldn't be checked
+        Raises:
+            AuthError: if they key has been revoked.
+            SynapseError: if a transient error meant a key couldn't be checked
                 for revocation.
         """
         try: