summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2020-02-04 12:06:18 +0000
committerGitHub <noreply@github.com>2020-02-04 12:06:18 +0000
commitc7d6d5c69eb9036749d57b2d299ff1f1a50d0768 (patch)
treebd549f6df389c645263e4ab7b2ea6dc21f972916
parentAdd typing to SyncHandler (#6821) (diff)
parentmake FederationHandler.send_invite async (diff)
downloadsynapse-c7d6d5c69eb9036749d57b2d299ff1f1a50d0768.tar.xz
Merge pull request #6837 from matrix-org/rav/federation_async
Port much of `synapse.handlers.federation` to async/await.
-rw-r--r--changelog.d/6837.misc1
-rw-r--r--synapse/handlers/federation.py429
-rw-r--r--synapse/handlers/message.py5
-rw-r--r--synapse/handlers/room_member.py12
4 files changed, 212 insertions, 235 deletions
diff --git a/changelog.d/6837.misc b/changelog.d/6837.misc
new file mode 100644
index 0000000000..0496f12de8
--- /dev/null
+++ b/changelog.d/6837.misc
@@ -0,0 +1 @@
+Port much of `synapse.handlers.federation` to async/await.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c86d3177e9..5728ea2ee7 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -65,7 +65,7 @@ from synapse.replication.http.federation import (
 from synapse.replication.http.membership import ReplicationUserJoinedLeftRoomRestServlet
 from synapse.state import StateResolutionStore, resolve_events_with_store
 from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap, UserID, get_domain_from_id
+from synapse.types import JsonDict, StateMap, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer, concurrently_execute
 from synapse.util.distributor import user_joined_room
 from synapse.util.retryutils import NotRetryingDestination
@@ -1184,13 +1184,12 @@ class FederationHandler(BaseHandler):
             )
             raise SynapseError(http_client.BAD_REQUEST, "Too many auth_events")
 
-    @defer.inlineCallbacks
-    def send_invite(self, target_host, event):
+    async def send_invite(self, target_host, event):
         """ Sends the invite to the remote server for signing.
 
         Invites must be signed by the invitee's server before distribution.
         """
-        pdu = yield self.federation_client.send_invite(
+        pdu = await self.federation_client.send_invite(
             destination=target_host,
             room_id=event.room_id,
             event_id=event.event_id,
@@ -1199,17 +1198,16 @@ class FederationHandler(BaseHandler):
 
         return pdu
 
-    @defer.inlineCallbacks
-    def on_event_auth(self, event_id):
-        event = yield self.store.get_event(event_id)
-        auth = yield self.store.get_auth_chain(
+    async def on_event_auth(self, event_id: str) -> List[EventBase]:
+        event = await self.store.get_event(event_id)
+        auth = await self.store.get_auth_chain(
             [auth_id for auth_id in event.auth_event_ids()], include_given=True
         )
-        return [e for e in auth]
+        return list(auth)
 
-    @log_function
-    @defer.inlineCallbacks
-    def do_invite_join(self, target_hosts, room_id, joinee, content):
+    async def do_invite_join(
+        self, target_hosts: Iterable[str], room_id: str, joinee: str, content: JsonDict
+    ) -> None:
         """ Attempts to join the `joinee` to the room `room_id` via the
         servers contained in `target_hosts`.
 
@@ -1222,17 +1220,17 @@ class FederationHandler(BaseHandler):
         have finished processing the join.
 
         Args:
-            target_hosts (Iterable[str]): List of servers to attempt to join the room with.
+            target_hosts: List of servers to attempt to join the room with.
 
-            room_id (str): The ID of the room to join.
+            room_id: The ID of the room to join.
 
-            joinee (str): The User ID of the joining user.
+            joinee: The User ID of the joining user.
 
-            content (dict): The event content to use for the join event.
+            content: The event content to use for the join event.
         """
         logger.debug("Joining %s to %s", joinee, room_id)
 
-        origin, event, room_version_obj = yield self._make_and_verify_event(
+        origin, event, room_version_obj = await self._make_and_verify_event(
             target_hosts,
             room_id,
             joinee,
@@ -1248,7 +1246,7 @@ class FederationHandler(BaseHandler):
 
         self.room_queues[room_id] = []
 
-        yield self._clean_room_for_join(room_id)
+        await self._clean_room_for_join(room_id)
 
         handled_events = set()
 
@@ -1262,7 +1260,7 @@ class FederationHandler(BaseHandler):
                 pass
 
             event_format_version = room_version_obj.event_format
-            ret = yield self.federation_client.send_join(
+            ret = await self.federation_client.send_join(
                 target_hosts, event, event_format_version
             )
 
@@ -1281,7 +1279,7 @@ class FederationHandler(BaseHandler):
             logger.debug("do_invite_join event: %s", event)
 
             try:
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=room_id,
                     room_creator_user_id="",
                     is_public=False,
@@ -1291,13 +1289,13 @@ class FederationHandler(BaseHandler):
                 # FIXME
                 pass
 
-            yield self._persist_auth_tree(
+            await self._persist_auth_tree(
                 origin, auth_chain, state, event, room_version_obj
             )
 
             # Check whether this room is the result of an upgrade of a room we already know
             # about. If so, migrate over user information
-            predecessor = yield self.store.get_room_predecessor(room_id)
+            predecessor = await self.store.get_room_predecessor(room_id)
             if not predecessor or not isinstance(predecessor.get("room_id"), str):
                 return
             old_room_id = predecessor["room_id"]
@@ -1307,7 +1305,7 @@ class FederationHandler(BaseHandler):
 
             # We retrieve the room member handler here as to not cause a cyclic dependency
             member_handler = self.hs.get_room_member_handler()
-            yield member_handler.transfer_room_state_on_room_upgrade(
+            await member_handler.transfer_room_state_on_room_upgrade(
                 old_room_id, room_id
             )
 
@@ -1324,8 +1322,6 @@ class FederationHandler(BaseHandler):
 
             run_in_background(self._handle_queued_pdus, room_queue)
 
-        return True
-
     async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
@@ -1348,20 +1344,17 @@ class FederationHandler(BaseHandler):
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
                 )
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_make_join_request(self, origin, room_id, user_id):
+    async def on_make_join_request(
+        self, origin: str, room_id: str, user_id: str
+    ) -> EventBase:
         """ We've received a /make_join/ request, so we create a partial
         join event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
 
         Args:
-            origin (str): The (verified) server name of the requesting server.
-            room_id (str): Room to create join event in
-            user_id (str): The user to create the join for
-
-        Returns:
-            Deferred[FrozenEvent]
+            origin: The (verified) server name of the requesting server.
+            room_id: Room to create join event in
+            user_id: The user to create the join for
         """
         if get_domain_from_id(user_id) != origin:
             logger.info(
@@ -1373,7 +1366,7 @@ class FederationHandler(BaseHandler):
 
         event_content = {"membership": Membership.JOIN}
 
-        room_version = yield self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
 
         builder = self.event_builder_factory.new(
             room_version,
@@ -1387,14 +1380,14 @@ class FederationHandler(BaseHandler):
         )
 
         try:
-            event, context = yield self.event_creation_handler.create_new_client_event(
+            event, context = await self.event_creation_handler.create_new_client_event(
                 builder=builder
             )
         except AuthError as e:
             logger.warning("Failed to create join to %s because %s", room_id, e)
             raise e
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1405,15 +1398,13 @@ class FederationHandler(BaseHandler):
 
         # The remote hasn't signed it yet, obviously. We'll do the full checks
         # when we get the event back in `on_send_join_request`
-        yield self.auth.check_from_context(
+        await self.auth.check_from_context(
             room_version, event, context, do_sig_check=False
         )
 
         return event
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_send_join_request(self, origin, pdu):
+    async def on_send_join_request(self, origin, pdu):
         """ We have received a join event for a room. Fully process it and
         respond with the current state and auth chains.
         """
@@ -1450,9 +1441,9 @@ class FederationHandler(BaseHandler):
         # would introduce the danger of backwards-compatibility problems.
         event.internal_metadata.send_on_behalf_of = origin
 
-        context = yield self._handle_new_event(origin, event)
+        context = await self._handle_new_event(origin, event)
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1470,19 +1461,18 @@ class FederationHandler(BaseHandler):
         if event.type == EventTypes.Member:
             if event.content["membership"] == Membership.JOIN:
                 user = UserID.from_string(event.state_key)
-                yield self.user_joined_room(user, event.room_id)
+                await self.user_joined_room(user, event.room_id)
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids()
 
         state_ids = list(prev_state_ids.values())
-        auth_chain = yield self.store.get_auth_chain(state_ids)
+        auth_chain = await self.store.get_auth_chain(state_ids)
 
-        state = yield self.store.get_events(list(prev_state_ids.values()))
+        state = await self.store.get_events(list(prev_state_ids.values()))
 
         return {"state": list(state.values()), "auth_chain": auth_chain}
 
-    @defer.inlineCallbacks
-    def on_invite_request(
+    async def on_invite_request(
         self, origin: str, event: EventBase, room_version: RoomVersion
     ):
         """ We've got an invite event. Process and persist it. Sign it.
@@ -1492,7 +1482,7 @@ class FederationHandler(BaseHandler):
         if event.state_key is None:
             raise SynapseError(400, "The invite event did not have a state key")
 
-        is_blocked = yield self.store.is_room_blocked(event.room_id)
+        is_blocked = await self.store.is_room_blocked(event.room_id)
         if is_blocked:
             raise SynapseError(403, "This room has been blocked on this server")
 
@@ -1535,14 +1525,15 @@ class FederationHandler(BaseHandler):
             )
         )
 
-        context = yield self.state_handler.compute_event_context(event)
-        yield self.persist_events_and_notify([(event, context)])
+        context = await self.state_handler.compute_event_context(event)
+        await self.persist_events_and_notify([(event, context)])
 
         return event
 
-    @defer.inlineCallbacks
-    def do_remotely_reject_invite(self, target_hosts, room_id, user_id, content):
-        origin, event, room_version = yield self._make_and_verify_event(
+    async def do_remotely_reject_invite(
+        self, target_hosts: Iterable[str], room_id: str, user_id: str, content: JsonDict
+    ) -> EventBase:
+        origin, event, room_version = await self._make_and_verify_event(
             target_hosts, room_id, user_id, "leave", content=content
         )
         # Mark as outlier as we don't have any state for this event; we're not
@@ -1558,22 +1549,27 @@ class FederationHandler(BaseHandler):
         except ValueError:
             pass
 
-        yield self.federation_client.send_leave(target_hosts, event)
+        await self.federation_client.send_leave(target_hosts, event)
 
-        context = yield self.state_handler.compute_event_context(event)
-        yield self.persist_events_and_notify([(event, context)])
+        context = await self.state_handler.compute_event_context(event)
+        await self.persist_events_and_notify([(event, context)])
 
         return event
 
-    @defer.inlineCallbacks
-    def _make_and_verify_event(
-        self, target_hosts, room_id, user_id, membership, content={}, params=None
-    ):
+    async def _make_and_verify_event(
+        self,
+        target_hosts: Iterable[str],
+        room_id: str,
+        user_id: str,
+        membership: str,
+        content: JsonDict = {},
+        params: Optional[Dict[str, str]] = None,
+    ) -> Tuple[str, EventBase, RoomVersion]:
         (
             origin,
             event,
             room_version,
-        ) = yield self.federation_client.make_membership_event(
+        ) = await self.federation_client.make_membership_event(
             target_hosts, room_id, user_id, membership, content, params=params
         )
 
@@ -1587,20 +1583,17 @@ class FederationHandler(BaseHandler):
         assert event.room_id == room_id
         return origin, event, room_version
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_make_leave_request(self, origin, room_id, user_id):
+    async def on_make_leave_request(
+        self, origin: str, room_id: str, user_id: str
+    ) -> EventBase:
         """ We've received a /make_leave/ request, so we create a partial
         leave event for the room and return that. We do *not* persist or
         process it until the other server has signed it and sent it back.
 
         Args:
-            origin (str): The (verified) server name of the requesting server.
-            room_id (str): Room to create leave event in
-            user_id (str): The user to create the leave for
-
-        Returns:
-            Deferred[FrozenEvent]
+            origin: The (verified) server name of the requesting server.
+            room_id: Room to create leave event in
+            user_id: The user to create the leave for
         """
         if get_domain_from_id(user_id) != origin:
             logger.info(
@@ -1610,7 +1603,7 @@ class FederationHandler(BaseHandler):
             )
             raise SynapseError(403, "User not from origin", Codes.FORBIDDEN)
 
-        room_version = yield self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
         builder = self.event_builder_factory.new(
             room_version,
             {
@@ -1622,11 +1615,11 @@ class FederationHandler(BaseHandler):
             },
         )
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
+        event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1638,7 +1631,7 @@ class FederationHandler(BaseHandler):
         try:
             # The remote hasn't signed it yet, obviously. We'll do the full checks
             # when we get the event back in `on_send_leave_request`
-            yield self.auth.check_from_context(
+            await self.auth.check_from_context(
                 room_version, event, context, do_sig_check=False
             )
         except AuthError as e:
@@ -1647,9 +1640,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_send_leave_request(self, origin, pdu):
+    async def on_send_leave_request(self, origin, pdu):
         """ We have received a leave event for a room. Fully process it."""
         event = pdu
 
@@ -1669,9 +1660,9 @@ class FederationHandler(BaseHandler):
 
         event.internal_metadata.outlier = False
 
-        context = yield self._handle_new_event(origin, event)
+        context = await self._handle_new_event(origin, event)
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -1793,11 +1784,10 @@ class FederationHandler(BaseHandler):
     def get_min_depth_for_context(self, context):
         return self.store.get_min_depth(context)
 
-    @defer.inlineCallbacks
-    def _handle_new_event(
+    async def _handle_new_event(
         self, origin, event, state=None, auth_events=None, backfilled=False
     ):
-        context = yield self._prep_event(
+        context = await self._prep_event(
             origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
 
@@ -1810,11 +1800,11 @@ class FederationHandler(BaseHandler):
                 and not backfilled
                 and not context.rejected
             ):
-                yield self.action_generator.handle_push_actions_for_event(
+                await self.action_generator.handle_push_actions_for_event(
                     event, context
                 )
 
-            yield self.persist_events_and_notify(
+            await self.persist_events_and_notify(
                 [(event, context)], backfilled=backfilled
             )
             success = True
@@ -1826,13 +1816,12 @@ class FederationHandler(BaseHandler):
 
         return context
 
-    @defer.inlineCallbacks
-    def _handle_new_events(
+    async def _handle_new_events(
         self,
         origin: str,
         event_infos: Iterable[_NewEventInfo],
         backfilled: bool = False,
-    ):
+    ) -> None:
         """Creates the appropriate contexts and persists events. The events
         should not depend on one another, e.g. this should be used to persist
         a bunch of outliers, but not a chunk of individual events that depend
@@ -1841,11 +1830,10 @@ class FederationHandler(BaseHandler):
         Notifies about the events where appropriate.
         """
 
-        @defer.inlineCallbacks
-        def prep(ev_info: _NewEventInfo):
+        async def prep(ev_info: _NewEventInfo):
             event = ev_info.event
             with nested_logging_context(suffix=event.event_id):
-                res = yield self._prep_event(
+                res = await self._prep_event(
                     origin,
                     event,
                     state=ev_info.state,
@@ -1854,14 +1842,14 @@ class FederationHandler(BaseHandler):
                 )
             return res
 
-        contexts = yield make_deferred_yieldable(
+        contexts = await make_deferred_yieldable(
             defer.gatherResults(
                 [run_in_background(prep, ev_info) for ev_info in event_infos],
                 consumeErrors=True,
             )
         )
 
-        yield self.persist_events_and_notify(
+        await self.persist_events_and_notify(
             [
                 (ev_info.event, context)
                 for ev_info, context in zip(event_infos, contexts)
@@ -1869,15 +1857,14 @@ class FederationHandler(BaseHandler):
             backfilled=backfilled,
         )
 
-    @defer.inlineCallbacks
-    def _persist_auth_tree(
+    async def _persist_auth_tree(
         self,
         origin: str,
         auth_events: List[EventBase],
         state: List[EventBase],
         event: EventBase,
         room_version: RoomVersion,
-    ):
+    ) -> None:
         """Checks the auth chain is valid (and passes auth checks) for the
         state and event. Then persists the auth chain and state atomically.
         Persists the event separately. Notifies about the persisted events
@@ -1892,14 +1879,11 @@ class FederationHandler(BaseHandler):
             event
             room_version: The room version we expect this room to have, and
                 will raise if it doesn't match the version in the create event.
-
-        Returns:
-            Deferred
         """
         events_to_context = {}
         for e in itertools.chain(auth_events, state):
             e.internal_metadata.outlier = True
-            ctx = yield self.state_handler.compute_event_context(e)
+            ctx = await self.state_handler.compute_event_context(e)
             events_to_context[e.event_id] = ctx
 
         event_map = {
@@ -1931,7 +1915,7 @@ class FederationHandler(BaseHandler):
                     missing_auth_events.add(e_id)
 
         for e_id in missing_auth_events:
-            m_ev = yield self.federation_client.get_pdu(
+            m_ev = await self.federation_client.get_pdu(
                 [origin],
                 e_id,
                 room_version=room_version.identifier,
@@ -1967,91 +1951,74 @@ class FederationHandler(BaseHandler):
                     raise
                 events_to_context[e.event_id].rejected = RejectedReason.AUTH_ERROR
 
-        yield self.persist_events_and_notify(
+        await self.persist_events_and_notify(
             [
                 (e, events_to_context[e.event_id])
                 for e in itertools.chain(auth_events, state)
             ]
         )
 
-        new_event_context = yield self.state_handler.compute_event_context(
+        new_event_context = await self.state_handler.compute_event_context(
             event, old_state=state
         )
 
-        yield self.persist_events_and_notify([(event, new_event_context)])
+        await self.persist_events_and_notify([(event, new_event_context)])
 
-    @defer.inlineCallbacks
-    def _prep_event(
+    async def _prep_event(
         self,
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
         auth_events: Optional[StateMap[EventBase]],
         backfilled: bool,
-    ):
-        """
-
-        Args:
-            origin:
-            event:
-            state:
-            auth_events:
-            backfilled:
-
-        Returns:
-            Deferred, which resolves to synapse.events.snapshot.EventContext
-        """
-        context = yield self.state_handler.compute_event_context(event, old_state=state)
+    ) -> EventContext:
+        context = await self.state_handler.compute_event_context(event, old_state=state)
 
         if not auth_events:
-            prev_state_ids = yield context.get_prev_state_ids()
-            auth_events_ids = yield self.auth.compute_auth_events(
+            prev_state_ids = await context.get_prev_state_ids()
+            auth_events_ids = await self.auth.compute_auth_events(
                 event, prev_state_ids, for_verification=True
             )
-            auth_events = yield self.store.get_events(auth_events_ids)
+            auth_events = await self.store.get_events(auth_events_ids)
             auth_events = {(e.type, e.state_key): e for e in auth_events.values()}
 
         # This is a hack to fix some old rooms where the initial join event
         # didn't reference the create event in its auth events.
         if event.type == EventTypes.Member and not event.auth_event_ids():
             if len(event.prev_event_ids()) == 1 and event.depth < 5:
-                c = yield self.store.get_event(
+                c = await self.store.get_event(
                     event.prev_event_ids()[0], allow_none=True
                 )
                 if c and c.type == EventTypes.Create:
                     auth_events[(c.type, c.state_key)] = c
 
-        context = yield self.do_auth(origin, event, context, auth_events=auth_events)
+        context = await self.do_auth(origin, event, context, auth_events=auth_events)
 
         if not context.rejected:
-            yield self._check_for_soft_fail(event, state, backfilled)
+            await self._check_for_soft_fail(event, state, backfilled)
 
         if event.type == EventTypes.GuestAccess and not context.rejected:
-            yield self.maybe_kick_guest_users(event)
+            await self.maybe_kick_guest_users(event)
 
         return context
 
-    @defer.inlineCallbacks
-    def _check_for_soft_fail(
+    async def _check_for_soft_fail(
         self, event: EventBase, state: Optional[Iterable[EventBase]], backfilled: bool
-    ):
-        """Checks if we should soft fail the event, if so marks the event as
+    ) -> None:
+        """Checks if we should soft fail the event; if so, marks the event as
         such.
 
         Args:
             event
             state: The state at the event if we don't have all the event's prev events
             backfilled: Whether the event is from backfill
-
-        Returns:
-            Deferred
         """
         # For new (non-backfilled and non-outlier) events we check if the event
         # passes auth based on the current state. If it doesn't then we
         # "soft-fail" the event.
         do_soft_fail_check = not backfilled and not event.internal_metadata.is_outlier()
         if do_soft_fail_check:
-            extrem_ids = yield self.store.get_latest_event_ids_in_room(event.room_id)
+            extrem_ids = await self.store.get_latest_event_ids_in_room(event.room_id)
 
             extrem_ids = set(extrem_ids)
             prev_event_ids = set(event.prev_event_ids())
@@ -2062,7 +2029,7 @@ class FederationHandler(BaseHandler):
                 do_soft_fail_check = False
 
         if do_soft_fail_check:
-            room_version = yield self.store.get_room_version_id(event.room_id)
+            room_version = await self.store.get_room_version_id(event.room_id)
             room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
             # Calculate the "current state".
@@ -2079,19 +2046,19 @@ class FederationHandler(BaseHandler):
                 # given state at the event. This should correctly handle cases
                 # like bans, especially with state res v2.
 
-                state_sets = yield self.state_store.get_state_groups(
+                state_sets = await self.state_store.get_state_groups(
                     event.room_id, extrem_ids
                 )
                 state_sets = list(state_sets.values())
                 state_sets.append(state)
-                current_state_ids = yield self.state_handler.resolve_events(
+                current_state_ids = await self.state_handler.resolve_events(
                     room_version, state_sets, event
                 )
                 current_state_ids = {
                     k: e.event_id for k, e in iteritems(current_state_ids)
                 }
             else:
-                current_state_ids = yield self.state_handler.get_current_state_ids(
+                current_state_ids = await self.state_handler.get_current_state_ids(
                     event.room_id, latest_event_ids=extrem_ids
                 )
 
@@ -2107,7 +2074,7 @@ class FederationHandler(BaseHandler):
                 e for k, e in iteritems(current_state_ids) if k in auth_types
             ]
 
-            current_auth_events = yield self.store.get_events(current_state_ids)
+            current_auth_events = await self.store.get_events(current_state_ids)
             current_auth_events = {
                 (e.type, e.state_key): e for e in current_auth_events.values()
             }
@@ -2120,15 +2087,14 @@ class FederationHandler(BaseHandler):
                 logger.warning("Soft-failing %r because %s", event, e)
                 event.internal_metadata.soft_failed = True
 
-    @defer.inlineCallbacks
-    def on_query_auth(
+    async def on_query_auth(
         self, origin, event_id, room_id, remote_auth_chain, rejects, missing
     ):
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=False, check_room_id=room_id
         )
 
@@ -2136,57 +2102,60 @@ class FederationHandler(BaseHandler):
         # don't want to fall into the trap of `missing` being wrong.
         for e in remote_auth_chain:
             try:
-                yield self._handle_new_event(origin, e)
+                await self._handle_new_event(origin, e)
             except AuthError:
                 pass
 
         # Now get the current auth_chain for the event.
-        local_auth_chain = yield self.store.get_auth_chain(
+        local_auth_chain = await self.store.get_auth_chain(
             [auth_id for auth_id in event.auth_event_ids()], include_given=True
         )
 
         # TODO: Check if we would now reject event_id. If so we need to tell
         # everyone.
 
-        ret = yield self.construct_auth_difference(local_auth_chain, remote_auth_chain)
+        ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
 
         logger.debug("on_query_auth returning: %s", ret)
 
         return ret
 
-    @defer.inlineCallbacks
-    def on_get_missing_events(
+    async def on_get_missing_events(
         self, origin, room_id, earliest_events, latest_events, limit
     ):
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
         limit = min(limit, 20)
 
-        missing_events = yield self.store.get_missing_events(
+        missing_events = await self.store.get_missing_events(
             room_id=room_id,
             earliest_events=earliest_events,
             latest_events=latest_events,
             limit=limit,
         )
 
-        missing_events = yield filter_events_for_server(
+        missing_events = await filter_events_for_server(
             self.storage, origin, missing_events
         )
 
         return missing_events
 
-    @defer.inlineCallbacks
-    @log_function
-    def do_auth(self, origin, event, context, auth_events):
+    async def do_auth(
+        self,
+        origin: str,
+        event: EventBase,
+        context: EventContext,
+        auth_events: StateMap[EventBase],
+    ) -> EventContext:
         """
 
         Args:
-            origin (str):
-            event (synapse.events.EventBase):
-            context (synapse.events.snapshot.EventContext):
-            auth_events (dict[(str, str)->synapse.events.EventBase]):
+            origin:
+            event:
+            context:
+            auth_events:
                 Map from (event_type, state_key) to event
 
                 Normally, our calculated auth_events based on the state of the room
@@ -2196,13 +2165,13 @@ class FederationHandler(BaseHandler):
 
                 Also NB that this function adds entries to it.
         Returns:
-            defer.Deferred[EventContext]: updated context object
+            updated context object
         """
-        room_version = yield self.store.get_room_version_id(event.room_id)
+        room_version = await self.store.get_room_version_id(event.room_id)
         room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
 
         try:
-            context = yield self._update_auth_events_and_context_for_auth(
+            context = await self._update_auth_events_and_context_for_auth(
                 origin, event, context, auth_events
             )
         except Exception:
@@ -2224,10 +2193,13 @@ class FederationHandler(BaseHandler):
 
         return context
 
-    @defer.inlineCallbacks
-    def _update_auth_events_and_context_for_auth(
-        self, origin, event, context, auth_events
-    ):
+    async def _update_auth_events_and_context_for_auth(
+        self,
+        origin: str,
+        event: EventBase,
+        context: EventContext,
+        auth_events: StateMap[EventBase],
+    ) -> EventContext:
         """Helper for do_auth. See there for docs.
 
         Checks whether a given event has the expected auth events. If it
@@ -2235,16 +2207,16 @@ class FederationHandler(BaseHandler):
         we can come to a consensus (e.g. if one server missed some valid
         state).
 
-        This attempts to resovle any potential divergence of state between
+        This attempts to resolve any potential divergence of state between
         servers, but is not essential and so failures should not block further
         processing of the event.
 
         Args:
-            origin (str):
-            event (synapse.events.EventBase):
-            context (synapse.events.snapshot.EventContext):
+            origin:
+            event:
+            context:
 
-            auth_events (dict[(str, str)->synapse.events.EventBase]):
+            auth_events:
                 Map from (event_type, state_key) to event
 
                 Normally, our calculated auth_events based on the state of the room
@@ -2255,7 +2227,7 @@ class FederationHandler(BaseHandler):
                 Also NB that this function adds entries to it.
 
         Returns:
-            defer.Deferred[EventContext]: updated context
+            updated context
         """
         event_auth_events = set(event.auth_event_ids())
 
@@ -2269,7 +2241,7 @@ class FederationHandler(BaseHandler):
         #
         # we start by checking if they are in the store, and then try calling /event_auth/.
         if missing_auth:
-            have_events = yield self.store.have_seen_events(missing_auth)
+            have_events = await self.store.have_seen_events(missing_auth)
             logger.debug("Events %s are in the store", have_events)
             missing_auth.difference_update(have_events)
 
@@ -2278,7 +2250,7 @@ class FederationHandler(BaseHandler):
             logger.info("auth_events contains unknown events: %s", missing_auth)
             try:
                 try:
-                    remote_auth_chain = yield self.federation_client.get_event_auth(
+                    remote_auth_chain = await self.federation_client.get_event_auth(
                         origin, event.room_id, event.event_id
                     )
                 except RequestSendFailed as e:
@@ -2287,7 +2259,7 @@ class FederationHandler(BaseHandler):
                     logger.info("Failed to get event auth from remote: %s", e)
                     return context
 
-                seen_remotes = yield self.store.have_seen_events(
+                seen_remotes = await self.store.have_seen_events(
                     [e.event_id for e in remote_auth_chain]
                 )
 
@@ -2310,7 +2282,7 @@ class FederationHandler(BaseHandler):
                         logger.debug(
                             "do_auth %s missing_auth: %s", event.event_id, e.event_id
                         )
-                        yield self._handle_new_event(origin, e, auth_events=auth)
+                        await self._handle_new_event(origin, e, auth_events=auth)
 
                         if e.event_id in event_auth_events:
                             auth_events[(e.type, e.state_key)] = e
@@ -2344,7 +2316,7 @@ class FederationHandler(BaseHandler):
 
         # XXX: currently this checks for redactions but I'm not convinced that is
         # necessary?
-        different_events = yield self.store.get_events_as_list(different_auth)
+        different_events = await self.store.get_events_as_list(different_auth)
 
         for d in different_events:
             if d.room_id != event.room_id:
@@ -2370,8 +2342,8 @@ class FederationHandler(BaseHandler):
         remote_auth_events.update({(d.type, d.state_key): d for d in different_events})
         remote_state = remote_auth_events.values()
 
-        room_version = yield self.store.get_room_version_id(event.room_id)
-        new_state = yield self.state_handler.resolve_events(
+        room_version = await self.store.get_room_version_id(event.room_id)
+        new_state = await self.state_handler.resolve_events(
             room_version, (local_state, remote_state), event
         )
 
@@ -2386,27 +2358,27 @@ class FederationHandler(BaseHandler):
 
         auth_events.update(new_state)
 
-        context = yield self._update_context_for_auth_events(
+        context = await self._update_context_for_auth_events(
             event, context, auth_events
         )
 
         return context
 
-    @defer.inlineCallbacks
-    def _update_context_for_auth_events(self, event, context, auth_events):
+    async def _update_context_for_auth_events(
+        self, event: EventBase, context: EventContext, auth_events: StateMap[EventBase]
+    ) -> EventContext:
         """Update the state_ids in an event context after auth event resolution,
         storing the changes as a new state group.
 
         Args:
-            event (Event): The event we're handling the context for
+            event: The event we're handling the context for
 
-            context (synapse.events.snapshot.EventContext): initial event context
+            context: initial event context
 
-            auth_events (dict[(str, str)->EventBase]): Events to update in the event
-                context.
+            auth_events: Events to update in the event context.
 
         Returns:
-            Deferred[EventContext]: new event context
+            new event context
         """
         # exclude the state key of the new event from the current_state in the context.
         if event.is_state():
@@ -2417,19 +2389,19 @@ class FederationHandler(BaseHandler):
             k: a.event_id for k, a in iteritems(auth_events) if k != event_key
         }
 
-        current_state_ids = yield context.get_current_state_ids()
+        current_state_ids = await context.get_current_state_ids()
         current_state_ids = dict(current_state_ids)
 
         current_state_ids.update(state_updates)
 
-        prev_state_ids = yield context.get_prev_state_ids()
+        prev_state_ids = await context.get_prev_state_ids()
         prev_state_ids = dict(prev_state_ids)
 
         prev_state_ids.update({k: a.event_id for k, a in iteritems(auth_events)})
 
         # create a new state group as a delta from the existing one.
         prev_group = context.state_group
-        state_group = yield self.state_store.store_state_group(
+        state_group = await self.state_store.store_state_group(
             event.event_id,
             event.room_id,
             prev_group=prev_group,
@@ -2446,8 +2418,9 @@ class FederationHandler(BaseHandler):
             delta_ids=state_updates,
         )
 
-    @defer.inlineCallbacks
-    def construct_auth_difference(self, local_auth, remote_auth):
+    async def construct_auth_difference(
+        self, local_auth: Iterable[EventBase], remote_auth: Iterable[EventBase]
+    ) -> Dict:
         """ Given a local and remote auth chain, find the differences. This
         assumes that we have already processed all events in remote_auth
 
@@ -2556,7 +2529,7 @@ class FederationHandler(BaseHandler):
         reason_map = {}
 
         for e in base_remote_rejected:
-            reason = yield self.store.get_rejection_reason(e.event_id)
+            reason = await self.store.get_rejection_reason(e.event_id)
             if reason is None:
                 # TODO: e is not in the current state, so we should
                 # construct some proof of that.
@@ -2641,33 +2614,31 @@ class FederationHandler(BaseHandler):
                 destinations, room_id, event_dict
             )
 
-    @defer.inlineCallbacks
-    @log_function
-    def on_exchange_third_party_invite_request(self, room_id, event_dict):
+    async def on_exchange_third_party_invite_request(
+        self, room_id: str, event_dict: JsonDict
+    ) -> None:
         """Handle an exchange_third_party_invite request from a remote server
 
         The remote server will call this when it wants to turn a 3pid invite
         into a normal m.room.member invite.
 
         Args:
-            room_id (str): The ID of the room.
+            room_id: The ID of the room.
 
             event_dict (dict[str, Any]): Dictionary containing the event body.
 
-        Returns:
-            Deferred: resolves (to None)
         """
-        room_version = yield self.store.get_room_version_id(room_id)
+        room_version = await self.store.get_room_version_id(room_id)
 
         # NB: event_dict has a particular specced format we might need to fudge
         # if we change event formats too much.
         builder = self.event_builder_factory.new(room_version, event_dict)
 
-        event, context = yield self.event_creation_handler.create_new_client_event(
+        event, context = await self.event_creation_handler.create_new_client_event(
             builder=builder
         )
 
-        event_allowed = yield self.third_party_event_rules.check_event_allowed(
+        event_allowed = await self.third_party_event_rules.check_event_allowed(
             event, context
         )
         if not event_allowed:
@@ -2678,16 +2649,16 @@ class FederationHandler(BaseHandler):
                 403, "This event is not allowed in this context", Codes.FORBIDDEN
             )
 
-        event, context = yield self.add_display_name_to_third_party_invite(
+        event, context = await self.add_display_name_to_third_party_invite(
             room_version, event_dict, event, context
         )
 
         try:
-            yield self.auth.check_from_context(room_version, event, context)
+            await self.auth.check_from_context(room_version, event, context)
         except AuthError as e:
             logger.warning("Denying third party invite %r because %s", event, e)
             raise e
-        yield self._check_signature(event, context)
+        await self._check_signature(event, context)
 
         # We need to tell the transaction queue to send this out, even
         # though the sender isn't a local user.
@@ -2695,7 +2666,7 @@ class FederationHandler(BaseHandler):
 
         # We retrieve the room member handler here as to not cause a cyclic dependency
         member_handler = self.hs.get_room_member_handler()
-        yield member_handler.send_membership_event(None, event, context)
+        await member_handler.send_membership_event(None, event, context)
 
     @defer.inlineCallbacks
     def add_display_name_to_third_party_invite(
@@ -2843,27 +2814,27 @@ class FederationHandler(BaseHandler):
         if "valid" not in response or not response["valid"]:
             raise AuthError(403, "Third party certificate was invalid")
 
-    @defer.inlineCallbacks
-    def persist_events_and_notify(self, event_and_contexts, backfilled=False):
+    async def persist_events_and_notify(
+        self,
+        event_and_contexts: Sequence[Tuple[EventBase, EventContext]],
+        backfilled: bool = False,
+    ) -> None:
         """Persists events and tells the notifier/pushers about them, if
         necessary.
 
         Args:
-            event_and_contexts(list[tuple[FrozenEvent, EventContext]])
-            backfilled (bool): Whether these events are a result of
+            event_and_contexts:
+            backfilled: Whether these events are a result of
                 backfilling or not
-
-        Returns:
-            Deferred
         """
         if self.config.worker_app:
-            yield self._send_events_to_master(
+            await self._send_events_to_master(
                 store=self.store,
                 event_and_contexts=event_and_contexts,
                 backfilled=backfilled,
             )
         else:
-            max_stream_id = yield self.storage.persistence.persist_events(
+            max_stream_id = await self.storage.persistence.persist_events(
                 event_and_contexts, backfilled=backfilled
             )
 
@@ -2874,15 +2845,17 @@ class FederationHandler(BaseHandler):
 
             if not backfilled:  # Never notify for backfilled events
                 for event, _ in event_and_contexts:
-                    yield self._notify_persisted_event(event, max_stream_id)
+                    await self._notify_persisted_event(event, max_stream_id)
 
-    def _notify_persisted_event(self, event, max_stream_id):
+    async def _notify_persisted_event(
+        self, event: EventBase, max_stream_id: int
+    ) -> None:
         """Checks to see if notifier/pushers should be notified about the
         event or not.
 
         Args:
-            event (FrozenEvent)
-            max_stream_id (int): The max_stream_id returned by persist_events
+            event:
+            max_stream_id: The max_stream_id returned by persist_events
         """
 
         extra_users = []
@@ -2906,29 +2879,29 @@ class FederationHandler(BaseHandler):
             event, event_stream_id, max_stream_id, extra_users=extra_users
         )
 
-        return self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
+        await self.pusher_pool.on_new_notifications(event_stream_id, max_stream_id)
 
-    def _clean_room_for_join(self, room_id):
+    async def _clean_room_for_join(self, room_id: str) -> None:
         """Called to clean up any data in DB for a given room, ready for the
         server to join the room.
 
         Args:
-            room_id (str)
+            room_id
         """
         if self.config.worker_app:
-            return self._clean_room_for_join_client(room_id)
+            await self._clean_room_for_join_client(room_id)
         else:
-            return self.store.clean_room_for_join(room_id)
+            await self.store.clean_room_for_join(room_id)
 
-    def user_joined_room(self, user, room_id):
+    async def user_joined_room(self, user: UserID, room_id: str) -> None:
         """Called when a new user has joined the room
         """
         if self.config.worker_app:
-            return self._notify_user_membership_change(
+            await self._notify_user_membership_change(
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return defer.succeed(user_joined_room(self.distributor, user, room_id))
+            user_joined_room(self.distributor, user, room_id)
 
     @defer.inlineCallbacks
     def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index bdf16c84d3..be6ae18a92 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -932,10 +932,9 @@ class EventCreationHandler(object):
                     # way? If we have been invited by a remote server, we need
                     # to get them to sign the event.
 
-                    returned_invite = yield federation_handler.send_invite(
-                        invitee.domain, event
+                    returned_invite = yield defer.ensureDeferred(
+                        federation_handler.send_invite(invitee.domain, event)
                     )
-
                     event.unsigned.pop("room_state", None)
 
                     # TODO: Make sure the signatures actually are correct.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 15e8aa5249..4260426369 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -944,8 +944,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         # join dance for now, since we're kinda implicitly checking
         # that we are allowed to join when we decide whether or not we
         # need to do the invite/join dance.
-        yield self.federation_handler.do_invite_join(
-            remote_room_hosts, room_id, user.to_string(), content
+        yield defer.ensureDeferred(
+            self.federation_handler.do_invite_join(
+                remote_room_hosts, room_id, user.to_string(), content
+            )
         )
         yield self._user_joined_room(user, room_id)
 
@@ -982,8 +984,10 @@ class RoomMemberMasterHandler(RoomMemberHandler):
         """
         fed_handler = self.federation_handler
         try:
-            ret = yield fed_handler.do_remotely_reject_invite(
-                remote_room_hosts, room_id, target.to_string(), content=content,
+            ret = yield defer.ensureDeferred(
+                fed_handler.do_remotely_reject_invite(
+                    remote_room_hosts, room_id, target.to_string(), content=content,
+                )
             )
             return ret
         except Exception as e: