summary refs log tree commit diff
path: root/synapse/federation/federation_client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_client.py')
-rw-r--r--synapse/federation/federation_client.py61
1 files changed, 49 insertions, 12 deletions
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index c767d30627..dbadf102f2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -19,7 +19,6 @@ import itertools
 import logging
 from typing import (
     TYPE_CHECKING,
-    Any,
     Awaitable,
     Callable,
     Collection,
@@ -79,7 +78,15 @@ class InvalidResponseError(RuntimeError):
     we couldn't parse
     """
 
-    pass
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SendJoinResult:
+    # The event to persist.
+    event: EventBase
+    # A string giving the server the event was sent to.
+    origin: str
+    state: List[EventBase]
+    auth_chain: List[EventBase]
 
 
 class FederationClient(FederationBase):
@@ -677,7 +684,7 @@ class FederationClient(FederationBase):
 
     async def send_join(
         self, destinations: Iterable[str], pdu: EventBase, room_version: RoomVersion
-    ) -> Dict[str, Any]:
+    ) -> SendJoinResult:
         """Sends a join event to one of a list of homeservers.
 
         Doing so will cause the remote server to add the event to the graph,
@@ -691,18 +698,38 @@ class FederationClient(FederationBase):
                 did the make_join)
 
         Returns:
-            a dict with members ``origin`` (a string
-            giving the server the event was sent to, ``state`` (?) and
-            ``auth_chain``.
+            The result of the send join request.
 
         Raises:
             SynapseError: if the chosen remote server returns a 300/400 code, or
                 no servers successfully handle the request.
         """
 
-        async def send_request(destination) -> Dict[str, Any]:
+        async def send_request(destination) -> SendJoinResult:
             response = await self._do_send_join(room_version, destination, pdu)
 
+            # If an event was returned (and expected to be returned):
+            #
+            # * Ensure it has the same event ID (note that the event ID is a hash
+            #   of the event fields for versions which support MSC3083).
+            # * Ensure the signatures are good.
+            #
+            # Otherwise, fallback to the provided event.
+            if room_version.msc3083_join_rules and response.event:
+                event = response.event
+
+                valid_pdu = await self._check_sigs_and_hash_and_fetch_one(
+                    pdu=event,
+                    origin=destination,
+                    outlier=True,
+                    room_version=room_version,
+                )
+
+                if valid_pdu is None or event.event_id != pdu.event_id:
+                    raise InvalidResponseError("Returned an invalid join event")
+            else:
+                event = pdu
+
             state = response.state
             auth_chain = response.auth_events
 
@@ -784,11 +811,21 @@ class FederationClient(FederationBase):
                     % (auth_chain_create_events,)
                 )
 
-            return {
-                "state": signed_state,
-                "auth_chain": signed_auth,
-                "origin": destination,
-            }
+            return SendJoinResult(
+                event=event,
+                state=signed_state,
+                auth_chain=signed_auth,
+                origin=destination,
+            )
+
+        if room_version.msc3083_join_rules:
+            # If the join is being authorised via allow rules, we need to send
+            # the /send_join back to the same server that was originally used
+            # with /make_join.
+            if "join_authorised_via_users_server" in pdu.content:
+                destinations = [
+                    get_domain_from_id(pdu.content["join_authorised_via_users_server"])
+                ]
 
         return await self._try_destination_list("send_join", destinations, send_request)