summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/federation.py49
-rw-r--r--tests/test_federation.py14
2 files changed, 31 insertions, 32 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index e54d509b62..01d9c5120e 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -165,8 +165,7 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    @defer.inlineCallbacks
-    def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False):
+    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
         """ Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
@@ -176,8 +175,6 @@ class FederationHandler(BaseHandler):
             pdu (FrozenEvent): received PDU
             sent_to_us_directly (bool): True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
-
-        Returns (Deferred): completes with None
         """
 
         room_id = pdu.room_id
@@ -186,7 +183,7 @@ class FederationHandler(BaseHandler):
         logger.info("handling received PDU: %s", pdu)
 
         # We reprocess pdus when we have seen them only as outliers
-        existing = yield self.store.get_event(
+        existing = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
@@ -230,7 +227,7 @@ class FederationHandler(BaseHandler):
         #
         # Note that if we were never in the room then we would have already
         # dropped the event, since we wouldn't know the room version.
-        is_in_room = yield self.auth.check_host_in_room(room_id, self.server_name)
+        is_in_room = await self.auth.check_host_in_room(room_id, self.server_name)
         if not is_in_room:
             logger.info(
                 "[%s %s] Ignoring PDU from %s as we're not in the room",
@@ -246,12 +243,12 @@ class FederationHandler(BaseHandler):
         # Get missing pdus if necessary.
         if not pdu.internal_metadata.is_outlier():
             # We only backfill backwards to the min depth.
-            min_depth = yield self.get_min_depth_for_context(pdu.room_id)
+            min_depth = await self.get_min_depth_for_context(pdu.room_id)
 
             logger.debug("[%s %s] min_depth: %d", room_id, event_id, min_depth)
 
             prevs = set(pdu.prev_event_ids())
-            seen = yield self.store.have_seen_events(prevs)
+            seen = await self.store.have_seen_events(prevs)
 
             if min_depth and pdu.depth < min_depth:
                 # This is so that we don't notify the user about this
@@ -271,7 +268,7 @@ class FederationHandler(BaseHandler):
                         len(missing_prevs),
                         shortstr(missing_prevs),
                     )
-                    with (yield self._room_pdu_linearizer.queue(pdu.room_id)):
+                    with (await self._room_pdu_linearizer.queue(pdu.room_id)):
                         logger.info(
                             "[%s %s] Acquired room lock to fetch %d missing prev_events",
                             room_id,
@@ -280,7 +277,7 @@ class FederationHandler(BaseHandler):
                         )
 
                         try:
-                            yield self._get_missing_events_for_pdu(
+                            await self._get_missing_events_for_pdu(
                                 origin, pdu, prevs, min_depth
                             )
                         except Exception as e:
@@ -291,7 +288,7 @@ class FederationHandler(BaseHandler):
 
                         # Update the set of things we've seen after trying to
                         # fetch the missing stuff
-                        seen = yield self.store.have_seen_events(prevs)
+                        seen = await self.store.have_seen_events(prevs)
 
                         if not prevs - seen:
                             logger.info(
@@ -355,7 +352,7 @@ class FederationHandler(BaseHandler):
                 event_map = {event_id: pdu}
                 try:
                     # Get the state of the events we know about
-                    ours = yield self.state_store.get_state_groups_ids(room_id, seen)
+                    ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
                     state_maps = list(
@@ -372,7 +369,7 @@ class FederationHandler(BaseHandler):
                             "Requesting state at missing prev_event %s", event_id,
                         )
 
-                        room_version = yield self.store.get_room_version(room_id)
+                        room_version = await self.store.get_room_version(room_id)
 
                         with nested_logging_context(p):
                             # note that if any of the missing prevs share missing state or
@@ -381,11 +378,11 @@ class FederationHandler(BaseHandler):
                             (
                                 remote_state,
                                 got_auth_chain,
-                            ) = yield self._get_state_for_room(origin, room_id, p)
+                            ) = await self._get_state_for_room(origin, room_id, p)
 
                             # we want the state *after* p; _get_state_for_room returns the
                             # state *before* p.
-                            remote_event = yield self.federation_client.get_pdu(
+                            remote_event = await self.federation_client.get_pdu(
                                 [origin], p, room_version, outlier=True
                             )
 
@@ -410,7 +407,7 @@ class FederationHandler(BaseHandler):
                             for x in remote_state:
                                 event_map[x.event_id] = x
 
-                    state_map = yield resolve_events_with_store(
+                    state_map = await resolve_events_with_store(
                         room_version,
                         state_maps,
                         event_map,
@@ -422,7 +419,7 @@ class FederationHandler(BaseHandler):
 
                     # First though we need to fetch all the events that are in
                     # state_map, so we can build up the state below.
-                    evs = yield self.store.get_events(
+                    evs = await self.store.get_events(
                         list(state_map.values()),
                         get_prev_content=False,
                         redact_behaviour=EventRedactBehaviour.AS_IS,
@@ -446,12 +443,11 @@ class FederationHandler(BaseHandler):
                         affected=event_id,
                     )
 
-        yield self._process_received_pdu(
+        await self._process_received_pdu(
             origin, pdu, state=state, auth_chain=auth_chain
         )
 
-    @defer.inlineCallbacks
-    def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
         """
         Args:
             origin (str): Origin of the pdu. Will be called to get the missing events
@@ -463,12 +459,12 @@ class FederationHandler(BaseHandler):
         room_id = pdu.room_id
         event_id = pdu.event_id
 
-        seen = yield self.store.have_seen_events(prevs)
+        seen = await self.store.have_seen_events(prevs)
 
         if not prevs - seen:
             return
 
-        latest = yield self.store.get_latest_event_ids_in_room(room_id)
+        latest = await self.store.get_latest_event_ids_in_room(room_id)
 
         # We add the prev events that we have seen to the latest
         # list to ensure the remote server doesn't give them to us
@@ -532,7 +528,7 @@ class FederationHandler(BaseHandler):
         # All that said: Let's try increasing the timout to 60s and see what happens.
 
         try:
-            missing_events = yield self.federation_client.get_missing_events(
+            missing_events = await self.federation_client.get_missing_events(
                 origin,
                 room_id,
                 earliest_events_ids=list(latest),
@@ -571,7 +567,7 @@ class FederationHandler(BaseHandler):
             )
             with nested_logging_context(ev.event_id):
                 try:
-                    yield self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
+                    await self.on_receive_pdu(origin, ev, sent_to_us_directly=False)
                 except FederationError as e:
                     if e.code == 403:
                         logger.warning(
@@ -1328,8 +1324,7 @@ class FederationHandler(BaseHandler):
 
         return True
 
-    @defer.inlineCallbacks
-    def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(self, room_queue):
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
@@ -1345,7 +1340,7 @@ class FederationHandler(BaseHandler):
                     p.room_id,
                 )
                 with nested_logging_context(p.event_id):
-                    yield self.on_receive_pdu(origin, p, sent_to_us_directly=True)
+                    await self.on_receive_pdu(origin, p, sent_to_us_directly=True)
             except Exception as e:
                 logger.warning(
                     "Error handling queued PDU %s from %s: %s", p.event_id, origin, e
diff --git a/tests/test_federation.py b/tests/test_federation.py
index ad165d7295..68684460c6 100644
--- a/tests/test_federation.py
+++ b/tests/test_federation.py
@@ -1,6 +1,6 @@
 from mock import Mock
 
-from twisted.internet.defer import maybeDeferred, succeed
+from twisted.internet.defer import ensureDeferred, maybeDeferred, succeed
 
 from synapse.events import FrozenEvent
 from synapse.logging.context import LoggingContext
@@ -70,8 +70,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         # Send the join, it should return None (which is not an error)
-        d = self.handler.on_receive_pdu(
-            "test.serv", join_event, sent_to_us_directly=True
+        d = ensureDeferred(
+            self.handler.on_receive_pdu(
+                "test.serv", join_event, sent_to_us_directly=True
+            )
         )
         self.reactor.advance(1)
         self.assertEqual(self.successResultOf(d), None)
@@ -119,8 +121,10 @@ class MessageAcceptTests(unittest.TestCase):
         )
 
         with LoggingContext(request="lying_event"):
-            d = self.handler.on_receive_pdu(
-                "test.serv", lying_event, sent_to_us_directly=True
+            d = ensureDeferred(
+                self.handler.on_receive_pdu(
+                    "test.serv", lying_event, sent_to_us_directly=True
+                )
             )
 
             # Step the reactor, so the database fetches come back