Diffstat (limited to 'synapse/storage/databases/main/event_federation.py')
-rw-r--r--  synapse/storage/databases/main/event_federation.py  71
1 file changed, 37 insertions(+), 34 deletions(-)
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 6e5761c7b7..0b69aa6a94 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -59,7 +59,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             include_given: include the given events in result
 
         Returns:
-            list of event_ids
+            An awaitable which resolves to a list of event_ids
         """
         return await self.db_pool.runInteraction(
             "get_auth_chain_ids",
@@ -95,7 +95,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return list(results)
 
-    def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+    async def get_auth_chain_difference(self, state_sets: List[Set[str]]) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
 
@@ -104,10 +104,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         chain.
 
         Returns:
-            Deferred[Set[str]]
+            The set of the difference in auth chains.
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
             state_sets,
@@ -252,8 +252,8 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         # Return all events where not all sets can reach them.
         return {eid for eid, n in event_to_missing_sets.items() if n}
 
-    def get_oldest_events_with_depth_in_room(self, room_id):
-        return self.db_pool.runInteraction(
+    async def get_oldest_events_with_depth_in_room(self, room_id):
+        return await self.db_pool.runInteraction(
             "get_oldest_events_with_depth_in_room",
             self.get_oldest_events_with_depth_in_room_txn,
             room_id,
@@ -293,7 +293,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         else:
             return max(row["depth"] for row in rows)
 
-    def get_prev_events_for_room(self, room_id: str):
+    async def get_prev_events_for_room(self, room_id: str) -> List[str]:
         """
         Gets a subset of the current forward extremities in the given room.
 
@@ -301,14 +301,14 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         events which refer to hundreds of prev_events.
 
         Args:
-            room_id (str): room_id
+            room_id: room_id
 
         Returns:
-            Deferred[List[str]]: the event ids of the forward extremites
+            The event ids of the forward extremities.
 
         """
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_prev_events_for_room", self._get_prev_events_for_room_txn, room_id
         )
 
@@ -328,17 +328,19 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return [row[0] for row in txn]
 
-    def get_rooms_with_many_extremities(self, min_count, limit, room_id_filter):
+    async def get_rooms_with_many_extremities(
+        self, min_count: int, limit: int, room_id_filter: Iterable[str]
+    ) -> List[str]:
         """Get the top rooms with at least N extremities.
 
         Args:
-            min_count (int): The minimum number of extremities
-            limit (int): The maximum number of rooms to return.
-            room_id_filter (iterable[str]): room_ids to exclude from the results
+            min_count: The minimum number of extremities
+            limit: The maximum number of rooms to return.
+            room_id_filter: room_ids to exclude from the results
 
         Returns:
-            Deferred[list]: At most `limit` room IDs that have at least
-            `min_count` extremities, sorted by extremity count.
+            At most `limit` room IDs that have at least `min_count` extremities,
+            sorted by extremity count.
         """
 
         def _get_rooms_with_many_extremities_txn(txn):
@@ -363,7 +365,7 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, query_args)
             return [room_id for room_id, in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_rooms_with_many_extremities", _get_rooms_with_many_extremities_txn
         )
 
@@ -376,10 +378,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             desc="get_latest_event_ids_in_room",
         )
 
-    def get_min_depth(self, room_id):
-        """ For hte given room, get the minimum depth we have seen for it.
+    async def get_min_depth(self, room_id: str) -> int:
+        """For the given room, get the minimum depth we have seen for it.
         """
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_min_depth", self._get_min_depth_interaction, room_id
         )
 
@@ -394,7 +396,9 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
 
         return int(min_depth) if min_depth is not None else None
 
-    def get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def get_forward_extremeties_for_room(
+        self, room_id: str, stream_ordering: int
+    ) -> List[str]:
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -402,11 +406,11 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         stream_orderings from that point.
 
         Args:
-            room_id (str):
-            stream_ordering (int):
+            room_id:
+            stream_ordering:
 
         Returns:
-            deferred, which resolves to a list of event_ids
+            A list of event_ids
         """
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
@@ -422,10 +426,10 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
         if last_change > self.stream_ordering_month_ago:
             stream_ordering = min(last_change, stream_ordering)
 
-        return self._get_forward_extremeties_for_room(room_id, stream_ordering)
+        return await self._get_forward_extremeties_for_room(room_id, stream_ordering)
 
     @cached(max_entries=5000, num_args=2)
-    def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
+    async def _get_forward_extremeties_for_room(self, room_id, stream_ordering):
         """For a given room_id and stream_ordering, return the forward
         extremeties of the room at that point in "time".
 
@@ -450,19 +454,18 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas
             txn.execute(sql, (stream_ordering, room_id))
             return [event_id for event_id, in txn]
 
-        return self.db_pool.runInteraction(
+        return await self.db_pool.runInteraction(
             "get_forward_extremeties_for_room", get_forward_extremeties_for_room_txn
         )
 
-    async def get_backfill_events(self, room_id, event_list, limit):
+    async def get_backfill_events(self, room_id: str, event_list: list, limit: int):
         """Get a list of Events for a given topic that occurred before (and
         including) the events in event_list. Return a list of max size `limit`
 
         Args:
-            txn
-            room_id (str)
-            event_list (list)
-            limit (int)
+            room_id
+            event_list
+            limit
         """
         event_ids = await self.db_pool.runInteraction(
             "get_backfill_events",
@@ -631,8 +634,8 @@ class EventFederationStore(EventFederationWorkerStore):
             _delete_old_forward_extrem_cache_txn,
         )
 
-    def clean_room_for_join(self, room_id):
-        return self.db_pool.runInteraction(
+    async def clean_room_for_join(self, room_id):
+        return await self.db_pool.runInteraction(
             "clean_room_for_join", self._clean_room_for_join_txn, room_id
         )
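
The pattern throughout this diff is the same: methods that previously returned the Deferred from db_pool.runInteraction directly become async def coroutines that await it, and the "Deferred[...]" return annotations in the docstrings are dropped accordingly. Below is a minimal caller-side sketch of what that migration means for code using this store; the handle function, the store argument, and the room_id value are hypothetical stand-ins for illustration, not code from this commit.

async def handle(store, room_id: str) -> None:
    # `store` stands in for an EventFederationWorkerStore instance (assumption).
    #
    # Before this change, these methods returned Deferreds, so callers
    # typically consumed them with inlineCallbacks-style code:
    #
    #     extremities = yield store.get_prev_events_for_room(room_id)
    #
    # After this change they are native coroutines, so callers await them:
    extremities = await store.get_prev_events_for_room(room_id)
    min_depth = await store.get_min_depth(room_id)
    print(
        "%s: %d forward extremities, min depth %s"
        % (room_id, len(extremities), min_depth)
    )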