summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py429
1 files changed, 310 insertions, 119 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..40836e0c51 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
 
 from ._base import BaseHandler
 
-from synapse.events.utils import prune_event
 from synapse.api.errors import (
-    AuthError, FederationError, SynapseError, StoreError,
+    AuthError, FederationError, StoreError,
 )
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.crypto.event_signing import (
-    compute_event_signature, check_event_content_hash,
-    add_hashes_and_signatures,
+    compute_event_signature, add_hashes_and_signatures,
 )
 from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
 
 from twisted.internet import defer
 
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Processing event: %s", event.event_id)
 
-        redacted_event = prune_event(event)
-
-        redacted_pdu_json = redacted_event.get_pdu_json()
-        try:
-            yield self.keyring.verify_json_for_server(
-                event.origin, redacted_pdu_json
-            )
-        except SynapseError as e:
-            logger.warn(
-                "Signature check failed for %s redacted to %s",
-                encode_canonical_json(pdu.get_pdu_json()),
-                encode_canonical_json(redacted_pdu_json),
-            )
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
-
-        if not check_event_content_hash(event):
-            logger.warn(
-                "Event content has been tampered, redacting %s, %s",
-                event.event_id, encode_canonical_json(event.get_dict())
-            )
-            event = redacted_event
-
         logger.debug("Event: %s", event)
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
@@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
             event.room_id,
             self.server_name
         )
-        if not is_in_room and not event.internal_metadata.outlier:
+        if not is_in_room and not event.internal_metadata.is_outlier():
             logger.debug("Got event for room we're not in.")
-
-            replication = self.replication_layer
-
-            if not state:
-                state, auth_chain = yield replication.get_state_for_room(
-                    origin, context=event.room_id, event_id=event.event_id,
-                )
-
-            if not auth_chain:
-                auth_chain = yield replication.get_event_auth(
-                    origin,
-                    context=event.room_id,
-                    event_id=event.event_id,
-                )
-
-            for e in auth_chain:
-                e.internal_metadata.outlier = True
-                try:
-                    yield self._handle_new_event(e, fetch_auth_from=origin)
-                except:
-                    logger.exception(
-                        "Failed to handle auth event %s",
-                        e.event_id,
-                    )
-
             current_state = state
 
-        if state:
+        if state and auth_chain is not None:
             for e in state:
-                logging.info("A :) %r", e)
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(origin, e, auth_events=auth)
                 except:
                     logger.exception(
                         "Failed to handle state event %s",
@@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
 
         try:
             yield self._handle_new_event(
+                origin,
                 event,
                 state=state,
                 backfilled=backfilled,
@@ -394,7 +344,14 @@ class FederationHandler(BaseHandler):
             for e in auth_chain:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(
+                        target_host, e, auth_events=auth
+                    )
                 except:
                     logger.exception(
                         "Failed to handle auth event %s",
@@ -405,8 +362,13 @@ class FederationHandler(BaseHandler):
                 # FIXME: Auth these.
                 e.internal_metadata.outlier = True
                 try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
                     yield self._handle_new_event(
-                        e, fetch_auth_from=target_host
+                        target_host, e, auth_events=auth
                     )
                 except:
                     logger.exception(
@@ -415,6 +377,7 @@ class FederationHandler(BaseHandler):
                     )
 
             yield self._handle_new_event(
+                target_host,
                 new_event,
                 state=state,
                 current_state=state,
@@ -481,7 +444,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(event)
+        context = yield self._handle_new_event(origin, event)
 
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -682,11 +645,12 @@ class FederationHandler(BaseHandler):
             waiters.pop().callback(None)
 
     @defer.inlineCallbacks
-    def _handle_new_event(self, event, state=None, backfilled=False,
-                          current_state=None, fetch_auth_from=None):
+    @log_function
+    def _handle_new_event(self, origin, event, state=None, backfilled=False,
+                          current_state=None, auth_events=None):
 
         logger.debug(
-            "_handle_new_event: Before annotate: %s, sigs: %s",
+            "_handle_new_event: %s, sigs: %s",
             event.event_id, event.signatures,
         )
 
@@ -694,65 +658,44 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )
 
+        if not auth_events:
+            auth_events = context.auth_events
+
         logger.debug(
-            "_handle_new_event: Before auth fetch: %s, sigs: %s",
-            event.event_id, event.signatures,
+            "_handle_new_event: %s, auth_events: %s",
+            event.event_id, auth_events,
         )
 
         is_new_state = not event.internal_metadata.is_outlier()
 
-        known_ids = set(
-            [s.event_id for s in context.auth_events.values()]
-        )
-
-        for e_id, _ in event.auth_events:
-            if e_id not in known_ids:
-                e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e and fetch_auth_from is not None:
-                    # Grab the auth_chain over federation if we are missing
-                    # auth events.
-                    auth_chain = yield self.replication_layer.get_event_auth(
-                        fetch_auth_from, event.event_id, event.room_id
-                    )
-                    for auth_event in auth_chain:
-                        yield self._handle_new_event(auth_event)
-                    e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e:
-                    # TODO: Do some conflict res to make sure that we're
-                    # not the ones who are wrong.
-                    logger.info(
-                        "Rejecting %s as %s not in db or %s",
-                        event.event_id, e_id, known_ids,
-                    )
-                    # FIXME: How does raising AuthError work with federation?
-                    raise AuthError(403, "Cannot find auth event")
-
-                context.auth_events[(e.type, e.state_key)] = e
-
-        logger.debug(
-            "_handle_new_event: Before hack: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
-
+        # This is a hack to fix some old rooms where the initial join event
+        # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_events:
             if len(event.prev_events) == 1:
                 c = yield self.store.get_event(event.prev_events[0][0])
                 if c.type == EventTypes.Create:
-                    context.auth_events[(c.type, c.state_key)] = c
+                    auth_events[(c.type, c.state_key)] = c
 
-        logger.debug(
-            "_handle_new_event: Before auth check: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+        try:
+            yield self.do_auth(
+                origin, event, context, auth_events=auth_events
+            )
+        except AuthError as e:
+            logger.warn(
+                "Rejecting %s because %s",
+                event.event_id, e.msg
+            )
 
-        self.auth.check(event, auth_events=context.auth_events)
+            context.rejected = RejectedReason.AUTH_ERROR
 
-        logger.debug(
-            "_handle_new_event: Before persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+            yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
+                is_new_state=False,
+                current_state=current_state,
+            )
+            raise
 
         yield self.store.persist_event(
             event,
@@ -762,9 +705,257 @@ class FederationHandler(BaseHandler):
             current_state=current_state,
         )
 
-        logger.debug(
-            "_handle_new_event: After persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
+        defer.returnValue(context)
+
+    @defer.inlineCallbacks
+    def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+                      missing):
+        # Just go through and process each event in `remote_auth_chain`. We
+        # don't want to fall into the trap of `missing` being wrong.
+        for e in remote_auth_chain:
+            try:
+                yield self._handle_new_event(origin, e)
+            except AuthError:
+                pass
+
+        # Now get the current auth_chain for the event.
+        local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+        # TODO: Check if we would now reject event_id. If so we need to tell
+        # everyone.
+
+        ret = yield self.construct_auth_difference(
+            local_auth_chain, remote_auth_chain
         )
 
-        defer.returnValue(context)
+        logger.debug("on_query_auth reutrning: %s", ret)
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    @log_function
+    def do_auth(self, origin, event, context, auth_events):
+        # Check if we have all the auth events.
+        res = yield self.store.have_events(
+            [e_id for e_id, _ in event.auth_events]
+        )
+
+        event_auth_events = set(e_id for e_id, _ in event.auth_events)
+        seen_events = set(res.keys())
+
+        missing_auth = event_auth_events - seen_events
+
+        if missing_auth:
+            logger.debug("Missing auth: %s", missing_auth)
+            # If we don't have all the auth events, we need to get them.
+            remote_auth_chain = yield self.replication_layer.get_event_auth(
+                origin, event.room_id, event.event_id
+            )
+
+            for e in remote_auth_chain:
+                try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in remote_auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    e.internal_metadata.outlier = True
+                    yield self._handle_new_event(
+                        origin, e, auth_events=auth
+                    )
+                    auth_events[(e.type, e.state_key)] = e
+                except AuthError:
+                    pass
+
+        # FIXME: Assumes we have and stored all the state for all the
+        # prev_events
+        current_state = set(e.event_id for e in auth_events.values())
+        different_auth = event_auth_events - current_state
+
+        if different_auth and not event.internal_metadata.is_outlier():
+            # Do auth conflict res.
+            logger.debug("Different auth: %s", different_auth)
+
+            # 1. Get what we think is the auth chain.
+            auth_ids = self.auth.compute_auth_events(event, context)
+            local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+            # 2. Get remote difference.
+            result = yield self.replication_layer.query_auth(
+                origin,
+                event.room_id,
+                event.event_id,
+                local_auth_chain,
+            )
+
+            # 3. Process any remote auth chain events we haven't seen.
+            for missing_id in result.get("missing", []):
+                try:
+                    for e in result["auth_chain"]:
+                        if e.event_id == missing_id:
+                            ev = e
+                            break
+
+                    auth_ids = [e_id for e_id, _ in ev.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in result["auth_chain"]
+                        if e.event_id in auth_ids
+                    }
+                    ev.internal_metadata.outlier = True
+                    yield self._handle_new_event(
+                        origin, ev, auth_events=auth
+                    )
+
+                    if ev.event_id in event_auth_events:
+                        auth_events[(ev.type, ev.state_key)] = ev
+                except AuthError:
+                    pass
+
+            # 4. Look at rejects and their proofs.
+            # TODO.
+
+            context.current_state.update(auth_events)
+            context.state_group = None
+
+        try:
+            self.auth.check(event, auth_events=auth_events)
+        except AuthError:
+            raise
+
+    @defer.inlineCallbacks
+    def construct_auth_difference(self, local_auth, remote_auth):
+        """ Given a local and remote auth chain, find the differences. This
+        assumes that we have already processed all events in remote_auth
+
+        Params:
+            local_auth (list)
+            remote_auth (list)
+
+        Returns:
+            dict
+        """
+
+        logger.debug("construct_auth_difference Start!")
+
+        # TODO: Make sure we are OK with local_auth or remote_auth having more
+        # auth events in them than strictly necessary.
+
+        def sort_fun(ev):
+            return ev.depth, ev.event_id
+
+        logger.debug("construct_auth_difference after sort_fun!")
+
+        # We find the differences by starting at the "bottom" of each list
+        # and iterating up on both lists. The lists are ordered by depth and
+        # then event_id, we iterate up both lists until we find the event ids
+        # don't match. Then we look at depth/event_id to see which side is
+        # missing that event, and iterate only up that list. Repeat.
+
+        remote_list = list(remote_auth)
+        remote_list.sort(key=sort_fun)
+
+        local_list = list(local_auth)
+        local_list.sort(key=sort_fun)
+
+        local_iter = iter(local_list)
+        remote_iter = iter(remote_list)
+
+        logger.debug("construct_auth_difference before get_next!")
+
+        def get_next(it, opt=None):
+            try:
+                return it.next()
+            except:
+                return opt
+
+        current_local = get_next(local_iter)
+        current_remote = get_next(remote_iter)
+
+        logger.debug("construct_auth_difference before while")
+
+        missing_remotes = []
+        missing_locals = []
+        while current_local or current_remote:
+            if current_remote is None:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local is None:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.event_id == current_remote.event_id:
+                current_local = get_next(local_iter)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.depth < current_remote.depth:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local.depth > current_remote.depth:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            # They have the same depth, so we fall back to the event_id order
+            if current_local.event_id < current_remote.event_id:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+
+            if current_local.event_id > current_remote.event_id:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+        logger.debug("construct_auth_difference after while")
+
+        # missing locals should be sent to the server
+        # We should find why we are missing remotes, as they will have been
+        # rejected.
+
+        # Remove events from missing_remotes if they are referencing a missing
+        # remote. We only care about the "root" rejected ones.
+        missing_remote_ids = [e.event_id for e in missing_remotes]
+        base_remote_rejected = list(missing_remotes)
+        for e in missing_remotes:
+            for e_id, _ in e.auth_events:
+                if e_id in missing_remote_ids:
+                    base_remote_rejected.remove(e)
+
+        reason_map = {}
+
+        for e in base_remote_rejected:
+            reason = yield self.store.get_rejection_reason(e.event_id)
+            if reason is None:
+                # FIXME: ERRR?!
+                logger.warn("Could not find reason for %s", e.event_id)
+                raise RuntimeError("")
+
+            reason_map[e.event_id] = reason
+
+            if reason == RejectedReason.AUTH_ERROR:
+                pass
+            elif reason == RejectedReason.REPLACED:
+                # TODO: Get proof
+                pass
+            elif reason == RejectedReason.NOT_ANCESTOR:
+                # TODO: Get proof.
+                pass
+
+        logger.debug("construct_auth_difference returning")
+
+        defer.returnValue({
+            "auth_chain": local_auth,
+            "rejects": {
+                e.event_id: {
+                    "reason": reason_map[e.event_id],
+                    "proof": None,
+                }
+                for e in base_remote_rejected
+            },
+            "missing": [e.event_id for e in missing_locals],
+        })