diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d30627..b7a10da15a 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,10 +19,10 @@ import itertools
import logging
from typing import (
TYPE_CHECKING,
- Any,
Awaitable,
Callable,
Collection,
+ Container,
Dict,
Iterable,
List,
@@ -79,7 +79,15 @@ class InvalidResponseError(RuntimeError):
we couldn't parse
"""
- pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+ # The event to persist.
+ event: EventBase
+ # A string giving the server the event was sent to.
+ origin: str
+ state: List[EventBase]
+ auth_chain: List[EventBase]
class FederationClient(FederationBase):
@@ -506,6 +514,7 @@ class FederationClient(FederationBase):
description: str,
destinations: Iterable[str],
callback: Callable[[str], Awaitable[T]],
+ failover_errcodes: Optional[Container[str]] = None,
failover_on_unknown_endpoint: bool = False,
) -> T:
"""Try an operation on a series of servers, until it succeeds
@@ -526,6 +535,9 @@ class FederationClient(FederationBase):
next server tried. Normally the stacktrace is logged but this is
suppressed if the exception is an InvalidResponseError.
+ failover_errcodes: Error codes (specific to this endpoint) which should
+ cause a failover when received as part of an HTTP 400 error.
+
failover_on_unknown_endpoint: if True, we will try other servers if it looks
like a server doesn't support the endpoint. This is typically useful
if the endpoint in question is new or experimental.
@@ -537,6 +549,9 @@ class FederationClient(FederationBase):
SynapseError if the chosen remote server returns a 300/400 code, or
no servers were reachable.
"""
+ if failover_errcodes is None:
+ failover_errcodes = ()
+
for destination in destinations:
if destination == self.server_name:
continue
@@ -551,11 +566,17 @@ class FederationClient(FederationBase):
synapse_error = e.to_synapse_error()
failover = False
- # Failover on an internal server error, or if the destination
- # doesn't implemented the endpoint for some reason.
+ # Failover should occur:
+ #
+ # * On internal server errors.
+ # * If the destination responds that it cannot complete the request.
+ # * If the destination doesn't implemented the endpoint for some reason.
if 500 <= e.code < 600:
failover = True
+ elif e.code == 400 and synapse_error.errcode in failover_errcodes:
+ failover = True
+
elif failover_on_unknown_endpoint and self._is_unknown_endpoint(
e, synapse_error
):
@@ -671,13 +692,25 @@ class FederationClient(FederationBase):
return destination, ev, room_version
+ # MSC3083 defines additional error codes for room joins. Unfortunately
+ # we do not yet know the room version, assume these will only be returned
+ # by valid room versions.
+ failover_errcodes = (
+ (Codes.UNABLE_AUTHORISE_JOIN, Codes.UNABLE_TO_GRANT_JOIN)
+ if membership == Membership.JOIN
+ else None
+ )
+
return await self._try_destination_list(
- "make_" + membership, destinations, send_request
+ "make_" + membership,
+ destinations,
+ send_request,
+ failover_errcodes=failover_errcodes,
)
async def send_join(
self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
- ) -> Dict[str, Any]:
+ ) -> SendJoinResult:
"""Sends a join event to one of a list of homeservers.
Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +724,38 @@ class FederationClient(FederationBase):
did the make_join)
Returns:
- a dict with members ``origin`` (a string
- giving the server the event was sent to, ``state`` (?) and
- ``auth_chain``.
+ The result of the send join request.
Raises:
SynapseError: if the chosen remote server returns a 300/400 code, or
no servers successfully handle the request.
"""
- async def send_request(destination) -> Dict[str, Any]:
+ async def send_request(destination) -> SendJoinResult:
response = await self._do_send_join(room_version, destination, pdu)
+ # If an event was returned (and expected to be returned):
+ #
+ # * Ensure it has the same event ID (note that the event ID is a hash
+ # of the event fields for versions which support MSC3083).
+ # * Ensure the signatures are good.
+ #
+ # Otherwise, fallback to the provided event.
+ if room_version.msc3083_join_rules and response.event:
+ event = response.event
+
+ valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+ pdu=event,
+ origin=destination,
+ outlier=True,
+ room_version=room_version,
+ )
+
+ if valid_pdu is None or event.event_id != pdu.event_id:
+ raise InvalidResponseError("Returned an invalid join event")
+ else:
+ event = pdu
+
state = response.state
auth_chain = response.auth_events
@@ -784,13 +837,32 @@ class FederationClient(FederationBase):
% (auth_chain_create_events,)
)
- return {
- "state": signed_state,
- "auth_chain": signed_auth,
- "origin": destination,
- }
+ return SendJoinResult(
+ event=event,
+ state=signed_state,
+ auth_chain=signed_auth,
+ origin=destination,
+ )
- return await self._try_destination_list("send_join", destinations, send_request)
+ # MSC3083 defines additional error codes for room joins.
+ failover_errcodes = None
+ if room_version.msc3083_join_rules:
+ failover_errcodes = (
+ Codes.UNABLE_AUTHORISE_JOIN,
+ Codes.UNABLE_TO_GRANT_JOIN,
+ )
+
+ # If the join is being authorised via allow rules, we need to send
+ # the /send_join back to the same server that was originally used
+ # with /make_join.
+ if "join_authorised_via_users_server" in pdu.content:
+ destinations = [
+ get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+ ]
+
+ return await self._try_destination_list(
+ "send_join", destinations, send_request, failover_errcodes=failover_errcodes
+ )
async def _do_send_join(
self, room_version: RoomVersion, destination: str, pdu: EventBase
|