summary refs log tree commit diff
path: root/synapse/handlers/federation.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2021-04-06 07:21:57 -0400
committerGitHub <noreply@github.com>2021-04-06 07:21:57 -0400
commitd959d28730ec6a0765ab72b10bcc96b1507233ac (patch)
treec3623edbf657acabff377549416f9d7d53a2e836 /synapse/handlers/federation.py
parentConvert storage test cases to HomeserverTestCase. (#9736) (diff)
downloadsynapse-d959d28730ec6a0765ab72b10bcc96b1507233ac.tar.xz
Add type hints to the federation handler and server. (#9743)
Diffstat (limited to 'synapse/handlers/federation.py')
-rw-r--r--synapse/handlers/federation.py161
1 files changed, 81 insertions, 80 deletions
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3ebee38ebe..5ea8a7b603 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,17 @@ import itertools
 import logging
 from collections.abc import Container
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+    TYPE_CHECKING,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    Union,
+)
 
 import attr
 from signedjson.key import decode_verify_key_bytes
@@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
 
         self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
 
-    async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
+    async def on_receive_pdu(
+        self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
+    ) -> None:
         """Process a PDU received via a federation /send/ transaction, or
         via backfill of missing prev_events
 
         Args:
-            origin (str): server which initiated the /send/ transaction. Will
+            origin: server which initiated the /send/ transaction. Will
                 be used to fetch missing events or state.
-            pdu (FrozenEvent): received PDU
-            sent_to_us_directly (bool): True if this event was pushed to us; False if
+            pdu: received PDU
+            sent_to_us_directly: True if this event was pushed to us; False if
                 we pulled it as the result of a missing prev_event.
         """
 
@@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
 
         await self._process_received_pdu(origin, pdu, state=state)
 
-    async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+    async def _get_missing_events_for_pdu(
+        self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+    ) -> None:
         """
         Args:
-            origin (str): Origin of the pdu. Will be called to get the missing events
+            origin: Origin of the pdu. Will be called to get the missing events
             pdu: received pdu
-            prevs (set(str)): List of event ids which we are missing
-            min_depth (int): Minimum depth of events to return.
+            prevs: List of event ids which we are missing
+            min_depth: Minimum depth of events to return.
         """
 
         room_id = pdu.room_id
@@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
         origin: str,
         event: EventBase,
         state: Optional[Iterable[EventBase]],
-    ):
+    ) -> None:
         """Called when we have a new pdu. We need to do auth checks and put it
         through the StateHandler.
 
@@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
             logger.exception("Failed to resync device for %s", sender)
 
     @log_function
-    async def backfill(self, dest, room_id, limit, extremities):
+    async def backfill(
+        self, dest: str, room_id: str, limit: int, extremities: List[str]
+    ) -> List[EventBase]:
         """Trigger a backfill request to `dest` for the given `room_id`
 
         This will attempt to get more events from the remote. If the other side
@@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
 
         curr_state = await self.state_handler.get_current_state(room_id)
 
-        def get_domains_from_state(state):
+        def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
             """Get joined domains from state
 
             Args:
-                state (dict[tuple, FrozenEvent]): State map from type/state
-                    key to event.
+                state: State map from type/state key to event.
 
             Returns:
-                list[tuple[str, int]]: Returns a list of servers with the
-                lowest depth of their joins. Sorted by lowest depth first.
+                Returns a list of servers with the lowest depth of their joins.
+                 Sorted by lowest depth first.
             """
             joined_users = [
                 (state_key, int(event.depth))
@@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
             domain for domain, depth in curr_domains if domain != self.server_name
         ]
 
-        async def try_backfill(domains):
+        async def try_backfill(domains: List[str]) -> bool:
             # TODO: Should we try multiple of these at a time?
             for dom in domains:
                 try:
@@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
         }
 
         for e_id, _ in sorted_extremeties_tuple:
-            likely_domains = get_domains_from_state(states[e_id])
+            likely_extremeties_domains = get_domains_from_state(states[e_id])
 
             success = await try_backfill(
-                [dom for dom, _ in likely_domains if dom not in tried_domains]
+                [
+                    dom
+                    for dom, _ in likely_extremeties_domains
+                    if dom not in tried_domains
+                ]
             )
             if success:
                 return True
 
-            tried_domains.update(dom for dom, _ in likely_domains)
+            tried_domains.update(dom for dom, _ in likely_extremeties_domains)
 
         return False
 
     async def _get_events_and_persist(
         self, destination: str, room_id: str, events: Iterable[str]
-    ):
+    ) -> None:
         """Fetch the given events from a server, and persist them as outliers.
 
         This function *does not* recursively get missing auth events of the
@@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
             event_infos,
         )
 
-    def _sanity_check_event(self, ev):
+    def _sanity_check_event(self, ev: EventBase) -> None:
         """
         Do some early sanity checks of a received event
 
@@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
         or cascade of event fetches.
 
         Args:
-            ev (synapse.events.EventBase): event to be checked
-
-        Returns: None
+            ev: event to be checked
 
         Raises:
             SynapseError if the event does not pass muster
@@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
             )
             raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
 
-    async def send_invite(self, target_host, event):
+    async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
         """Sends the invite to the remote server for signing.
 
         Invites must be signed by the invitee's server before distribution.
@@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
 
             run_in_background(self._handle_queued_pdus, room_queue)
 
-    async def _handle_queued_pdus(self, room_queue):
+    async def _handle_queued_pdus(
+        self, room_queue: List[Tuple[EventBase, str]]
+    ) -> None:
         """Process PDUs which got queued up while we were busy send_joining.
 
         Args:
-            room_queue (list[FrozenEvent, str]): list of PDUs to be processed
-                and the servers that sent them
+            room_queue: list of PDUs to be processed and the servers that sent them
         """
         for p, origin in room_queue:
             try:
@@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_join_request(self, origin, pdu):
+    async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
         """We have received a join event for a room. Fully process it and
         respond with the current state and auth chains.
         """
@@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
 
     async def on_invite_request(
         self, origin: str, event: EventBase, room_version: RoomVersion
-    ):
+    ) -> EventBase:
         """We've got an invite event. Process and persist it. Sign it.
 
         Respond with the now signed event.
@@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
 
         return event
 
-    async def on_send_leave_request(self, origin, pdu):
+    async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
         """ We have received a leave event for a room. Fully process it."""
         event = pdu
 
@@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
         else:
             return None
 
-    async def get_min_depth_for_context(self, context):
+    async def get_min_depth_for_context(self, context: str) -> int:
         return await self.store.get_min_depth(context)
 
     async def _handle_new_event(
-        self, origin, event, state=None, auth_events=None, backfilled=False
-    ):
+        self,
+        origin: str,
+        event: EventBase,
+        state: Optional[Iterable[EventBase]] = None,
+        auth_events: Optional[MutableStateMap[EventBase]] = None,
+        backfilled: bool = False,
+    ) -> EventContext:
         context = await self._prep_event(
             origin, event, state=state, auth_events=auth_events, backfilled=backfilled
         )
@@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
             logger.warning("Soft-failing %r because %s", event, e)
             event.internal_metadata.soft_failed = True
 
-    async def on_query_auth(
-        self, origin, event_id, room_id, remote_auth_chain, rejects, missing
-    ):
-        in_room = await self.auth.check_host_in_room(room_id, origin)
-        if not in_room:
-            raise AuthError(403, "Host not in room.")
-
-        event = await self.store.get_event(event_id, check_room_id=room_id)
-
-        # Just go through and process each event in `remote_auth_chain`. We
-        # don't want to fall into the trap of `missing` being wrong.
-        for e in remote_auth_chain:
-            try:
-                await self._handle_new_event(origin, e)
-            except AuthError:
-                pass
-
-        # Now get the current auth_chain for the event.
-        local_auth_chain = await self.store.get_auth_chain(
-            room_id, list(event.auth_event_ids()), include_given=True
-        )
-
-        # TODO: Check if we would now reject event_id. If so we need to tell
-        # everyone.
-
-        ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
-
-        logger.debug("on_query_auth returning: %s", ret)
-
-        return ret
-
     async def on_get_missing_events(
-        self, origin, room_id, earliest_events, latest_events, limit
-    ):
+        self,
+        origin: str,
+        room_id: str,
+        earliest_events: List[str],
+        latest_events: List[str],
+        limit: int,
+    ) -> List[EventBase]:
         in_room = await self.auth.check_host_in_room(room_id, origin)
         if not in_room:
             raise AuthError(403, "Host not in room.")
@@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
         assumes that we have already processed all events in remote_auth
 
         Params:
-            local_auth (list)
-            remote_auth (list)
+            local_auth
+            remote_auth
 
         Returns:
             dict
@@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
 
     @log_function
     async def exchange_third_party_invite(
-        self, sender_user_id, target_user_id, room_id, signed
-    ):
+        self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
+    ) -> None:
         third_party_invite = {"signed": signed}
 
         event_dict = {
@@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
         await member_handler.send_membership_event(None, event, context)
 
     async def add_display_name_to_third_party_invite(
-        self, room_version, event_dict, event, context
-    ):
+        self,
+        room_version: str,
+        event_dict: JsonDict,
+        event: EventBase,
+        context: EventContext,
+    ) -> Tuple[EventBase, EventContext]:
         key = (
             EventTypes.ThirdPartyInvite,
             event.content["third_party_invite"]["signed"]["token"],
@@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
         EventValidator().validate_new(event, self.config)
         return (event, context)
 
-    async def _check_signature(self, event, context):
+    async def _check_signature(self, event: EventBase, context: EventContext) -> None:
         """
         Checks that the signature in the event is consistent with its invite.
 
         Args:
-            event (Event): The m.room.member event to check
-            context (EventContext):
+            event: The m.room.member event to check
+            context:
 
         Raises:
             AuthError: if signature didn't match any keys, or key has been
@@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
 
         raise last_exception
 
-    async def _check_key_revocation(self, public_key, url):
+    async def _check_key_revocation(self, public_key: str, url: str) -> None:
         """
         Checks whether public_key has been revoked.
 
         Args:
-            public_key (str): base-64 encoded public key.
-            url (str): Key revocation URL.
+            public_key: base-64 encoded public key.
+            url: Key revocation URL.
 
         Raises:
             AuthError: if they key has been revoked.