summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2020-04-24 14:36:38 -0400
committerGitHub <noreply@github.com>2020-04-24 14:36:38 -0400
commit33bceb7f705012b7543bd13dbf4c3941833ae356 (patch)
treeaca08be2571d59e01fac6228c9e47dd238621b17
parentMerge pull request #7337 from matrix-org/rav/fix_update_limit_assertion (diff)
downloadsynapse-33bceb7f705012b7543bd13dbf4c3941833ae356.tar.xz
Convert some of the federation handler methods to async/await. (#7338)
-rw-r--r--changelog.d/7338.misc1
-rw-r--r--synapse/handlers/federation.py49
2 files changed, 25 insertions, 25 deletions
diff --git a/changelog.d/7338.misc b/changelog.d/7338.misc
new file mode 100644
index 0000000000..7cafd074ca
--- /dev/null
+++ b/changelog.d/7338.misc
@@ -0,0 +1 @@
+Convert some federation handler code to async/await.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index c7aa7acf3b..41b96c0a73 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -343,7 +343,7 @@ class FederationHandler(BaseHandler):
                     ours = await self.state_store.get_state_groups_ids(room_id, seen)
 
                     # state_maps is a list of mappings from (type, state_key) to event_id
-                    state_maps = list(ours.values())  # type: list[StateMap[str]]
+                    state_maps = list(ours.values())  # type: List[StateMap[str]]
 
                     # we don't need this any more, let's delete it.
                     del ours
@@ -1694,16 +1694,15 @@ class FederationHandler(BaseHandler):
 
         return None
 
-    @defer.inlineCallbacks
-    def get_state_for_pdu(self, room_id, event_id):
+    async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
         """Returns the state at the event. i.e. not including said event.
         """
 
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.state_store.get_state_groups(room_id, [event_id])
+        state_groups = await self.state_store.get_state_groups(room_id, [event_id])
 
         if state_groups:
             _, state = list(iteritems(state_groups)).pop()
@@ -1714,7 +1713,7 @@ class FederationHandler(BaseHandler):
                 if "replaces_state" in event.unsigned:
                     prev_id = event.unsigned["replaces_state"]
                     if prev_id != event.event_id:
-                        prev_event = yield self.store.get_event(prev_id)
+                        prev_event = await self.store.get_event(prev_id)
                         results[(event.type, event.state_key)] = prev_event
                 else:
                     del results[(event.type, event.state_key)]
@@ -1724,15 +1723,14 @@ class FederationHandler(BaseHandler):
         else:
             return []
 
-    @defer.inlineCallbacks
-    def get_state_ids_for_pdu(self, room_id, event_id):
+    async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event.
         """
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=False, check_room_id=room_id
         )
 
-        state_groups = yield self.state_store.get_state_groups_ids(room_id, [event_id])
+        state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
         if state_groups:
             _, state = list(state_groups.items()).pop()
@@ -1751,49 +1749,50 @@ class FederationHandler(BaseHandler):
         else:
             return []
 
-    @defer.inlineCallbacks
     @log_function
-    def on_backfill_request(self, origin, room_id, pdu_list, limit):
-        in_room = yield self.auth.check_host_in_room(room_id, origin)
+    async def on_backfill_request(
+        self, origin: str, room_id: str, pdu_list: List[str], limit: int
+    ) -> List[EventBase]:
+        in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
 
         # Synapse asks for 100 events per backfill request. Do not allow more.
         limit = min(limit, 100)
 
-        events = yield self.store.get_backfill_events(room_id, pdu_list, limit)
+        events = await self.store.get_backfill_events(room_id, pdu_list, limit)
 
-        events = yield filter_events_for_server(self.storage, origin, events)
+        events = await filter_events_for_server(self.storage, origin, events)
 
         return events
 
-    @defer.inlineCallbacks
     @log_function
-    def get_persisted_pdu(self, origin, event_id):
+    async def get_persisted_pdu(
+        self, origin: str, event_id: str
+    ) -> Optional[EventBase]:
         """Get an event from the database for the given server.
 
         Args:
-            origin [str]: hostname of server which is requesting the event; we
+            origin: hostname of server which is requesting the event; we
                will check that the server is allowed to see it.
-            event_id [str]: id of the event being requested
+            event_id: id of the event being requested
 
         Returns:
-            Deferred[EventBase|None]: None if we know nothing about the event;
-                otherwise the (possibly-redacted) event.
+            None if we know nothing about the event; otherwise the (possibly-redacted) event.
 
         Raises:
             AuthError if the server is not currently in the room
         """
-        event = yield self.store.get_event(
+        event = await self.store.get_event(
             event_id, allow_none=True, allow_rejected=True
         )
 
         if event:
-            in_room = yield self.auth.check_host_in_room(event.room_id, origin)
+            in_room = await self.auth.check_host_in_room(event.room_id, origin)
             if not in_room:
                 raise AuthError(403, "Host not in room.")
 
-            events = yield filter_events_for_server(self.storage, origin, [event])
+            events = await filter_events_for_server(self.storage, origin, [event])
             event = events[0]
             return event
         else:
@@ -2397,7 +2396,7 @@ class FederationHandler(BaseHandler):
         """
         # exclude the state key of the new event from the current_state in the context.
         if event.is_state():
-            event_key = (event.type, event.state_key)
+            event_key = (event.type, event.state_key)  # type: Optional[Tuple[str, str]]
         else:
             event_key = None
         state_updates = {