summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py40
-rw-r--r--synapse/api/constants.py6
-rw-r--r--synapse/crypto/keyclient.py7
-rw-r--r--synapse/events/__init__.py2
-rw-r--r--synapse/events/builder.py5
-rw-r--r--synapse/events/utils.py11
-rw-r--r--synapse/federation/federation_client.py142
-rw-r--r--synapse/federation/federation_server.py123
-rw-r--r--synapse/federation/replication.py2
-rw-r--r--synapse/federation/transport/client.py16
-rw-r--r--synapse/federation/transport/server.py21
-rw-r--r--synapse/handlers/federation.py427
-rw-r--r--synapse/http/matrixfederationclient.py37
-rw-r--r--synapse/state.py16
-rw-r--r--synapse/storage/rejections.py12
-rw-r--r--tests/handlers/test_federation.py2
16 files changed, 697 insertions, 172 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 9c03024512..37e31d2b6f 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -102,6 +102,8 @@ class Auth(object):
     def check_host_in_room(self, room_id, host):
         curr_state = yield self.state.get_current_state(room_id)
 
+        logger.debug("Got curr_state %s", curr_state)
+
         for event in curr_state:
             if event.type == EventTypes.Member:
                 try:
@@ -358,9 +360,23 @@ class Auth(object):
     def add_auth_events(self, builder, context):
         yield run_on_reactor()
 
-        if builder.type == EventTypes.Create:
-            builder.auth_events = []
-            return
+        auth_ids = self.compute_auth_events(builder, context)
+
+        auth_events_entries = yield self.store.add_event_hashes(
+            auth_ids
+        )
+
+        builder.auth_events = auth_events_entries
+
+        context.auth_events = {
+            k: v
+            for k, v in context.current_state.items()
+            if v.event_id in auth_ids
+        }
+
+    def compute_auth_events(self, event, context):
+        if event.type == EventTypes.Create:
+            return []
 
         auth_ids = []
 
@@ -373,7 +389,7 @@ class Auth(object):
         key = (EventTypes.JoinRules, "", )
         join_rule_event = context.current_state.get(key)
 
-        key = (EventTypes.Member, builder.user_id, )
+        key = (EventTypes.Member, event.user_id, )
         member_event = context.current_state.get(key)
 
         key = (EventTypes.Create, "", )
@@ -387,8 +403,8 @@ class Auth(object):
         else:
             is_public = False
 
-        if builder.type == EventTypes.Member:
-            e_type = builder.content["membership"]
+        if event.type == EventTypes.Member:
+            e_type = event.content["membership"]
             if e_type in [Membership.JOIN, Membership.INVITE]:
                 if join_rule_event:
                     auth_ids.append(join_rule_event.event_id)
@@ -403,17 +419,7 @@ class Auth(object):
             if member_event.content["membership"] == Membership.JOIN:
                 auth_ids.append(member_event.event_id)
 
-        auth_events_entries = yield self.store.add_event_hashes(
-            auth_ids
-        )
-
-        builder.auth_events = auth_events_entries
-
-        context.auth_events = {
-            k: v
-            for k, v in context.current_state.items()
-            if v.event_id in auth_ids
-        }
+        return auth_ids
 
     @log_function
     def _can_send_event(self, event, auth_events):
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index 7ee6dcc46e..0d3fc629af 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -74,3 +74,9 @@ class EventTypes(object):
     Message = "m.room.message"
     Topic = "m.room.topic"
     Name = "m.room.name"
+
+
+class RejectedReason(object):
+    AUTH_ERROR = "auth_error"
+    REPLACED = "replaced"
+    NOT_ANCESTOR = "not_ancestor"
diff --git a/synapse/crypto/keyclient.py b/synapse/crypto/keyclient.py
index 9c910fa3fc..cdb6279764 100644
--- a/synapse/crypto/keyclient.py
+++ b/synapse/crypto/keyclient.py
@@ -61,9 +61,11 @@ class SynapseKeyClientProtocol(HTTPClient):
 
     def __init__(self):
         self.remote_key = defer.Deferred()
+        self.host = None
 
     def connectionMade(self):
-        logger.debug("Connected to %s", self.transport.getHost())
+        self.host = self.transport.getHost()
+        logger.debug("Connected to %s", self.host)
         self.sendCommand(b"GET", b"/_matrix/key/v1/")
         self.endHeaders()
         self.timer = reactor.callLater(
@@ -92,8 +94,7 @@ class SynapseKeyClientProtocol(HTTPClient):
         self.timer.cancel()
 
     def on_timeout(self):
-        logger.debug("Timeout waiting for response from %s",
-                     self.transport.getHost())
+        logger.debug("Timeout waiting for response from %s", self.host)
         self.remote_key.errback(IOError("Timeout waiting for response"))
         self.transport.abortConnection()
 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index 4252e5ab5c..bf07951027 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -18,7 +18,7 @@ from synapse.util.frozenutils import freeze, unfreeze
 
 class _EventInternalMetadata(object):
     def __init__(self, internal_metadata_dict):
-        self.__dict__ = internal_metadata_dict
+        self.__dict__ = dict(internal_metadata_dict)
 
     def get_dict(self):
         return dict(self.__dict__)
diff --git a/synapse/events/builder.py b/synapse/events/builder.py
index a9b1b99a10..9d45bdb892 100644
--- a/synapse/events/builder.py
+++ b/synapse/events/builder.py
@@ -23,14 +23,15 @@ import copy
 
 
 class EventBuilder(EventBase):
-    def __init__(self, key_values={}):
+    def __init__(self, key_values={}, internal_metadata_dict={}):
         signatures = copy.deepcopy(key_values.pop("signatures", {}))
         unsigned = copy.deepcopy(key_values.pop("unsigned", {}))
 
         super(EventBuilder, self).__init__(
             key_values,
             signatures=signatures,
-            unsigned=unsigned
+            unsigned=unsigned,
+            internal_metadata_dict=internal_metadata_dict,
         )
 
     def build(self):
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index e391aca4cc..7ae5d42b96 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -45,12 +45,14 @@ def prune_event(event):
         "membership",
     ]
 
+    event_dict = event.get_dict()
+
     new_content = {}
 
     def add_fields(*fields):
         for field in fields:
             if field in event.content:
-                new_content[field] = event.content[field]
+                new_content[field] = event_dict["content"][field]
 
     if event_type == EventTypes.Member:
         add_fields("membership")
@@ -75,7 +77,7 @@ def prune_event(event):
 
     allowed_fields = {
         k: v
-        for k, v in event.get_dict().items()
+        for k, v in event_dict.items()
         if k in allowed_keys
     }
 
@@ -86,7 +88,10 @@ def prune_event(event):
     if "age_ts" in event.unsigned:
         allowed_fields["unsigned"]["age_ts"] = event.unsigned["age_ts"]
 
-    return type(event)(allowed_fields)
+    return type(event)(
+        allowed_fields,
+        internal_metadata_dict=event.internal_metadata.get_dict()
+    )
 
 
 def serialize_event(e, time_now_ms, client_event=True):
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c80f4c61bc..e1539bd0e0 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -20,6 +20,13 @@ from .units import Edu
 
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import SynapseError
 
 import logging
 
@@ -126,6 +133,11 @@ class FederationClient(object):
             for p in transaction_data["pdus"]
         ]
 
+        for i, pdu in enumerate(pdus):
+            pdus[i] = yield self._check_sigs_and_hash(pdu)
+
+            # FIXME: We should handle signature failures more gracefully.
+
         defer.returnValue(pdus)
 
     @defer.inlineCallbacks
@@ -159,6 +171,22 @@ class FederationClient(object):
                 transaction_data = yield self.transport_layer.get_event(
                     destination, event_id
                 )
+
+                logger.debug("transaction_data %r", transaction_data)
+
+                pdu_list = [
+                    self.event_from_pdu_json(p, outlier=outlier)
+                    for p in transaction_data["pdus"]
+                ]
+
+                if pdu_list:
+                    pdu = pdu_list[0]
+
+                    # Check signatures are correct.
+                    pdu = yield self._check_sigs_and_hash(pdu)
+
+                    break
+
             except Exception as e:
                 logger.info(
                     "Failed to get PDU %s from %s because %s",
@@ -166,18 +194,6 @@ class FederationClient(object):
                 )
                 continue
 
-            logger.debug("transaction_data %r", transaction_data)
-
-            pdu_list = [
-                self.event_from_pdu_json(p, outlier=outlier)
-                for p in transaction_data["pdus"]
-            ]
-
-            if pdu_list:
-                pdu = pdu_list[0]
-                # TODO: We need to check signatures here
-                break
-
         defer.returnValue(pdu)
 
     @defer.inlineCallbacks
@@ -208,6 +224,16 @@ class FederationClient(object):
             for p in result.get("auth_chain", [])
         ]
 
+        for i, pdu in enumerate(pdus):
+            pdus[i] = yield self._check_sigs_and_hash(pdu)
+
+            # FIXME: We should handle signature failures more gracefully.
+
+        for i, pdu in enumerate(auth_chain):
+            auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+            # FIXME: We should handle signature failures more gracefully.
+
         defer.returnValue((pdus, auth_chain))
 
     @defer.inlineCallbacks
@@ -222,6 +248,11 @@ class FederationClient(object):
             for p in res["auth_chain"]
         ]
 
+        for i, pdu in enumerate(auth_chain):
+            auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+            # FIXME: We should handle signature failures more gracefully.
+
         auth_chain.sort(key=lambda e: e.depth)
 
         defer.returnValue(auth_chain)
@@ -260,6 +291,16 @@ class FederationClient(object):
             for p in content.get("auth_chain", [])
         ]
 
+        for i, pdu in enumerate(state):
+            state[i] = yield self._check_sigs_and_hash(pdu)
+
+            # FIXME: We should handle signature failures more gracefully.
+
+        for i, pdu in enumerate(auth_chain):
+            auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+
+            # FIXME: We should handle signature failures more gracefully.
+
         auth_chain.sort(key=lambda e: e.depth)
 
         defer.returnValue({
@@ -281,7 +322,48 @@ class FederationClient(object):
 
         logger.debug("Got response to send_invite: %s", pdu_dict)
 
-        defer.returnValue(self.event_from_pdu_json(pdu_dict))
+        pdu = self.event_from_pdu_json(pdu_dict)
+
+        # Check signatures are correct.
+        pdu = yield self._check_sigs_and_hash(pdu)
+
+        # FIXME: We should handle signature failures more gracefully.
+
+        defer.returnValue(pdu)
+
+    @defer.inlineCallbacks
+    def query_auth(self, destination, room_id, event_id, local_auth):
+        """
+        Params:
+            destination (str)
+            event_it (str)
+            local_auth (list)
+        """
+        time_now = self._clock.time_msec()
+
+        send_content = {
+            "auth_chain": [e.get_pdu_json(time_now) for e in local_auth],
+        }
+
+        code, content = yield self.transport_layer.send_query_auth(
+            destination=destination,
+            room_id=room_id,
+            event_id=event_id,
+            content=send_content,
+        )
+
+        auth_chain = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content["auth_chain"]
+        ]
+
+        ret = {
+            "auth_chain": auth_chain,
+            "rejects": content.get("rejects", []),
+            "missing": content.get("missing", []),
+        }
+
+        defer.returnValue(ret)
 
     def event_from_pdu_json(self, pdu_json, outlier=False):
         event = FrozenEvent(
@@ -291,3 +373,37 @@ class FederationClient(object):
         event.internal_metadata.outlier = outlier
 
         return event
+
+    @defer.inlineCallbacks
+    def _check_sigs_and_hash(self, pdu):
+        """Throws a SynapseError if the PDU does not have the correct
+        signatures.
+
+        Returns:
+            FrozenEvent: Either the given event or it redacted if it failed the
+            content hash check.
+        """
+        # Check signatures are correct.
+        redacted_event = prune_event(pdu)
+        redacted_pdu_json = redacted_event.get_pdu_json()
+
+        try:
+            yield self.keyring.verify_json_for_server(
+                pdu.origin, redacted_pdu_json
+            )
+        except SynapseError:
+            logger.warn(
+                "Signature check failed for %s redacted to %s",
+                encode_canonical_json(pdu.get_pdu_json()),
+                encode_canonical_json(redacted_pdu_json),
+            )
+            raise
+
+        if not check_event_content_hash(pdu):
+            logger.warn(
+                "Event content has been tampered, redacting %s, %s",
+                pdu.event_id, encode_canonical_json(pdu.get_dict())
+            )
+            defer.returnValue(redacted_event)
+
+        defer.returnValue(pdu)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 0597725ce7..84ed0a0ba0 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -21,6 +21,13 @@ from .units import Transaction, Edu
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import FederationError, SynapseError
 
 import logging
 
@@ -97,8 +104,10 @@ class FederationServer(object):
         response = yield self.transaction_actions.have_responded(transaction)
 
         if response:
-            logger.debug("[%s] We've already responed to this request",
-                         transaction.transaction_id)
+            logger.debug(
+                "[%s] We've already responed to this request",
+                transaction.transaction_id
+            )
             defer.returnValue(response)
             return
 
@@ -221,6 +230,36 @@ class FederationServer(object):
             "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus],
         }))
 
+    @defer.inlineCallbacks
+    def on_query_auth_request(self, origin, content, event_id):
+        auth_chain = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content["auth_chain"]
+        ]
+
+        missing = [
+            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            for e in content.get("missing", [])
+        ]
+
+        ret = yield self.handler.on_query_auth(
+            origin, event_id, auth_chain, content.get("rejects", []), missing
+        )
+
+        time_now = self._clock.time_msec()
+        send_content = {
+            "auth_chain": [
+                e.get_pdu_json(time_now)
+                for e in ret["auth_chain"]
+            ],
+            "rejects": ret.get("rejects", []),
+            "missing": ret.get("missing", []),
+        }
+
+        defer.returnValue(
+            (200, send_content)
+        )
+
     @log_function
     def _get_persisted_pdu(self, origin, event_id, do_auth=True):
         """ Get a PDU from the database with given origin and id.
@@ -253,6 +292,9 @@ class FederationServer(object):
             origin, pdu.event_id, do_auth=False
         )
 
+        # FIXME: Currently we fetch an event again when we already have it
+        # if it has been marked as an outlier.
+
         already_seen = (
             existing and (
                 not existing.internal_metadata.is_outlier()
@@ -264,14 +306,27 @@ class FederationServer(object):
             defer.returnValue({})
             return
 
+        # Check signature.
+        try:
+            pdu = yield self._check_sigs_and_hash(pdu)
+        except SynapseError as e:
+            raise FederationError(
+                "ERROR",
+                e.code,
+                e.msg,
+                affected=pdu.event_id,
+            )
+
         state = None
 
         auth_chain = []
 
         have_seen = yield self.store.have_events(
-            [e for e, _ in pdu.prev_events]
+            [ev for ev, _ in pdu.prev_events]
         )
 
+        fetch_state = False
+
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
@@ -308,19 +363,29 @@ class FederationServer(object):
                                 logger.debug("Processed pdu %s", event_id)
                             else:
                                 logger.warn("Failed to get PDU %s", event_id)
+                                fetch_state = True
                         except:
                             # TODO(erikj): Do some more intelligent retries.
                             logger.exception("Failed to get PDU")
+                            fetch_state = True
             else:
-                # We need to get the state at this event, since we have reached
-                # a backward extremity edge.
-                logger.debug(
-                    "_handle_new_pdu getting state for %s",
-                    pdu.room_id
-                )
-                state, auth_chain = yield self.get_state_for_room(
-                    origin, pdu.room_id, pdu.event_id,
-                )
+                prevs = {e_id for e_id, _ in pdu.prev_events}
+                seen = set(have_seen.keys())
+                if prevs - seen:
+                    fetch_state = True
+        else:
+            fetch_state = True
+
+        if fetch_state:
+            # We need to get the state at this event, since we haven't
+            # processed all the prev events.
+            logger.debug(
+                "_handle_new_pdu getting state for %s",
+                pdu.room_id
+            )
+            state, auth_chain = yield self.get_state_for_room(
+                origin, pdu.room_id, pdu.event_id,
+            )
 
         ret = yield self.handler.on_receive_pdu(
             origin,
@@ -343,3 +408,37 @@ class FederationServer(object):
         event.internal_metadata.outlier = outlier
 
         return event
+
+    @defer.inlineCallbacks
+    def _check_sigs_and_hash(self, pdu):
+        """Throws a SynapseError if the PDU does not have the correct
+        signatures.
+
+        Returns:
+            FrozenEvent: Either the given event or it redacted if it failed the
+            content hash check.
+        """
+        # Check signatures are correct.
+        redacted_event = prune_event(pdu)
+        redacted_pdu_json = redacted_event.get_pdu_json()
+
+        try:
+            yield self.keyring.verify_json_for_server(
+                pdu.origin, redacted_pdu_json
+            )
+        except SynapseError:
+            logger.warn(
+                "Signature check failed for %s redacted to %s",
+                encode_canonical_json(pdu.get_pdu_json()),
+                encode_canonical_json(redacted_pdu_json),
+            )
+            raise
+
+        if not check_event_content_hash(pdu):
+            logger.warn(
+                "Event content has been tampered, redacting %s, %s",
+                pdu.event_id, encode_canonical_json(pdu.get_dict())
+            )
+            defer.returnValue(redacted_event)
+
+        defer.returnValue(pdu)
diff --git a/synapse/federation/replication.py b/synapse/federation/replication.py
index 9ef4834927..e442c6c5d5 100644
--- a/synapse/federation/replication.py
+++ b/synapse/federation/replication.py
@@ -51,6 +51,8 @@ class ReplicationLayer(FederationClient, FederationServer):
     def __init__(self, hs, transport_layer):
         self.server_name = hs.hostname
 
+        self.keyring = hs.get_keyring()
+
         self.transport_layer = transport_layer
         self.transport_layer.register_received_handler(self)
         self.transport_layer.register_request_handler(self)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index e634a3a213..4cb1dea2de 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -213,3 +213,19 @@ class TransportLayerClient(object):
         )
 
         defer.returnValue(response)
+
+    @defer.inlineCallbacks
+    @log_function
+    def send_query_auth(self, destination, room_id, event_id, content):
+        path = PREFIX + "/query_auth/%s/%s" % (room_id, event_id)
+
+        code, content = yield self.client.post_json(
+            destination=destination,
+            path=path,
+            data=content,
+        )
+
+        if not 200 <= code < 300:
+            raise RuntimeError("Got %d from send_invite", code)
+
+        defer.returnValue(json.loads(content))
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index a380a6910b..9c9f8d525b 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -42,7 +42,7 @@ class TransportLayerServer(object):
         content = None
         origin = None
 
-        if request.method == "PUT":
+        if request.method in ["PUT", "POST"]:
             # TODO: Handle other method types? other content types?
             try:
                 content_bytes = request.content.read()
@@ -234,6 +234,16 @@ class TransportLayerServer(object):
                 )
             )
         )
+        self.server.register_path(
+            "POST",
+            re.compile("^" + PREFIX + "/query_auth/([^/]*)/([^/]*)$"),
+            self._with_authentication(
+                lambda origin, content, query, context, event_id:
+                self._on_query_auth_request(
+                    origin, content, event_id,
+                )
+            )
+        )
 
     @defer.inlineCallbacks
     @log_function
@@ -325,3 +335,12 @@ class TransportLayerServer(object):
         )
 
         defer.returnValue((200, content))
+
+    @defer.inlineCallbacks
+    @log_function
+    def _on_query_auth_request(self, origin, content, event_id):
+        new_content = yield self.request_handler.on_query_auth_request(
+            origin, content, event_id
+        )
+
+        defer.returnValue((200, new_content))
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index bcdcc90a18..35cad4182a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -17,19 +17,16 @@
 
 from ._base import BaseHandler
 
-from synapse.events.utils import prune_event
 from synapse.api.errors import (
-    AuthError, FederationError, SynapseError, StoreError,
+    AuthError, FederationError, StoreError,
 )
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.crypto.event_signing import (
-    compute_event_signature, check_event_content_hash,
-    add_hashes_and_signatures,
+    compute_event_signature, add_hashes_and_signatures,
 )
 from synapse.types import UserID
-from syutil.jsonutil import encode_canonical_json
 
 from twisted.internet import defer
 
@@ -113,33 +110,6 @@ class FederationHandler(BaseHandler):
 
         logger.debug("Processing event: %s", event.event_id)
 
-        redacted_event = prune_event(event)
-
-        redacted_pdu_json = redacted_event.get_pdu_json()
-        try:
-            yield self.keyring.verify_json_for_server(
-                event.origin, redacted_pdu_json
-            )
-        except SynapseError as e:
-            logger.warn(
-                "Signature check failed for %s redacted to %s",
-                encode_canonical_json(pdu.get_pdu_json()),
-                encode_canonical_json(redacted_pdu_json),
-            )
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
-
-        if not check_event_content_hash(event):
-            logger.warn(
-                "Event content has been tampered, redacting %s, %s",
-                event.event_id, encode_canonical_json(event.get_dict())
-            )
-            event = redacted_event
-
         logger.debug("Event: %s", event)
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
@@ -149,41 +119,20 @@ class FederationHandler(BaseHandler):
             event.room_id,
             self.server_name
         )
-        if not is_in_room and not event.internal_metadata.outlier:
+        if not is_in_room and not event.internal_metadata.is_outlier():
             logger.debug("Got event for room we're not in.")
-
-            replication = self.replication_layer
-
-            if not state:
-                state, auth_chain = yield replication.get_state_for_room(
-                    origin, context=event.room_id, event_id=event.event_id,
-                )
-
-            if not auth_chain:
-                auth_chain = yield replication.get_event_auth(
-                    origin,
-                    context=event.room_id,
-                    event_id=event.event_id,
-                )
-
-            for e in auth_chain:
-                e.internal_metadata.outlier = True
-                try:
-                    yield self._handle_new_event(e, fetch_auth_from=origin)
-                except:
-                    logger.exception(
-                        "Failed to handle auth event %s",
-                        e.event_id,
-                    )
-
             current_state = state
 
-        if state:
+        if state and auth_chain is not None:
             for e in state:
-                logging.info("A :) %r", e)
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(origin, e, auth_events=auth)
                 except:
                     logger.exception(
                         "Failed to handle state event %s",
@@ -192,6 +141,7 @@ class FederationHandler(BaseHandler):
 
         try:
             yield self._handle_new_event(
+                origin,
                 event,
                 state=state,
                 backfilled=backfilled,
@@ -394,7 +344,14 @@ class FederationHandler(BaseHandler):
             for e in auth_chain:
                 e.internal_metadata.outlier = True
                 try:
-                    yield self._handle_new_event(e)
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    yield self._handle_new_event(
+                        target_host, e, auth_events=auth
+                    )
                 except:
                     logger.exception(
                         "Failed to handle auth event %s",
@@ -405,8 +362,13 @@ class FederationHandler(BaseHandler):
                 # FIXME: Auth these.
                 e.internal_metadata.outlier = True
                 try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in auth_chain
+                        if e.event_id in auth_ids
+                    }
                     yield self._handle_new_event(
-                        e, fetch_auth_from=target_host
+                        target_host, e, auth_events=auth
                     )
                 except:
                     logger.exception(
@@ -415,6 +377,7 @@ class FederationHandler(BaseHandler):
                     )
 
             yield self._handle_new_event(
+                target_host,
                 new_event,
                 state=state,
                 current_state=state,
@@ -481,7 +444,7 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(event)
+        context = yield self._handle_new_event(origin, event)
 
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -682,11 +645,12 @@ class FederationHandler(BaseHandler):
             waiters.pop().callback(None)
 
     @defer.inlineCallbacks
-    def _handle_new_event(self, event, state=None, backfilled=False,
-                          current_state=None, fetch_auth_from=None):
+    @log_function
+    def _handle_new_event(self, origin, event, state=None, backfilled=False,
+                          current_state=None, auth_events=None):
 
         logger.debug(
-            "_handle_new_event: Before annotate: %s, sigs: %s",
+            "_handle_new_event: %s, sigs: %s",
             event.event_id, event.signatures,
         )
 
@@ -694,65 +658,44 @@ class FederationHandler(BaseHandler):
             event, old_state=state
         )
 
+        if not auth_events:
+            auth_events = context.auth_events
+
         logger.debug(
-            "_handle_new_event: Before auth fetch: %s, sigs: %s",
-            event.event_id, event.signatures,
+            "_handle_new_event: %s, auth_events: %s",
+            event.event_id, auth_events,
         )
 
         is_new_state = not event.internal_metadata.is_outlier()
 
-        known_ids = set(
-            [s.event_id for s in context.auth_events.values()]
-        )
-
-        for e_id, _ in event.auth_events:
-            if e_id not in known_ids:
-                e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e and fetch_auth_from is not None:
-                    # Grab the auth_chain over federation if we are missing
-                    # auth events.
-                    auth_chain = yield self.replication_layer.get_event_auth(
-                        fetch_auth_from, event.event_id, event.room_id
-                    )
-                    for auth_event in auth_chain:
-                        yield self._handle_new_event(auth_event)
-                    e = yield self.store.get_event(e_id, allow_none=True)
-
-                if not e:
-                    # TODO: Do some conflict res to make sure that we're
-                    # not the ones who are wrong.
-                    logger.info(
-                        "Rejecting %s as %s not in db or %s",
-                        event.event_id, e_id, known_ids,
-                    )
-                    # FIXME: How does raising AuthError work with federation?
-                    raise AuthError(403, "Cannot find auth event")
-
-                context.auth_events[(e.type, e.state_key)] = e
-
-        logger.debug(
-            "_handle_new_event: Before hack: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
-
+        # This is a hack to fix some old rooms where the initial join event
+        # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_events:
             if len(event.prev_events) == 1:
                 c = yield self.store.get_event(event.prev_events[0][0])
                 if c.type == EventTypes.Create:
-                    context.auth_events[(c.type, c.state_key)] = c
+                    auth_events[(c.type, c.state_key)] = c
 
-        logger.debug(
-            "_handle_new_event: Before auth check: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+        try:
+            yield self.do_auth(
+                origin, event, context, auth_events=auth_events
+            )
+        except AuthError as e:
+            logger.warn(
+                "Rejecting %s because %s",
+                event.event_id, e.msg
+            )
 
-        self.auth.check(event, auth_events=context.auth_events)
+            context.rejected = RejectedReason.AUTH_ERROR
 
-        logger.debug(
-            "_handle_new_event: Before persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
-        )
+            yield self.store.persist_event(
+                event,
+                context=context,
+                backfilled=backfilled,
+                is_new_state=False,
+                current_state=current_state,
+            )
+            raise
 
         yield self.store.persist_event(
             event,
@@ -762,9 +705,255 @@ class FederationHandler(BaseHandler):
             current_state=current_state,
         )
 
-        logger.debug(
-            "_handle_new_event: After persist_event: %s, sigs: %s",
-            event.event_id, event.signatures,
+        defer.returnValue(context)
+
+    @defer.inlineCallbacks
+    def on_query_auth(self, origin, event_id, remote_auth_chain, rejects,
+                      missing):
+        # Just go through and process each event in `remote_auth_chain`. We
+        # don't want to fall into the trap of `missing` being wrong.
+        for e in remote_auth_chain:
+            try:
+                yield self._handle_new_event(origin, e)
+            except AuthError:
+                pass
+
+        # Now get the current auth_chain for the event.
+        local_auth_chain = yield self.store.get_auth_chain([event_id])
+
+        # TODO: Check if we would now reject event_id. If so we need to tell
+        # everyone.
+
+        ret = yield self.construct_auth_difference(
+            local_auth_chain, remote_auth_chain
         )
 
-        defer.returnValue(context)
+        logger.debug("on_query_auth reutrning: %s", ret)
+
+        defer.returnValue(ret)
+
+    @defer.inlineCallbacks
+    @log_function
+    def do_auth(self, origin, event, context, auth_events):
+        # Check if we have all the auth events.
+        res = yield self.store.have_events(
+            [e_id for e_id, _ in event.auth_events]
+        )
+
+        event_auth_events = set(e_id for e_id, _ in event.auth_events)
+        seen_events = set(res.keys())
+
+        missing_auth = event_auth_events - seen_events
+
+        if missing_auth:
+            logger.debug("Missing auth: %s", missing_auth)
+            # If we don't have all the auth events, we need to get them.
+            remote_auth_chain = yield self.replication_layer.get_event_auth(
+                origin, event.room_id, event.event_id
+            )
+
+            for e in remote_auth_chain:
+                try:
+                    auth_ids = [e_id for e_id, _ in e.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in remote_auth_chain
+                        if e.event_id in auth_ids
+                    }
+                    e.internal_metadata.outlier = True
+                    yield self._handle_new_event(
+                        origin, e, auth_events=auth
+                    )
+                    auth_events[(e.type, e.state_key)] = e
+                except AuthError:
+                    pass
+
+        # FIXME: Assumes we have and stored all the state for all the
+        # prev_events
+        current_state = set(e.event_id for e in auth_events.values())
+        different_auth = event_auth_events - current_state
+
+        if different_auth and not event.internal_metadata.is_outlier():
+            # Do auth conflict res.
+            logger.debug("Different auth: %s", different_auth)
+
+            # 1. Get what we think is the auth chain.
+            auth_ids = self.auth.compute_auth_events(event, context)
+            local_auth_chain = yield self.store.get_auth_chain(auth_ids)
+
+            # 2. Get remote difference.
+            result = yield self.replication_layer.query_auth(
+                origin,
+                event.room_id,
+                event.event_id,
+                local_auth_chain,
+            )
+
+            # 3. Process any remote auth chain events we haven't seen.
+            for missing_id in result.get("missing", []):
+                try:
+                    for e in result["auth_chain"]:
+                        if e.event_id == missing_id:
+                            ev = e
+                            break
+
+                    auth_ids = [e_id for e_id, _ in ev.auth_events]
+                    auth = {
+                        (e.type, e.state_key): e for e in result["auth_chain"]
+                        if e.event_id in auth_ids
+                    }
+                    ev.internal_metadata.outlier = True
+                    yield self._handle_new_event(
+                        origin, ev, auth_events=auth
+                    )
+                    auth_events[(ev.type, ev.state_key)] = ev
+                except AuthError:
+                    pass
+
+            # 4. Look at rejects and their proofs.
+            # TODO.
+
+            context.current_state.update(auth_events)
+            context.state_group = None
+
+        try:
+            self.auth.check(event, auth_events=auth_events)
+        except AuthError:
+            raise
+
+    @defer.inlineCallbacks
+    def construct_auth_difference(self, local_auth, remote_auth):
+        """ Given a local and remote auth chain, find the differences. This
+        assumes that we have already processed all events in remote_auth
+
+        Params:
+            local_auth (list)
+            remote_auth (list)
+
+        Returns:
+            dict
+        """
+
+        logger.debug("construct_auth_difference Start!")
+
+        # TODO: Make sure we are OK with local_auth or remote_auth having more
+        # auth events in them than strictly necessary.
+
+        def sort_fun(ev):
+            return ev.depth, ev.event_id
+
+        logger.debug("construct_auth_difference after sort_fun!")
+
+        # We find the differences by starting at the "bottom" of each list
+        # and iterating up on both lists. The lists are ordered by depth and
+        # then event_id, we iterate up both lists until we find the event ids
+        # don't match. Then we look at depth/event_id to see which side is
+        # missing that event, and iterate only up that list. Repeat.
+
+        remote_list = list(remote_auth)
+        remote_list.sort(key=sort_fun)
+
+        local_list = list(local_auth)
+        local_list.sort(key=sort_fun)
+
+        local_iter = iter(local_list)
+        remote_iter = iter(remote_list)
+
+        logger.debug("construct_auth_difference before get_next!")
+
+        def get_next(it, opt=None):
+            try:
+                return it.next()
+            except:
+                return opt
+
+        current_local = get_next(local_iter)
+        current_remote = get_next(remote_iter)
+
+        logger.debug("construct_auth_difference before while")
+
+        missing_remotes = []
+        missing_locals = []
+        while current_local or current_remote:
+            if current_remote is None:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local is None:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.event_id == current_remote.event_id:
+                current_local = get_next(local_iter)
+                current_remote = get_next(remote_iter)
+                continue
+
+            if current_local.depth < current_remote.depth:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+                continue
+
+            if current_local.depth > current_remote.depth:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+            # They have the same depth, so we fall back to the event_id order
+            if current_local.event_id < current_remote.event_id:
+                missing_locals.append(current_local)
+                current_local = get_next(local_iter)
+
+            if current_local.event_id > current_remote.event_id:
+                missing_remotes.append(current_remote)
+                current_remote = get_next(remote_iter)
+                continue
+
+        logger.debug("construct_auth_difference after while")
+
+        # missing locals should be sent to the server
+        # We should find why we are missing remotes, as they will have been
+        # rejected.
+
+        # Remove events from missing_remotes if they are referencing a missing
+        # remote. We only care about the "root" rejected ones.
+        missing_remote_ids = [e.event_id for e in missing_remotes]
+        base_remote_rejected = list(missing_remotes)
+        for e in missing_remotes:
+            for e_id, _ in e.auth_events:
+                if e_id in missing_remote_ids:
+                    base_remote_rejected.remove(e)
+
+        reason_map = {}
+
+        for e in base_remote_rejected:
+            reason = yield self.store.get_rejection_reason(e.event_id)
+            if reason is None:
+                # FIXME: ERRR?!
+                logger.warn("Could not find reason for %s", e.event_id)
+                raise RuntimeError("")
+
+            reason_map[e.event_id] = reason
+
+            if reason == RejectedReason.AUTH_ERROR:
+                pass
+            elif reason == RejectedReason.REPLACED:
+                # TODO: Get proof
+                pass
+            elif reason == RejectedReason.NOT_ANCESTOR:
+                # TODO: Get proof.
+                pass
+
+        logger.debug("construct_auth_difference returning")
+
+        defer.returnValue({
+            "auth_chain": local_auth,
+            "rejects": {
+                e.event_id: {
+                    "reason": reason_map[e.event_id],
+                    "proof": None,
+                }
+                for e in base_remote_rejected
+            },
+            "missing": [e.event_id for e in missing_locals],
+        })
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 1dda3ba2c7..c7bf1b47b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -245,6 +245,43 @@ class MatrixFederationHttpClient(object):
         defer.returnValue((response.code, body))
 
     @defer.inlineCallbacks
+    def post_json(self, destination, path, data={}):
+        """ Sends the specifed json data using POST
+
+        Args:
+            destination (str): The remote server to send the HTTP request
+                to.
+            path (str): The HTTP path.
+            data (dict): A dict containing the data that will be used as
+                the request body. This will be encoded as JSON.
+
+        Returns:
+            Deferred: Succeeds when we get a 2xx HTTP response. The result
+            will be the decoded JSON body. On a 4xx or 5xx error response a
+            CodeMessageException is raised.
+        """
+
+        def body_callback(method, url_bytes, headers_dict):
+            self.sign_request(
+                destination, method, url_bytes, headers_dict, data
+            )
+            return _JsonProducer(data)
+
+        response = yield self._create_request(
+            destination.encode("ascii"),
+            "POST",
+            path.encode("ascii"),
+            body_callback=body_callback,
+            headers_dict={"Content-Type": ["application/json"]},
+        )
+
+        logger.debug("Getting resp body")
+        body = yield readBody(response)
+        logger.debug("Got resp body")
+
+        defer.returnValue((response.code, body))
+
+    @defer.inlineCallbacks
     def get_json(self, destination, path, args={}, retry_on_dns_fail=True):
         """ GETs some json from the given host homeserver and path
 
diff --git a/synapse/state.py b/synapse/state.py
index 081bc31bb5..038e5eba11 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -168,10 +168,17 @@ class StateHandler(object):
         first is the name of a state group if one and only one is involved,
         otherwise `None`.
         """
+        logger.debug("resolve_state_groups event_ids %s", event_ids)
+
         state_groups = yield self.store.get_state_groups(
             event_ids
         )
 
+        logger.debug(
+            "resolve_state_groups state_groups %s",
+            state_groups.keys()
+        )
+
         group_names = set(state_groups.keys())
         if len(group_names) == 1:
             name, state_list = state_groups.items().pop()
@@ -207,6 +214,15 @@ class StateHandler(object):
             if len(v.values()) > 1
         }
 
+        logger.debug(
+            "resolve_state_groups Unconflicted state: %s",
+            unconflicted_state.values(),
+        )
+        logger.debug(
+            "resolve_state_groups Conflicted state: %s",
+            conflicted_state.values(),
+        )
+
         if event_type:
             prev_states_events = conflicted_state.get(
                 (event_type, state_key), []
diff --git a/synapse/storage/rejections.py b/synapse/storage/rejections.py
index 7d38b31f44..4e1a9a2783 100644
--- a/synapse/storage/rejections.py
+++ b/synapse/storage/rejections.py
@@ -28,6 +28,16 @@ class RejectionsStore(SQLBaseStore):
             values={
                 "event_id": event_id,
                 "reason": reason,
-                "last_failure": self._clock.time_msec(),
+                "last_check": self._clock.time_msec(),
             }
         )
+
+    def get_rejection_reason(self, event_id):
+        return self._simple_select_one_onecol(
+            table="rejections",
+            retcol="reason",
+            keyvalues={
+                "event_id": event_id,
+            },
+            allow_none=True,
+        )
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index ed21defd13..44dbce6bea 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -52,6 +52,7 @@ class FederationTestCase(unittest.TestCase):
                 "get_room",
                 "get_destination_retry_timings",
                 "set_destination_retry_timings",
+                "have_events",
             ]),
             resource_for_federation=NonCallableMock(),
             http_client=NonCallableMock(spec_set=[]),
@@ -90,6 +91,7 @@ class FederationTestCase(unittest.TestCase):
         self.datastore.persist_event.return_value = defer.succeed(None)
         self.datastore.get_room.return_value = defer.succeed(True)
         self.auth.check_host_in_room.return_value = defer.succeed(True)
+        self.datastore.have_events.return_value = defer.succeed({})
 
         def annotate(ev, old_state=None):
             context = Mock()