summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/api/auth.py14
-rw-r--r--synapse/events/__init__.py2
-rw-r--r--synapse/federation/federation_base.py118
-rw-r--r--synapse/federation/federation_client.py99
-rw-r--r--synapse/federation/federation_server.py57
-rw-r--r--synapse/handlers/federation.py173
-rw-r--r--synapse/state.py22
-rw-r--r--synapse/storage/__init__.py211
-rw-r--r--tests/handlers/test_federation.py5
9 files changed, 429 insertions, 272 deletions
diff --git a/synapse/api/auth.py b/synapse/api/auth.py
index 37e31d2b6f..7105ee21dc 100644
--- a/synapse/api/auth.py
+++ b/synapse/api/auth.py
@@ -102,8 +102,6 @@ class Auth(object):
     def check_host_in_room(self, room_id, host):
         curr_state = yield self.state.get_current_state(room_id)
 
-        logger.debug("Got curr_state %s", curr_state)
-
         for event in curr_state:
             if event.type == EventTypes.Member:
                 try:
@@ -360,7 +358,7 @@ class Auth(object):
     def add_auth_events(self, builder, context):
         yield run_on_reactor()
 
-        auth_ids = self.compute_auth_events(builder, context)
+        auth_ids = self.compute_auth_events(builder, context.current_state)
 
         auth_events_entries = yield self.store.add_event_hashes(
             auth_ids
@@ -374,26 +372,26 @@ class Auth(object):
             if v.event_id in auth_ids
         }
 
-    def compute_auth_events(self, event, context):
+    def compute_auth_events(self, event, current_state):
         if event.type == EventTypes.Create:
             return []
 
         auth_ids = []
 
         key = (EventTypes.PowerLevels, "", )
-        power_level_event = context.current_state.get(key)
+        power_level_event = current_state.get(key)
 
         if power_level_event:
             auth_ids.append(power_level_event.event_id)
 
         key = (EventTypes.JoinRules, "", )
-        join_rule_event = context.current_state.get(key)
+        join_rule_event = current_state.get(key)
 
         key = (EventTypes.Member, event.user_id, )
-        member_event = context.current_state.get(key)
+        member_event = current_state.get(key)
 
         key = (EventTypes.Create, "", )
-        create_event = context.current_state.get(key)
+        create_event = current_state.get(key)
         if create_event:
             auth_ids.append(create_event.event_id)
 
diff --git a/synapse/events/__init__.py b/synapse/events/__init__.py
index bf07951027..8f0c6e959f 100644
--- a/synapse/events/__init__.py
+++ b/synapse/events/__init__.py
@@ -77,7 +77,7 @@ class EventBase(object):
         return self.content["membership"]
 
     def is_state(self):
-        return hasattr(self, "state_key")
+        return hasattr(self, "state_key") and self.state_key is not None
 
     def get_dict(self):
         d = dict(self._event_dict)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
new file mode 100644
index 0000000000..a990aec4fd
--- /dev/null
+++ b/synapse/federation/federation_base.py
@@ -0,0 +1,118 @@
+# -*- coding: utf-8 -*-
+# Copyright 2015 OpenMarket Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from twisted.internet import defer
+
+from synapse.events.utils import prune_event
+
+from syutil.jsonutil import encode_canonical_json
+
+from synapse.crypto.event_signing import check_event_content_hash
+
+from synapse.api.errors import SynapseError
+
+import logging
+
+
+logger = logging.getLogger(__name__)
+
+
+class FederationBase(object):
+    @defer.inlineCallbacks
+    def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
+        """Takes a list of PDUs and checks the signatures and hashs of each
+        one. If a PDU fails its signature check then we check if we have it in
+        the database and if not then request if from the originating server of
+        that PDU.
+
+        If a PDU fails its content hash check then it is redacted.
+
+        The given list of PDUs are not modified, instead the function returns
+        a new list.
+
+        Args:
+            pdu (list)
+            outlier (bool)
+
+        Returns:
+            Deferred : A list of PDUs that have valid signatures and hashes.
+        """
+        signed_pdus = []
+        for pdu in pdus:
+            try:
+                new_pdu = yield self._check_sigs_and_hash(pdu)
+                signed_pdus.append(new_pdu)
+            except SynapseError:
+                # FIXME: We should handle signature failures more gracefully.
+
+                # Check local db.
+                new_pdu = yield self.store.get_event(
+                    pdu.event_id,
+                    allow_rejected=True
+                )
+                if new_pdu:
+                    signed_pdus.append(new_pdu)
+                    continue
+
+                # Check pdu.origin
+                if pdu.origin != origin:
+                    new_pdu = yield self.get_pdu(
+                        destinations=[pdu.origin],
+                        event_id=pdu.event_id,
+                        outlier=outlier,
+                    )
+
+                    if new_pdu:
+                        signed_pdus.append(new_pdu)
+                        continue
+
+                logger.warn("Failed to find copy of %s with valid signature")
+
+        defer.returnValue(signed_pdus)
+
+    @defer.inlineCallbacks
+    def _check_sigs_and_hash(self, pdu):
+        """Throws a SynapseError if the PDU does not have the correct
+        signatures.
+
+        Returns:
+            FrozenEvent: Either the given event or it redacted if it failed the
+            content hash check.
+        """
+        # Check signatures are correct.
+        redacted_event = prune_event(pdu)
+        redacted_pdu_json = redacted_event.get_pdu_json()
+
+        try:
+            yield self.keyring.verify_json_for_server(
+                pdu.origin, redacted_pdu_json
+            )
+        except SynapseError:
+            logger.warn(
+                "Signature check failed for %s redacted to %s",
+                encode_canonical_json(pdu.get_pdu_json()),
+                encode_canonical_json(redacted_pdu_json),
+            )
+            raise
+
+        if not check_event_content_hash(pdu):
+            logger.warn(
+                "Event content has been tampered, redacting %s, %s",
+                pdu.event_id, encode_canonical_json(pdu.get_dict())
+            )
+            defer.returnValue(redacted_event)
+
+        defer.returnValue(pdu)
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index e1539bd0e0..5fac629709 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -16,17 +16,11 @@
 
 from twisted.internet import defer
 
+from .federation_base import FederationBase
 from .units import Edu
 
 from synapse.util.logutils import log_function
 from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
-
-from syutil.jsonutil import encode_canonical_json
-
-from synapse.crypto.event_signing import check_event_content_hash
-
-from synapse.api.errors import SynapseError
 
 import logging
 
@@ -34,7 +28,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class FederationClient(object):
+class FederationClient(FederationBase):
     @log_function
     def send_pdu(self, pdu, destinations):
         """Informs the replication layer about a new PDU generated within the
@@ -224,17 +218,17 @@ class FederationClient(object):
             for p in result.get("auth_chain", [])
         ]
 
-        for i, pdu in enumerate(pdus):
-            pdus[i] = yield self._check_sigs_and_hash(pdu)
-
-            # FIXME: We should handle signature failures more gracefully.
+        signed_pdus = yield self._check_sigs_and_hash_and_fetch(
+            destination, pdus, outlier=True
+        )
 
-        for i, pdu in enumerate(auth_chain):
-            auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+        signed_auth = yield self._check_sigs_and_hash_and_fetch(
+            destination, auth_chain, outlier=True
+        )
 
-            # FIXME: We should handle signature failures more gracefully.
+        signed_auth.sort(key=lambda e: e.depth)
 
-        defer.returnValue((pdus, auth_chain))
+        defer.returnValue((signed_pdus, signed_auth))
 
     @defer.inlineCallbacks
     @log_function
@@ -248,14 +242,13 @@ class FederationClient(object):
             for p in res["auth_chain"]
         ]
 
-        for i, pdu in enumerate(auth_chain):
-            auth_chain[i] = yield self._check_sigs_and_hash(pdu)
-
-            # FIXME: We should handle signature failures more gracefully.
+        signed_auth = yield self._check_sigs_and_hash_and_fetch(
+            destination, auth_chain, outlier=True
+        )
 
-        auth_chain.sort(key=lambda e: e.depth)
+        signed_auth.sort(key=lambda e: e.depth)
 
-        defer.returnValue(auth_chain)
+        defer.returnValue(signed_auth)
 
     @defer.inlineCallbacks
     def make_join(self, destination, room_id, user_id):
@@ -291,21 +284,19 @@ class FederationClient(object):
             for p in content.get("auth_chain", [])
         ]
 
-        for i, pdu in enumerate(state):
-            state[i] = yield self._check_sigs_and_hash(pdu)
-
-            # FIXME: We should handle signature failures more gracefully.
-
-        for i, pdu in enumerate(auth_chain):
-            auth_chain[i] = yield self._check_sigs_and_hash(pdu)
+        signed_state = yield self._check_sigs_and_hash_and_fetch(
+            destination, state, outlier=True
+        )
 
-            # FIXME: We should handle signature failures more gracefully.
+        signed_auth = yield self._check_sigs_and_hash_and_fetch(
+            destination, auth_chain, outlier=True
+        )
 
         auth_chain.sort(key=lambda e: e.depth)
 
         defer.returnValue({
-            "state": state,
-            "auth_chain": auth_chain,
+            "state": signed_state,
+            "auth_chain": signed_auth,
         })
 
     @defer.inlineCallbacks
@@ -353,12 +344,18 @@ class FederationClient(object):
         )
 
         auth_chain = [
-            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            self.event_from_pdu_json(e)
             for e in content["auth_chain"]
         ]
 
+        signed_auth = yield self._check_sigs_and_hash_and_fetch(
+            destination, auth_chain, outlier=True
+        )
+
+        signed_auth.sort(key=lambda e: e.depth)
+
         ret = {
-            "auth_chain": auth_chain,
+            "auth_chain": signed_auth,
             "rejects": content.get("rejects", []),
             "missing": content.get("missing", []),
         }
@@ -373,37 +370,3 @@ class FederationClient(object):
         event.internal_metadata.outlier = outlier
 
         return event
-
-    @defer.inlineCallbacks
-    def _check_sigs_and_hash(self, pdu):
-        """Throws a SynapseError if the PDU does not have the correct
-        signatures.
-
-        Returns:
-            FrozenEvent: Either the given event or it redacted if it failed the
-            content hash check.
-        """
-        # Check signatures are correct.
-        redacted_event = prune_event(pdu)
-        redacted_pdu_json = redacted_event.get_pdu_json()
-
-        try:
-            yield self.keyring.verify_json_for_server(
-                pdu.origin, redacted_pdu_json
-            )
-        except SynapseError:
-            logger.warn(
-                "Signature check failed for %s redacted to %s",
-                encode_canonical_json(pdu.get_pdu_json()),
-                encode_canonical_json(redacted_pdu_json),
-            )
-            raise
-
-        if not check_event_content_hash(pdu):
-            logger.warn(
-                "Event content has been tampered, redacting %s, %s",
-                pdu.event_id, encode_canonical_json(pdu.get_dict())
-            )
-            defer.returnValue(redacted_event)
-
-        defer.returnValue(pdu)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 5fbd8b19de..4742ca9390 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -16,16 +16,12 @@
 
 from twisted.internet import defer
 
+from .federation_base import FederationBase
 from .units import Transaction, Edu
 
 from synapse.util.logutils import log_function
 from synapse.util.logcontext import PreserveLoggingContext
 from synapse.events import FrozenEvent
-from synapse.events.utils import prune_event
-
-from syutil.jsonutil import encode_canonical_json
-
-from synapse.crypto.event_signing import check_event_content_hash
 
 from synapse.api.errors import FederationError, SynapseError
 
@@ -35,7 +31,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class FederationServer(object):
+class FederationServer(FederationBase):
     def set_handler(self, handler):
         """Sets the handler that the replication layer will use to communicate
         receipt of new PDUs from other home servers. The required methods are
@@ -251,17 +247,20 @@ class FederationServer(object):
             Deferred: Results in `dict` with the same format as `content`
         """
         auth_chain = [
-            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
+            self.event_from_pdu_json(e)
             for e in content["auth_chain"]
         ]
 
-        missing = [
-            (yield self._check_sigs_and_hash(self.event_from_pdu_json(e)))
-            for e in content.get("missing", [])
-        ]
+        signed_auth = yield self._check_sigs_and_hash_and_fetch(
+            origin, auth_chain, outlier=True
+        )
 
         ret = yield self.handler.on_query_auth(
-            origin, event_id, auth_chain, content.get("rejects", []), missing
+            origin,
+            event_id,
+            signed_auth,
+            content.get("rejects", []),
+            content.get("missing", []),
         )
 
         time_now = self._clock.time_msec()
@@ -426,37 +425,3 @@ class FederationServer(object):
         event.internal_metadata.outlier = outlier
 
         return event
-
-    @defer.inlineCallbacks
-    def _check_sigs_and_hash(self, pdu):
-        """Throws a SynapseError if the PDU does not have the correct
-        signatures.
-
-        Returns:
-            FrozenEvent: Either the given event or it redacted if it failed the
-            content hash check.
-        """
-        # Check signatures are correct.
-        redacted_event = prune_event(pdu)
-        redacted_pdu_json = redacted_event.get_pdu_json()
-
-        try:
-            yield self.keyring.verify_json_for_server(
-                pdu.origin, redacted_pdu_json
-            )
-        except SynapseError:
-            logger.warn(
-                "Signature check failed for %s redacted to %s",
-                encode_canonical_json(pdu.get_pdu_json()),
-                encode_canonical_json(redacted_pdu_json),
-            )
-            raise
-
-        if not check_event_content_hash(pdu):
-            logger.warn(
-                "Event content has been tampered, redacting %s, %s",
-                pdu.event_id, encode_canonical_json(pdu.get_dict())
-            )
-            defer.returnValue(redacted_event)
-
-        defer.returnValue(pdu)
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 8bf5a4cc11..2e2c23ef63 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -30,6 +30,7 @@ from synapse.types import UserID
 
 from twisted.internet import defer
 
+import itertools
 import logging
 
 
@@ -123,8 +124,21 @@ class FederationHandler(BaseHandler):
             logger.debug("Got event for room we're not in.")
             current_state = state
 
+        event_ids = set()
+        if state:
+            event_ids |= {e.event_id for e in state}
+        if auth_chain:
+            event_ids |= {e.event_id for e in auth_chain}
+
+        seen_ids = (yield self.store.have_events(event_ids)).keys()
+
         if state and auth_chain is not None:
-            for e in state:
+            # If we have any state or auth_chain given to us by the replication
+            # layer, then we should handle them (if we haven't before.)
+            for e in itertools.chain(auth_chain, state):
+                if e.event_id in seen_ids:
+                    continue
+
                 e.internal_metadata.outlier = True
                 try:
                     auth_ids = [e_id for e_id, _ in e.auth_events]
@@ -132,7 +146,10 @@ class FederationHandler(BaseHandler):
                         (e.type, e.state_key): e for e in auth_chain
                         if e.event_id in auth_ids
                     }
-                    yield self._handle_new_event(origin, e, auth_events=auth)
+                    yield self._handle_new_event(
+                        origin, e, auth_events=auth
+                    )
+                    seen_ids.add(e.event_id)
                 except:
                     logger.exception(
                         "Failed to handle state event %s",
@@ -498,6 +515,8 @@ class FederationHandler(BaseHandler):
                     "Failed to get destination from event %s", s.event_id
                 )
 
+        destinations.remove(origin)
+
         logger.debug(
             "on_send_join_request: Sending event: %s, signatures: %s",
             event.event_id,
@@ -618,6 +637,7 @@ class FederationHandler(BaseHandler):
         event = yield self.store.get_event(
             event_id,
             allow_none=True,
+            allow_rejected=True,
         )
 
         if event:
@@ -701,6 +721,8 @@ class FederationHandler(BaseHandler):
 
             context.rejected = RejectedReason.AUTH_ERROR
 
+            # FIXME: Don't store as rejected with AUTH_ERROR if we haven't
+            # seen all the auth events.
             yield self.store.persist_event(
                 event,
                 context=context,
@@ -750,7 +772,7 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        logger.debug("on_query_auth reutrning: %s", ret)
+        logger.debug("on_query_auth returning: %s", ret)
 
         defer.returnValue(ret)
 
@@ -770,41 +792,45 @@ class FederationHandler(BaseHandler):
         if missing_auth:
             logger.debug("Missing auth: %s", missing_auth)
             # If we don't have all the auth events, we need to get them.
-            remote_auth_chain = yield self.replication_layer.get_event_auth(
-                origin, event.room_id, event.event_id
-            )
+            try:
+                remote_auth_chain = yield self.replication_layer.get_event_auth(
+                    origin, event.room_id, event.event_id
+                )
 
-            seen_remotes = yield self.store.have_events(
-                [e.event_id for e in remote_auth_chain]
-            )
+                seen_remotes = yield self.store.have_events(
+                    [e.event_id for e in remote_auth_chain]
+                )
 
-            for e in remote_auth_chain:
-                if e.event_id in seen_remotes.keys():
-                    continue
+                for e in remote_auth_chain:
+                    if e.event_id in seen_remotes.keys():
+                        continue
 
-                if e.event_id == event.event_id:
-                    continue
+                    if e.event_id == event.event_id:
+                        continue
 
-                try:
-                    auth_ids = [e_id for e_id, _ in e.auth_events]
-                    auth = {
-                        (e.type, e.state_key): e for e in remote_auth_chain
-                        if e.event_id in auth_ids
-                    }
-                    e.internal_metadata.outlier = True
+                    try:
+                        auth_ids = [e_id for e_id, _ in e.auth_events]
+                        auth = {
+                            (e.type, e.state_key): e for e in remote_auth_chain
+                            if e.event_id in auth_ids
+                        }
+                        e.internal_metadata.outlier = True
 
-                    logger.debug(
-                        "do_auth %s missing_auth: %s",
-                        event.event_id, e.event_id
-                    )
-                    yield self._handle_new_event(
-                        origin, e, auth_events=auth
-                    )
+                        logger.debug(
+                            "do_auth %s missing_auth: %s",
+                            event.event_id, e.event_id
+                        )
+                        yield self._handle_new_event(
+                            origin, e, auth_events=auth
+                        )
 
-                    if e.event_id in event_auth_events:
-                        auth_events[(e.type, e.state_key)] = e
-                except AuthError:
-                    pass
+                        if e.event_id in event_auth_events:
+                            auth_events[(e.type, e.state_key)] = e
+                    except AuthError:
+                        pass
+            except:
+                # FIXME:
+                logger.exception("Failed to get auth chain")
 
         # FIXME: Assumes we have and stored all the state for all the
         # prev_events
@@ -816,50 +842,57 @@ class FederationHandler(BaseHandler):
             logger.debug("Different auth: %s", different_auth)
 
             # 1. Get what we think is the auth chain.
-            auth_ids = self.auth.compute_auth_events(event, context)
-            local_auth_chain = yield self.store.get_auth_chain(auth_ids)
-
-            # 2. Get remote difference.
-            result = yield self.replication_layer.query_auth(
-                origin,
-                event.room_id,
-                event.event_id,
-                local_auth_chain,
-            )
-
-            seen_remotes = yield self.store.have_events(
-                [e.event_id for e in result["auth_chain"]]
+            auth_ids = self.auth.compute_auth_events(
+                event, context.current_state
             )
+            local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
-            # 3. Process any remote auth chain events we haven't seen.
-            for ev in result["auth_chain"]:
-                if ev.event_id in seen_remotes.keys():
-                    continue
+            try:
+                # 2. Get remote difference.
+                result = yield self.replication_layer.query_auth(
+                    origin,
+                    event.room_id,
+                    event.event_id,
+                    local_auth_chain,
+                )
 
-                if ev.event_id == event.event_id:
-                    continue
+                seen_remotes = yield self.store.have_events(
+                    [e.event_id for e in result["auth_chain"]]
+                )
 
-                try:
-                    auth_ids = [e_id for e_id, _ in ev.auth_events]
-                    auth = {
-                        (e.type, e.state_key): e for e in result["auth_chain"]
-                        if e.event_id in auth_ids
-                    }
-                    ev.internal_metadata.outlier = True
+                # 3. Process any remote auth chain events we haven't seen.
+                for ev in result["auth_chain"]:
+                    if ev.event_id in seen_remotes.keys():
+                        continue
+
+                    if ev.event_id == event.event_id:
+                        continue
+
+                    try:
+                        auth_ids = [e_id for e_id, _ in ev.auth_events]
+                        auth = {
+                            (e.type, e.state_key): e for e in result["auth_chain"]
+                            if e.event_id in auth_ids
+                        }
+                        ev.internal_metadata.outlier = True
+
+                        logger.debug(
+                            "do_auth %s different_auth: %s",
+                            event.event_id, e.event_id
+                        )
 
-                    logger.debug(
-                        "do_auth %s different_auth: %s",
-                        event.event_id, e.event_id
-                    )
+                        yield self._handle_new_event(
+                            origin, ev, auth_events=auth
+                        )
 
-                    yield self._handle_new_event(
-                        origin, ev, auth_events=auth
-                    )
+                        if ev.event_id in event_auth_events:
+                            auth_events[(ev.type, ev.state_key)] = ev
+                    except AuthError:
+                        pass
 
-                    if ev.event_id in event_auth_events:
-                        auth_events[(ev.type, ev.state_key)] = ev
-                except AuthError:
-                    pass
+            except:
+                # FIXME:
+                logger.exception("Failed to query auth chain")
 
             # 4. Look at rejects and their proofs.
             # TODO.
@@ -983,7 +1016,7 @@ class FederationHandler(BaseHandler):
             if reason is None:
                 # FIXME: ERRR?!
                 logger.warn("Could not find reason for %s", e.event_id)
-                raise RuntimeError("")
+                raise RuntimeError("Could not find reason for %s" % e.event_id)
 
             reason_map[e.event_id] = reason
 
diff --git a/synapse/state.py b/synapse/state.py
index 8a056ee955..695a5e7ac4 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -37,7 +37,10 @@ def _get_state_key_from_event(event):
 KeyStateTuple = namedtuple("KeyStateTuple", ("context", "type", "state_key"))
 
 
-AuthEventTypes = (EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,)
+AuthEventTypes = (
+    EventTypes.Create, EventTypes.Member, EventTypes.PowerLevels,
+    EventTypes.JoinRules,
+)
 
 
 class StateHandler(object):
@@ -100,7 +103,9 @@ class StateHandler(object):
             context.state_group = None
 
             if hasattr(event, "auth_events") and event.auth_events:
-                auth_ids = zip(*event.auth_events)[0]
+                auth_ids = self.hs.get_auth().compute_auth_events(
+                    event, context.current_state
+                )
                 context.auth_events = {
                     k: v
                     for k, v in context.current_state.items()
@@ -146,7 +151,9 @@ class StateHandler(object):
                 event.unsigned["replaces_state"] = replaces.event_id
 
         if hasattr(event, "auth_events") and event.auth_events:
-            auth_ids = zip(*event.auth_events)[0]
+            auth_ids = self.hs.get_auth().compute_auth_events(
+                event, context.current_state
+            )
             context.auth_events = {
                 k: v
                 for k, v in context.current_state.items()
@@ -259,6 +266,15 @@ class StateHandler(object):
         auth_events.update(resolved_state)
 
         for key, events in conflicted_state.items():
+            if key[0] == EventTypes.JoinRules:
+                resolved_state[key] = self._resolve_auth_events(
+                    events,
+                    auth_events
+                )
+
+        auth_events.update(resolved_state)
+
+        for key, events in conflicted_state.items():
             if key[0] == EventTypes.Member:
                 resolved_state[key] = self._resolve_auth_events(
                     events,
diff --git a/synapse/storage/__init__.py b/synapse/storage/__init__.py
index 7c54b1b9d3..a63c59a8a2 100644
--- a/synapse/storage/__init__.py
+++ b/synapse/storage/__init__.py
@@ -128,21 +128,144 @@ class DataStore(RoomMemberStore, RoomStore,
             pass
 
     @defer.inlineCallbacks
-    def get_event(self, event_id, allow_none=False):
-        events = yield self._get_events([event_id])
+    def get_event(self, event_id, check_redacted=True,
+                  get_prev_content=False, allow_rejected=False,
+                  allow_none=False):
+        """Get an event from the database by event_id.
+
+        Args:
+            event_id (str): The event_id of the event to fetch
+            check_redacted (bool): If True, check if event has been redacted
+                and redact it.
+            get_prev_content (bool): If True and event is a state event,
+                include the previous states content in the unsigned field.
+            allow_rejected (bool): If True return rejected events.
+            allow_none (bool): If True, return None if no event found, if
+                False throw an exception.
 
-        if not events:
-            if allow_none:
-                defer.returnValue(None)
-            else:
-                raise RuntimeError("Could not find event %s" % (event_id,))
+        Returns:
+            Deferred : A FrozenEvent.
+        """
+        event = yield self.runInteraction(
+            "get_event", self._get_event_txn,
+            event_id,
+            check_redacted=check_redacted,
+            get_prev_content=get_prev_content,
+            allow_rejected=allow_rejected,
+        )
 
-        defer.returnValue(events[0])
+        if not event and not allow_none:
+            raise RuntimeError("Could not find event %s" % (event_id,))
+
+        defer.returnValue(event)
 
     @log_function
     def _persist_event_txn(self, txn, event, context, backfilled,
                            stream_ordering=None, is_new_state=True,
                            current_state=None):
+
+        # We purposefully do this first since if we include a `current_state`
+        # key, we *want* to update the `current_state_events` table
+        if current_state:
+            txn.execute(
+                "DELETE FROM current_state_events WHERE room_id = ?",
+                (event.room_id,)
+            )
+
+            for s in current_state:
+                self._simple_insert_txn(
+                    txn,
+                    "current_state_events",
+                    {
+                        "event_id": s.event_id,
+                        "room_id": s.room_id,
+                        "type": s.type,
+                        "state_key": s.state_key,
+                    },
+                    or_replace=True,
+                )
+
+        if event.is_state() and is_new_state:
+            if not backfilled and not context.rejected:
+                self._simple_insert_txn(
+                    txn,
+                    table="state_forward_extremities",
+                    values={
+                        "event_id": event.event_id,
+                        "room_id": event.room_id,
+                        "type": event.type,
+                        "state_key": event.state_key,
+                    },
+                    or_replace=True,
+                )
+
+                for prev_state_id, _ in event.prev_state:
+                    self._simple_delete_txn(
+                        txn,
+                        table="state_forward_extremities",
+                        keyvalues={
+                            "event_id": prev_state_id,
+                        }
+                    )
+
+        outlier = event.internal_metadata.is_outlier()
+
+        if not outlier:
+            self._store_state_groups_txn(txn, event, context)
+
+            self._update_min_depth_for_room_txn(
+                txn,
+                event.room_id,
+                event.depth
+            )
+
+        self._handle_prev_events(
+            txn,
+            outlier=outlier,
+            event_id=event.event_id,
+            prev_events=event.prev_events,
+            room_id=event.room_id,
+        )
+
+        have_persisted = self._simple_select_one_onecol_txn(
+            txn,
+            table="event_json",
+            keyvalues={"event_id": event.event_id},
+            retcol="event_id",
+            allow_none=True,
+        )
+
+        metadata_json = encode_canonical_json(
+            event.internal_metadata.get_dict()
+        )
+
+        # If we have already persisted this event, we don't need to do any
+        # more processing.
+        # The processing above must be done on every call to persist event,
+        # since they might not have happened on previous calls. For example,
+        # if we are persisting an event that we had persisted as an outlier,
+        # but is no longer one.
+        if have_persisted:
+            if not outlier:
+                sql = (
+                    "UPDATE event_json SET internal_metadata = ?"
+                    " WHERE event_id = ?"
+                )
+                txn.execute(
+                    sql,
+                    (metadata_json.decode("UTF-8"), event.event_id,)
+                )
+
+                sql = (
+                    "UPDATE events SET outlier = 0"
+                    " WHERE event_id = ?"
+                )
+                txn.execute(
+                    sql,
+                    (event.event_id,)
+                )
+            return
+
         if event.type == EventTypes.Member:
             self._store_room_member_txn(txn, event)
         elif event.type == EventTypes.Feedback:
@@ -154,8 +277,6 @@ class DataStore(RoomMemberStore, RoomStore,
         elif event.type == EventTypes.Redaction:
             self._store_redaction(txn, event)
 
-        outlier = event.internal_metadata.is_outlier()
-
         event_dict = {
             k: v
             for k, v in event.get_dict().items()
@@ -165,10 +286,6 @@ class DataStore(RoomMemberStore, RoomStore,
             ]
         }
 
-        metadata_json = encode_canonical_json(
-            event.internal_metadata.get_dict()
-        )
-
         self._simple_insert_txn(
             txn,
             table="event_json",
@@ -224,41 +341,10 @@ class DataStore(RoomMemberStore, RoomStore,
             )
             raise _RollbackButIsFineException("_persist_event")
 
-        self._handle_prev_events(
-            txn,
-            outlier=outlier,
-            event_id=event.event_id,
-            prev_events=event.prev_events,
-            room_id=event.room_id,
-        )
-
-        if not outlier:
-            self._store_state_groups_txn(txn, event, context)
-
         if context.rejected:
             self._store_rejections_txn(txn, event.event_id, context.rejected)
 
-        if current_state:
-            txn.execute(
-                "DELETE FROM current_state_events WHERE room_id = ?",
-                (event.room_id,)
-            )
-
-            for s in current_state:
-                self._simple_insert_txn(
-                    txn,
-                    "current_state_events",
-                    {
-                        "event_id": s.event_id,
-                        "room_id": s.room_id,
-                        "type": s.type,
-                        "state_key": s.state_key,
-                    },
-                    or_replace=True,
-                )
-
-        is_state = hasattr(event, "state_key") and event.state_key is not None
-        if is_state:
+        if event.is_state():
             vals = {
                 "event_id": event.event_id,
                 "room_id": event.room_id,
@@ -266,6 +352,7 @@ class DataStore(RoomMemberStore, RoomStore,
                 "state_key": event.state_key,
             }
 
+            # TODO: How does this work with backfilling?
             if hasattr(event, "replaces_state"):
                 vals["prev_state"] = event.replaces_state
 
@@ -302,28 +389,6 @@ class DataStore(RoomMemberStore, RoomStore,
                     or_ignore=True,
                 )
 
-            if not backfilled and not context.rejected:
-                self._simple_insert_txn(
-                    txn,
-                    table="state_forward_extremities",
-                    values={
-                        "event_id": event.event_id,
-                        "room_id": event.room_id,
-                        "type": event.type,
-                        "state_key": event.state_key,
-                    },
-                    or_replace=True,
-                )
-
-                for prev_state_id, _ in event.prev_state:
-                    self._simple_delete_txn(
-                        txn,
-                        table="state_forward_extremities",
-                        keyvalues={
-                            "event_id": prev_state_id,
-                        }
-                    )
-
         for hash_alg, hash_base64 in event.hashes.items():
             hash_bytes = decode_base64(hash_base64)
             self._store_event_content_hash_txn(
@@ -354,13 +419,6 @@ class DataStore(RoomMemberStore, RoomStore,
             txn, event.event_id, ref_alg, ref_hash_bytes
         )
 
-        if not outlier:
-            self._update_min_depth_for_room_txn(
-                txn,
-                event.room_id,
-                event.depth
-            )
-
     def _store_redaction(self, txn, event):
         txn.execute(
             "INSERT OR IGNORE INTO redactions "
@@ -477,6 +535,9 @@ class DataStore(RoomMemberStore, RoomStore,
             the rejected reason string if we rejected the event, else maps to
             None.
         """
+        if not event_ids:
+            return defer.succeed({})
+
         def f(txn):
             sql = (
                 "SELECT e.event_id, reason FROM events as e "
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index 44dbce6bea..4270481139 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -91,7 +91,10 @@ class FederationTestCase(unittest.TestCase):
         self.datastore.persist_event.return_value = defer.succeed(None)
         self.datastore.get_room.return_value = defer.succeed(True)
         self.auth.check_host_in_room.return_value = defer.succeed(True)
-        self.datastore.have_events.return_value = defer.succeed({})
+
+        def have_events(event_ids):
+            return defer.succeed({})
+        self.datastore.have_events.side_effect = have_events
 
         def annotate(ev, old_state=None):
             context = Mock()