summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-06-26 22:34:41 +0100
committerRichard van der Hoff <richard@matrix.org>2019-06-26 22:34:41 +0100
commita4daa899ec4cd195fc10936f68df5c78314b366c (patch)
tree35e88ff388b0f7652773a79930b732aa04f16bde /synapse/handlers/federation.py
parentchangelog (diff)
parentImprove docs on choosing server_name (#5558) (diff)
downloadsynapse-a4daa899ec4cd195fc10936f68df5c78314b366c.tar.xz
Merge branch 'develop' into rav/saml2_client
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py974
1 files changed, 457 insertions, 517 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index ac5ca79143..02d397c498 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright 2014-2016 OpenMarket Ltd
-# Copyright 2018 New Vector Ltd
+# Copyright 2017-2018 New Vector Ltd
+# Copyright 2019 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -33,6 +34,7 @@ from synapse.api.constants import EventTypes, Membership, RejectedReason
 from synapse.api.errors import (
     AuthError,
     CodeMessageException,
+    Codes,
     FederationDeniedError,
     FederationError,
     RequestSendFailed,
@@ -80,7 +82,7 @@ def shortstr(iterable, maxitems=5):
     items = list(itertools.islice(iterable, maxitems + 1))
     if len(items) <= maxitems:
         return str(items)
-    return u"[" + u", ".join(repr(r) for r in items[:maxitems]) + u", ...]"
+    return "[" + ", ".join(repr(r) for r in items[:maxitems]) + ", ...]"
 
 
 class FederationHandler(BaseHandler):
@@ -113,24 +115,24 @@ class FederationHandler(BaseHandler):
         self.config = hs.config
         self.http_client = hs.get_simple_http_client()
 
-        self._send_events_to_master = (
-            ReplicationFederationSendEventsRestServlet.make_client(hs)
+        self._send_events_to_master = ReplicationFederationSendEventsRestServlet.make_client(
+            hs
         )
-        self._notify_user_membership_change = (
-            ReplicationUserJoinedLeftRoomRestServlet.make_client(hs)
+        self._notify_user_membership_change = ReplicationUserJoinedLeftRoomRestServlet.make_client(
+            hs
         )
-        self._clean_room_for_join_client = (
-            ReplicationCleanRoomRestServlet.make_client(hs)
+        self._clean_room_for_join_client = ReplicationCleanRoomRestServlet.make_client(
+            hs
         )
 
         # When joining a room we need to queue any events for that room up
         self.room_queues = {}
         self._room_pdu_linearizer = Linearizer("fed_room_pdu")
 
+        self.third_party_event_rules = hs.get_third_party_event_rules()
+
     @defer.inlineCallbacks
-    def on_receive_pdu(
-            self, origin, pdu, sent_to_us_directly=False,
-    ):
+    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -147,26 +149,19 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        logger.info(
-            "[%s %s] handling received PDU: %s",
-            room_id, event_id, pdu,
-        )
+        logger.info("[%s %s] handling received PDU: %s", room_id, event_id, pdu)
 
         # We reprocess pdus when we have seen them only as outliers
         existing = yield self.store.get_event(
-            event_id,
-            allow_none=True,
-            allow_rejected=True,
+            event_id, allow_none=True, allow_rejected=True
         )
 
         # FIXME: Currently we fetch an event again when we already have it
         # if it has been marked as an outlier.
 
-        already_seen = (
-            existing and (
-                not existing.internal_metadata.is_outlier()
-                or pdu.internal_metadata.is_outlier()
-            )
+        already_seen = existing and (
+            not existing.internal_metadata.is_outlier()
+            or pdu.internal_metadata.is_outlier()
         )
         if already_seen:
             logger.debug("[%s %s]: Already seen pdu", room_id, event_id)
@@ -178,20 +173,19 @@ class FederationHandler(BaseHandler):
         try:
             self._sanity_check_event(pdu)
         except SynapseError as err:
-            logger.warn("[%s %s] Received event failed sanity checks", room_id, event_id)
-            raise FederationError(
-                "ERROR",
-                err.code,
-                err.msg,
-                affected=pdu.event_id,
+            logger.warn(
+                "[%s %s] Received event failed sanity checks", room_id, event_id
             )
+            raise FederationError("ERROR", err.code, err.msg, affected=pdu.event_id)
 
         # If we are currently in the process of joining this room, then we
         # queue up events for later processing.
         if room_id in self.room_queues:
             logger.info(
                 "[%s %s] Queuing PDU from %s for now: join in progress",
-                room_id, event_id, origin,
+                room_id,
+                event_id,
+                origin,
             )
             self.room_queues[room_id].append((pdu, origin))
             return
@@ -202,14 +196,13 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(
-            room_id,
-            self.server_name
-        )
+        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
-                room_id, event_id, origin,
+                room_id,
+                event_id,
+                origin,
             )
             defer.returnValue(None)
 
@@ -219,14 +212,9 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(
-                pdu.room_id
-            )
+            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
 
-            logger.debug(
-                "[%s %s] min_depth: %d",
-                room_id, event_id, min_depth,
-            )
+            logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
             seen = yield self.store.have_seen_events(prevs)
@@ -244,12 +232,17 @@ class FederationHandler(BaseHandler):
                     # at a time.
                     logger.info(
                         "[%s %s] Acquiring room lock to fetch %d missing prev_events: %s",
-                        room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+                        room_id,
+                        event_id,
+                        len(missing_prevs),
+                        shortstr(missing_prevs),
                     )
                     with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
-                            room_id, event_id, len(missing_prevs),
+                            room_id,
+                            event_id,
+                            len(missing_prevs),
                         )
 
                         yield self._get_missing_events_for_pdu(
@@ -263,12 +256,16 @@ class FederationHandler(BaseHandler):
                         if not prevs - seen:
                             logger.info(
                                 "[%s %s] Found all missing prev_events",
-                                room_id, event_id,
+                                room_id,
+                                event_id,
                             )
                 elif missing_prevs:
                     logger.info(
                         "[%s %s] Not recursively fetching %d missing prev_events: %s",
-                        room_id, event_id, len(missing_prevs), shortstr(missing_prevs),
+                        room_id,
+                        event_id,
+                        len(missing_prevs),
+                        shortstr(missing_prevs),
                     )
 
             if prevs - seen:
@@ -299,7 +296,10 @@ class FederationHandler(BaseHandler):
                 if sent_to_us_directly:
                     logger.warn(
                         "[%s %s] Rejecting: failed to fetch %d prev events: %s",
-                        room_id, event_id, len(prevs - seen), shortstr(prevs - seen)
+                        room_id,
+                        event_id,
+                        len(prevs - seen),
+                        shortstr(prevs - seen),
                     )
                     raise FederationError(
                         "ERROR",
@@ -314,9 +314,7 @@ class FederationHandler(BaseHandler):
                 # Calculate the state after each of the previous events, and
                 # resolve them to find the correct state at the current event.
                 auth_chains = set()
-                event_map = {
-                    event_id: pdu,
-                }
+                event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
                     ours = yield self.store.get_state_groups_ids(room_id, seen)
@@ -333,7 +331,9 @@ class FederationHandler(BaseHandler):
                     for p in prevs - seen:
                         logger.info(
                             "[%s %s] Requesting state at missing prev_event %s",
-                            room_id, event_id, p,
+                            room_id,
+                            event_id,
+                            p,
                         )
 
                         room_version = yield self.store.get_room_version(room_id)
@@ -344,19 +344,19 @@ class FederationHandler(BaseHandler):
                             # by the get_pdu_cache in federation_client.
                             remote_state, got_auth_chain = (
                                 yield self.federation_client.get_state_for_room(
-                                    origin, room_id, p,
+                                    origin, room_id, p
                                 )
                             )
 
                             # we want the state *after* p; get_state_for_room returns the
                             # state *before* p.
                             remote_event = yield self.federation_client.get_pdu(
-                                [origin], p, room_version, outlier=True,
+                                [origin], p, room_version, outlier=True
                             )
 
                             if remote_event is None:
                                 raise Exception(
-                                    "Unable to get missing prev_event %s" % (p, )
+                                    "Unable to get missing prev_event %s" % (p,)
                                 )
 
                             if remote_event.is_state():
@@ -376,7 +376,9 @@ class FederationHandler(BaseHandler):
                                 event_map[x.event_id] = x
 
                     state_map = yield resolve_events_with_store(
-                        room_version, state_maps, event_map,
+                        room_version,
+                        state_maps,
+                        event_map,
                         state_res_store=StateResolutionStore(self.store),
                     )
 
@@ -392,15 +394,15 @@ class FederationHandler(BaseHandler):
                     )
                     event_map.update(evs)
 
-                    state = [
-                        event_map[e] for e in six.itervalues(state_map)
-                    ]
+                    state = [event_map[e] for e in six.itervalues(state_map)]
                     auth_chain = list(auth_chains)
                 except Exception:
                     logger.warn(
                         "[%s %s] Error attempting to resolve state at missing "
                         "prev_events",
-                        room_id, event_id, exc_info=True,
+                        room_id,
+                        event_id,
+                        exc_info=True,
                     )
                     raise FederationError(
                         "ERROR",
@@ -410,10 +412,7 @@ class FederationHandler(BaseHandler):
                     )
 
         yield self._process_received_pdu(
-            origin,
-            pdu,
-            state=state,
-            auth_chain=auth_chain,
+            origin, pdu, state=state, auth_chain=auth_chain
         )
 
     @defer.inlineCallbacks
@@ -443,7 +442,10 @@ class FederationHandler(BaseHandler):
 
         logger.info(
             "[%s %s]: Requesting missing events between %s and %s",
-            room_id, event_id, shortstr(latest), event_id,
+            room_id,
+            event_id,
+            shortstr(latest),
+            event_id,
         )
 
         # XXX: we set timeout to 10s to help workaround
@@ -494,19 +496,29 @@ class FederationHandler(BaseHandler):
         #
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
-        missing_events = yield self.federation_client.get_missing_events(
-            origin,
-            room_id,
-            earliest_events_ids=list(latest),
-            latest_events=[pdu],
-            limit=10,
-            min_depth=min_depth,
-            timeout=60000,
-        )
+        try:
+            missing_events = yield self.federation_client.get_missing_events(
+                origin,
+                room_id,
+                earliest_events_ids=list(latest),
+                latest_events=[pdu],
+                limit=10,
+                min_depth=min_depth,
+                timeout=60000,
+            )
+        except RequestSendFailed as e:
+            # We failed to get the missing events, but since we need to handle
+            # the case of `get_missing_events` not returning the necessary
+            # events anyway, it is safe to simply log the error and continue.
+            logger.warn("[%s %s]: Failed to get prev_events: %s", room_id, event_id, e)
+            return
 
         logger.info(
             "[%s %s]: Got %d prev_events: %s",
-            room_id, event_id, len(missing_events), shortstr(missing_events),
+            room_id,
+            event_id,
+            len(missing_events),
+            shortstr(missing_events),
         )
 
         # We want to sort these by depth so we process them and
@@ -516,20 +528,20 @@ class FederationHandler(BaseHandler):
         for ev in missing_events:
             logger.info(
                 "[%s %s] Handling received prev_event %s",
-                room_id, event_id, ev.event_id,
+                room_id,
+                event_id,
+                ev.event_id,
             )
             with logcontext.nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(
-                        origin,
-                        ev,
-                        sent_to_us_directly=False,
-                    )
+                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warn(
                             "[%s %s] Received prev_event %s failed history check.",
-                            room_id, event_id, ev.event_id,
+                            room_id,
+                            event_id,
+                            ev.event_id,
                         )
                     else:
                         raise
@@ -542,10 +554,7 @@ class FederationHandler(BaseHandler):
         room_id = event.room_id
         event_id = event.event_id
 
-        logger.debug(
-            "[%s %s] Processing event: %s",
-            room_id, event_id, event,
-        )
+        logger.debug("[%s %s] Processing event: %s", room_id, event_id, event)
 
         event_ids = set()
         if state:
@@ -567,43 +576,32 @@ class FederationHandler(BaseHandler):
                 e.internal_metadata.outlier = True
                 auth_ids = e.auth_event_ids()
                 auth = {
-                    (e.type, e.state_key): e for e in auth_chain
+                    (e.type, e.state_key): e
+                    for e in auth_chain
                     if e.event_id in auth_ids or e.type == EventTypes.Create
                 }
-                event_infos.append({
-                    "event": e,
-                    "auth_events": auth,
-                })
+                event_infos.append({"event": e, "auth_events": auth})
                 seen_ids.add(e.event_id)
 
             logger.info(
                 "[%s %s] persisting newly-received auth/state events %s",
-                room_id, event_id, [e["event"].event_id for e in event_infos]
+                room_id,
+                event_id,
+                [e["event"].event_id for e in event_infos],
             )
             yield self._handle_new_events(origin, event_infos)
 
         try:
-            context = yield self._handle_new_event(
-                origin,
-                event,
-                state=state,
-            )
+            context = yield self._handle_new_event(origin, event, state=state)
         except AuthError as e:
-            raise FederationError(
-                "ERROR",
-                e.code,
-                e.msg,
-                affected=event.event_id,
-            )
+            raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
         room = yield self.store.get_room(room_id)
 
         if not room:
             try:
                 yield self.store.store_room(
-                    room_id=room_id,
-                    room_creator_user_id="",
-                    is_public=False,
+                    room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except StoreError:
                 logger.exception("Failed to store room.")
@@ -617,12 +615,10 @@ class FederationHandler(BaseHandler):
 
                 prev_state_ids = yield context.get_prev_state_ids(self.store)
 
-                prev_state_id = prev_state_ids.get(
-                    (event.type, event.state_key)
-                )
+                prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
                     prev_state = yield self.store.get_event(
-                        prev_state_id, allow_none=True,
+                        prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
                         newly_joined = False
@@ -653,10 +649,7 @@ class FederationHandler(BaseHandler):
         room_version = yield self.store.get_room_version(room_id)
 
         events = yield self.federation_client.backfill(
-            dest,
-            room_id,
-            limit=limit,
-            extremities=extremities,
+            dest, room_id, limit=limit, extremities=extremities
         )
 
         # ideally we'd sanity check the events here for excess prev_events etc,
@@ -683,16 +676,9 @@ class FederationHandler(BaseHandler):
 
         event_ids = set(e.event_id for e in events)
 
-        edges = [
-            ev.event_id
-            for ev in events
-            if set(ev.prev_event_ids()) - event_ids
-        ]
+        edges = [ev.event_id for ev in events if set(ev.prev_event_ids()) - event_ids]
 
-        logger.info(
-            "backfill: Got %d events with %d edges",
-            len(events), len(edges),
-        )
+        logger.info("backfill: Got %d events with %d edges", len(events), len(edges))
 
         # For each edge get the current state.
 
@@ -701,9 +687,7 @@ class FederationHandler(BaseHandler):
         events_to_state = {}
         for e_id in edges:
             state, auth = yield self.federation_client.get_state_for_room(
-                destination=dest,
-                room_id=room_id,
-                event_id=e_id
+                destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
             auth_events.update({s.event_id: s for s in state})
@@ -712,12 +696,14 @@ class FederationHandler(BaseHandler):
 
         required_auth = set(
             a_id
-            for event in events + list(state_events.values()) + list(auth_events.values())
+            for event in events
+            + list(state_events.values())
+            + list(auth_events.values())
             for a_id in event.auth_event_ids()
         )
-        auth_events.update({
-            e_id: event_map[e_id] for e_id in required_auth if e_id in event_map
-        })
+        auth_events.update(
+            {e_id: event_map[e_id] for e_id in required_auth if e_id in event_map}
+        )
         missing_auth = required_auth - set(auth_events)
         failed_to_fetch = set()
 
@@ -736,27 +722,30 @@ class FederationHandler(BaseHandler):
             if missing_auth - failed_to_fetch:
                 logger.info(
                     "Fetching missing auth for backfill: %r",
-                    missing_auth - failed_to_fetch
+                    missing_auth - failed_to_fetch,
                 )
 
-                results = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-                    [
-                        logcontext.run_in_background(
-                            self.federation_client.get_pdu,
-                            [dest],
-                            event_id,
-                            room_version=room_version,
-                            outlier=True,
-                            timeout=10000,
-                        )
-                        for event_id in missing_auth - failed_to_fetch
-                    ],
-                    consumeErrors=True
-                )).addErrback(unwrapFirstError)
+                results = yield logcontext.make_deferred_yieldable(
+                    defer.gatherResults(
+                        [
+                            logcontext.run_in_background(
+                                self.federation_client.get_pdu,
+                                [dest],
+                                event_id,
+                                room_version=room_version,
+                                outlier=True,
+                                timeout=10000,
+                            )
+                            for event_id in missing_auth - failed_to_fetch
+                        ],
+                        consumeErrors=True,
+                    )
+                ).addErrback(unwrapFirstError)
                 auth_events.update({a.event_id: a for a in results if a})
                 required_auth.update(
                     a_id
-                    for event in results if event
+                    for event in results
+                    if event
                     for a_id in event.auth_event_ids()
                 )
                 missing_auth = required_auth - set(auth_events)
@@ -788,15 +777,19 @@ class FederationHandler(BaseHandler):
                 continue
 
             a.internal_metadata.outlier = True
-            ev_infos.append({
-                "event": a,
-                "auth_events": {
-                    (auth_events[a_id].type, auth_events[a_id].state_key):
-                    auth_events[a_id]
-                    for a_id in a.auth_event_ids()
-                    if a_id in auth_events
+            ev_infos.append(
+                {
+                    "event": a,
+                    "auth_events": {
+                        (
+                            auth_events[a_id].type,
+                            auth_events[a_id].state_key,
+                        ): auth_events[a_id]
+                        for a_id in a.auth_event_ids()
+                        if a_id in auth_events
+                    },
                 }
-            })
+            )
 
         # Step 1b: persist the events in the chunk we fetched state for (i.e.
         # the backwards extremities) as non-outliers.
@@ -804,23 +797,24 @@ class FederationHandler(BaseHandler):
             # For paranoia we ensure that these events are marked as
             # non-outliers
             ev = event_map[e_id]
-            assert(not ev.internal_metadata.is_outlier())
-
-            ev_infos.append({
-                "event": ev,
-                "state": events_to_state[e_id],
-                "auth_events": {
-                    (auth_events[a_id].type, auth_events[a_id].state_key):
-                    auth_events[a_id]
-                    for a_id in ev.auth_event_ids()
-                    if a_id in auth_events
+            assert not ev.internal_metadata.is_outlier()
+
+            ev_infos.append(
+                {
+                    "event": ev,
+                    "state": events_to_state[e_id],
+                    "auth_events": {
+                        (
+                            auth_events[a_id].type,
+                            auth_events[a_id].state_key,
+                        ): auth_events[a_id]
+                        for a_id in ev.auth_event_ids()
+                        if a_id in auth_events
+                    },
                 }
-            })
+            )
 
-        yield self._handle_new_events(
-            dest, ev_infos,
-            backfilled=True,
-        )
+        yield self._handle_new_events(dest, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -831,14 +825,12 @@ class FederationHandler(BaseHandler):
 
             # For paranoia we ensure that these events are marked as
             # non-outliers
-            assert(not event.internal_metadata.is_outlier())
+            assert not event.internal_metadata.is_outlier()
 
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(
-                dest, event, backfilled=True,
-            )
+            yield self._handle_new_event(dest, event, backfilled=True)
 
         defer.returnValue(events)
 
@@ -847,9 +839,7 @@ class FederationHandler(BaseHandler):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(
-            room_id
-        )
+        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -881,31 +871,27 @@ class FederationHandler(BaseHandler):
         #   state *before* the event, ignoring the special casing certain event
         #   types have.
 
-        forward_events = yield self.store.get_successor_events(
-            list(extremities),
-        )
+        forward_events = yield self.store.get_successor_events(list(extremities))
 
         extremities_events = yield self.store.get_events(
-            forward_events,
-            check_redacted=False,
-            get_prev_content=False,
+            forward_events, check_redacted=False, get_prev_content=False
         )
 
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
         filtered_extremities = yield filter_events_for_server(
-            self.store, self.server_name, list(extremities_events.values()),
-            redact=False, check_history_visibility_only=True,
+            self.store,
+            self.server_name,
+            list(extremities_events.values()),
+            redact=False,
+            check_history_visibility_only=True,
         )
 
         if not filtered_extremities:
             defer.returnValue(False)
 
         # Check if we reached a point where we should start backfilling.
-        sorted_extremeties_tuple = sorted(
-            extremities.items(),
-            key=lambda e: -int(e[1])
-        )
+        sorted_extremeties_tuple = sorted(extremities.items(), key=lambda e: -int(e[1]))
         max_depth = sorted_extremeties_tuple[0][1]
 
         # We don't want to specify too many extremities as it causes the backfill
@@ -914,8 +900,7 @@ class FederationHandler(BaseHandler):
 
         if current_depth > max_depth:
             logger.debug(
-                "Not backfilling as we don't need to. %d < %d",
-                max_depth, current_depth,
+                "Not backfilling as we don't need to. %d < %d", max_depth, current_depth
             )
             return
 
@@ -940,8 +925,7 @@ class FederationHandler(BaseHandler):
             joined_users = [
                 (state_key, int(event.depth))
                 for (e_type, state_key), event in iteritems(state)
-                if e_type == EventTypes.Member
-                and event.membership == Membership.JOIN
+                if e_type == EventTypes.Member and event.membership == Membership.JOIN
             ]
 
             joined_domains = {}
@@ -961,8 +945,7 @@ class FederationHandler(BaseHandler):
         curr_domains = get_domains_from_state(curr_state)
 
         likely_domains = [
-            domain for domain, depth in curr_domains
-            if domain != self.server_name
+            domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
         @defer.inlineCallbacks
@@ -971,28 +954,20 @@ class FederationHandler(BaseHandler):
             for dom in domains:
                 try:
                     yield self.backfill(
-                        dom, room_id,
-                        limit=100,
-                        extremities=extremities,
+                        dom, room_id, limit=100, extremities=extremities
                     )
                     # If this succeeded then we probably already have the
                     # appropriate stuff.
                     # TODO: We can probably do something more intelligent here.
                     defer.returnValue(True)
                 except SynapseError as e:
-                    logger.info(
-                        "Failed to backfill from %s because %s",
-                        dom, e,
-                    )
+                    logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
                 except CodeMessageException as e:
                     if 400 <= e.code < 500:
                         raise
 
-                    logger.info(
-                        "Failed to backfill from %s because %s",
-                        dom, e,
-                    )
+                    logger.info("Failed to backfill from %s because %s", dom, e)
                     continue
                 except NotRetryingDestination as e:
                     logger.info(str(e))
@@ -1001,10 +976,7 @@ class FederationHandler(BaseHandler):
                     logger.info(e)
                     continue
                 except Exception as e:
-                    logger.exception(
-                        "Failed to backfill from %s because %s",
-                        dom, e,
-                    )
+                    logger.exception("Failed to backfill from %s because %s", dom, e)
                     continue
 
             defer.returnValue(False)
@@ -1025,10 +997,11 @@ class FederationHandler(BaseHandler):
         resolve = logcontext.preserve_fn(
             self.state_handler.resolve_state_groups_for_events
         )
-        states = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [resolve(room_id, [e]) for e in event_ids],
-            consumeErrors=True,
-        ))
+        states = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
+            )
+        )
 
         # dict[str, dict[tuple, str]], a map from event_id to state map of
         # event_ids.
@@ -1036,23 +1009,23 @@ class FederationHandler(BaseHandler):
 
         state_map = yield self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
-            get_prev_content=False
+            get_prev_content=False,
         )
         states = {
             key: {
                 k: state_map[e_id]
                 for k, e_id in iteritems(state_dict)
                 if e_id in state_map
-            } for key, state_dict in iteritems(states)
+            }
+            for key, state_dict in iteritems(states)
         }
 
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
-            success = yield try_backfill([
-                dom for dom, _ in likely_domains
-                if dom not in tried_domains
-            ])
+            success = yield try_backfill(
+                [dom for dom, _ in likely_domains if dom not in tried_domains]
+            )
             if success:
                 defer.returnValue(True)
 
@@ -1077,20 +1050,20 @@ class FederationHandler(BaseHandler):
             SynapseError if the event does not pass muster
         """
         if len(ev.prev_event_ids()) > 20:
-            logger.warn("Rejecting event %s which has %i prev_events",
-                        ev.event_id, len(ev.prev_event_ids()))
-            raise SynapseError(
-                http_client.BAD_REQUEST,
-                "Too many prev_events",
+            logger.warn(
+                "Rejecting event %s which has %i prev_events",
+                ev.event_id,
+                len(ev.prev_event_ids()),
             )
+            raise SynapseError(http_client.BAD_REQUEST, "Too many prev_events")
 
         if len(ev.auth_event_ids()) > 10:
-            logger.warn("Rejecting event %s which has %i auth_events",
-                        ev.event_id, len(ev.auth_event_ids()))
-            raise SynapseError(
-                http_client.BAD_REQUEST,
-                "Too many auth_events",
+            logger.warn(
+                "Rejecting event %s which has %i auth_events",
+                ev.event_id,
+                len(ev.auth_event_ids()),
             )
+            raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
 
     @defer.inlineCallbacks
     def send_invite(self, target_host, event):
@@ -1102,7 +1075,7 @@ class FederationHandler(BaseHandler):
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
-            pdu=event
+            pdu=event,
         )
 
         defer.returnValue(pdu)
@@ -1111,8 +1084,7 @@ class FederationHandler(BaseHandler):
     def on_event_auth(self, event_id):
         event = yield self.store.get_event(event_id)
         auth = yield self.store.get_auth_chain(
-            [auth_id for auth_id in event.auth_event_ids()],
-            include_given=True
+            [auth_id for auth_id in event.auth_event_ids()], include_given=True
         )
         defer.returnValue([e for e in auth])
 
@@ -1138,15 +1110,13 @@ class FederationHandler(BaseHandler):
             joinee,
             "join",
             content,
-            params={
-                "ver": KNOWN_ROOM_VERSIONS,
-            },
+            params={"ver": KNOWN_ROOM_VERSIONS},
         )
 
         # This shouldn't happen, because the RoomMemberHandler has a
         # linearizer lock which only allows one operation per user per room
         # at a time - so this is just paranoia.
-        assert (room_id not in self.room_queues)
+        assert room_id not in self.room_queues
 
         self.room_queues[room_id] = []
 
@@ -1163,7 +1133,7 @@ class FederationHandler(BaseHandler):
             except ValueError:
                 pass
             ret = yield self.federation_client.send_join(
-                target_hosts, event, event_format_version,
+                target_hosts, event, event_format_version
             )
 
             origin = ret["origin"]
@@ -1182,17 +1152,13 @@ class FederationHandler(BaseHandler):
 
             try:
                 yield self.store.store_room(
-                    room_id=room_id,
-                    room_creator_user_id="",
-                    is_public=False
+                    room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except Exception:
                 # FIXME
                 pass
 
-            yield self._persist_auth_tree(
-                origin, auth_chain, state, event
-            )
+            yield self._persist_auth_tree(origin, auth_chain, state, event)
 
             logger.debug("Finished joining %s to %s", joinee, room_id)
         finally:
@@ -1219,14 +1185,18 @@ class FederationHandler(BaseHandler):
         """
         for p, origin in room_queue:
             try:
-                logger.info("Processing queued PDU %s which was received "
-                            "while we were joining %s", p.event_id, p.room_id)
+                logger.info(
+                    "Processing queued PDU %s which was received "
+                    "while we were joining %s",
+                    p.event_id,
+                    p.room_id,
+                )
                 with logcontext.nested_logging_context(p.event_id):
                     yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warn(
-                    "Error handling queued PDU %s from %s: %s",
-                    p.event_id, origin, e)
+                    "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
+                )
 
     @defer.inlineCallbacks
     @log_function
@@ -1247,21 +1217,30 @@ class FederationHandler(BaseHandler):
                 "room_id": room_id,
                 "sender": user_id,
                 "state_key": user_id,
-            }
+            },
         )
 
         try:
             event, context = yield self.event_creation_handler.create_new_client_event(
-                builder=builder,
+                builder=builder
             )
         except AuthError as e:
             logger.warn("Failed to create join %r because %s", event, e)
             raise e
 
+        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event, context
+        )
+        if not event_allowed:
+            logger.info("Creation of join %s forbidden by third-party rules", event)
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
+
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
         yield self.auth.check_from_context(
-            room_version, event, context, do_sig_check=False,
+            room_version, event, context, do_sig_check=False
         )
 
         defer.returnValue(event)
@@ -1296,9 +1275,16 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context = yield self._handle_new_event(
-            origin, event
+        context = yield self._handle_new_event(origin, event)
+
+        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event, context
         )
+        if not event_allowed:
+            logger.info("Sending of join %s forbidden by third-party rules", event)
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
 
         logger.debug(
             "on_send_join_request: After _handle_new_event: %s, sigs: %s",
@@ -1318,10 +1304,7 @@ class FederationHandler(BaseHandler):
 
         state = yield self.store.get_events(list(prev_state_ids.values()))
 
-        defer.returnValue({
-            "state": list(state.values()),
-            "auth_chain": auth_chain,
-        })
+        defer.returnValue({"state": list(state.values()), "auth_chain": auth_chain})
 
     @defer.inlineCallbacks
     def on_invite_request(self, origin, pdu):
@@ -1342,7 +1325,7 @@ class FederationHandler(BaseHandler):
             raise SynapseError(403, "This server does not accept room invites")
 
         if not self.spam_checker.user_may_invite(
-            event.sender, event.state_key, event.room_id,
+            event.sender, event.state_key, event.room_id
         ):
             raise SynapseError(
                 403, "This user is not permitted to send invites to this server/user"
@@ -1354,26 +1337,23 @@ class FederationHandler(BaseHandler):
 
         sender_domain = get_domain_from_id(event.sender)
         if sender_domain != origin:
-            raise SynapseError(400, "The invite event was not from the server sending it")
+            raise SynapseError(
+                400, "The invite event was not from the server sending it"
+            )
 
         if not self.is_mine_id(event.state_key):
             raise SynapseError(400, "The invite event must be for this server")
 
         # block any attempts to invite the server notices mxid
         if event.state_key == self._server_notices_mxid:
-            raise SynapseError(
-                http_client.FORBIDDEN,
-                "Cannot invite this user",
-            )
+            raise SynapseError(http_client.FORBIDDEN, "Cannot invite this user")
 
         event.internal_metadata.outlier = True
         event.internal_metadata.out_of_band_membership = True
 
         event.signatures.update(
             compute_event_signature(
-                event.get_pdu_json(),
-                self.hs.hostname,
-                self.hs.config.signing_key[0]
+                event.get_pdu_json(), self.hs.hostname, self.hs.config.signing_key[0]
             )
         )
 
@@ -1385,10 +1365,7 @@ class FederationHandler(BaseHandler):
     @defer.inlineCallbacks
     def do_remotely_reject_invite(self, target_hosts, room_id, user_id):
         origin, event, event_format_version = yield self._make_and_verify_event(
-            target_hosts,
-            room_id,
-            user_id,
-            "leave"
+            target_hosts, room_id, user_id, "leave"
         )
         # Mark as outlier as we don't have any state for this event; we're not
         # even in the room.
@@ -1403,10 +1380,7 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.federation_client.send_leave(
-            target_hosts,
-            event
-        )
+        yield self.federation_client.send_leave(target_hosts, event)
 
         context = yield self.state_handler.compute_event_context(event)
         yield self.persist_events_and_notify([(event, context)])
@@ -1414,25 +1388,21 @@ class FederationHandler(BaseHandler):
         defer.returnValue(event)
 
     @defer.inlineCallbacks
-    def _make_and_verify_event(self, target_hosts, room_id, user_id, membership,
-                               content={}, params=None):
+    def _make_and_verify_event(
+        self, target_hosts, room_id, user_id, membership, content={}, params=None
+    ):
         origin, event, format_ver = yield self.federation_client.make_membership_event(
-            target_hosts,
-            room_id,
-            user_id,
-            membership,
-            content,
-            params=params,
+            target_hosts, room_id, user_id, membership, content, params=params
         )
 
         logger.debug("Got response to make_%s: %s", membership, event)
 
         # We should assert some things.
         # FIXME: Do this in a nicer way
-        assert(event.type == EventTypes.Member)
-        assert(event.user_id == user_id)
-        assert(event.state_key == user_id)
-        assert(event.room_id == room_id)
+        assert event.type == EventTypes.Member
+        assert event.user_id == user_id
+        assert event.state_key == user_id
+        assert event.room_id == room_id
         defer.returnValue((origin, event, format_ver))
 
     @defer.inlineCallbacks
@@ -1451,18 +1421,27 @@ class FederationHandler(BaseHandler):
                 "room_id": room_id,
                 "sender": user_id,
                 "state_key": user_id,
-            }
+            },
         )
 
         event, context = yield self.event_creation_handler.create_new_client_event(
-            builder=builder,
+            builder=builder
         )
 
+        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event, context
+        )
+        if not event_allowed:
+            logger.warning("Creation of leave %s forbidden by third-party rules", event)
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
+
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
             yield self.auth.check_from_context(
-                room_version, event, context, do_sig_check=False,
+                room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
             logger.warn("Failed to create new leave %r because %s", event, e)
@@ -1484,9 +1463,16 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        yield self._handle_new_event(
-            origin, event
+        context = yield self._handle_new_event(origin, event)
+
+        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event, context
         )
+        if not event_allowed:
+            logger.info("Sending of leave %s forbidden by third-party rules", event)
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
 
         logger.debug(
             "on_send_leave_request: After _handle_new_event: %s, sigs: %s",
@@ -1502,18 +1488,14 @@ class FederationHandler(BaseHandler):
         """
 
         event = yield self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id,
+            event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.store.get_state_groups(
-            room_id, [event_id]
-        )
+        state_groups = yield self.store.get_state_groups(room_id, [event_id])
 
         if state_groups:
             _, state = list(iteritems(state_groups)).pop()
-            results = {
-                (e.type, e.state_key): e for e in state
-            }
+            results = {(e.type, e.state_key): e for e in state}
 
             if event.is_state():
                 # Get previous state
@@ -1535,12 +1517,10 @@ class FederationHandler(BaseHandler):
         """Returns the state at the event. i.e. not including said event.
         """
         event = yield self.store.get_event(
-            event_id, allow_none=False, check_room_id=room_id,
+            event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.store.get_state_groups_ids(
-            room_id, [event_id]
-        )
+        state_groups = yield self.store.get_state_groups_ids(room_id, [event_id])
 
         if state_groups:
             _, state = list(state_groups.items()).pop()
@@ -1566,11 +1546,7 @@ class FederationHandler(BaseHandler):
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        events = yield self.store.get_backfill_events(
-            room_id,
-            pdu_list,
-            limit
-        )
+        events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
 
         events = yield filter_events_for_server(self.store, origin, events)
 
@@ -1594,22 +1570,15 @@ class FederationHandler(BaseHandler):
             AuthError if the server is not currently in the room
         """
         event = yield self.store.get_event(
-            event_id,
-            allow_none=True,
-            allow_rejected=True,
+            event_id, allow_none=True, allow_rejected=True
         )
 
         if event:
-            in_room = yield self.auth.check_host_in_room(
-                event.room_id,
-                origin
-            )
+            in_room = yield self.auth.check_host_in_room(event.room_id, origin)
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
-            events = yield filter_events_for_server(
-                self.store, origin, [event],
-            )
+            events = yield filter_events_for_server(self.store, origin, [event])
             event = events[0]
             defer.returnValue(event)
         else:
@@ -1619,13 +1588,11 @@ class FederationHandler(BaseHandler):
         return self.store.get_min_depth(context)
 
     @defer.inlineCallbacks
-    def _handle_new_event(self, origin, event, state=None, auth_events=None,
-                          backfilled=False):
+    def _handle_new_event(
+        self, origin, event, state=None, auth_events=None, backfilled=False
+    ):
         context = yield self._prep_event(
-            origin, event,
-            state=state,
-            auth_events=auth_events,
-            backfilled=backfilled,
+            origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
 
         # reraise does not allow inlineCallbacks to preserve the stacktrace, so we
@@ -1638,15 +1605,13 @@ class FederationHandler(BaseHandler):
                 )
 
             yield self.persist_events_and_notify(
-                [(event, context)],
-                backfilled=backfilled,
+                [(event, context)], backfilled=backfilled
             )
             success = True
         finally:
             if not success:
                 logcontext.run_in_background(
-                    self.store.remove_push_actions_from_staging,
-                    event.event_id,
+                    self.store.remove_push_actions_from_staging, event.event_id
                 )
 
         defer.returnValue(context)
@@ -1674,12 +1639,15 @@ class FederationHandler(BaseHandler):
                 )
             defer.returnValue(res)
 
-        contexts = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                logcontext.run_in_background(prep, ev_info)
-                for ev_info in event_infos
-            ], consumeErrors=True,
-        ))
+        contexts = yield logcontext.make_deferred_yieldable(
+            defer.gatherResults(
+                [
+                    logcontext.run_in_background(prep, ev_info)
+                    for ev_info in event_infos
+                ],
+                consumeErrors=True,
+            )
+        )
 
         yield self.persist_events_and_notify(
             [
@@ -1714,8 +1682,7 @@ class FederationHandler(BaseHandler):
             events_to_context[e.event_id] = ctx
 
         event_map = {
-            e.event_id: e
-            for e in itertools.chain(auth_events, state, [event])
+            e.event_id: e for e in itertools.chain(auth_events, state, [event])
         }
 
         create_event = None
@@ -1730,7 +1697,7 @@ class FederationHandler(BaseHandler):
             raise SynapseError(400, "No create event in state")
 
         room_version = create_event.content.get(
-            "room_version", RoomVersions.V1.identifier,
+            "room_version", RoomVersions.V1.identifier
         )
 
         missing_auth_events = set()
@@ -1741,11 +1708,7 @@ class FederationHandler(BaseHandler):
 
         for e_id in missing_auth_events:
             m_ev = yield self.federation_client.get_pdu(
-                [origin],
-                e_id,
-                room_version=room_version,
-                outlier=True,
-                timeout=10000,
+                [origin], e_id, room_version=room_version, outlier=True, timeout=10000
             )
             if m_ev and m_ev.event_id == e_id:
                 event_map[e_id] = m_ev
@@ -1770,10 +1733,7 @@ class FederationHandler(BaseHandler):
                 # cause SynapseErrors in auth.check. We don't want to give up
                 # the attempt to federate altogether in such cases.
 
-                logger.warn(
-                    "Rejecting %s because %s",
-                    e.event_id, err.msg
-                )
+                logger.warn("Rejecting %s because %s", e.event_id, err.msg)
 
                 if e == event:
                     raise
@@ -1783,16 +1743,14 @@ class FederationHandler(BaseHandler):
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
-            ],
+            ]
         )
 
         new_event_context = yield self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        yield self.persist_events_and_notify(
-            [(event, new_event_context)],
-        )
+        yield self.persist_events_and_notify([(event, new_event_context)])
 
     @defer.inlineCallbacks
     def _prep_event(self, origin, event, state, auth_events, backfilled):
@@ -1808,40 +1766,30 @@ class FederationHandler(BaseHandler):
         Returns:
             Deferred, which resolves to synapse.events.snapshot.EventContext
         """
-        context = yield self.state_handler.compute_event_context(
-            event, old_state=state,
-        )
+        context = yield self.state_handler.compute_event_context(event, old_state=state)
 
         if not auth_events:
             prev_state_ids = yield context.get_prev_state_ids(self.store)
             auth_events_ids = yield self.auth.compute_auth_events(
-                event, prev_state_ids, for_verification=True,
+                event, prev_state_ids, for_verification=True
             )
             auth_events = yield self.store.get_events(auth_events_ids)
-            auth_events = {
-                (e.type, e.state_key): e for e in auth_events.values()
-            }
+            auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_event_ids():
             if len(event.prev_event_ids()) == 1 and event.depth < 5:
                 c = yield self.store.get_event(
-                    event.prev_event_ids()[0],
-                    allow_none=True,
+                    event.prev_event_ids()[0], allow_none=True
                 )
                 if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
         try:
-            yield self.do_auth(
-                origin, event, context, auth_events=auth_events
-            )
+            yield self.do_auth(origin, event, context, auth_events=auth_events)
         except AuthError as e:
-            logger.warn(
-                "[%s %s] Rejecting: %s",
-                event.room_id, event.event_id, e.msg
-            )
+            logger.warn("[%s %s] Rejecting: %s", event.room_id, event.event_id, e.msg)
 
             context.rejected = RejectedReason.AUTH_ERROR
 
@@ -1872,9 +1820,7 @@ class FederationHandler(BaseHandler):
         # "soft-fail" the event.
         do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
         if do_soft_fail_check:
-            extrem_ids = yield self.store.get_latest_event_ids_in_room(
-                event.room_id,
-            )
+            extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id)
 
             extrem_ids = set(extrem_ids)
             prev_event_ids = set(event.prev_event_ids())
@@ -1902,31 +1848,31 @@ class FederationHandler(BaseHandler):
                 # like bans, especially with state res v2.
 
                 state_sets = yield self.store.get_state_groups(
-                    event.room_id, extrem_ids,
+                    event.room_id, extrem_ids
                 )
                 state_sets = list(state_sets.values())
                 state_sets.append(state)
                 current_state_ids = yield self.state_handler.resolve_events(
-                    room_version, state_sets, event,
+                    room_version, state_sets, event
                 )
                 current_state_ids = {
                     k: e.event_id for k, e in iteritems(current_state_ids)
                 }
             else:
                 current_state_ids = yield self.state_handler.get_current_state_ids(
-                    event.room_id, latest_event_ids=extrem_ids,
+                    event.room_id, latest_event_ids=extrem_ids
                 )
 
             logger.debug(
                 "Doing soft-fail check for %s: state %s",
-                event.event_id, current_state_ids,
+                event.event_id,
+                current_state_ids,
             )
 
             # Now check if event pass auth against said current state
             auth_types = auth_types_for_event(event)
             current_state_ids = [
-                e for k, e in iteritems(current_state_ids)
-                if k in auth_types
+                e for k, e in iteritems(current_state_ids) if k in auth_types
             ]
 
             current_auth_events = yield self.store.get_events(current_state_ids)
@@ -1937,19 +1883,14 @@ class FederationHandler(BaseHandler):
             try:
                 self.auth.check(room_version, event, auth_events=current_auth_events)
             except AuthError as e:
-                logger.warn(
-                    "Soft-failing %r because %s",
-                    event, e,
-                )
+                logger.warn("Soft-failing %r because %s", event, e)
                 event.internal_metadata.soft_failed = True
 
     @defer.inlineCallbacks
-    def on_query_auth(self, origin, event_id, room_id, remote_auth_chain, rejects,
-                      missing):
-        in_room = yield self.auth.check_host_in_room(
-            room_id,
-            origin
-        )
+    def on_query_auth(
+        self, origin, event_id, room_id, remote_auth_chain, rejects, missing
+    ):
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -1967,28 +1908,23 @@ class FederationHandler(BaseHandler):
 
         # Now get the current auth_chain for the event.
         local_auth_chain = yield self.store.get_auth_chain(
-            [auth_id for auth_id in event.auth_event_ids()],
-            include_given=True
+            [auth_id for auth_id in event.auth_event_ids()], include_given=True
         )
 
         # TODO: Check if we would now reject event_id. If so we need to tell
         # everyone.
 
-        ret = yield self.construct_auth_difference(
-            local_auth_chain, remote_auth_chain
-        )
+        ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain)
 
         logger.debug("on_query_auth returning: %s", ret)
 
         defer.returnValue(ret)
 
     @defer.inlineCallbacks
-    def on_get_missing_events(self, origin, room_id, earliest_events,
-                              latest_events, limit):
-        in_room = yield self.auth.check_host_in_room(
-            room_id,
-            origin
-        )
+    def on_get_missing_events(
+        self, origin, room_id, earliest_events, latest_events, limit
+    ):
+        in_room = yield self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
@@ -2002,7 +1938,7 @@ class FederationHandler(BaseHandler):
         )
 
         missing_events = yield filter_events_for_server(
-            self.store, origin, missing_events,
+            self.store, origin, missing_events
         )
 
         defer.returnValue(missing_events)
@@ -2090,25 +2026,17 @@ class FederationHandler(BaseHandler):
 
         if missing_auth:
             # TODO: can we use store.have_seen_events here instead?
-            have_events = yield self.store.get_seen_events_with_rejections(
-                missing_auth
-            )
+            have_events = yield self.store.get_seen_events_with_rejections(missing_auth)
             logger.debug("Got events %s from store", have_events)
             missing_auth.difference_update(have_events.keys())
         else:
             have_events = {}
 
-        have_events.update({
-            e.event_id: ""
-            for e in auth_events.values()
-        })
+        have_events.update({e.event_id: "" for e in auth_events.values()})
 
         if missing_auth:
             # If we don't have all the auth events, we need to get them.
-            logger.info(
-                "auth_events contains unknown events: %s",
-                missing_auth,
-            )
+            logger.info("auth_events contains unknown events: %s", missing_auth)
             try:
                 try:
                     remote_auth_chain = yield self.federation_client.get_event_auth(
@@ -2134,18 +2062,16 @@ class FederationHandler(BaseHandler):
                     try:
                         auth_ids = e.auth_event_ids()
                         auth = {
-                            (e.type, e.state_key): e for e in remote_auth_chain
+                            (e.type, e.state_key): e
+                            for e in remote_auth_chain
                             if e.event_id in auth_ids or e.type == EventTypes.Create
                         }
                         e.internal_metadata.outlier = True
 
                         logger.debug(
-                            "do_auth %s missing_auth: %s",
-                            event.event_id, e.event_id
-                        )
-                        yield self._handle_new_event(
-                            origin, e, auth_events=auth
+                            "do_auth %s missing_auth: %s", event.event_id, e.event_id
                         )
+                        yield self._handle_new_event(origin, e, auth_events=auth)
 
                         if e.event_id in event_auth_events:
                             auth_events[(e.type, e.state_key)] = e
@@ -2181,35 +2107,36 @@ class FederationHandler(BaseHandler):
         room_version = yield self.store.get_room_version(event.room_id)
 
         different_events = yield logcontext.make_deferred_yieldable(
-            defer.gatherResults([
-                logcontext.run_in_background(
-                    self.store.get_event,
-                    d,
-                    allow_none=True,
-                    allow_rejected=False,
-                )
-                for d in different_auth
-                if d in have_events and not have_events[d]
-            ], consumeErrors=True)
+            defer.gatherResults(
+                [
+                    logcontext.run_in_background(
+                        self.store.get_event, d, allow_none=True, allow_rejected=False
+                    )
+                    for d in different_auth
+                    if d in have_events and not have_events[d]
+                ],
+                consumeErrors=True,
+            )
         ).addErrback(unwrapFirstError)
 
         if different_events:
             local_view = dict(auth_events)
             remote_view = dict(auth_events)
-            remote_view.update({
-                (d.type, d.state_key): d for d in different_events if d
-            })
+            remote_view.update(
+                {(d.type, d.state_key): d for d in different_events if d}
+            )
 
             new_state = yield self.state_handler.resolve_events(
                 room_version,
                 [list(local_view.values()), list(remote_view.values())],
-                event
+                event,
             )
 
             logger.info(
                 "After state res: updating auth_events with new state %s",
                 {
-                    (d.type, d.state_key): d.event_id for d in new_state.values()
+                    (d.type, d.state_key): d.event_id
+                    for d in new_state.values()
                     if auth_events.get((d.type, d.state_key)) != d
                 },
             )
@@ -2221,7 +2148,7 @@ class FederationHandler(BaseHandler):
             )
 
             yield self._update_context_for_auth_events(
-                event, context, auth_events, event_key,
+                event, context, auth_events, event_key
             )
 
         if not different_auth:
@@ -2255,21 +2182,14 @@ class FederationHandler(BaseHandler):
         prev_state_ids = yield context.get_prev_state_ids(self.store)
 
         # 1. Get what we think is the auth chain.
-        auth_ids = yield self.auth.compute_auth_events(
-            event, prev_state_ids
-        )
-        local_auth_chain = yield self.store.get_auth_chain(
-            auth_ids, include_given=True
-        )
+        auth_ids = yield self.auth.compute_auth_events(event, prev_state_ids)
+        local_auth_chain = yield self.store.get_auth_chain(auth_ids, include_given=True)
 
         try:
             # 2. Get remote difference.
             try:
                 result = yield self.federation_client.query_auth(
-                    origin,
-                    event.room_id,
-                    event.event_id,
-                    local_auth_chain,
+                    origin, event.room_id, event.event_id, local_auth_chain
                 )
             except RequestSendFailed as e:
                 # The other side isn't around or doesn't implement the
@@ -2294,19 +2214,15 @@ class FederationHandler(BaseHandler):
                     auth = {
                         (e.type, e.state_key): e
                         for e in result["auth_chain"]
-                        if e.event_id in auth_ids
-                        or event.type == EventTypes.Create
+                        if e.event_id in auth_ids or event.type == EventTypes.Create
                     }
                     ev.internal_metadata.outlier = True
 
                     logger.debug(
-                        "do_auth %s different_auth: %s",
-                        event.event_id, e.event_id
+                        "do_auth %s different_auth: %s", event.event_id, e.event_id
                     )
 
-                    yield self._handle_new_event(
-                        origin, ev, auth_events=auth
-                    )
+                    yield self._handle_new_event(origin, ev, auth_events=auth)
 
                     if ev.event_id in event_auth_events:
                         auth_events[(ev.type, ev.state_key)] = ev
@@ -2321,12 +2237,11 @@ class FederationHandler(BaseHandler):
         # TODO.
 
         yield self._update_context_for_auth_events(
-            event, context, auth_events, event_key,
+            event, context, auth_events, event_key
         )
 
     @defer.inlineCallbacks
-    def _update_context_for_auth_events(self, event, context, auth_events,
-                                        event_key):
+    def _update_context_for_auth_events(self, event, context, auth_events, event_key):
         """Update the state_ids in an event context after auth event resolution,
         storing the changes as a new state group.
 
@@ -2343,8 +2258,7 @@ class FederationHandler(BaseHandler):
                 this will not be included in the current_state in the context.
         """
         state_updates = {
-            k: a.event_id for k, a in iteritems(auth_events)
-            if k != event_key
+            k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
         current_state_ids = yield context.get_current_state_ids(self.store)
         current_state_ids = dict(current_state_ids)
@@ -2354,9 +2268,7 @@ class FederationHandler(BaseHandler):
         prev_state_ids = yield context.get_prev_state_ids(self.store)
         prev_state_ids = dict(prev_state_ids)
 
-        prev_state_ids.update({
-            k: a.event_id for k, a in iteritems(auth_events)
-        })
+        prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
@@ -2505,30 +2417,23 @@ class FederationHandler(BaseHandler):
 
         logger.debug("construct_auth_difference returning")
 
-        defer.returnValue({
-            "auth_chain": local_auth,
-            "rejects": {
-                e.event_id: {
-                    "reason": reason_map[e.event_id],
-                    "proof": None,
-                }
-                for e in base_remote_rejected
-            },
-            "missing": [e.event_id for e in missing_locals],
-        })
+        defer.returnValue(
+            {
+                "auth_chain": local_auth,
+                "rejects": {
+                    e.event_id: {"reason": reason_map[e.event_id], "proof": None}
+                    for e in base_remote_rejected
+                },
+                "missing": [e.event_id for e in missing_locals],
+            }
+        )
 
     @defer.inlineCallbacks
     @log_function
     def exchange_third_party_invite(
-            self,
-            sender_user_id,
-            target_user_id,
-            room_id,
-            signed,
+        self, sender_user_id, target_user_id, room_id, signed
     ):
-        third_party_invite = {
-            "signed": signed,
-        }
+        third_party_invite = {"signed": signed}
 
         event_dict = {
             "type": EventTypes.Member,
@@ -2550,6 +2455,18 @@ class FederationHandler(BaseHandler):
                 builder=builder
             )
 
+            event_allowed = yield self.third_party_event_rules.check_event_allowed(
+                event, context
+            )
+            if not event_allowed:
+                logger.info(
+                    "Creation of threepid invite %s forbidden by third-party rules",
+                    event,
+                )
+                raise SynapseError(
+                    403, "This event is not allowed in this context", Codes.FORBIDDEN
+                )
+
             event, context = yield self.add_display_name_to_third_party_invite(
                 room_version, event_dict, event, context
             )
@@ -2572,9 +2489,7 @@ class FederationHandler(BaseHandler):
         else:
             destinations = set(x.split(":", 1)[-1] for x in (sender_user_id, room_id))
             yield self.federation_client.forward_third_party_invite(
-                destinations,
-                room_id,
-                event_dict,
+                destinations, room_id, event_dict
             )
 
     @defer.inlineCallbacks
@@ -2595,9 +2510,20 @@ class FederationHandler(BaseHandler):
         builder = self.event_builder_factory.new(room_version, event_dict)
 
         event, context = yield self.event_creation_handler.create_new_client_event(
-            builder=builder,
+            builder=builder
         )
 
+        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+            event, context
+        )
+        if not event_allowed:
+            logger.warning(
+                "Exchange of threepid invite %s forbidden by third-party rules", event
+            )
+            raise SynapseError(
+                403, "This event is not allowed in this context", Codes.FORBIDDEN
+            )
+
         event, context = yield self.add_display_name_to_third_party_invite(
             room_version, event_dict, event, context
         )
@@ -2613,21 +2539,16 @@ class FederationHandler(BaseHandler):
         # though the sender isn't a local user.
         event.internal_metadata.send_on_behalf_of = get_domain_from_id(event.sender)
 
-        # XXX we send the invite here, but send_membership_event also sends it,
-        # so we end up making two requests. I think this is redundant.
-        returned_invite = yield self.send_invite(origin, event)
-        # TODO: Make sure the signatures actually are correct.
-        event.signatures.update(returned_invite.signatures)
-
         member_handler = self.hs.get_room_member_handler()
         yield member_handler.send_membership_event(None, event, context)
 
     @defer.inlineCallbacks
-    def add_display_name_to_third_party_invite(self, room_version, event_dict,
-                                               event, context):
+    def add_display_name_to_third_party_invite(
+        self, room_version, event_dict, event, context
+    ):
         key = (
             EventTypes.ThirdPartyInvite,
-            event.content["third_party_invite"]["signed"]["token"]
+            event.content["third_party_invite"]["signed"]["token"],
         )
         original_invite = None
         prev_state_ids = yield context.get_prev_state_ids(self.store)
@@ -2641,8 +2562,7 @@ class FederationHandler(BaseHandler):
             event_dict["content"]["third_party_invite"]["display_name"] = display_name
         else:
             logger.info(
-                "Could not find invite event for third_party_invite: %r",
-                event_dict
+                "Could not find invite event for third_party_invite: %r", event_dict
             )
             # We don't discard here as this is not the appropriate place to do
             # auth checks. If we need the invite and don't have it then the
@@ -2651,7 +2571,7 @@ class FederationHandler(BaseHandler):
         builder = self.event_builder_factory.new(room_version, event_dict)
         EventValidator().validate_builder(builder)
         event, context = yield self.event_creation_handler.create_new_client_event(
-            builder=builder,
+            builder=builder
         )
         EventValidator().validate_new(event)
         defer.returnValue((event, context))
@@ -2675,9 +2595,7 @@ class FederationHandler(BaseHandler):
         token = signed["token"]
 
         prev_state_ids = yield context.get_prev_state_ids(self.store)
-        invite_event_id = prev_state_ids.get(
-            (EventTypes.ThirdPartyInvite, token,)
-        )
+        invite_event_id = prev_state_ids.get((EventTypes.ThirdPartyInvite, token))
 
         invite_event = None
         if invite_event_id:
@@ -2686,25 +2604,59 @@ class FederationHandler(BaseHandler):
         if not invite_event:
             raise AuthError(403, "Could not find invite")
 
+        logger.debug("Checking auth on event %r", event.content)
+
         last_exception = None
+        # for each public key in the 3pid invite event
         for public_key_object in self.hs.get_auth().get_public_keys(invite_event):
             try:
+                # for each sig on the third_party_invite block of the actual invite
                 for server, signature_block in signed["signatures"].items():
                     for key_name, encoded_signature in signature_block.items():
                         if not key_name.startswith("ed25519:"):
                             continue
 
-                        public_key = public_key_object["public_key"]
-                        verify_key = decode_verify_key_bytes(
+                        logger.debug(
+                            "Attempting to verify sig with key %s from %r "
+                            "against pubkey %r",
                             key_name,
-                            decode_base64(public_key)
+                            server,
+                            public_key_object,
                         )
-                        verify_signed_json(signed, server, verify_key)
-                        if "key_validity_url" in public_key_object:
-                            yield self._check_key_revocation(
-                                public_key,
-                                public_key_object["key_validity_url"]
+
+                        try:
+                            public_key = public_key_object["public_key"]
+                            verify_key = decode_verify_key_bytes(
+                                key_name, decode_base64(public_key)
+                            )
+                            verify_signed_json(signed, server, verify_key)
+                            logger.debug(
+                                "Successfully verified sig with key %s from %r "
+                                "against pubkey %r",
+                                key_name,
+                                server,
+                                public_key_object,
+                            )
+                        except Exception:
+                            logger.info(
+                                "Failed to verify sig with key %s from %r "
+                                "against pubkey %r",
+                                key_name,
+                                server,
+                                public_key_object,
+                            )
+                            raise
+                        try:
+                            if "key_validity_url" in public_key_object:
+                                yield self._check_key_revocation(
+                                    public_key, public_key_object["key_validity_url"]
+                                )
+                        except Exception:
+                            logger.info(
+                                "Failed to query key_validity_url %s",
+                                public_key_object["key_validity_url"],
                             )
+                            raise
                         return
             except Exception as e:
                 last_exception = e
@@ -2725,15 +2677,9 @@ class FederationHandler(BaseHandler):
                 for revocation.
         """
         try:
-            response = yield self.http_client.get_json(
-                url,
-                {"public_key": public_key}
-            )
+            response = yield self.http_client.get_json(url, {"public_key": public_key})
         except Exception:
-            raise SynapseError(
-                502,
-                "Third party certificate could not be checked"
-            )
+            raise SynapseError(502, "Third party certificate could not be checked")
         if "valid" not in response or not response["valid"]:
             raise AuthError(403, "Third party certificate was invalid")
 
@@ -2754,12 +2700,11 @@ class FederationHandler(BaseHandler):
             yield self._send_events_to_master(
                 store=self.store,
                 event_and_contexts=event_and_contexts,
-                backfilled=backfilled
+                backfilled=backfilled,
             )
         else:
             max_stream_id = yield self.store.persist_events(
-                event_and_contexts,
-                backfilled=backfilled,
+                event_and_contexts, backfilled=backfilled
             )
 
             if not backfilled:  # Never notify for backfilled events
@@ -2793,13 +2738,10 @@ class FederationHandler(BaseHandler):
 
         event_stream_id = event.internal_metadata.stream_ordering
         self.notifier.on_new_room_event(
-            event, event_stream_id, max_stream_id,
-            extra_users=extra_users
+            event, event_stream_id, max_stream_id, extra_users=extra_users
         )
 
-        return self.pusher_pool.on_new_notifications(
-            event_stream_id, max_stream_id,
-        )
+        return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
 
     def _clean_room_for_join(self, room_id):
         """Called to clean up any data in DB for a given room, ready for the
@@ -2818,9 +2760,7 @@ class FederationHandler(BaseHandler):
         """
         if self.config.worker_app:
             return self._notify_user_membership_change(
-                room_id=room_id,
-                user_id=user.to_string(),
-                change="joined",
+                room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
             return user_joined_room(self.distributor, user, room_id)