summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py714
1 files changed, 442 insertions, 272 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index f599e817aa..771ab3bc43 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -26,20 +26,24 @@ from synapse.api.errors import (
 from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.events.validator import EventValidator
 from synapse.util import unwrapFirstError
-from synapse.util.logcontext import PreserveLoggingContext
+from synapse.util.logcontext import (
+    PreserveLoggingContext, preserve_fn, preserve_context_over_deferred
+)
+from synapse.util.metrics import measure_func
 from synapse.util.logutils import log_function
 from synapse.util.async import run_on_reactor
 from synapse.util.frozenutils import unfreeze
 from synapse.crypto.event_signing import (
     compute_event_signature, add_hashes_and_signatures,
 )
-from synapse.types import UserID
+from synapse.types import UserID, get_domain_from_id
 
 from synapse.events.utils import prune_event
 
 from synapse.util.retryutils import NotRetryingDestination
 
 from synapse.push.action_generator import ActionGenerator
+from synapse.util.distributor import user_joined_room
 
 from twisted.internet import defer
 
@@ -49,10 +53,6 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-def user_joined_room(distributor, user, room_id):
-    return distributor.fire("user_joined_room", user, room_id)
-
-
 class FederationHandler(BaseHandler):
     """Handles events that originated from federation.
         Responsible for:
@@ -69,10 +69,6 @@ class FederationHandler(BaseHandler):
 
         self.hs = hs
 
-        self.distributor.observe("user_joined_room", self.user_joined_room)
-
-        self.waiting_for_join_list = {}
-
         self.store = hs.get_datastore()
         self.replication_layer = hs.get_replication_layer()
         self.state_handler = hs.get_state_handler()
@@ -84,28 +80,14 @@ class FederationHandler(BaseHandler):
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
 
-    def handle_new_event(self, event, destinations):
-        """ Takes in an event from the client to server side, that has already
-        been authed and handled by the state module, and sends it to any
-        remote home servers that may be interested.
-
-        Args:
-            event: The event to send
-            destinations: A list of destinations to send it to
-
-        Returns:
-            Deferred: Resolved when it has successfully been queued for
-            processing.
-        """
-
-        return self.replication_layer.send_pdu(event, destinations)
-
     @log_function
     @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, backfilled, state=None,
-                       auth_chain=None):
+    def on_receive_pdu(self, origin, pdu, state=None, auth_chain=None):
         """ Called by the ReplicationLayer when we have a new pdu. We need to
         do auth checks and put it through the StateHandler.
+
+        auth_chain and state are None if we already have the necessary state
+        and prev_events in the db
         """
         event = pdu
 
@@ -123,17 +105,25 @@ class FederationHandler(BaseHandler):
 
         # FIXME (erikj): Awful hack to make the case where we are not currently
         # in the room work
-        current_state = None
-        is_in_room = yield self.auth.check_host_in_room(
-            event.room_id,
-            self.server_name
-        )
-        if not is_in_room and not event.internal_metadata.is_outlier():
-            logger.debug("Got event for room we're not in.")
+        # If state and auth_chain are None, then we don't need to do this check
+        # as we already know we have enough state in the DB to handle this
+        # event.
+        if state and auth_chain and not event.internal_metadata.is_outlier():
+            is_in_room = yield self.auth.check_host_in_room(
+                event.room_id,
+                self.server_name
+            )
+        else:
+            is_in_room = True
+        if not is_in_room:
+            logger.info(
+                "Got event for room we're not in: %r %r",
+                event.room_id, event.event_id
+            )
 
             try:
                 event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                    auth_chain, state, event
+                    origin, auth_chain, state, event
                 )
             except AuthError as e:
                 raise FederationError(
@@ -175,19 +165,13 @@ class FederationHandler(BaseHandler):
                     })
                     seen_ids.add(e.event_id)
 
-                yield self._handle_new_events(
-                    origin,
-                    event_infos,
-                    outliers=True
-                )
+                yield self._handle_new_events(origin, event_infos)
 
             try:
                 context, event_stream_id, max_stream_id = yield self._handle_new_event(
                     origin,
                     event,
                     state=state,
-                    backfilled=backfilled,
-                    current_state=current_state,
                 )
             except AuthError as e:
                 raise FederationError(
@@ -216,32 +200,42 @@ class FederationHandler(BaseHandler):
             except StoreError:
                 logger.exception("Failed to store room.")
 
-        if not backfilled:
-            extra_users = []
-            if event.type == EventTypes.Member:
-                target_user_id = event.state_key
-                target_user = UserID.from_string(target_user_id)
-                extra_users.append(target_user)
+        extra_users = []
+        if event.type == EventTypes.Member:
+            target_user_id = event.state_key
+            target_user = UserID.from_string(target_user_id)
+            extra_users.append(target_user)
 
-            with PreserveLoggingContext():
-                self.notifier.on_new_room_event(
-                    event, event_stream_id, max_stream_id,
-                    extra_users=extra_users
-                )
+        with PreserveLoggingContext():
+            self.notifier.on_new_room_event(
+                event, event_stream_id, max_stream_id,
+                extra_users=extra_users
+            )
 
         if event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
-                prev_state = context.current_state.get((event.type, event.state_key))
-                if not prev_state or prev_state.membership != Membership.JOIN:
-                    # Only fire user_joined_room if the user has acutally
-                    # joined the room. Don't bother if the user is just
-                    # changing their profile info.
+                # Only fire user_joined_room if the user has actually
+                # joined the room. Don't bother if the user is just
+                # changing their profile info.
+                newly_joined = True
+                prev_state_id = context.prev_state_ids.get(
+                    (event.type, event.state_key)
+                )
+                if prev_state_id:
+                    prev_state = yield self.store.get_event(
+                        prev_state_id, allow_none=True,
+                    )
+                    if prev_state and prev_state.membership == Membership.JOIN:
+                        newly_joined = False
+
+                if newly_joined:
                     user = UserID.from_string(event.state_key)
                     yield user_joined_room(self.distributor, user, event.room_id)
 
+    @measure_func("_filter_events_for_server")
     @defer.inlineCallbacks
     def _filter_events_for_server(self, server_name, room_id, events):
-        event_to_state = yield self.store.get_state_for_events(
+        event_to_state_ids = yield self.store.get_state_ids_for_events(
             frozenset(e.event_id for e in events),
             types=(
                 (EventTypes.RoomHistoryVisibility, ""),
@@ -249,6 +243,30 @@ class FederationHandler(BaseHandler):
             )
         )
 
+        # We only want to pull out member events that correspond to the
+        # server's domain.
+
+        def check_match(id):
+            try:
+                return server_name == get_domain_from_id(id)
+            except:
+                return False
+
+        event_map = yield self.store.get_events([
+            e_id for key_to_eid in event_to_state_ids.values()
+            for key, e_id in key_to_eid
+            if key[0] != EventTypes.Member or check_match(key[1])
+        ])
+
+        event_to_state = {
+            e_id: {
+                key: event_map[inner_e_id]
+                for key, inner_e_id in key_to_eid.items()
+                if inner_e_id in event_map
+            }
+            for e_id, key_to_eid in event_to_state_ids.items()
+        }
+
         def redact_disallowed(event, state):
             if not state:
                 return event
@@ -265,7 +283,7 @@ class FederationHandler(BaseHandler):
                         if ev.type != EventTypes.Member:
                             continue
                         try:
-                            domain = UserID.from_string(ev.state_key).domain
+                            domain = get_domain_from_id(ev.state_key)
                         except:
                             continue
 
@@ -290,11 +308,15 @@ class FederationHandler(BaseHandler):
 
     @log_function
     @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities=[]):
+    def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
+
+        This will attempt to get more events from the remote. This may
+        be successful and still return no events if the other side has no new
+        events to offer.
         """
-        if not extremities:
-            extremities = yield self.store.get_oldest_events_in_room(room_id)
+        if dest == self.server_name:
+            raise SynapseError(400, "Can't backfill from self.")
 
         events = yield self.replication_layer.backfill(
             dest,
@@ -303,6 +325,16 @@ class FederationHandler(BaseHandler):
             extremities=extremities,
         )
 
+        # Don't bother processing events we already have.
+        seen_events = yield self.store.have_events_in_timeline(
+            set(e.event_id for e in events)
+        )
+
+        events = [e for e in events if e.event_id not in seen_events]
+
+        if not events:
+            defer.returnValue([])
+
         event_map = {e.event_id: e for e in events}
 
         event_ids = set(e.event_id for e in events)
@@ -334,40 +366,73 @@ class FederationHandler(BaseHandler):
             state_events.update({s.event_id: s for s in state})
             events_to_state[e_id] = state
 
-        seen_events = yield self.store.have_events(
-            set(auth_events.keys()) | set(state_events.keys())
-        )
-
-        all_events = events + state_events.values() + auth_events.values()
         required_auth = set(
-            a_id for event in all_events for a_id, _ in event.auth_events
+            a_id
+            for event in events + state_events.values() + auth_events.values()
+            for a_id, _ in event.auth_events
         )
-
+        auth_events.update({
+            e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
+        })
         missing_auth = required_auth - set(auth_events)
-        results = yield defer.gatherResults(
-            [
-                self.replication_layer.get_pdu(
-                    [dest],
-                    event_id,
-                    outlier=True,
-                    timeout=10000,
+        failed_to_fetch = set()
+
+        # Try and fetch any missing auth events from both DB and remote servers.
+        # We repeatedly do this until we stop finding new auth events.
+        while missing_auth - failed_to_fetch:
+            logger.info("Missing auth for backfill: %r", missing_auth)
+            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+            auth_events.update(ret_events)
+
+            required_auth.update(
+                a_id for event in ret_events.values() for a_id, _ in event.auth_events
+            )
+            missing_auth = required_auth - set(auth_events)
+
+            if missing_auth - failed_to_fetch:
+                logger.info(
+                    "Fetching missing auth for backfill: %r",
+                    missing_auth - failed_to_fetch
                 )
-                for event_id in missing_auth
-            ],
-            consumeErrors=True
-        ).addErrback(unwrapFirstError)
-        auth_events.update({a.event_id: a for a in results})
+
+                results = yield preserve_context_over_deferred(defer.gatherResults(
+                    [
+                        preserve_fn(self.replication_layer.get_pdu)(
+                            [dest],
+                            event_id,
+                            outlier=True,
+                            timeout=10000,
+                        )
+                        for event_id in missing_auth - failed_to_fetch
+                    ],
+                    consumeErrors=True
+                )).addErrback(unwrapFirstError)
+                auth_events.update({a.event_id: a for a in results if a})
+                required_auth.update(
+                    a_id
+                    for event in results if event
+                    for a_id, _ in event.auth_events
+                )
+                missing_auth = required_auth - set(auth_events)
+
+                failed_to_fetch = missing_auth - set(auth_events)
+
+        seen_events = yield self.store.have_events(
+            set(auth_events.keys()) | set(state_events.keys())
+        )
 
         ev_infos = []
         for a in auth_events.values():
             if a.event_id in seen_events:
                 continue
+            a.internal_metadata.outlier = True
             ev_infos.append({
                 "event": a,
                 "auth_events": {
                     (auth_events[a_id].type, auth_events[a_id].state_key):
                     auth_events[a_id]
                     for a_id, _ in a.auth_events
+                    if a_id in auth_events
                 }
             })
 
@@ -379,23 +444,27 @@ class FederationHandler(BaseHandler):
                     (auth_events[a_id].type, auth_events[a_id].state_key):
                     auth_events[a_id]
                     for a_id, _ in event_map[e_id].auth_events
+                    if a_id in auth_events
                 }
             })
 
+        yield self._handle_new_events(
+            dest, ev_infos,
+            backfilled=True,
+        )
+
         events.sort(key=lambda e: e.depth)
 
         for event in events:
             if event in events_to_state:
                 continue
 
-            ev_infos.append({
-                "event": event,
-            })
-
-        yield self._handle_new_events(
-            dest, ev_infos,
-            backfilled=True,
-        )
+            # We store these one at a time since each event depends on the
+            # previous to work out the state.
+            # TODO: We can probably do something more clever here.
+            yield self._handle_new_event(
+                dest, event, backfilled=True,
+            )
 
         defer.returnValue(events)
 
@@ -419,6 +488,10 @@ class FederationHandler(BaseHandler):
         )
         max_depth = sorted_extremeties_tuple[0][1]
 
+        # We don't want to specify too many extremities as it causes the backfill
+        # request URI to be too long.
+        extremities = dict(sorted_extremeties_tuple[:5])
+
         if current_depth > max_depth:
             logger.debug(
                 "Not backfilling as we don't need to. %d < %d",
@@ -444,7 +517,7 @@ class FederationHandler(BaseHandler):
             joined_domains = {}
             for u, d in joined_users:
                 try:
-                    dom = UserID.from_string(u).domain
+                    dom = get_domain_from_id(u)
                     old_d = joined_domains.get(dom)
                     if old_d:
                         joined_domains[dom] = min(d, old_d)
@@ -459,7 +532,7 @@ class FederationHandler(BaseHandler):
 
         likely_domains = [
             domain for domain, depth in curr_domains
-            if domain is not self.server_name
+            if domain != self.server_name
         ]
 
         @defer.inlineCallbacks
@@ -467,11 +540,15 @@ class FederationHandler(BaseHandler):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    events = yield self.backfill(
+                    yield self.backfill(
                         dom, room_id,
                         limit=100,
                         extremities=[e for e in extremities.keys()]
                     )
+                    # If this succeeded then we probably already have the
+                    # appropriate stuff.
+                    # TODO: We can probably do something more intelligent here.
+                    defer.returnValue(True)
                 except SynapseError as e:
                     logger.info(
                         "Failed to backfill from %s because %s",
@@ -497,8 +574,6 @@ class FederationHandler(BaseHandler):
                     )
                     continue
 
-                if events:
-                    defer.returnValue(True)
             defer.returnValue(False)
 
         success = yield try_backfill(likely_domains)
@@ -513,12 +588,24 @@ class FederationHandler(BaseHandler):
 
         event_ids = list(extremities.keys())
 
-        states = yield defer.gatherResults([
-            self.state_handler.resolve_state_groups(room_id, [e])
+        states = yield preserve_context_over_deferred(defer.gatherResults([
+            preserve_fn(self.state_handler.resolve_state_groups)(room_id, [e])
             for e in event_ids
-        ])
+        ]))
         states = dict(zip(event_ids, [s[1] for s in states]))
 
+        state_map = yield self.store.get_events(
+            [e_id for ids in states.values() for e_id in ids],
+            get_prev_content=False
+        )
+        states = {
+            key: {
+                k: state_map[e_id]
+                for k, e_id in state_dict.items()
+                if e_id in state_map
+            } for key, state_dict in states.items()
+        }
+
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
@@ -628,7 +715,7 @@ class FederationHandler(BaseHandler):
                 pass
 
             event_stream_id, max_stream_id = yield self._persist_auth_tree(
-                auth_chain, state, event
+                origin, auth_chain, state, event
             )
 
             with PreserveLoggingContext():
@@ -647,7 +734,7 @@ class FederationHandler(BaseHandler):
                     continue
 
                 try:
-                    self.on_receive_pdu(origin, p, backfilled=False)
+                    self.on_receive_pdu(origin, p)
                 except:
                     logger.exception("Couldn't handle pdu")
 
@@ -670,11 +757,18 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        event, context = yield self._create_new_client_event(
-            builder=builder,
-        )
+        try:
+            message_handler = self.hs.get_handlers().message_handler
+            event, context = yield message_handler._create_new_client_event(
+                builder=builder,
+            )
+        except AuthError as e:
+            logger.warn("Failed to create join %r because %s", event, e)
+            raise e
 
-        self.auth.check(event, auth_events=context.current_state)
+        # The remote hasn't signed it yet, obviously. We'll do the full checks
+        # when we get the event back in `on_send_join_request`
+        yield self.auth.check_from_context(event, context, do_sig_check=False)
 
         defer.returnValue(event)
 
@@ -720,39 +814,15 @@ class FederationHandler(BaseHandler):
                 user = UserID.from_string(event.state_key)
                 yield user_joined_room(self.distributor, user, event.room_id)
 
-        new_pdu = event
-
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.JOIN:
-                        destinations.add(
-                            UserID.from_string(s.state_key).domain
-                        )
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
-        destinations.discard(origin)
-
-        logger.debug(
-            "on_send_join_request: Sending event: %s, signatures: %s",
-            event.event_id,
-            event.signatures,
-        )
-
-        self.replication_layer.send_pdu(new_pdu, destinations)
-
-        state_ids = [e.event_id for e in context.current_state.values()]
+        state_ids = context.prev_state_ids.values()
         auth_chain = yield self.store.get_auth_chain(set(
             [event.event_id] + state_ids
         ))
 
+        state = yield self.store.get_events(context.prev_state_ids.values())
+
         defer.returnValue({
-            "state": context.current_state.values(),
+            "state": state.values(),
             "auth_chain": auth_chain,
         })
 
@@ -765,6 +835,7 @@ class FederationHandler(BaseHandler):
         event = pdu
 
         event.internal_metadata.outlier = True
+        event.internal_metadata.invite_from_remote = True
 
         event.signatures.update(
             compute_event_signature(
@@ -779,7 +850,6 @@ class FederationHandler(BaseHandler):
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            backfilled=False,
         )
 
         target_user = UserID.from_string(event.state_key)
@@ -793,13 +863,19 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
-        origin, event = yield self._make_and_verify_event(
-            target_hosts,
-            room_id,
-            user_id,
-            "leave"
-        )
-        signed_event = self._sign_event(event)
+        try:
+            origin, event = yield self._make_and_verify_event(
+                target_hosts,
+                room_id,
+                user_id,
+                "leave"
+            )
+            signed_event = self._sign_event(event)
+        except SynapseError:
+            raise
+        except CodeMessageException as e:
+            logger.warn("Failed to reject invite: %s", e)
+            raise SynapseError(500, "Failed to reject invite")
 
         # Try the host we successfully got a response to /make_join/
         # request first.
@@ -809,17 +885,22 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.replication_layer.send_leave(
-            target_hosts,
-            signed_event
-        )
+        try:
+            yield self.replication_layer.send_leave(
+                target_hosts,
+                signed_event
+            )
+        except SynapseError:
+            raise
+        except CodeMessageException as e:
+            logger.warn("Failed to reject invite: %s", e)
+            raise SynapseError(500, "Failed to reject invite")
 
         context = yield self.state_handler.compute_event_context(event)
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
-            backfilled=False,
         )
 
         target_user = UserID.from_string(event.state_key)
@@ -889,11 +970,18 @@ class FederationHandler(BaseHandler):
             "state_key": user_id,
         })
 
-        event, context = yield self._create_new_client_event(
+        message_handler = self.hs.get_handlers().message_handler
+        event, context = yield message_handler._create_new_client_event(
             builder=builder,
         )
 
-        self.auth.check(event, auth_events=context.current_state)
+        try:
+            # The remote hasn't signed it yet, obviously. We'll do the full checks
+            # when we get the event back in `on_send_leave_request`
+            yield self.auth.check_from_context(event, context, do_sig_check=False)
+        except AuthError as e:
+            logger.warn("Failed to create new leave %r because %s", event, e)
+            raise e
 
         defer.returnValue(event)
 
@@ -932,43 +1020,14 @@ class FederationHandler(BaseHandler):
                 event, event_stream_id, max_stream_id, extra_users=extra_users
             )
 
-        new_pdu = event
-
-        destinations = set()
-
-        for k, s in context.current_state.items():
-            try:
-                if k[0] == EventTypes.Member:
-                    if s.content["membership"] == Membership.LEAVE:
-                        destinations.add(
-                            UserID.from_string(s.state_key).domain
-                        )
-            except:
-                logger.warn(
-                    "Failed to get destination from event %s", s.event_id
-                )
-
-        destinations.discard(origin)
-
-        logger.debug(
-            "on_send_leave_request: Sending event: %s, signatures: %s",
-            event.event_id,
-            event.signatures,
-        )
-
-        self.replication_layer.send_pdu(new_pdu, destinations)
-
         defer.returnValue(None)
 
     @defer.inlineCallbacks
-    def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True):
+    def get_state_for_pdu(self, room_id, event_id):
+        """Returns the state at the event. i.e. not including said event.
+        """
         yield run_on_reactor()
 
-        if do_auth:
-            in_room = yield self.auth.check_host_in_room(room_id, origin)
-            if not in_room:
-                raise AuthError(403, "Host not in room.")
-
         state_groups = yield self.store.get_state_groups(
             room_id, [event_id]
         )
@@ -992,19 +1051,50 @@ class FederationHandler(BaseHandler):
 
             res = results.values()
             for event in res:
-                event.signatures.update(
-                    compute_event_signature(
-                        event,
-                        self.hs.hostname,
-                        self.hs.config.signing_key[0]
+                # We sign these again because there was a bug where we
+                # incorrectly signed things the first time round
+                if self.hs.is_mine_id(event.event_id):
+                    event.signatures.update(
+                        compute_event_signature(
+                            event,
+                            self.hs.hostname,
+                            self.hs.config.signing_key[0]
+                        )
                     )
-                )
 
             defer.returnValue(res)
         else:
             defer.returnValue([])
 
     @defer.inlineCallbacks
+    def get_state_ids_for_pdu(self, room_id, event_id):
+        """Returns the state at the event. i.e. not including said event.
+        """
+        yield run_on_reactor()
+
+        state_groups = yield self.store.get_state_groups_ids(
+            room_id, [event_id]
+        )
+
+        if state_groups:
+            _, state = state_groups.items().pop()
+            results = state
+
+            event = yield self.store.get_event(event_id)
+            if event and event.is_state():
+                # Get previous state
+                if "replaces_state" in event.unsigned:
+                    prev_id = event.unsigned["replaces_state"]
+                    if prev_id != event.event_id:
+                        results[(event.type, event.state_key)] = prev_id
+                else:
+                    del results[(event.type, event.state_key)]
+
+            defer.returnValue(results.values())
+        else:
+            defer.returnValue([])
+
+    @defer.inlineCallbacks
     @log_function
     def on_backfill_request(self, origin, room_id, pdu_list, limit):
         in_room = yield self.auth.check_host_in_room(room_id, origin)
@@ -1036,16 +1126,17 @@ class FederationHandler(BaseHandler):
         )
 
         if event:
-            # FIXME: This is a temporary work around where we occasionally
-            # return events slightly differently than when they were
-            # originally signed
-            event.signatures.update(
-                compute_event_signature(
-                    event,
-                    self.hs.hostname,
-                    self.hs.config.signing_key[0]
+            if self.hs.is_mine_id(event.event_id):
+                # FIXME: This is a temporary work around where we occasionally
+                # return events slightly differently than when they were
+                # originally signed
+                event.signatures.update(
+                    compute_event_signature(
+                        event,
+                        self.hs.hostname,
+                        self.hs.config.signing_key[0]
+                    )
                 )
-            )
 
             if do_auth:
                 in_room = yield self.auth.check_host_in_room(
@@ -1055,6 +1146,12 @@ class FederationHandler(BaseHandler):
                 if not in_room:
                     raise AuthError(403, "Host not in room.")
 
+                events = yield self._filter_events_for_server(
+                    origin, event.room_id, [event]
+                )
+
+                event = events[0]
+
             defer.returnValue(event)
         else:
             defer.returnValue(None)
@@ -1063,50 +1160,47 @@ class FederationHandler(BaseHandler):
     def get_min_depth_for_context(self, context):
         return self.store.get_min_depth(context)
 
-    @log_function
-    def user_joined_room(self, user, room_id):
-        waiters = self.waiting_for_join_list.get(
-            (user.to_string(), room_id),
-            []
-        )
-        while waiters:
-            waiters.pop().callback(None)
-
     @defer.inlineCallbacks
     @log_function
-    def _handle_new_event(self, origin, event, state=None, backfilled=False,
-                          current_state=None, auth_events=None):
-
-        outlier = event.internal_metadata.is_outlier()
-
+    def _handle_new_event(self, origin, event, state=None, auth_events=None,
+                          backfilled=False):
         context = yield self._prep_event(
             origin, event,
             state=state,
             auth_events=auth_events,
         )
 
-        if not backfilled and not event.internal_metadata.is_outlier():
+        if not event.internal_metadata.is_outlier():
             action_generator = ActionGenerator(self.hs)
             yield action_generator.handle_push_actions_for_event(
-                event, context, self
+                event, context
             )
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event,
             context=context,
             backfilled=backfilled,
-            is_new_state=(not outlier and not backfilled),
-            current_state=current_state,
         )
 
+        if not backfilled:
+            # this intentionally does not yield: we don't care about the result
+            # and don't need to wait for it.
+            preserve_fn(self.hs.get_pusherpool().on_new_notifications)(
+                event_stream_id, max_stream_id
+            )
+
         defer.returnValue((context, event_stream_id, max_stream_id))
 
     @defer.inlineCallbacks
-    def _handle_new_events(self, origin, event_infos, backfilled=False,
-                           outliers=False):
-        contexts = yield defer.gatherResults(
+    def _handle_new_events(self, origin, event_infos, backfilled=False):
+        """Creates the appropriate contexts and persists events. The events
+        should not depend on one another, e.g. this should be used to persist
+        a bunch of outliers, but not a chunk of individual events that depend
+        on each other for state calculations.
+        """
+        contexts = yield preserve_context_over_deferred(defer.gatherResults(
             [
-                self._prep_event(
+                preserve_fn(self._prep_event)(
                     origin,
                     ev_info["event"],
                     state=ev_info.get("state"),
@@ -1114,7 +1208,7 @@ class FederationHandler(BaseHandler):
                 )
                 for ev_info in event_infos
             ]
-        )
+        ))
 
         yield self.store.persist_events(
             [
@@ -1122,30 +1216,35 @@ class FederationHandler(BaseHandler):
                 for ev_info, context in itertools.izip(event_infos, contexts)
             ],
             backfilled=backfilled,
-            is_new_state=(not outliers and not backfilled),
         )
 
     @defer.inlineCallbacks
-    def _persist_auth_tree(self, auth_events, state, event):
+    def _persist_auth_tree(self, origin, auth_events, state, event):
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event seperately.
 
+        Will attempt to fetch missing auth events.
+
+        Args:
+            origin (str): Where the events came from
+            auth_events (list)
+            state (list)
+            event (Event)
+
         Returns:
             2-tuple of (event_stream_id, max_stream_id) from the persist_event
             call for `event`
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
-            ctx = yield self.state_handler.compute_event_context(
-                e, outlier=True,
-            )
-            events_to_context[e.event_id] = ctx
             e.internal_metadata.outlier = True
+            ctx = yield self.state_handler.compute_event_context(e)
+            events_to_context[e.event_id] = ctx
 
         event_map = {
             e.event_id: e
-            for e in auth_events
+            for e in itertools.chain(auth_events, state, [event])
         }
 
         create_event = None
@@ -1154,10 +1253,29 @@ class FederationHandler(BaseHandler):
                 create_event = e
                 break
 
+        missing_auth_events = set()
+        for e in itertools.chain(auth_events, state, [event]):
+            for e_id, _ in e.auth_events:
+                if e_id not in event_map:
+                    missing_auth_events.add(e_id)
+
+        for e_id in missing_auth_events:
+            m_ev = yield self.replication_layer.get_pdu(
+                [origin],
+                e_id,
+                outlier=True,
+                timeout=10000,
+            )
+            if m_ev and m_ev.event_id == e_id:
+                event_map[e_id] = m_ev
+            else:
+                logger.info("Failed to find auth event %r", e_id)
+
         for e in itertools.chain(auth_events, state, [event]):
             auth_for_e = {
                 (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id]
                 for e_id, _ in e.auth_events
+                if e_id in event_map
             }
             if create_event:
                 auth_for_e[(EventTypes.Create, "")] = create_event
@@ -1185,17 +1303,14 @@ class FederationHandler(BaseHandler):
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
             ],
-            is_new_state=False,
         )
 
         new_event_context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=False,
+            event, old_state=state
         )
 
         event_stream_id, max_stream_id = yield self.store.persist_event(
             event, new_event_context,
-            backfilled=False,
-            is_new_state=True,
             current_state=state,
         )
 
@@ -1203,14 +1318,19 @@ class FederationHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state=None, auth_events=None):
-        outlier = event.internal_metadata.is_outlier()
 
         context = yield self.state_handler.compute_event_context(
-            event, old_state=state, outlier=outlier,
+            event, old_state=state,
         )
 
         if not auth_events:
-            auth_events = context.current_state
+            auth_events_ids = yield self.auth.compute_auth_events(
+                event, context.prev_state_ids, for_verification=True,
+            )
+            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = {
+                (e.type, e.state_key): e for e in auth_events.values()
+            }
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
@@ -1236,8 +1356,7 @@ class FederationHandler(BaseHandler):
             context.rejected = RejectedReason.AUTH_ERROR
 
         if event.type == EventTypes.GuestAccess:
-            full_context = yield self.store.get_current_state(room_id=event.room_id)
-            yield self.maybe_kick_guest_users(event, full_context)
+            yield self.maybe_kick_guest_users(event)
 
         defer.returnValue(context)
 
@@ -1305,6 +1424,11 @@ class FederationHandler(BaseHandler):
         current_state = set(e.event_id for e in auth_events.values())
         event_auth_events = set(e_id for e_id, _ in event.auth_events)
 
+        if event.is_state():
+            event_key = (event.type, event.state_key)
+        else:
+            event_key = None
+
         if event_auth_events - current_state:
             have_events = yield self.store.have_events(
                 event_auth_events - current_state
@@ -1378,9 +1502,9 @@ class FederationHandler(BaseHandler):
             # Do auth conflict res.
             logger.info("Different auth: %s", different_auth)
 
-            different_events = yield defer.gatherResults(
+            different_events = yield preserve_context_over_deferred(defer.gatherResults(
                 [
-                    self.store.get_event(
+                    preserve_fn(self.store.get_event)(
                         d,
                         allow_none=True,
                         allow_rejected=False,
@@ -1389,13 +1513,13 @@ class FederationHandler(BaseHandler):
                     if d in have_events and not have_events[d]
                 ],
                 consumeErrors=True
-            ).addErrback(unwrapFirstError)
+            )).addErrback(unwrapFirstError)
 
             if different_events:
                 local_view = dict(auth_events)
                 remote_view = dict(auth_events)
                 remote_view.update({
-                    (d.type, d.state_key): d for d in different_events
+                    (d.type, d.state_key): d for d in different_events if d
                 })
 
                 new_state, prev_state = self.state_handler.resolve_events(
@@ -1408,8 +1532,16 @@ class FederationHandler(BaseHandler):
                 current_state = set(e.event_id for e in auth_events.values())
                 different_auth = event_auth_events - current_state
 
-                context.current_state.update(auth_events)
-                context.state_group = None
+                context.current_state_ids = dict(context.current_state_ids)
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids = dict(context.prev_state_ids)
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
+                context.state_group = self.store.get_next_state_group()
 
         if different_auth and not event.internal_metadata.is_outlier():
             logger.info("Different auth after resolution: %s", different_auth)
@@ -1430,8 +1562,8 @@ class FederationHandler(BaseHandler):
 
             if do_resolution:
                 # 1. Get what we think is the auth chain.
-                auth_ids = self.auth.compute_auth_events(
-                    event, context.current_state
+                auth_ids = yield self.auth.compute_auth_events(
+                    event, context.prev_state_ids
                 )
                 local_auth_chain = yield self.store.get_auth_chain(auth_ids)
 
@@ -1487,13 +1619,22 @@ class FederationHandler(BaseHandler):
                 # 4. Look at rejects and their proofs.
                 # TODO.
 
-                context.current_state.update(auth_events)
-                context.state_group = None
+                context.current_state_ids = dict(context.current_state_ids)
+                context.current_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                    if k != event_key
+                })
+                context.prev_state_ids = dict(context.prev_state_ids)
+                context.prev_state_ids.update({
+                    k: a.event_id for k, a in auth_events.items()
+                })
+                context.state_group = self.store.get_next_state_group()
 
         try:
             self.auth.check(event, auth_events=auth_events)
-        except AuthError:
-            raise
+        except AuthError as e:
+            logger.warn("Failed auth resolution for %r because %s", event, e)
+            raise e
 
     @defer.inlineCallbacks
     def construct_auth_difference(self, local_auth, remote_auth):
@@ -1663,14 +1804,22 @@ class FederationHandler(BaseHandler):
         if (yield self.auth.check_host_in_room(room_id, self.hs.hostname)):
             builder = self.event_builder_factory.new(event_dict)
             EventValidator().validate_new(builder)
-            event, context = yield self._create_new_client_event(builder=builder)
+            message_handler = self.hs.get_handlers().message_handler
+            event, context = yield message_handler._create_new_client_event(
+                builder=builder
+            )
 
             event, context = yield self.add_display_name_to_third_party_invite(
                 event_dict, event, context
             )
 
-            self.auth.check(event, context.current_state)
-            yield self._check_signature(event, auth_events=context.current_state)
+            try:
+                yield self.auth.check_from_context(event, context)
+            except AuthError as e:
+                logger.warn("Denying new third party invite %r because %s", event, e)
+                raise e
+
+            yield self._check_signature(event, context)
             member_handler = self.hs.get_handlers().room_member_handler
             yield member_handler.send_membership_event(None, event, context)
         else:
@@ -1686,7 +1835,8 @@ class FederationHandler(BaseHandler):
     def on_exchange_third_party_invite_request(self, origin, room_id, event_dict):
         builder = self.event_builder_factory.new(event_dict)
 
-        event, context = yield self._create_new_client_event(
+        message_handler = self.hs.get_handlers().message_handler
+        event, context = yield message_handler._create_new_client_event(
             builder=builder,
         )
 
@@ -1694,8 +1844,12 @@ class FederationHandler(BaseHandler):
             event_dict, event, context
         )
 
-        self.auth.check(event, auth_events=context.current_state)
-        yield self._check_signature(event, auth_events=context.current_state)
+        try:
+            yield self.auth.check_from_context(event, context)
+        except AuthError as e:
+            logger.warn("Denying third party invite %r because %s", event, e)
+            raise e
+        yield self._check_signature(event, context)
 
         returned_invite = yield self.send_invite(origin, event)
         # TODO: Make sure the signatures actually are correct.
@@ -1709,41 +1863,56 @@ class FederationHandler(BaseHandler):
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"]
         )
-        original_invite = context.current_state.get(key)
-        if not original_invite:
+        original_invite = None
+        original_invite_id = context.prev_state_ids.get(key)
+        if original_invite_id:
+            original_invite = yield self.store.get_event(
+                original_invite_id, allow_none=True
+            )
+        if original_invite:
+            display_name = original_invite.content["display_name"]
+            event_dict["content"]["third_party_invite"]["display_name"] = display_name
+        else:
             logger.info(
-                "Could not find invite event for third_party_invite - "
-                "discarding: %s" % (event_dict,)
+                "Could not find invite event for third_party_invite: %r",
+                event_dict
             )
-            return
+            # We don't discard here as this is not the appropriate place to do
+            # auth checks. If we need the invite and don't have it then the
+            # auth check code will explode appropriately.
 
-        display_name = original_invite.content["display_name"]
-        event_dict["content"]["third_party_invite"]["display_name"] = display_name
         builder = self.event_builder_factory.new(event_dict)
         EventValidator().validate_new(builder)
-        event, context = yield self._create_new_client_event(builder=builder)
+        message_handler = self.hs.get_handlers().message_handler
+        event, context = yield message_handler._create_new_client_event(builder=builder)
         defer.returnValue((event, context))
 
     @defer.inlineCallbacks
-    def _check_signature(self, event, auth_events):
+    def _check_signature(self, event, context):
         """
         Checks that the signature in the event is consistent with its invite.
-        :param event (Event): The m.room.member event to check
-        :param auth_events (dict<(event type, state_key), event>)
 
-        :raises
-            AuthError if signature didn't match any keys, or key has been
+        Args:
+            event (Event): The m.room.member event to check
+            context (EventContext):
+
+        Raises:
+            AuthError: if signature didn't match any keys, or key has been
                 revoked,
-            SynapseError if a transient error meant a key couldn't be checked
+            SynapseError: if a transient error meant a key couldn't be checked
                 for revocation.
         """
         signed = event.content["third_party_invite"]["signed"]
         token = signed["token"]
 
-        invite_event = auth_events.get(
+        invite_event_id = context.prev_state_ids.get(
             (EventTypes.ThirdPartyInvite, token,)
         )
 
+        invite_event = None
+        if invite_event_id:
+            invite_event = yield self.store.get_event(invite_event_id, allow_none=True)
+
         if not invite_event:
             raise AuthError(403, "Could not find invite")
 
@@ -1776,12 +1945,13 @@ class FederationHandler(BaseHandler):
         """
         Checks whether public_key has been revoked.
 
-        :param public_key (str): base-64 encoded public key.
-        :param url (str): Key revocation URL.
+        Args:
+            public_key (str): base-64 encoded public key.
+            url (str): Key revocation URL.
 
-        :raises
-            AuthError if they key has been revoked.
-            SynapseError if a transient error meant a key couldn't be checked
+        Raises:
+            AuthError: if the key has been revoked.
+            SynapseError: if a transient error meant a key couldn't be checked
                 for revocation.
         """
         try: