summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2015-01-29 16:50:23 +0000
committerErik Johnston <erik@matrix.org>2015-01-29 16:52:33 +0000
commit78015948a7febb18e000651f72f8f58830a55b93 (patch)
tree42aac1eb30b723f21074bb814f33d8713008725a
parentMake post_json(...) actually send data. (diff)
downloadsynapse-78015948a7febb18e000651f72f8f58830a55b93.tar.xz
Initial implementation of auth conflict resolution
-rw-r--r--synapse/events/utils.py6
-rw-r--r--synapse/federation/federation_client.py2
-rw-r--r--synapse/federation/federation_server.py33
-rw-r--r--synapse/federation/transport/client.py16
-rw-r--r--synapse/federation/transport/server.py21
-rw-r--r--synapse/handlers/federation.py207
-rw-r--r--synapse/storage/rejections.py4
-rw-r--r--tests/handlers/test_federation.py2
8 files changed, 210 insertions, 81 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index bcb5457278..10a6b9f264 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -45,12 +45,14 @@ def prune_event(event):
         "membership",
     ]
 
+    event_dict = event.get_dict()
+
     new_content = {}
 
     def add_fields(*fields):
         for field in fields:
             if field in event.content:
-                new_content[field] = event.content[field]
+                new_content[field] = event_dict["content"][field]
 
     if event_type == EventTypes.Member:
         add_fields("membership")
@@ -75,7 +77,7 @@ def prune_event(event):
 
     allowed_fields = {
         k: v
-        for k, v in event.get_dict().items()
+        for k, v in event_dict.items()
         if k in allowed_keys
     }
 
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index ebcd593506..1173ca817b 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -345,7 +345,7 @@ class FederationClient(object):
             "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
         }
 
-        code, content = yield self.transport_layer.send_invite(
+        code, content = yield self.transport_layer.send_query_auth(
             destination=destination,
             room_id=room_id,
             event_id=event_id,
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index fc5342afaa..8cff4e6472 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -230,6 +230,39 @@ class FederationServer(object):
             "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
         }))
 
+    @defer.inlineCallbacks
+    def on_query_auth_request(self, origin, content, event_id):
+        auth_chain = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content["auth_chain"]
+        ]
+
+        missing = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content.get("missing", [])
+        ]
+
+        ret = yield self.handler.on_query_auth(
+            origin, event_id, auth_chain, content.get("rejects", []), missing
+        )
+
+        time_now = self._clock.time_msec()
+        send_content = {
+            "auth_chain": [
+                e.get_pdu_json(time_now)
+                for e in ret["auth_chain"]
+            ],
+            "rejects": content.get("rejects", []),
+            "missing": [
+                e.get_pdu_json(time_now)
+                for e in ret.get("missing", [])
+            ],
+        }
+
+        defer.returnValue(
+            (200, send_content)
+        )
+
     @log_function
     def _get_persisted_pdu(self, origin, event_id, do_auth=True):
         """ Get a PDU from the database with given origin and id.
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e634a3a213..4cb1dea2de 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -213,3 +213,19 @@ class TransportLayerClient(object):
         )
 
         defer.returnValue(response)
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_query_auth(self, destination, room_id, event_id, content):
+        path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+
+        code, content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_invite", code)
+
+        defer.returnValue(json.loads(content))
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a380a6910b..9c9f8d525b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -42,7 +42,7 @@ class TransportLayerServer(object):
         content = None
         origin = None
 
-        if request.method == "PUT":
+        if request.method in ["PUT", "POST"]:
             # TODO: Handle other method types? other content types?
             try:
                 content_bytes = request.content.read()
@@ -234,6 +234,16 @@ class TransportLayerServer(object):
                 )
             )
         )
+        self.server.register_path(
+            "POST",
+            re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_query_auth_request(
+                    origin, content, event_id,
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -325,3 +335,12 @@ class TransportLayerServer(object):
         )
 
         defer.returnValue((200, content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def _on_query_auth_request(self, origin, content, event_id):
+        new_content = yield self.request_handler.on_query_auth_request(
+            origin, content, event_id
+        )
+
+        defer.returnValue((200, new_content))
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 97e3c503b9..14c26d8cea 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -126,7 +126,7 @@ class FederationHandler(BaseHandler):
 
             if not state:
                 state, auth_chain = yield replication.get_state_for_room(
-                    origin, context=event.room_id, event_id=event.event_id,
+                    origin, room_id=event.room_id, event_id=event.event_id,
                 )
 
             if not auth_chain:
@@ -139,7 +139,7 @@ class FederationHandler(BaseHandler):
             for e in auth_chain:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e, fetch_auth_from=origin)
+                    yield self._handle_new_event(origin, e)
                 except:
                     logger.exception(
                         "Failed to handle auth event %s",
@@ -152,7 +152,7 @@ class FederationHandler(BaseHandler):
             for e in state:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    yield self._handle_new_event(origin, e)
                 except:
                     logger.exception(
                         "Failed to handle state event %s",
@@ -161,6 +161,7 @@ class FederationHandler(BaseHandler):
 
         try:
             yield self._handle_new_event(
+                origin,
                 event,
                 state=state,
                 backfilled=backfilled,
@@ -363,7 +364,14 @@ class FederationHandler(BaseHandler):
             for e in auth_chain:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(
+                        target_host, e, auth_events=auth
+                    )
                 except:
                     logger.exception(
                         "Failed to handle auth event %s",
@@ -374,8 +382,13 @@ class FederationHandler(BaseHandler):
                 # FIXME: Auth these.
                 e.internal_metadata.outlier = True
                 try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
                     yield self._handle_new_event(
-                        e, fetch_auth_from=target_host
+                        target_host, e, auth_events=auth
                     )
                 except:
                     logger.exception(
@@ -384,6 +397,7 @@ class FederationHandler(BaseHandler):
                     )
 
             yield self._handle_new_event(
+                target_host,
                 new_event,
                 state=state,
                 current_state=state,
@@ -450,7 +464,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(event)
+        context = yield self._handle_new_event(origin, event)
 
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -651,11 +665,12 @@ class FederationHandler(BaseHandler):
             waiters.pop().callback(None)
 
     @defer.inlineCallbacks
-    def _handle_new_event(self, event, state=None, backfilled=False,
-                          current_state=None, fetch_auth_from=None):
+    @log_function
+    def _handle_new_event(self, origin, event, state=None, backfilled=False,
+                          current_state=None, auth_events=None):
 
         logger.debug(
-            "_handle_new_event: Before annotate: %s, sigs: %s",
+            "_handle_new_event: %s, sigs: %s",
             event.event_id, event.signatures,
         )
 
@@ -663,62 +678,34 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )
 
+        if not auth_events:
+            auth_events = context.auth_events
+
         logger.debug(
-            "_handle_new_event: Before auth fetch: %s, sigs: %s",
-            event.event_id, event.signatures,
+            "_handle_new_event: %s, auth_events: %s",
+            event.event_id, auth_events,
         )
 
         is_new_state = not event.internal_metadata.is_outlier()
 
-        known_ids = set(
-            [s.event_id for s in context.auth_events.values()]
-        )
-
-        for e_id, _ in event.auth_events:
-            if e_id not in known_ids:
-                e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e and fetch_auth_from is not None:
-                    # Grab the auth_chain over federation if we are missing
-                    # auth events.
-                    auth_chain = yield self.replication_layer.get_event_auth(
-                        fetch_auth_from, event.event_id, event.room_id
-                    )
-                    for auth_event in auth_chain:
-                        yield self._handle_new_event(auth_event)
-                    e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e:
-                    # TODO: Do some conflict res to make sure that we're
-                    # not the ones who are wrong.
-                    logger.info(
-                        "Rejecting %s as %s not in db or %s",
-                        event.event_id, e_id, known_ids,
-                    )
-                    # FIXME: How does raising AuthError work with federation?
-                    raise AuthError(403, "Cannot find auth event")
-
-                context.auth_events[(e.type, e.state_key)] = e
-
-        logger.debug(
-            "_handle_new_event: Before hack: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
-
+        # This is a hack to fix some old rooms where the initial join event
+        # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_events:
             if len(event.prev_events) == 1:
                 c = yield self.store.get_event(event.prev_events[0][0])
                 if c.type == EventTypes.Create:
-                    context.auth_events[(c.type, c.state_key)] = c
-
-        logger.debug(
-            "_handle_new_event: Before auth check: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+                    auth_events[(c.type, c.state_key)] = c
 
         try:
-            self.auth.check(event, auth_events=context.auth_events)
-        except AuthError:
+            yield self.do_auth(
+                origin, event, context, auth_events=auth_events
+            )
+        except AuthError as e:
+            logger.warn(
+                "Rejecting %s because %s",
+                event.event_id, e.msg
+            )
+
             # TODO: Store rejection.
             context.rejected = RejectedReason.AUTH_ERROR
 
@@ -731,11 +718,6 @@ class FederationHandler(BaseHandler):
             )
             raise
 
-        logger.debug(
-            "_handle_new_event: Before persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
-
         yield self.store.persist_event(
             event,
             context=context,
@@ -744,25 +726,73 @@ class FederationHandler(BaseHandler):
             current_state=current_state,
         )
 
-        logger.debug(
-            "_handle_new_event: After persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
+        defer.returnValue(context)
+
+    @defer.inlineCallbacks
+    def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+                      missing):
+        # Just go through and process each event in `remote_auth_chain`. We
+        # don't want to fall into the trap of `missing` being wrong.
+        for e in remote_auth_chain:
+            try:
+                yield self._handle_new_event(origin, e)
+            except AuthError:
+                pass
+
+        # Now get the current auth_chain for the event.
+        local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+        # TODO: Check if we would now reject event_id. If so we need to tell
+        # everyone.
+
+        ret = yield self.construct_auth_difference(
+            local_auth_chain, remote_auth_chain
         )
 
-        defer.returnValue(context)
+        logger.debug("on_query_auth reutrning: %s", ret)
+
+        defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def do_auth(self, origin, event, context):
-        for e_id, _ in event.auth_events:
-            pass
+    @log_function
+    def do_auth(self, origin, event, context, auth_events):
+        # Check if we have all the auth events.
+        res = yield self.store.have_events(
+            [e_id for e_id, _ in event.auth_events]
+        )
 
-        auth_events = set(e_id for e_id, _ in event.auth_events)
-        current_state = set(e.event_id for e in context.auth_events.values())
+        event_auth_events = set(e_id for e_id, _ in event.auth_events)
+        seen_events = set(res.keys())
 
-        missing_auth = auth_events - current_state
+        missing_auth = event_auth_events - seen_events
 
         if missing_auth:
+            logger.debug("Missing auth: %s", missing_auth)
+            # If we don't have all the auth events, we need to get them.
+            remote_auth_chain = yield self.replication_layer.get_event_auth(
+                origin, event.room_id, event.event_id
+            )
+
+            for e in remote_auth_chain:
+                try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in remote_auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(
+                        origin, e, auth_events=auth
+                    )
+                    auth_events[(e.type, e.state_key)] = e
+                except AuthError:
+                    pass
+
+        current_state = set(e.event_id for e in auth_events.values())
+        different_auth = event_auth_events - current_state
+
+        if different_auth and not event.internal_metadata.is_outlier():
             # Do auth conflict res.
+            logger.debug("Different auth: %s", different_auth)
 
             # 1. Get what we think is the auth chain.
             auth_ids = self.auth.compute_auth_events(event, context)
@@ -778,14 +808,24 @@ class FederationHandler(BaseHandler):
 
             # 3. Process any remote auth chain events we haven't seen.
             for e in result.get("missing", []):
-                # TODO.
-                pass
+                try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in result["auth_chain"]
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(
+                        origin, e, auth_events=auth
+                    )
+                    auth_events[(e.type, e.state_key)] = e
+                except AuthError:
+                    pass
 
             # 4. Look at rejects and their proofs.
             # TODO.
 
         try:
-            self.auth.check(event, auth_events=context.auth_events)
+            self.auth.check(event, auth_events=auth_events)
         except AuthError:
             raise
 
@@ -802,12 +842,16 @@ class FederationHandler(BaseHandler):
             dict
         """
 
+        logger.debug("construct_auth_difference Start!")
+
         # TODO: Make sure we are OK with local_auth or remote_auth having more
         # auth events in them than strictly necessary.
 
         def sort_fun(ev):
             return ev.depth, ev.event_id
 
+        logger.debug("construct_auth_difference after sort_fun!")
+
         # We find the differences by starting at the "bottom" of each list
         # and iterating up on both lists. The lists are ordered by depth and
         # then event_id, we iterate up both lists until we find the event ids
@@ -823,11 +867,18 @@ class FederationHandler(BaseHandler):
         local_iter = iter(local_list)
         remote_iter = iter(remote_list)
 
-        current_local = local_iter.next()
-        current_remote = remote_iter.next()
+        logger.debug("construct_auth_difference before get_next!")
 
         def get_next(it, opt=None):
-            return it.next() if it.has_next() else opt
+            try:
+                return it.next()
+            except:
+                return opt
+
+        current_local = get_next(local_iter)
+        current_remote = get_next(remote_iter)
+
+        logger.debug("construct_auth_difference before while")
 
         missing_remotes = []
         missing_locals = []
@@ -867,6 +918,8 @@ class FederationHandler(BaseHandler):
                 current_remote = get_next(remote_iter)
                 continue
 
+        logger.debug("construct_auth_difference after while")
+
         # missing locals should be sent to the server
         # We should find why we are missing remotes, as they will have been
         # rejected.
@@ -886,6 +939,7 @@ class FederationHandler(BaseHandler):
             reason = yield self.store.get_rejection_reason(e.event_id)
             if reason is None:
                 # FIXME: ERRR?!
+                logger.warn("Could not find reason for %s", e.event_id)
                 raise RuntimeError("")
 
             reason_map[e.event_id] = reason
@@ -899,7 +953,10 @@ class FederationHandler(BaseHandler):
                 # TODO: Get proof.
                 pass
 
+        logger.debug("construct_auth_difference returning")
+
         defer.returnValue({
+            "auth_chain": local_auth,
             "rejects": {
                 e.event_id: {
                     "reason": reason_map[e.event_id],
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index b7249700d7..4e1a9a2783 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -28,12 +28,12 @@ class RejectionsStore(SQLBaseStore):
             values={
                 "event_id": event_id,
                 "reason": reason,
-                "last_failure": self._clock.time_msec(),
+                "last_check": self._clock.time_msec(),
             }
         )
 
     def get_rejection_reason(self, event_id):
-        self._simple_select_one_onecol(
+        return self._simple_select_one_onecol(
             table="rejections",
             retcol="reason",
             keyvalues={
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ed21defd13..44dbce6bea 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
                 "get_room",
                 "get_destination_retry_timings",
                 "set_destination_retry_timings",
+                "have_events",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),
@@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
         self.datastore.persist_event.return_value = defer.succeed(None)
         self.datastore.get_room.return_value = defer.succeed(True)
         self.auth.check_host_in_room.return_value = defer.succeed(True)
+        self.datastore.have_events.return_value = defer.succeed({})
 
         def annotate(ev, old_state=None):
             context = Mock()