summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py38
-rw-r--r--synapse/api/constants.py6
-rw-r--r--synapse/federation/federation_client.py39
-rw-r--r--synapse/handlers/federation.py211
-rw-r--r--synapse/storage/rejections.py10
-rw-r--r--synapse/storage/schema/im.sql1
6 files changed, 253 insertions, 52 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index a342a0e0da..461faa8c78 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -353,9 +353,23 @@ class Auth(object):
     def add_auth_events(self, builder, context):
         yield run_on_reactor()
 
-        if builder.type == EventTypes.Create:
-            builder.auth_events = []
-            return
+        auth_ids = self.compute_auth_events(builder, context)
+
+        auth_events_entries = yield self.store.add_event_hashes(
+            auth_ids
+        )
+
+        builder.auth_events = auth_events_entries
+
+        context.auth_events = {
+            k: v
+            for k, v in context.current_state.items()
+            if v.event_id in auth_ids
+        }
+
+    def compute_auth_events(self, event, context):
+        if event.type == EventTypes.Create:
+            return []
 
         auth_ids = []
 
@@ -368,7 +382,7 @@ class Auth(object):
         key = (EventTypes.JoinRules, "", )
         join_rule_event = context.current_state.get(key)
 
-        key = (EventTypes.Member, builder.user_id, )
+        key = (EventTypes.Member, event.user_id, )
         member_event = context.current_state.get(key)
 
         key = (EventTypes.Create, "", )
@@ -382,8 +396,8 @@ class Auth(object):
         else:
             is_public = False
 
-        if builder.type == EventTypes.Member:
-            e_type = builder.content["membership"]
+        if event.type == EventTypes.Member:
+            e_type = event.content["membership"]
             if e_type in [Membership.JOIN, Membership.INVITE]:
                 if join_rule_event:
                     auth_ids.append(join_rule_event.event_id)
@@ -398,17 +412,7 @@ class Auth(object):
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
 
-        auth_events_entries = yield self.store.add_event_hashes(
-            auth_ids
-        )
-
-        builder.auth_events = auth_events_entries
-
-        context.auth_events = {
-            k: v
-            for k, v in context.current_state.items()
-            if v.event_id in auth_ids
-        }
+        return auth_ids
 
     @log_function
     def _can_send_event(self, event, auth_events):
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 7ee6dcc46e..0d3fc629af 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -74,3 +74,9 @@ class EventTypes(object):
     Message = "m.room.message"
     Topic = "m.room.topic"
     Name = "m.room.name"
+
+
+class RejectedReason(object):
+    AUTH_ERROR = "auth_error"
+    REPLACED = "replaced"
+    NOT_ANCESTOR = "not_ancestor"
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 91b44cd8b3..ebcd593506 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -331,6 +331,45 @@ class FederationClient(object):
 
         defer.returnValue(pdu)
 
+    @defer.inlineCallbacks
+    def query_auth(self, destination, room_id, event_id, local_auth):
+        """
+        Params:
+            destination (str)
+            event_it (str)
+            local_auth (list)
+        """
+        time_now = self._clock.time_msec()
+
+        send_content = {
+            "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
+        }
+
+        code, content = yield self.transport_layer.send_invite(
+            destination=destination,
+            room_id=room_id,
+            event_id=event_id,
+            content=send_content,
+        )
+
+        auth_chain = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content["auth_chain"]
+        ]
+
+        missing = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content.get("missing", [])
+        ]
+
+        ret = {
+            "auth_chain": auth_chain,
+            "rejects": content.get("rejects", []),
+            "missing": missing,
+        }
+
+        defer.returnValue(ret)
+
     def event_from_pdu_json(self, pdu_json, outlier=False):
         event = FrozenEvent(
             pdu_json
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..97e3c503b9 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
 
 from ._base import BaseHandler
 
-from synapse.events.utils import prune_event
 from synapse.api.errors import (
-    AuthError, FederationError, SynapseError, StoreError,
+    AuthError, FederationError, StoreError,
 )
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.crypto.event_signing import (
-    compute_event_signature, check_event_content_hash,
-    add_hashes_and_signatures,
+    compute_event_signature, add_hashes_and_signatures,
 )
 from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
 
 from twisted.internet import defer
 
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Processing event: %s", event.event_id)
 
-        redacted_event = prune_event(event)
-
-        redacted_pdu_json = redacted_event.get_pdu_json()
-        try:
-            yield self.keyring.verify_json_for_server(
-                event.origin, redacted_pdu_json
-            )
-        except SynapseError as e:
-            logger.warn(
-                "Signature check failed for %s redacted to %s",
-                encode_canonical_json(pdu.get_pdu_json()),
-                encode_canonical_json(redacted_pdu_json),
-            )
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
-
-        if not check_event_content_hash(event):
-            logger.warn(
-                "Event content has been tampered, redacting %s, %s",
-                event.event_id, encode_canonical_json(event.get_dict())
-            )
-            event = redacted_event
-
         logger.debug("Event: %s", event)
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
@@ -180,7 +150,6 @@ class FederationHandler(BaseHandler):
 
         if state:
             for e in state:
-                logging.info("A :) %r", e)
                 e.internal_metadata.outlier = True
                 try:
                     yield self._handle_new_event(e)
@@ -747,7 +716,20 @@ class FederationHandler(BaseHandler):
             event.event_id, event.signatures,
         )
 
-        self.auth.check(event, auth_events=context.auth_events)
+        try:
+            self.auth.check(event, auth_events=context.auth_events)
+        except AuthError:
+            # TODO: Store rejection.
+            context.rejected = RejectedReason.AUTH_ERROR
+
+            yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
+                is_new_state=False,
+                current_state=current_state,
+            )
+            raise
 
         logger.debug(
             "_handle_new_event: Before persist_event: %s, sigs: %s",
@@ -768,3 +750,162 @@ class FederationHandler(BaseHandler):
         )
 
         defer.returnValue(context)
+
+    @defer.inlineCallbacks
+    def do_auth(self, origin, event, context):
+        for e_id, _ in event.auth_events:
+            pass
+
+        auth_events = set(e_id for e_id, _ in event.auth_events)
+        current_state = set(e.event_id for e in context.auth_events.values())
+
+        missing_auth = auth_events - current_state
+
+        if missing_auth:
+            # Do auth conflict res.
+
+            # 1. Get what we think is the auth chain.
+            auth_ids = self.auth.compute_auth_events(event, context)
+            local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+            # 2. Get remote difference.
+            result = yield self.replication_layer.query_auth(
+                origin,
+                event.room_id,
+                event.event_id,
+                local_auth_chain,
+            )
+
+            # 3. Process any remote auth chain events we haven't seen.
+            for e in result.get("missing", []):
+                # TODO.
+                pass
+
+            # 4. Look at rejects and their proofs.
+            # TODO.
+
+        try:
+            self.auth.check(event, auth_events=context.auth_events)
+        except AuthError:
+            raise
+
+    @defer.inlineCallbacks
+    def construct_auth_difference(self, local_auth, remote_auth):
+        """ Given a local and remote auth chain, find the differences. This
+        assumes that we have already processed all events in remote_auth
+
+        Params:
+            local_auth (list)
+            remote_auth (list)
+
+        Returns:
+            dict
+        """
+
+        # TODO: Make sure we are OK with local_auth or remote_auth having more
+        # auth events in them than strictly necessary.
+
+        def sort_fun(ev):
+            return ev.depth, ev.event_id
+
+        # We find the differences by starting at the "bottom" of each list
+        # and iterating up on both lists. The lists are ordered by depth and
+        # then event_id, we iterate up both lists until we find the event ids
+        # don't match. Then we look at depth/event_id to see which side is
+        # missing that event, and iterate only up that list. Repeat.
+
+        remote_list = list(remote_auth)
+        remote_list.sort(key=sort_fun)
+
+        local_list = list(local_auth)
+        local_list.sort(key=sort_fun)
+
+        local_iter = iter(local_list)
+        remote_iter = iter(remote_list)
+
+        current_local = local_iter.next()
+        current_remote = remote_iter.next()
+
+        def get_next(it, opt=None):
+            return it.next() if it.has_next() else opt
+
+        missing_remotes = []
+        missing_locals = []
+        while current_local and current_remote:
+            if current_remote is None:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local is None:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.event_id == current_remote.event_id:
+                current_local = get_next(local_iter)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.depth < current_remote.depth:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local.depth > current_remote.depth:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            # They have the same depth, so we fall back to the event_id order
+            if current_local.event_id < current_remote.event_id:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+
+            if current_local.event_id > current_remote.event_id:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+        # missing locals should be sent to the server
+        # We should find why we are missing remotes, as they will have been
+        # rejected.
+
+        # Remove events from missing_remotes if they are referencing a missing
+        # remote. We only care about the "root" rejected ones.
+        missing_remote_ids = [e.event_id for e in missing_remotes]
+        base_remote_rejected = list(missing_remotes)
+        for e in missing_remotes:
+            for e_id, _ in e.auth_events:
+                if e_id in missing_remote_ids:
+                    base_remote_rejected.remove(e)
+
+        reason_map = {}
+
+        for e in base_remote_rejected:
+            reason = yield self.store.get_rejection_reason(e.event_id)
+            if reason is None:
+                # FIXME: ERRR?!
+                raise RuntimeError("")
+
+            reason_map[e.event_id] = reason
+
+            if reason == RejectedReason.AUTH_ERROR:
+                pass
+            elif reason == RejectedReason.REPLACED:
+                # TODO: Get proof
+                pass
+            elif reason == RejectedReason.NOT_ANCESTOR:
+                # TODO: Get proof.
+                pass
+
+        defer.returnValue({
+            "rejects": {
+                e.event_id: {
+                    "reason": reason_map[e.event_id],
+                    "proof": None,
+                }
+                for e in base_remote_rejected
+            },
+            "missing": missing_locals,
+        })
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 7d38b31f44..b7249700d7 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -31,3 +31,13 @@ class RejectionsStore(SQLBaseStore):
                 "last_failure": self._clock.time_msec(),
             }
         )
+
+    def get_rejection_reason(self, event_id):
+        self._simple_select_one_onecol(
+            table="rejections",
+            retcol="reason",
+            keyvalues={
+                "event_id": event_id,
+            },
+            allow_none=True,
+        )
diff --git a/synapse/storage/schema/im.sql b/synapse/storage/schema/im.sql
index bc7c6b6ed5..5866a387f6 100644
--- a/synapse/storage/schema/im.sql
+++ b/synapse/storage/schema/im.sql
@@ -128,5 +128,6 @@ CREATE TABLE IF NOT EXISTS rejections(
     event_id TEXT NOT NULL,
     reason TEXT NOT NULL,
     last_check TEXT NOT NULL,
+    root_rejected TEXT,
     CONSTRAINT ev_id UNIQUE (event_id) ON CONFLICT REPLACE
 );