summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py104
1 files changed, 88 insertions, 16 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py

index c767d30627..b7a10da15a 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py
@@ -19,10 +19,10 @@ import itertools import logging from typing import ( TYPE_CHECKING, - Any, Awaitable, Callable, Collection, + Container, Dict, Iterable, List, @@ -79,7 +79,15 @@ class InvalidResponseError(RuntimeError): we couldn't parse """ - pass + +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SendJoinResult: + # The event to persist. + event: EventBase + # A string giving the server the event was sent to. + origin: str + state: List[EventBase] + auth_chain: List[EventBase] class FederationClient(FederationBase): @@ -506,6 +514,7 @@ class FederationClient(FederationBase): description: str, destinations: Iterable[str], callback: Callable[[str], Awaitable[T]], + failover_errcodes: Optional[Container[str]] = None, failover_on_unknown_endpoint: bool = False, ) -> T: """Try an operation on a series of servers, until it succeeds @@ -526,6 +535,9 @@ class FederationClient(FederationBase): next server tried. Normally the stacktrace is logged but this is suppressed if the exception is an InvalidResponseError. + failover_errcodes: Error codes (specific to this endpoint) which should + cause a failover when received as part of an HTTP 400 error. + failover_on_unknown_endpoint: if True, we will try other servers if it looks like a server doesn't support the endpoint. This is typically useful if the endpoint in question is new or experimental. @@ -537,6 +549,9 @@ class FederationClient(FederationBase): SynapseError if the chosen remote server returns a 300/400 code, or no servers were reachable. """ + if failover_errcodes is None: + failover_errcodes = () + for destination in destinations: if destination == self.server_name: continue @@ -551,11 +566,17 @@ class FederationClient(FederationBase): synapse_error = e.to_synapse_error() failover = False - # Failover on an internal server error, or if the destination - # doesn't implemented the endpoint for some reason. + # Failover should occur: + # + # * On internal server errors. + # * If the destination responds that it cannot complete the request. + # * If the destination doesn't implemented the endpoint for some reason. if 500 <= e.code < 600: failover = True + elif e.code == 400 and synapse_error.errcode in failover_errcodes: + failover = True + elif failover_on_unknown_endpoint and self._is_unknown_endpoint( e, synapse_error ): @@ -671,13 +692,25 @@ class FederationClient(FederationBase): return destination, ev, room_version + # MSC3083 defines additional error codes for room joins. Unfortunately + # we do not yet know the room version, assume these will only be returned + # by valid room versions. + failover_errcodes = ( + (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN) + if membership == Membership.JOIN + else None + ) + return await self._try_destination_list( - "make_" + membership, destinations, send_request + "make_" + membership, + destinations, + send_request, + failover_errcodes=failover_errcodes, ) async def send_join( self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion - ) -> Dict[str, Any]: + ) -> SendJoinResult: """Sends a join event to one of a list of homeservers. Doing so will cause the remote server to add the event to the graph, @@ -691,18 +724,38 @@ class FederationClient(FederationBase): did the make_join) Returns: - a dict with members ``origin`` (a string - giving the server the event was sent to, ``state`` (?) and - ``auth_chain``. + The result of the send join request. Raises: SynapseError: if the chosen remote server returns a 300/400 code, or no servers successfully handle the request. """ - async def send_request(destination) -> Dict[str, Any]: + async def send_request(destination) -> SendJoinResult: response = await self._do_send_join(room_version, destination, pdu) + # If an event was returned (and expected to be returned): + # + # * Ensure it has the same event ID (note that the event ID is a hash + # of the event fields for versions which support MSC3083). + # * Ensure the signatures are good. + # + # Otherwise, fallback to the provided event. + if room_version.msc3083_join_rules and response.event: + event = response.event + + valid_pdu = await self._check_sigs_and_hash_and_fetch_one( + pdu=event, + origin=destination, + outlier=True, + room_version=room_version, + ) + + if valid_pdu is None or event.event_id != pdu.event_id: + raise InvalidResponseError("Returned an invalid join event") + else: + event = pdu + state = response.state auth_chain = response.auth_events @@ -784,13 +837,32 @@ class FederationClient(FederationBase): % (auth_chain_create_events,) ) - return { - "state": signed_state, - "auth_chain": signed_auth, - "origin": destination, - } + return SendJoinResult( + event=event, + state=signed_state, + auth_chain=signed_auth, + origin=destination, + ) - return await self._try_destination_list("send_join", destinations, send_request) + # MSC3083 defines additional error codes for room joins. + failover_errcodes = None + if room_version.msc3083_join_rules: + failover_errcodes = ( + Codes.UNABLE_AUTHORISE_JOIN, + Codes.UNABLE_TO_GRANT_JOIN, + ) + + # If the join is being authorised via allow rules, we need to send + # the /send_join back to the same server that was originally used + # with /make_join. + if "join_authorised_via_users_server" in pdu.content: + destinations = [ + get_domain_from_id(pdu.content["join_authorised_via_users_server"]) + ] + + return await self._try_destination_list( + "send_join", destinations, send_request, failover_errcodes=failover_errcodes + ) async def _do_send_join( self, room_version: RoomVersion, destination: str, pdu: EventBase