summary refs log tree commit diff
path: root/synapse/federation/federation_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/federation_server.py')
-rw-r--r--synapse/federation/federation_server.py91
1 files changed, 81 insertions, 10 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index af9cb98f67..482bbdd867 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -20,6 +20,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
@@ -571,7 +572,7 @@ class FederationServer(FederationBase):
     ) -> JsonDict:
         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
-        return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
+        return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
 
     async def _on_context_state_request_compute(
         self, room_id: str, event_id: Optional[str]
@@ -645,27 +646,61 @@ class FederationServer(FederationBase):
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
     async def on_send_join_request(
-        self, origin: str, content: JsonDict, room_id: str
+        self,
+        origin: str,
+        content: JsonDict,
+        room_id: str,
+        caller_supports_partial_state: bool = False,
     ) -> Dict[str, Any]:
         event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
 
         prev_state_ids = await context.get_prev_state_ids()
-        state_ids = list(prev_state_ids.values())
-        auth_chain = await self.store.get_auth_chain(room_id, state_ids)
-        state = await self.store.get_events(state_ids)
 
+        state_event_ids: Collection[str]
+        servers_in_room: Optional[Collection[str]]
+        if caller_supports_partial_state:
+            state_event_ids = _get_event_ids_for_partial_state_join(
+                event, prev_state_ids
+            )
+            servers_in_room = await self.state.get_hosts_in_room_at_events(
+                room_id, event_ids=event.prev_event_ids()
+            )
+        else:
+            state_event_ids = prev_state_ids.values()
+            servers_in_room = None
+
+        auth_chain_event_ids = await self.store.get_auth_chain_ids(
+            room_id, state_event_ids
+        )
+
+        # if the caller has opted in, we can omit any auth_chain events which are
+        # already in state_event_ids
+        if caller_supports_partial_state:
+            auth_chain_event_ids.difference_update(state_event_ids)
+
+        auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
+        state_events = await self.store.get_events_as_list(state_event_ids)
+
+        # we try to do all the async stuff before this point, so that time_now is as
+        # accurate as possible.
         time_now = self._clock.time_msec()
-        event_json = event.get_pdu_json()
-        return {
+        event_json = event.get_pdu_json(time_now)
+        resp = {
             # TODO Remove the unstable prefix when servers have updated.
             "org.matrix.msc3083.v2.event": event_json,
             "event": event_json,
-            "state": [p.get_pdu_json(time_now) for p in state.values()],
-            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
+            "state": [p.get_pdu_json(time_now) for p in state_events],
+            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
+            "org.matrix.msc3706.partial_state": caller_supports_partial_state,
         }
 
+        if servers_in_room is not None:
+            resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
+
+        return resp
+
     async def on_make_leave_request(
         self, origin: str, room_id: str, user_id: str
     ) -> Dict[str, Any]:
@@ -1339,3 +1374,39 @@ class FederationHandlerRegistry:
         # error.
         logger.warning("No handler registered for query type %s", query_type)
         raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+
+def _get_event_ids_for_partial_state_join(
+    join_event: EventBase,
+    prev_state_ids: StateMap[str],
+) -> Collection[str]:
+    """Calculate state to be retuned in a partial_state send_join
+
+    Args:
+        join_event: the join event being send_joined
+        prev_state_ids: the event ids of the state before the join
+
+    Returns:
+        the event ids to be returned
+    """
+
+    # return all non-member events
+    state_event_ids = {
+        event_id
+        for (event_type, state_key), event_id in prev_state_ids.items()
+        if event_type != EventTypes.Member
+    }
+
+    # we also need the current state of the current user (it's going to
+    # be an auth event for the new join, so we may as well return it)
+    current_membership_event_id = prev_state_ids.get(
+        (EventTypes.Member, join_event.state_key)
+    )
+    if current_membership_event_id is not None:
+        state_event_ids.add(current_membership_event_id)
+
+    # TODO: return a few more members:
+    #   - those with invites
+    #   - those that are kicked? / banned
+
+    return state_event_ids