diff --git a/changelog.d/9743.misc b/changelog.d/9743.misc
new file mode 100644
index 0000000000..c2f75c1df9
--- /dev/null
+++ b/changelog.d/9743.misc
@@ -0,0 +1 @@
+Add missing type hints to federation handler and server.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 71cb120ef7..b9f8d966a6 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -739,22 +739,20 @@ class FederationServer(FederationBase):
await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True)
- def __str__(self):
+ def __str__(self) -> str:
return "<ReplicationLayer(%s)>" % self.server_name
async def exchange_third_party_invite(
self, sender_user_id: str, target_user_id: str, room_id: str, signed: Dict
- ):
- ret = await self.handler.exchange_third_party_invite(
+ ) -> None:
+ await self.handler.exchange_third_party_invite(
sender_user_id, target_user_id, room_id, signed
)
- return ret
- async def on_exchange_third_party_invite_request(self, event_dict: Dict):
- ret = await self.handler.on_exchange_third_party_invite_request(event_dict)
- return ret
+ async def on_exchange_third_party_invite_request(self, event_dict: Dict) -> None:
+ await self.handler.on_exchange_third_party_invite_request(event_dict)
- async def check_server_matches_acl(self, server_name: str, room_id: str):
+ async def check_server_matches_acl(self, server_name: str, room_id: str) -> None:
"""Check if the given server is allowed by the server ACLs in the room
Args:
@@ -878,7 +876,7 @@ class FederationHandlerRegistry:
def register_edu_handler(
self, edu_type: str, handler: Callable[[str, JsonDict], Awaitable[None]]
- ):
+ ) -> None:
"""Sets the handler callable that will be used to handle an incoming
federation EDU of the given type.
@@ -897,7 +895,7 @@ class FederationHandlerRegistry:
def register_query_handler(
self, query_type: str, handler: Callable[[dict], Awaitable[JsonDict]]
- ):
+ ) -> None:
"""Sets the handler callable that will be used to handle an incoming
federation query of the given type.
@@ -915,15 +913,17 @@ class FederationHandlerRegistry:
self.query_handlers[query_type] = handler
- def register_instance_for_edu(self, edu_type: str, instance_name: str):
+ def register_instance_for_edu(self, edu_type: str, instance_name: str) -> None:
"""Register that the EDU handler is on a different instance than master."""
self._edu_type_to_instance[edu_type] = [instance_name]
- def register_instances_for_edu(self, edu_type: str, instance_names: List[str]):
+ def register_instances_for_edu(
+ self, edu_type: str, instance_names: List[str]
+ ) -> None:
"""Register that the EDU handler is on multiple instances."""
self._edu_type_to_instance[edu_type] = instance_names
- async def on_edu(self, edu_type: str, origin: str, content: dict):
+ async def on_edu(self, edu_type: str, origin: str, content: dict) -> None:
if not self.config.use_presence and edu_type == EduTypes.Presence:
return
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index 84e39c5a46..5ef0556ef7 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -620,8 +620,8 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServlet):
PATH = "/exchange_third_party_invite/(?P<room_id>[^/]*)"
async def on_PUT(self, origin, content, query, room_id):
- content = await self.handler.on_exchange_third_party_invite_request(content)
- return 200, content
+ await self.handler.on_exchange_third_party_invite_request(content)
+ return 200, {}
class FederationClientKeysQueryServlet(BaseFederationServlet):
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index 3ebee38ebe..5ea8a7b603 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -21,7 +21,17 @@ import itertools
import logging
from collections.abc import Container
from http import HTTPStatus
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Sequence, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ Union,
+)
import attr
from signedjson.key import decode_verify_key_bytes
@@ -171,15 +181,17 @@ class FederationHandler(BaseHandler):
self._ephemeral_messages_enabled = hs.config.enable_ephemeral_messages
- async def on_receive_pdu(self, origin, pdu, sent_to_us_directly=False) -> None:
+ async def on_receive_pdu(
+ self, origin: str, pdu: EventBase, sent_to_us_directly: bool = False
+ ) -> None:
"""Process a PDU received via a federation /send/ transaction, or
via backfill of missing prev_events
Args:
- origin (str): server which initiated the /send/ transaction. Will
+ origin: server which initiated the /send/ transaction. Will
be used to fetch missing events or state.
- pdu (FrozenEvent): received PDU
- sent_to_us_directly (bool): True if this event was pushed to us; False if
+ pdu: received PDU
+ sent_to_us_directly: True if this event was pushed to us; False if
we pulled it as the result of a missing prev_event.
"""
@@ -411,13 +423,15 @@ class FederationHandler(BaseHandler):
await self._process_received_pdu(origin, pdu, state=state)
- async def _get_missing_events_for_pdu(self, origin, pdu, prevs, min_depth):
+ async def _get_missing_events_for_pdu(
+ self, origin: str, pdu: EventBase, prevs: Set[str], min_depth: int
+ ) -> None:
"""
Args:
- origin (str): Origin of the pdu. Will be called to get the missing events
+ origin: Origin of the pdu. Will be called to get the missing events
pdu: received pdu
- prevs (set(str)): List of event ids which we are missing
- min_depth (int): Minimum depth of events to return.
+ prevs: List of event ids which we are missing
+ min_depth: Minimum depth of events to return.
"""
room_id = pdu.room_id
@@ -778,7 +792,7 @@ class FederationHandler(BaseHandler):
origin: str,
event: EventBase,
state: Optional[Iterable[EventBase]],
- ):
+ ) -> None:
"""Called when we have a new pdu. We need to do auth checks and put it
through the StateHandler.
@@ -887,7 +901,9 @@ class FederationHandler(BaseHandler):
logger.exception("Failed to resync device for %s", sender)
@log_function
- async def backfill(self, dest, room_id, limit, extremities):
+ async def backfill(
+ self, dest: str, room_id: str, limit: int, extremities: List[str]
+ ) -> List[EventBase]:
"""Trigger a backfill request to `dest` for the given `room_id`
This will attempt to get more events from the remote. If the other side
@@ -1142,16 +1158,15 @@ class FederationHandler(BaseHandler):
curr_state = await self.state_handler.get_current_state(room_id)
- def get_domains_from_state(state):
+ def get_domains_from_state(state: StateMap[EventBase]) -> List[Tuple[str, int]]:
"""Get joined domains from state
Args:
- state (dict[tuple, FrozenEvent]): State map from type/state
- key to event.
+ state: State map from type/state key to event.
Returns:
- list[tuple[str, int]]: Returns a list of servers with the
- lowest depth of their joins. Sorted by lowest depth first.
+ Returns a list of servers with the lowest depth of their joins.
+ Sorted by lowest depth first.
"""
joined_users = [
(state_key, int(event.depth))
@@ -1179,7 +1194,7 @@ class FederationHandler(BaseHandler):
domain for domain, depth in curr_domains if domain != self.server_name
]
- async def try_backfill(domains):
+ async def try_backfill(domains: List[str]) -> bool:
# TODO: Should we try multiple of these at a time?
for dom in domains:
try:
@@ -1258,21 +1273,25 @@ class FederationHandler(BaseHandler):
}
for e_id, _ in sorted_extremeties_tuple:
- likely_domains = get_domains_from_state(states[e_id])
+ likely_extremeties_domains = get_domains_from_state(states[e_id])
success = await try_backfill(
- [dom for dom, _ in likely_domains if dom not in tried_domains]
+ [
+ dom
+ for dom, _ in likely_extremeties_domains
+ if dom not in tried_domains
+ ]
)
if success:
return True
- tried_domains.update(dom for dom, _ in likely_domains)
+ tried_domains.update(dom for dom, _ in likely_extremeties_domains)
return False
async def _get_events_and_persist(
self, destination: str, room_id: str, events: Iterable[str]
- ):
+ ) -> None:
"""Fetch the given events from a server, and persist them as outliers.
This function *does not* recursively get missing auth events of the
@@ -1348,7 +1367,7 @@ class FederationHandler(BaseHandler):
event_infos,
)
- def _sanity_check_event(self, ev):
+ def _sanity_check_event(self, ev: EventBase) -> None:
"""
Do some early sanity checks of a received event
@@ -1357,9 +1376,7 @@ class FederationHandler(BaseHandler):
or cascade of event fetches.
Args:
- ev (synapse.events.EventBase): event to be checked
-
- Returns: None
+ ev: event to be checked
Raises:
SynapseError if the event does not pass muster
@@ -1380,7 +1397,7 @@ class FederationHandler(BaseHandler):
)
raise SynapseError(HTTPStatus.BAD_REQUEST, "Too many auth_events")
- async def send_invite(self, target_host, event):
+ async def send_invite(self, target_host: str, event: EventBase) -> EventBase:
"""Sends the invite to the remote server for signing.
Invites must be signed by the invitee's server before distribution.
@@ -1528,12 +1545,13 @@ class FederationHandler(BaseHandler):
run_in_background(self._handle_queued_pdus, room_queue)
- async def _handle_queued_pdus(self, room_queue):
+ async def _handle_queued_pdus(
+ self, room_queue: List[Tuple[EventBase, str]]
+ ) -> None:
"""Process PDUs which got queued up while we were busy send_joining.
Args:
- room_queue (list[FrozenEvent, str]): list of PDUs to be processed
- and the servers that sent them
+ room_queue: list of PDUs to be processed and the servers that sent them
"""
for p, origin in room_queue:
try:
@@ -1612,7 +1630,7 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_join_request(self, origin, pdu):
+ async def on_send_join_request(self, origin: str, pdu: EventBase) -> JsonDict:
"""We have received a join event for a room. Fully process it and
respond with the current state and auth chains.
"""
@@ -1668,7 +1686,7 @@ class FederationHandler(BaseHandler):
async def on_invite_request(
self, origin: str, event: EventBase, room_version: RoomVersion
- ):
+ ) -> EventBase:
"""We've got an invite event. Process and persist it. Sign it.
Respond with the now signed event.
@@ -1841,7 +1859,7 @@ class FederationHandler(BaseHandler):
return event
- async def on_send_leave_request(self, origin, pdu):
+ async def on_send_leave_request(self, origin: str, pdu: EventBase) -> None:
""" We have received a leave event for a room. Fully process it."""
event = pdu
@@ -1969,12 +1987,17 @@ class FederationHandler(BaseHandler):
else:
return None
- async def get_min_depth_for_context(self, context):
+ async def get_min_depth_for_context(self, context: str) -> int:
return await self.store.get_min_depth(context)
async def _handle_new_event(
- self, origin, event, state=None, auth_events=None, backfilled=False
- ):
+ self,
+ origin: str,
+ event: EventBase,
+ state: Optional[Iterable[EventBase]] = None,
+ auth_events: Optional[MutableStateMap[EventBase]] = None,
+ backfilled: bool = False,
+ ) -> EventContext:
context = await self._prep_event(
origin, event, state=state, auth_events=auth_events, backfilled=backfilled
)
@@ -2280,40 +2303,14 @@ class FederationHandler(BaseHandler):
logger.warning("Soft-failing %r because %s", event, e)
event.internal_metadata.soft_failed = True
- async def on_query_auth(
- self, origin, event_id, room_id, remote_auth_chain, rejects, missing
- ):
- in_room = await self.auth.check_host_in_room(room_id, origin)
- if not in_room:
- raise AuthError(403, "Host not in room.")
-
- event = await self.store.get_event(event_id, check_room_id=room_id)
-
- # Just go through and process each event in `remote_auth_chain`. We
- # don't want to fall into the trap of `missing` being wrong.
- for e in remote_auth_chain:
- try:
- await self._handle_new_event(origin, e)
- except AuthError:
- pass
-
- # Now get the current auth_chain for the event.
- local_auth_chain = await self.store.get_auth_chain(
- room_id, list(event.auth_event_ids()), include_given=True
- )
-
- # TODO: Check if we would now reject event_id. If so we need to tell
- # everyone.
-
- ret = await self.construct_auth_difference(local_auth_chain, remote_auth_chain)
-
- logger.debug("on_query_auth returning: %s", ret)
-
- return ret
-
async def on_get_missing_events(
- self, origin, room_id, earliest_events, latest_events, limit
- ):
+ self,
+ origin: str,
+ room_id: str,
+ earliest_events: List[str],
+ latest_events: List[str],
+ limit: int,
+ ) -> List[EventBase]:
in_room = await self.auth.check_host_in_room(room_id, origin)
if not in_room:
raise AuthError(403, "Host not in room.")
@@ -2617,8 +2614,8 @@ class FederationHandler(BaseHandler):
assumes that we have already processed all events in remote_auth
Params:
- local_auth (list)
- remote_auth (list)
+ local_auth
+ remote_auth
Returns:
dict
@@ -2742,8 +2739,8 @@ class FederationHandler(BaseHandler):
@log_function
async def exchange_third_party_invite(
- self, sender_user_id, target_user_id, room_id, signed
- ):
+ self, sender_user_id: str, target_user_id: str, room_id: str, signed: JsonDict
+ ) -> None:
third_party_invite = {"signed": signed}
event_dict = {
@@ -2835,8 +2832,12 @@ class FederationHandler(BaseHandler):
await member_handler.send_membership_event(None, event, context)
async def add_display_name_to_third_party_invite(
- self, room_version, event_dict, event, context
- ):
+ self,
+ room_version: str,
+ event_dict: JsonDict,
+ event: EventBase,
+ context: EventContext,
+ ) -> Tuple[EventBase, EventContext]:
key = (
EventTypes.ThirdPartyInvite,
event.content["third_party_invite"]["signed"]["token"],
@@ -2872,13 +2873,13 @@ class FederationHandler(BaseHandler):
EventValidator().validate_new(event, self.config)
return (event, context)
- async def _check_signature(self, event, context):
+ async def _check_signature(self, event: EventBase, context: EventContext) -> None:
"""
Checks that the signature in the event is consistent with its invite.
Args:
- event (Event): The m.room.member event to check
- context (EventContext):
+ event: The m.room.member event to check
+ context:
Raises:
AuthError: if signature didn't match any keys, or key has been
@@ -2964,13 +2965,13 @@ class FederationHandler(BaseHandler):
raise last_exception
- async def _check_key_revocation(self, public_key, url):
+ async def _check_key_revocation(self, public_key: str, url: str) -> None:
"""
Checks whether public_key has been revoked.
Args:
- public_key (str): base-64 encoded public key.
- url (str): Key revocation URL.
+ public_key: base-64 encoded public key.
+ url: Key revocation URL.
Raises:
AuthError: if they key has been revoked.
|