summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2019-12-11 16:36:06 +0000
committerGitHub <noreply@github.com>2019-12-11 16:36:06 +0000
commit894d2addacf4e7b9119784863781c893adbf34ff (patch)
tree6b4d9b35cf57a9a60c29daf8d09c7976544a7acf
parentMerge pull request #6504 from matrix-org/erikj/account_validity_async_await (diff)
parentchangelog (diff)
downloadsynapse-894d2addacf4e7b9119784863781c893adbf34ff.tar.xz
Merge pull request #6517 from matrix-org/rav/event_auth/13
Port some of FederationHandler to async/await
-rw-r--r--changelog.d/6517.misc1
-rw-r--r--synapse/handlers/federation.py159
-rw-r--r--synapse/handlers/pagination.py27
-rw-r--r--tests/test_federation.py14
4 files changed, 98 insertions, 103 deletions
diff --git a/changelog.d/6517.misc b/changelog.d/6517.misc
new file mode 100644
index 0000000000..c6ffed9952
--- /dev/null
+++ b/changelog.d/6517.misc
@@ -0,0 +1 @@
+Port some of FederationHandler to async/await.
\ No newline at end of file
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index cf9c46d027..bcd3b422aa 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -19,7 +19,7 @@
 
 import itertools
 import logging
-from typing import Dict, Iterable, Optional, Sequence, Tuple
+from typing import Dict, Iterable, List, Optional, Sequence, Tuple
 
 import six
 from six import iteritems, itervalues
@@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
         """
 
         room_id = pdu.room_id
@@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
         logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
 
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
 
             if min_depth and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
@@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
                         )
 
                         try:
-                            yield self._get_missing_events_for_pdu(
+                            await self._get_missing_events_for_pdu(
                                 origin, pdu, prevs, min_depth
                             )
                         except Exception as e:
@@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
-                        seen = yield self.store.have_seen_events(prevs)
+                        seen = await self.store.have_seen_events(prevs)
 
                         if not prevs - seen:
                             logger.info(
@@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
                     state_maps = list(
@@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
                             "Requesting state at missing prev_event %s", event_id,
                         )
 
-                        room_version = yield self.store.get_room_version(room_id)
+                        room_version = await self.store.get_room_version(room_id)
 
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
@@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self._get_state_for_room(origin, room_id, p)
+                            ) = await self._get_state_for_room(origin, room_id, p)
 
                             # we want the state *after* p; _get_state_for_room returns the
                             # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
+                            remote_event = await self.federation_client.get_pdu(
                                 [origin], p, room_version, outlier=True
                             )
 
@@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    state_map = yield resolve_events_with_store(
+                    state_map = await resolve_events_with_store(
                         room_version,
                         state_maps,
                         event_map,
@@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
 
                     # First though we need to fetch all the events that are in
                     # state_map, so we can build up the state below.
-                    evs = yield self.store.get_events(
+                    evs = await self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
                         redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(
+        await self._process_received_pdu(
             origin, pdu, state=state, auth_chain=auth_chain
         )
 
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
 
         if not prevs - seen:
             return
 
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
@@ -583,30 +579,30 @@ class FederationHandler(BaseHandler):
                     else:
                         raise
 
-    @defer.inlineCallbacks
     @log_function
-    def _get_state_for_room(self, destination, room_id, event_id):
+    async def _get_state_for_room(
+        self, destination: str, room_id: str, event_id: str
+    ) -> Tuple[List[EventBase], List[EventBase]]:
         """Requests all of the room state at a given event from a remote homeserver.
 
         Args:
-            destination (str): The remote homeserver to query for the state.
-            room_id (str): The id of the room we're interested in.
-            event_id (str): The id of the event we want the state at.
+            destination:: The remote homeserver to query for the state.
+            room_id: The id of the room we're interested in.
+            event_id: The id of the event we want the state at.
 
         Returns:
-            Deferred[Tuple[List[EventBase], List[EventBase]]]:
-                A list of events in the state, and a list of events in the auth chain
-                for the given event.
+            A list of events in the state, and a list of events in the auth chain
+            for the given event.
         """
         (
             state_event_ids,
             auth_event_ids,
-        ) = yield self.federation_client.get_room_state_ids(
+        ) = await self.federation_client.get_room_state_ids(
             destination, room_id, event_id=event_id
         )
 
         desired_events = set(state_event_ids + auth_event_ids)
-        event_map = yield self._get_events_from_store_or_dest(
+        event_map = await self._get_events_from_store_or_dest(
             destination, room_id, desired_events
         )
 
@@ -625,20 +621,20 @@ class FederationHandler(BaseHandler):
 
         return pdus, auth_chain
 
-    @defer.inlineCallbacks
-    def _get_events_from_store_or_dest(self, destination, room_id, event_ids):
+    async def _get_events_from_store_or_dest(
+        self, destination: str, room_id: str, event_ids: Iterable[str]
+    ) -> Dict[str, EventBase]:
         """Fetch events from a remote destination, checking if we already have them.
 
         Args:
-            destination (str)
-            room_id (str)
-            event_ids (Iterable[str])
+            destination
+            room_id
+            event_ids
 
         Returns:
-            Deferred[dict[str, EventBase]]: A deferred resolving to a map
-            from event_id to event
+            map from event_id to event
         """
-        fetched_events = yield self.store.get_events(event_ids, allow_rejected=True)
+        fetched_events = await self.store.get_events(event_ids, allow_rejected=True)
 
         missing_events = set(event_ids) - fetched_events.keys()
 
@@ -651,7 +647,7 @@ class FederationHandler(BaseHandler):
             event_ids,
         )
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
         # XXX 20 requests at once? really?
         for batch in batch_iter(missing_events, 20):
@@ -665,7 +661,7 @@ class FederationHandler(BaseHandler):
                 for e_id in batch
             ]
 
-            res = yield make_deferred_yieldable(
+            res = await make_deferred_yieldable(
                 defer.DeferredList(deferreds, consumeErrors=True)
             )
             for success, result in res:
@@ -674,8 +670,7 @@ class FederationHandler(BaseHandler):
 
         return fetched_events
 
-    @defer.inlineCallbacks
-    def _process_received_pdu(self, origin, event, state, auth_chain):
+    async def _process_received_pdu(self, origin, event, state, auth_chain):
         """ Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
         """
@@ -690,7 +685,7 @@ class FederationHandler(BaseHandler):
         if auth_chain:
             event_ids |= {e.event_id for e in auth_chain}
 
-        seen_ids = yield self.store.have_seen_events(event_ids)
+        seen_ids = await self.store.have_seen_events(event_ids)
 
         if state and auth_chain is not None:
             # If we have any state or auth_chain given to us by the replication
@@ -717,18 +712,18 @@ class FederationHandler(BaseHandler):
                 event_id,
                 [e.event.event_id for e in event_infos],
             )
-            yield self._handle_new_events(origin, event_infos)
+            await self._handle_new_events(origin, event_infos)
 
         try:
-            context = yield self._handle_new_event(origin, event, state=state)
+            context = await self._handle_new_event(origin, event, state=state)
         except AuthError as e:
             raise FederationError("ERROR", e.code, e.msg, affected=event.event_id)
 
-        room = yield self.store.get_room(room_id)
+        room = await self.store.get_room(room_id)
 
         if not room:
             try:
-                yield self.store.store_room(
+                await self.store.store_room(
                     room_id=room_id, room_creator_user_id="", is_public=False
                 )
             except StoreError:
@@ -741,11 +736,11 @@ class FederationHandler(BaseHandler):
                 # changing their profile info.
                 newly_joined = True
 
-                prev_state_ids = yield context.get_prev_state_ids(self.store)
+                prev_state_ids = await context.get_prev_state_ids(self.store)
 
                 prev_state_id = prev_state_ids.get((event.type, event.state_key))
                 if prev_state_id:
-                    prev_state = yield self.store.get_event(
+                    prev_state = await self.store.get_event(
                         prev_state_id, allow_none=True
                     )
                     if prev_state and prev_state.membership == Membership.JOIN:
@@ -753,11 +748,10 @@ class FederationHandler(BaseHandler):
 
                 if newly_joined:
                     user = UserID.from_string(event.state_key)
-                    yield self.user_joined_room(user, room_id)
+                    await self.user_joined_room(user, room_id)
 
     @log_function
-    @defer.inlineCallbacks
-    def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(self, dest, room_id, limit, extremities):
         """ Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -774,9 +768,9 @@ class FederationHandler(BaseHandler):
         if dest == self.server_name:
             raise SynapseError(400, "Can't backfill from self.")
 
-        room_version = yield self.store.get_room_version(room_id)
+        room_version = await self.store.get_room_version(room_id)
 
-        events = yield self.federation_client.backfill(
+        events = await self.federation_client.backfill(
             dest, room_id, limit=limit, extremities=extremities
         )
 
@@ -791,7 +785,7 @@ class FederationHandler(BaseHandler):
         #     self._sanity_check_event(ev)
 
         # Don't bother processing events we already have.
-        seen_events = yield self.store.have_events_in_timeline(
+        seen_events = await self.store.have_events_in_timeline(
             set(e.event_id for e in events)
         )
 
@@ -814,7 +808,7 @@ class FederationHandler(BaseHandler):
         state_events = {}
         events_to_state = {}
         for e_id in edges:
-            state, auth = yield self._get_state_for_room(
+            state, auth = await self._get_state_for_room(
                 destination=dest, room_id=room_id, event_id=e_id
             )
             auth_events.update({a.event_id: a for a in auth})
@@ -839,7 +833,7 @@ class FederationHandler(BaseHandler):
         # We repeatedly do this until we stop finding new auth events.
         while missing_auth - failed_to_fetch:
             logger.info("Missing auth for backfill: %r", missing_auth)
-            ret_events = yield self.store.get_events(missing_auth - failed_to_fetch)
+            ret_events = await self.store.get_events(missing_auth - failed_to_fetch)
             auth_events.update(ret_events)
 
             required_auth.update(
@@ -853,7 +847,7 @@ class FederationHandler(BaseHandler):
                     missing_auth - failed_to_fetch,
                 )
 
-                results = yield make_deferred_yieldable(
+                results = await make_deferred_yieldable(
                     defer.gatherResults(
                         [
                             run_in_background(
@@ -880,7 +874,7 @@ class FederationHandler(BaseHandler):
 
                 failed_to_fetch = missing_auth - set(auth_events)
 
-        seen_events = yield self.store.have_seen_events(
+        seen_events = await self.store.have_seen_events(
             set(auth_events.keys()) | set(state_events.keys())
         )
 
@@ -942,7 +936,7 @@ class FederationHandler(BaseHandler):
                 )
             )
 
-        yield self._handle_new_events(dest, ev_infos, backfilled=True)
+        await self._handle_new_events(dest, ev_infos, backfilled=True)
 
         # Step 2: Persist the rest of the events in the chunk one by one
         events.sort(key=lambda e: e.depth)
@@ -958,16 +952,15 @@ class FederationHandler(BaseHandler):
             # We store these one at a time since each event depends on the
             # previous to work out the state.
             # TODO: We can probably do something more clever here.
-            yield self._handle_new_event(dest, event, backfilled=True)
+            await self._handle_new_event(dest, event, backfilled=True)
 
         return events
 
-    @defer.inlineCallbacks
-    def maybe_backfill(self, room_id, current_depth):
+    async def maybe_backfill(self, room_id, current_depth):
         """Checks the database to see if we should backfill before paginating,
         and if so do.
         """
-        extremities = yield self.store.get_oldest_events_with_depth_in_room(room_id)
+        extremities = await self.store.get_oldest_events_with_depth_in_room(room_id)
 
         if not extremities:
             logger.debug("Not backfilling as no extremeties found.")
@@ -999,9 +992,9 @@ class FederationHandler(BaseHandler):
         #   state *before* the event, ignoring the special casing certain event
         #   types have.
 
-        forward_events = yield self.store.get_successor_events(list(extremities))
+        forward_events = await self.store.get_successor_events(list(extremities))
 
-        extremities_events = yield self.store.get_events(
+        extremities_events = await self.store.get_events(
             forward_events,
             redact_behaviour=EventRedactBehaviour.AS_IS,
             get_prev_content=False,
@@ -1009,7 +1002,7 @@ class FederationHandler(BaseHandler):
 
         # We set `check_history_visibility_only` as we might otherwise get false
         # positives from users having been erased.
-        filtered_extremities = yield filter_events_for_server(
+        filtered_extremities = await filter_events_for_server(
             self.storage,
             self.server_name,
             list(extremities_events.values()),
@@ -1039,7 +1032,7 @@ class FederationHandler(BaseHandler):
         # First we try hosts that are already in the room
         # TODO: HEURISTIC ALERT.
 
-        curr_state = yield self.state_handler.get_current_state(room_id)
+        curr_state = await self.state_handler.get_current_state(room_id)
 
         def get_domains_from_state(state):
             """Get joined domains from state
@@ -1078,12 +1071,11 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        @defer.inlineCallbacks
-        def try_backfill(domains):
+        async def try_backfill(domains):
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
-                    yield self.backfill(
+                    await self.backfill(
                         dom, room_id, limit=100, extremities=extremities
                     )
                     # If this succeeded then we probably already have the
@@ -1114,7 +1106,7 @@ class FederationHandler(BaseHandler):
 
             return False
 
-        success = yield try_backfill(likely_domains)
+        success = await try_backfill(likely_domains)
         if success:
             return True
 
@@ -1128,7 +1120,7 @@ class FederationHandler(BaseHandler):
 
         logger.debug("calling resolve_state_groups in _maybe_backfill")
         resolve = preserve_fn(self.state_handler.resolve_state_groups_for_events)
-        states = yield make_deferred_yieldable(
+        states = await make_deferred_yieldable(
             defer.gatherResults(
                 [resolve(room_id, [e]) for e in event_ids], consumeErrors=True
             )
@@ -1138,7 +1130,7 @@ class FederationHandler(BaseHandler):
         # event_ids.
         states = dict(zip(event_ids, [s.state for s in states]))
 
-        state_map = yield self.store.get_events(
+        state_map = await self.store.get_events(
             [e_id for ids in itervalues(states) for e_id in itervalues(ids)],
             get_prev_content=False,
         )
@@ -1154,7 +1146,7 @@ class FederationHandler(BaseHandler):
         for e_id, _ in sorted_extremeties_tuple:
             likely_domains = get_domains_from_state(states[e_id])
 
-            success = yield try_backfill(
+            success = await try_backfill(
                 [dom for dom, _ in likely_domains if dom not in tried_domains]
             )
             if success:
@@ -1331,8 +1323,7 @@ class FederationHandler(BaseHandler):
 
         return True
 
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
@@ -1348,7 +1339,7 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
@@ -2907,7 +2898,7 @@ class FederationHandler(BaseHandler):
                 room_id=room_id, user_id=user.to_string(), change="joined"
             )
         else:
-            return user_joined_room(self.distributor, user, room_id)
+            return defer.succeed(user_joined_room(self.distributor, user, room_id))
 
     @defer.inlineCallbacks
     def get_room_complexity(self, remote_room_hosts, room_id):
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 8514ddc600..00a6afc963 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -280,8 +280,7 @@ class PaginationHandler(object):
 
             await self.storage.purge_events.purge_room(room_id)
 
-    @defer.inlineCallbacks
-    def get_messages(
+    async def get_messages(
         self,
         requester,
         room_id=None,
@@ -307,7 +306,7 @@ class PaginationHandler(object):
             room_token = pagin_config.from_token.room_key
         else:
             pagin_config.from_token = (
-                yield self.hs.get_event_sources().get_current_token_for_pagination()
+                await self.hs.get_event_sources().get_current_token_for_pagination()
             )
             room_token = pagin_config.from_token.room_key
 
@@ -319,11 +318,11 @@ class PaginationHandler(object):
 
         source_config = pagin_config.get_source_config("room")
 
-        with (yield self.pagination_lock.read(room_id)):
+        with (await self.pagination_lock.read(room_id)):
             (
                 membership,
                 member_event_id,
-            ) = yield self.auth.check_in_room_or_world_readable(room_id, user_id)
+            ) = await self.auth.check_in_room_or_world_readable(room_id, user_id)
 
             if source_config.direction == "b":
                 # if we're going backwards, we might need to backfill. This
@@ -331,7 +330,7 @@ class PaginationHandler(object):
                 if room_token.topological:
                     max_topo = room_token.topological
                 else:
-                    max_topo = yield self.store.get_max_topological_token(
+                    max_topo = await self.store.get_max_topological_token(
                         room_id, room_token.stream
                     )
 
@@ -339,18 +338,18 @@ class PaginationHandler(object):
                     # If they have left the room then clamp the token to be before
                     # they left the room, to save the effort of loading from the
                     # database.
-                    leave_token = yield self.store.get_topological_token_for_event(
+                    leave_token = await self.store.get_topological_token_for_event(
                         member_event_id
                     )
                     leave_token = RoomStreamToken.parse(leave_token)
                     if leave_token.topological < max_topo:
                         source_config.from_key = str(leave_token)
 
-                yield self.hs.get_handlers().federation_handler.maybe_backfill(
+                await self.hs.get_handlers().federation_handler.maybe_backfill(
                     room_id, max_topo
                 )
 
-            events, next_key = yield self.store.paginate_room_events(
+            events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
                 from_key=source_config.from_key,
                 to_key=source_config.to_key,
@@ -365,7 +364,7 @@ class PaginationHandler(object):
             if event_filter:
                 events = event_filter.filter(events)
 
-            events = yield filter_events_for_client(
+            events = await filter_events_for_client(
                 self.storage, user_id, events, is_peeking=(member_event_id is None)
             )
 
@@ -385,19 +384,19 @@ class PaginationHandler(object):
                 (EventTypes.Member, event.sender) for event in events
             )
 
-            state_ids = yield self.state_store.get_state_ids_for_event(
+            state_ids = await self.state_store.get_state_ids_for_event(
                 events[0].event_id, state_filter=state_filter
             )
 
             if state_ids:
-                state = yield self.store.get_events(list(state_ids.values()))
+                state = await self.store.get_events(list(state_ids.values()))
                 state = state.values()
 
         time_now = self.clock.time_msec()
 
         chunk = {
             "chunk": (
-                yield self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     events, time_now, as_client_event=as_client_event
                 )
             ),
@@ -406,7 +405,7 @@ class PaginationHandler(object):
         }
 
         if state:
-            chunk["state"] = yield self._event_serializer.serialize_events(
+            chunk["state"] = await self._event_serializer.serialize_events(
                 state, time_now, as_client_event=as_client_event
             )
 
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ad165d7295..68684460c6 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,6 +1,6 @@
 from mock import Mock
 
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
 
 from synapse.events import FrozenEvent
 from synapse.logging.context import LoggingContext
@@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Send the join, it should return None (which is not an error)
-        d = self.handler.on_receive_pdu(
-            "test.serv", join_event, sent_to_us_directly=True
+        d = ensureDeferred(
+            self.handler.on_receive_pdu(
+                "test.serv", join_event, sent_to_us_directly=True
+            )
         )
         self.reactor.advance(1)
         self.assertEqual(self.successResultOf(d), None)
@@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         with LoggingContext(request="lying_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", lying_event, sent_to_us_directly=True
+            d = ensureDeferred(
+                self.handler.on_receive_pdu(
+                    "test.serv", lying_event, sent_to_us_directly=True
+                )
             )
 
             # Step the reactor, so the database fetches come back