summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py508
1 files changed, 375 insertions, 133 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index d26975a88a..8bf5a4cc11 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,18 +17,16 @@
 
 from ._base import BaseHandler
 
-from synapse.events.utils import prune_event
 from synapse.api.errors import (
-    AuthError, FederationError, SynapseError, StoreError,
+    AuthError, FederationError, StoreError,
 )
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.crypto.event_signing import (
-    compute_event_signature, check_event_content_hash,
-    add_hashes_and_signatures,
+    compute_event_signature, add_hashes_and_signatures,
 )
-from syutil.jsonutil import encode_canonical_json
+from synapse.types import UserID
 
 from twisted.internet import defer
 
@@ -75,14 +73,14 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def handle_new_event(self, event, snapshot, destinations):
+    def handle_new_event(self, event, destinations):
         """ Takes in an event from the client to server side, that has already
         been authed and handled by the state module, and sends it to any
         remote home servers that may be interested.
 
         Args:
-            event
-            snapshot (.storage.Snapshot): THe snapshot the event happened after
+            event: The event to send
+            destinations: A list of destinations to send it to
 
         Returns:
             Deferred: Resolved when it has successfully been queued for
@@ -112,33 +110,6 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Processing event: %s", event.event_id)
 
-        redacted_event = prune_event(event)
-
-        redacted_pdu_json = redacted_event.get_pdu_json()
-        try:
-            yield self.keyring.verify_json_for_server(
-                event.origin, redacted_pdu_json
-            )
-        except SynapseError as e:
-            logger.warn(
-                "Signature check failed for %s redacted to %s",
-                encode_canonical_json(pdu.get_pdu_json()),
-                encode_canonical_json(redacted_pdu_json),
-            )
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
-
-        if not check_event_content_hash(event):
-            logger.warn(
-                "Event content has been tampered, redacting %s, %s",
-                event.event_id, encode_canonical_json(event.get_dict())
-            )
-            event = redacted_event
-
         logger.debug("Event: %s", event)
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
@@ -148,41 +119,20 @@ class FederationHandler(BaseHandler):
             event.room_id,
             self.server_name
         )
-        if not is_in_room and not event.internal_metadata.outlier:
+        if not is_in_room and not event.internal_metadata.is_outlier():
             logger.debug("Got event for room we're not in.")
-
-            replication = self.replication_layer
-
-            if not state:
-                state, auth_chain = yield replication.get_state_for_context(
-                    origin, context=event.room_id, event_id=event.event_id,
-                )
-
-            if not auth_chain:
-                auth_chain = yield replication.get_event_auth(
-                    origin,
-                    context=event.room_id,
-                    event_id=event.event_id,
-                )
-
-            for e in auth_chain:
-                e.internal_metadata.outlier = True
-                try:
-                    yield self._handle_new_event(e, fetch_auth_from=origin)
-                except:
-                    logger.exception(
-                        "Failed to handle auth event %s",
-                        e.event_id,
-                    )
-
             current_state = state
 
-        if state:
+        if state and auth_chain is not None:
             for e in state:
-                logging.info("A :) %r", e)
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(origin, e, auth_events=auth)
                 except:
                     logger.exception(
                         "Failed to handle state event %s",
@@ -191,6 +141,7 @@ class FederationHandler(BaseHandler):
 
         try:
             yield self._handle_new_event(
+                origin,
                 event,
                 state=state,
                 backfilled=backfilled,
@@ -227,7 +178,7 @@ class FederationHandler(BaseHandler):
             extra_users = []
             if event.type == EventTypes.Member:
                 target_user_id = event.state_key
-                target_user = self.hs.parse_userid(target_user_id)
+                target_user = UserID.from_string(target_user_id)
                 extra_users.append(target_user)
 
             yield self.notifier.on_new_room_event(
@@ -236,7 +187,7 @@ class FederationHandler(BaseHandler):
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
-                user = self.hs.parse_userid(event.state_key)
+                user = UserID.from_string(event.state_key)
                 yield self.distributor.fire(
                     "user_joined_room", user=user, room_id=event.room_id
                 )
@@ -281,7 +232,7 @@ class FederationHandler(BaseHandler):
         """
         pdu = yield self.replication_layer.send_invite(
             destination=target_host,
-            context=event.room_id,
+            room_id=event.room_id,
             event_id=event.event_id,
             pdu=event
         )
@@ -392,8 +343,19 @@ class FederationHandler(BaseHandler):
 
             for e in auth_chain:
                 e.internal_metadata.outlier = True
+
+                if e.event_id == event.event_id:
+                    continue
+
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(
+                        target_host, e, auth_events=auth
+                    )
                 except:
                     logger.exception(
                         "Failed to handle auth event %s",
@@ -401,11 +363,18 @@ class FederationHandler(BaseHandler):
                     )
 
             for e in state:
-                # FIXME: Auth these.
+                if e.event_id == event.event_id:
+                    continue
+
                 e.internal_metadata.outlier = True
                 try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
                     yield self._handle_new_event(
-                        e, fetch_auth_from=target_host
+                        target_host, e, auth_events=auth
                     )
                 except:
                     logger.exception(
@@ -413,10 +382,18 @@ class FederationHandler(BaseHandler):
                         e.event_id,
                     )
 
+            auth_ids = [e_id for e_id, _ in event.auth_events]
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_chain
+                if e.event_id in auth_ids
+            }
+
             yield self._handle_new_event(
+                target_host,
                 new_event,
                 state=state,
                 current_state=state,
+                auth_events=auth_events,
             )
 
             yield self.notifier.on_new_room_event(
@@ -480,7 +457,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(event)
+        context = yield self._handle_new_event(origin, event)
 
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -491,7 +468,7 @@ class FederationHandler(BaseHandler):
         extra_users = []
         if event.type == EventTypes.Member:
             target_user_id = event.state_key
-            target_user = self.hs.parse_userid(target_user_id)
+            target_user = UserID.from_string(target_user_id)
             extra_users.append(target_user)
 
         yield self.notifier.on_new_room_event(
@@ -500,7 +477,7 @@ class FederationHandler(BaseHandler):
 
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
-                user = self.hs.parse_userid(event.state_key)
+                user = UserID.from_string(event.state_key)
                 yield self.distributor.fire(
                     "user_joined_room", user=user, room_id=event.room_id
                 )
@@ -514,7 +491,7 @@ class FederationHandler(BaseHandler):
                 if k[0] == EventTypes.Member:
                     if s.content["membership"] == Membership.JOIN:
                         destinations.add(
-                            self.hs.parse_userid(s.state_key).domain
+                            UserID.from_string(s.state_key).domain
                         )
             except:
                 logger.warn(
@@ -565,7 +542,7 @@ class FederationHandler(BaseHandler):
             backfilled=False,
         )
 
-        target_user = self.hs.parse_userid(event.state_key)
+        target_user = UserID.from_string(event.state_key)
         yield self.notifier.on_new_room_event(
             event, extra_users=[target_user],
         )
@@ -617,13 +594,13 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     @log_function
-    def on_backfill_request(self, origin, context, pdu_list, limit):
-        in_room = yield self.auth.check_host_in_room(context, origin)
+    def on_backfill_request(self, origin, room_id, pdu_list, limit):
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
         events = yield self.store.get_backfill_events(
-            context,
+            room_id,
             pdu_list,
             limit
         )
@@ -681,11 +658,12 @@ class FederationHandler(BaseHandler):
             waiters.pop().callback(None)
 
     @defer.inlineCallbacks
-    def _handle_new_event(self, event, state=None, backfilled=False,
-                          current_state=None, fetch_auth_from=None):
+    @log_function
+    def _handle_new_event(self, origin, event, state=None, backfilled=False,
+                          current_state=None, auth_events=None):
 
         logger.debug(
-            "_handle_new_event: Before annotate: %s, sigs: %s",
+            "_handle_new_event: %s, sigs: %s",
             event.event_id, event.signatures,
         )
 
@@ -693,65 +671,44 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )
 
+        if not auth_events:
+            auth_events = context.auth_events
+
         logger.debug(
-            "_handle_new_event: Before auth fetch: %s, sigs: %s",
-            event.event_id, event.signatures,
+            "_handle_new_event: %s, auth_events: %s",
+            event.event_id, auth_events,
         )
 
         is_new_state = not event.internal_metadata.is_outlier()
 
-        known_ids = set(
-            [s.event_id for s in context.auth_events.values()]
-        )
-
-        for e_id, _ in event.auth_events:
-            if e_id not in known_ids:
-                e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e and fetch_auth_from is not None:
-                    # Grab the auth_chain over federation if we are missing
-                    # auth events.
-                    auth_chain = yield self.replication_layer.get_event_auth(
-                        fetch_auth_from, event.event_id, event.room_id
-                    )
-                    for auth_event in auth_chain:
-                        yield self._handle_new_event(auth_event)
-                    e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e:
-                    # TODO: Do some conflict res to make sure that we're
-                    # not the ones who are wrong.
-                    logger.info(
-                        "Rejecting %s as %s not in db or %s",
-                        event.event_id, e_id, known_ids,
-                    )
-                    # FIXME: How does raising AuthError work with federation?
-                    raise AuthError(403, "Cannot find auth event")
-
-                context.auth_events[(e.type, e.state_key)] = e
-
-        logger.debug(
-            "_handle_new_event: Before hack: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
-
+        # This is a hack to fix some old rooms where the initial join event
+        # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_events:
             if len(event.prev_events) == 1:
                 c = yield self.store.get_event(event.prev_events[0][0])
                 if c.type == EventTypes.Create:
-                    context.auth_events[(c.type, c.state_key)] = c
+                    auth_events[(c.type, c.state_key)] = c
 
-        logger.debug(
-            "_handle_new_event: Before auth check: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+        try:
+            yield self.do_auth(
+                origin, event, context, auth_events=auth_events
+            )
+        except AuthError as e:
+            logger.warn(
+                "Rejecting %s because %s",
+                event.event_id, e.msg
+            )
 
-        self.auth.check(event, auth_events=context.auth_events)
+            context.rejected = RejectedReason.AUTH_ERROR
 
-        logger.debug(
-            "_handle_new_event: Before persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+            yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
+                is_new_state=False,
+                current_state=current_state,
+            )
+            raise
 
         yield self.store.persist_event(
             event,
@@ -761,9 +718,294 @@ class FederationHandler(BaseHandler):
             current_state=current_state,
         )
 
-        logger.debug(
-            "_handle_new_event: After persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
+        defer.returnValue(context)
+
+    @defer.inlineCallbacks
+    def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+                      missing):
+        # Just go through and process each event in `remote_auth_chain`. We
+        # don't want to fall into the trap of `missing` being wrong.
+        for e in remote_auth_chain:
+            try:
+                yield self._handle_new_event(origin, e)
+            except AuthError:
+                pass
+
+        # Now get the current auth_chain for the event.
+        local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+        # TODO: Check if we would now reject event_id. If so we need to tell
+        # everyone.
+
+        ret = yield self.construct_auth_difference(
+            local_auth_chain, remote_auth_chain
         )
 
-        defer.returnValue(context)
+        for event in ret["auth_chain"]:
+            event.signatures.update(
+                compute_event_signature(
+                    event,
+                    self.hs.hostname,
+                    self.hs.config.signing_key[0]
+                )
+            )
+
+        logger.debug("on_query_auth reutrning: %s", ret)
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    @log_function
+    def do_auth(self, origin, event, context, auth_events):
+        # Check if we have all the auth events.
+        res = yield self.store.have_events(
+            [e_id for e_id, _ in event.auth_events]
+        )
+
+        event_auth_events = set(e_id for e_id, _ in event.auth_events)
+        seen_events = set(res.keys())
+
+        missing_auth = event_auth_events - seen_events
+
+        if missing_auth:
+            logger.debug("Missing auth: %s", missing_auth)
+            # If we don't have all the auth events, we need to get them.
+            remote_auth_chain = yield self.replication_layer.get_event_auth(
+                origin, event.room_id, event.event_id
+            )
+
+            seen_remotes = yield self.store.have_events(
+                [e.event_id for e in remote_auth_chain]
+            )
+
+            for e in remote_auth_chain:
+                if e.event_id in seen_remotes.keys():
+                    continue
+
+                if e.event_id == event.event_id:
+                    continue
+
+                try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in remote_auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    e.internal_metadata.outlier = True
+
+                    logger.debug(
+                        "do_auth %s missing_auth: %s",
+                        event.event_id, e.event_id
+                    )
+                    yield self._handle_new_event(
+                        origin, e, auth_events=auth
+                    )
+
+                    if e.event_id in event_auth_events:
+                        auth_events[(e.type, e.state_key)] = e
+                except AuthError:
+                    pass
+
+        # FIXME: Assumes we have and stored all the state for all the
+        # prev_events
+        current_state = set(e.event_id for e in auth_events.values())
+        different_auth = event_auth_events - current_state
+
+        if different_auth and not event.internal_metadata.is_outlier():
+            # Do auth conflict res.
+            logger.debug("Different auth: %s", different_auth)
+
+            # 1. Get what we think is the auth chain.
+            auth_ids = self.auth.compute_auth_events(event, context)
+            local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+            # 2. Get remote difference.
+            result = yield self.replication_layer.query_auth(
+                origin,
+                event.room_id,
+                event.event_id,
+                local_auth_chain,
+            )
+
+            seen_remotes = yield self.store.have_events(
+                [e.event_id for e in result["auth_chain"]]
+            )
+
+            # 3. Process any remote auth chain events we haven't seen.
+            for ev in result["auth_chain"]:
+                if ev.event_id in seen_remotes.keys():
+                    continue
+
+                if ev.event_id == event.event_id:
+                    continue
+
+                try:
+                    auth_ids = [e_id for e_id, _ in ev.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in result["auth_chain"]
+                        if e.event_id in auth_ids
+                    }
+                    ev.internal_metadata.outlier = True
+
+                    logger.debug(
+                        "do_auth %s different_auth: %s",
+                        event.event_id, e.event_id
+                    )
+
+                    yield self._handle_new_event(
+                        origin, ev, auth_events=auth
+                    )
+
+                    if ev.event_id in event_auth_events:
+                        auth_events[(ev.type, ev.state_key)] = ev
+                except AuthError:
+                    pass
+
+            # 4. Look at rejects and their proofs.
+            # TODO.
+
+            context.current_state.update(auth_events)
+            context.state_group = None
+
+        try:
+            self.auth.check(event, auth_events=auth_events)
+        except AuthError:
+            raise
+
+    @defer.inlineCallbacks
+    def construct_auth_difference(self, local_auth, remote_auth):
+        """ Given a local and remote auth chain, find the differences. This
+        assumes that we have already processed all events in remote_auth
+
+        Params:
+            local_auth (list)
+            remote_auth (list)
+
+        Returns:
+            dict
+        """
+
+        logger.debug("construct_auth_difference Start!")
+
+        # TODO: Make sure we are OK with local_auth or remote_auth having more
+        # auth events in them than strictly necessary.
+
+        def sort_fun(ev):
+            return ev.depth, ev.event_id
+
+        logger.debug("construct_auth_difference after sort_fun!")
+
+        # We find the differences by starting at the "bottom" of each list
+        # and iterating up on both lists. The lists are ordered by depth and
+        # then event_id, we iterate up both lists until we find the event ids
+        # don't match. Then we look at depth/event_id to see which side is
+        # missing that event, and iterate only up that list. Repeat.
+
+        remote_list = list(remote_auth)
+        remote_list.sort(key=sort_fun)
+
+        local_list = list(local_auth)
+        local_list.sort(key=sort_fun)
+
+        local_iter = iter(local_list)
+        remote_iter = iter(remote_list)
+
+        logger.debug("construct_auth_difference before get_next!")
+
+        def get_next(it, opt=None):
+            try:
+                return it.next()
+            except:
+                return opt
+
+        current_local = get_next(local_iter)
+        current_remote = get_next(remote_iter)
+
+        logger.debug("construct_auth_difference before while")
+
+        missing_remotes = []
+        missing_locals = []
+        while current_local or current_remote:
+            if current_remote is None:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local is None:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.event_id == current_remote.event_id:
+                current_local = get_next(local_iter)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.depth < current_remote.depth:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local.depth > current_remote.depth:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            # They have the same depth, so we fall back to the event_id order
+            if current_local.event_id < current_remote.event_id:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+
+            if current_local.event_id > current_remote.event_id:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+        logger.debug("construct_auth_difference after while")
+
+        # missing locals should be sent to the server
+        # We should find why we are missing remotes, as they will have been
+        # rejected.
+
+        # Remove events from missing_remotes if they are referencing a missing
+        # remote. We only care about the "root" rejected ones.
+        missing_remote_ids = [e.event_id for e in missing_remotes]
+        base_remote_rejected = list(missing_remotes)
+        for e in missing_remotes:
+            for e_id, _ in e.auth_events:
+                if e_id in missing_remote_ids:
+                    base_remote_rejected.remove(e)
+
+        reason_map = {}
+
+        for e in base_remote_rejected:
+            reason = yield self.store.get_rejection_reason(e.event_id)
+            if reason is None:
+                # FIXME: ERRR?!
+                logger.warn("Could not find reason for %s", e.event_id)
+                raise RuntimeError("")
+
+            reason_map[e.event_id] = reason
+
+            if reason == RejectedReason.AUTH_ERROR:
+                pass
+            elif reason == RejectedReason.REPLACED:
+                # TODO: Get proof
+                pass
+            elif reason == RejectedReason.NOT_ANCESTOR:
+                # TODO: Get proof.
+                pass
+
+        logger.debug("construct_auth_difference returning")
+
+        defer.returnValue({
+            "auth_chain": local_auth,
+            "rejects": {
+                e.event_id: {
+                    "reason": reason_map[e.event_id],
+                    "proof": None,
+                }
+                for e in base_remote_rejected
+            },
+            "missing": [e.event_id for e in missing_locals],
+        })