summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-02-12 10:44:16 +0000
committerGitHub <noreply@github.com>2022-02-12 10:44:16 +0000
commit63c46349c41aa967e64a5a4042ef5177f934be47 (patch)
treeec133260cacd1f8b720b61fa7d2b5d222c4646b6
parentEnable cache time-based expiry by default (#11849) (diff)
downloadsynapse-63c46349c41aa967e64a5a4042ef5177f934be47.tar.xz
Implement MSC3706: partial state in `/send_join` response (#11967)
* Make `get_auth_chain_ids` return a Set

It has a set internally, and a set is often useful where it gets used, so let's
avoid converting to an intermediate list.

* Minor refactors in `on_send_join_request`

A little bit of non-functional groundwork

* Implement MSC3706: partial state in /send_join response
-rw-r--r--changelog.d/11967.feature1
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/federation/federation_server.py91
-rw-r--r--synapse/federation/transport/server/federation.py20
-rw-r--r--synapse/storage/databases/main/event_federation.py12
-rw-r--r--tests/federation/test_federation_server.py148
-rw-r--r--tests/storage/test_event_federation.py8
7 files changed, 262 insertions, 21 deletions
diff --git a/changelog.d/11967.feature b/changelog.d/11967.feature
new file mode 100644
index 0000000000..d09320a290
--- /dev/null
+++ b/changelog.d/11967.feature
@@ -0,0 +1 @@
+Experimental implementation of [MSC3706](https://github.com/matrix-org/matrix-doc/pull/3706): extensions to `/send_join` to support reduced response size.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index f05a803a71..09d692d9a1 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -61,3 +61,6 @@ class ExperimentalConfig(Config):
         self.msc2409_to_device_messages_enabled: bool = experimental.get(
             "msc2409_to_device_messages_enabled", False
         )
+
+        # MSC3706 (server-side support for partial state in /send_join responses)
+        self.msc3706_enabled: bool = experimental.get("msc3706_enabled", False)
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index af9cb98f67..482bbdd867 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -20,6 +20,7 @@ from typing import (
     Any,
     Awaitable,
     Callable,
+    Collection,
     Dict,
     Iterable,
     List,
@@ -64,7 +65,7 @@ from synapse.replication.http.federation import (
     ReplicationGetQueryRestServlet,
 )
 from synapse.storage.databases.main.lock import Lock
-from synapse.types import JsonDict, get_domain_from_id
+from synapse.types import JsonDict, StateMap, get_domain_from_id
 from synapse.util import json_decoder, unwrapFirstError
 from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
 from synapse.util.caches.response_cache import ResponseCache
@@ -571,7 +572,7 @@ class FederationServer(FederationBase):
     ) -> JsonDict:
         state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
         auth_chain_ids = await self.store.get_auth_chain_ids(room_id, state_ids)
-        return {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids}
+        return {"pdu_ids": state_ids, "auth_chain_ids": list(auth_chain_ids)}
 
     async def _on_context_state_request_compute(
         self, room_id: str, event_id: Optional[str]
@@ -645,27 +646,61 @@ class FederationServer(FederationBase):
         return {"event": ret_pdu.get_pdu_json(time_now)}
 
     async def on_send_join_request(
-        self, origin: str, content: JsonDict, room_id: str
+        self,
+        origin: str,
+        content: JsonDict,
+        room_id: str,
+        caller_supports_partial_state: bool = False,
     ) -> Dict[str, Any]:
         event, context = await self._on_send_membership_event(
             origin, content, Membership.JOIN, room_id
         )
 
         prev_state_ids = await context.get_prev_state_ids()
-        state_ids = list(prev_state_ids.values())
-        auth_chain = await self.store.get_auth_chain(room_id, state_ids)
-        state = await self.store.get_events(state_ids)
 
+        state_event_ids: Collection[str]
+        servers_in_room: Optional[Collection[str]]
+        if caller_supports_partial_state:
+            state_event_ids = _get_event_ids_for_partial_state_join(
+                event, prev_state_ids
+            )
+            servers_in_room = await self.state.get_hosts_in_room_at_events(
+                room_id, event_ids=event.prev_event_ids()
+            )
+        else:
+            state_event_ids = prev_state_ids.values()
+            servers_in_room = None
+
+        auth_chain_event_ids = await self.store.get_auth_chain_ids(
+            room_id, state_event_ids
+        )
+
+        # if the caller has opted in, we can omit any auth_chain events which are
+        # already in state_event_ids
+        if caller_supports_partial_state:
+            auth_chain_event_ids.difference_update(state_event_ids)
+
+        auth_chain_events = await self.store.get_events_as_list(auth_chain_event_ids)
+        state_events = await self.store.get_events_as_list(state_event_ids)
+
+        # we try to do all the async stuff before this point, so that time_now is as
+        # accurate as possible.
         time_now = self._clock.time_msec()
-        event_json = event.get_pdu_json()
-        return {
+        event_json = event.get_pdu_json(time_now)
+        resp = {
             # TODO Remove the unstable prefix when servers have updated.
             "org.matrix.msc3083.v2.event": event_json,
             "event": event_json,
-            "state": [p.get_pdu_json(time_now) for p in state.values()],
-            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain],
+            "state": [p.get_pdu_json(time_now) for p in state_events],
+            "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain_events],
+            "org.matrix.msc3706.partial_state": caller_supports_partial_state,
         }
 
+        if servers_in_room is not None:
+            resp["org.matrix.msc3706.servers_in_room"] = list(servers_in_room)
+
+        return resp
+
     async def on_make_leave_request(
         self, origin: str, room_id: str, user_id: str
     ) -> Dict[str, Any]:
@@ -1339,3 +1374,39 @@ class FederationHandlerRegistry:
         # error.
         logger.warning("No handler registered for query type %s", query_type)
         raise NotFoundError("No handler for Query type '%s'" % (query_type,))
+
+
+def _get_event_ids_for_partial_state_join(
+    join_event: EventBase,
+    prev_state_ids: StateMap[str],
+) -> Collection[str]:
+    """Calculate state to be retuned in a partial_state send_join
+
+    Args:
+        join_event: the join event being send_joined
+        prev_state_ids: the event ids of the state before the join
+
+    Returns:
+        the event ids to be returned
+    """
+
+    # return all non-member events
+    state_event_ids = {
+        event_id
+        for (event_type, state_key), event_id in prev_state_ids.items()
+        if event_type != EventTypes.Member
+    }
+
+    # we also need the current state of the current user (it's going to
+    # be an auth event for the new join, so we may as well return it)
+    current_membership_event_id = prev_state_ids.get(
+        (EventTypes.Member, join_event.state_key)
+    )
+    if current_membership_event_id is not None:
+        state_event_ids.add(current_membership_event_id)
+
+    # TODO: return a few more members:
+    #   - those with invites
+    #   - those that are kicked? / banned
+
+    return state_event_ids
diff --git a/synapse/federation/transport/server/federation.py b/synapse/federation/transport/server/federation.py
index d86dfede4e..310733097d 100644
--- a/synapse/federation/transport/server/federation.py
+++ b/synapse/federation/transport/server/federation.py
@@ -412,6 +412,16 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
 
     PREFIX = FEDERATION_V2_PREFIX
 
+    def __init__(
+        self,
+        hs: "HomeServer",
+        authenticator: Authenticator,
+        ratelimiter: FederationRateLimiter,
+        server_name: str,
+    ):
+        super().__init__(hs, authenticator, ratelimiter, server_name)
+        self._msc3706_enabled = hs.config.experimental.msc3706_enabled
+
     async def on_PUT(
         self,
         origin: str,
@@ -422,7 +432,15 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet):
     ) -> Tuple[int, JsonDict]:
         # TODO(paul): assert that event_id parsed from path actually
         #   match those given in content
-        result = await self.handler.on_send_join_request(origin, content, room_id)
+
+        partial_state = False
+        if self._msc3706_enabled:
+            partial_state = parse_boolean_from_args(
+                query, "org.matrix.msc3706.partial_state", default=False
+            )
+        result = await self.handler.on_send_join_request(
+            origin, content, room_id, caller_supports_partial_state=partial_state
+        )
         return 200, result
 
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 22f6474127..277e6422eb 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -121,7 +121,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         room_id: str,
         event_ids: Collection[str],
         include_given: bool = False,
-    ) -> List[str]:
+    ) -> Set[str]:
         """Get auth events for given event_ids. The events *must* be state events.
 
         Args:
@@ -130,7 +130,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             include_given: include the given events in result
 
         Returns:
-            list of event_ids
+            set of event_ids
         """
 
         # Check if we have indexed the room so we can use the chain cover
@@ -159,7 +159,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
     def _get_auth_chain_ids_using_cover_index_txn(
         self, txn: Cursor, room_id: str, event_ids: Collection[str], include_given: bool
-    ) -> List[str]:
+    ) -> Set[str]:
         """Calculates the auth chain IDs using the chain index."""
 
         # First we look up the chain ID/sequence numbers for the given events.
@@ -272,11 +272,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 txn.execute(sql, (chain_id, max_no))
                 results.update(r for r, in txn)
 
-        return list(results)
+        return results
 
     def _get_auth_chain_ids_txn(
         self, txn: LoggingTransaction, event_ids: Collection[str], include_given: bool
-    ) -> List[str]:
+    ) -> Set[str]:
         """Calculates the auth chain IDs.
 
         This is used when we don't have a cover index for the room.
@@ -331,7 +331,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             front = new_front
             results.update(front)
 
-        return list(results)
+        return results
 
     async def get_auth_chain_difference(
         self, room_id: str, state_sets: List[Set[str]]
diff --git a/tests/federation/test_federation_server.py b/tests/federation/test_federation_server.py
index 1af284bd2f..d084919ef7 100644
--- a/tests/federation/test_federation_server.py
+++ b/tests/federation/test_federation_server.py
@@ -16,12 +16,21 @@ import logging
 
 from parameterized import parameterized
 
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.crypto.event_signing import add_hashes_and_signatures
 from synapse.events import make_event_from_dict
 from synapse.federation.federation_server import server_matches_acl_event
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
 
 from tests import unittest
+from tests.unittest import override_config
 
 
 class FederationServerTests(unittest.FederatingHomeserverTestCase):
@@ -152,6 +161,145 @@ class StateQueryTests(unittest.FederatingHomeserverTestCase):
         self.assertEqual(channel.json_body["errcode"], "M_FORBIDDEN")
 
 
+class SendJoinFederationTests(unittest.FederatingHomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer):
+        super().prepare(reactor, clock, hs)
+
+        # create the room
+        creator_user_id = self.register_user("kermit", "test")
+        tok = self.login("kermit", "test")
+        self._room_id = self.helper.create_room_as(
+            room_creator=creator_user_id, tok=tok
+        )
+
+        # a second member on the orgin HS
+        second_member_user_id = self.register_user("fozzie", "bear")
+        tok2 = self.login("fozzie", "bear")
+        self.helper.join(self._room_id, second_member_user_id, tok=tok2)
+
+    def _make_join(self, user_id) -> JsonDict:
+        channel = self.make_signed_federation_request(
+            "GET",
+            f"/_matrix/federation/v1/make_join/{self._room_id}/{user_id}"
+            f"?ver={DEFAULT_ROOM_VERSION}",
+        )
+        self.assertEquals(channel.code, 200, channel.json_body)
+        return channel.json_body
+
+    def test_send_join(self):
+        """happy-path test of send_join"""
+        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
+        join_result = self._make_join(joining_user)
+
+        join_event_dict = join_result["event"]
+        add_hashes_and_signatures(
+            KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+            join_event_dict,
+            signature_name=self.OTHER_SERVER_NAME,
+            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+        )
+        channel = self.make_signed_federation_request(
+            "PUT",
+            f"/_matrix/federation/v2/send_join/{self._room_id}/x",
+            content=join_event_dict,
+        )
+        self.assertEquals(channel.code, 200, channel.json_body)
+
+        # we should get complete room state back
+        returned_state = [
+            (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
+        ]
+        self.assertCountEqual(
+            returned_state,
+            [
+                ("m.room.create", ""),
+                ("m.room.power_levels", ""),
+                ("m.room.join_rules", ""),
+                ("m.room.history_visibility", ""),
+                ("m.room.member", "@kermit:test"),
+                ("m.room.member", "@fozzie:test"),
+                # nb: *not* the joining user
+            ],
+        )
+
+        # also check the auth chain
+        returned_auth_chain_events = [
+            (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
+        ]
+        self.assertCountEqual(
+            returned_auth_chain_events,
+            [
+                ("m.room.create", ""),
+                ("m.room.member", "@kermit:test"),
+                ("m.room.power_levels", ""),
+                ("m.room.join_rules", ""),
+            ],
+        )
+
+        # the room should show that the new user is a member
+        r = self.get_success(
+            self.hs.get_state_handler().get_current_state(self._room_id)
+        )
+        self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
+
+    @override_config({"experimental_features": {"msc3706_enabled": True}})
+    def test_send_join_partial_state(self):
+        """When MSC3706 support is enabled, /send_join should return partial state"""
+        joining_user = "@misspiggy:" + self.OTHER_SERVER_NAME
+        join_result = self._make_join(joining_user)
+
+        join_event_dict = join_result["event"]
+        add_hashes_and_signatures(
+            KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+            join_event_dict,
+            signature_name=self.OTHER_SERVER_NAME,
+            signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+        )
+        channel = self.make_signed_federation_request(
+            "PUT",
+            f"/_matrix/federation/v2/send_join/{self._room_id}/x?org.matrix.msc3706.partial_state=true",
+            content=join_event_dict,
+        )
+        self.assertEquals(channel.code, 200, channel.json_body)
+
+        # expect a reduced room state
+        returned_state = [
+            (ev["type"], ev["state_key"]) for ev in channel.json_body["state"]
+        ]
+        self.assertCountEqual(
+            returned_state,
+            [
+                ("m.room.create", ""),
+                ("m.room.power_levels", ""),
+                ("m.room.join_rules", ""),
+                ("m.room.history_visibility", ""),
+            ],
+        )
+
+        # the auth chain should not include anything already in "state"
+        returned_auth_chain_events = [
+            (ev["type"], ev["state_key"]) for ev in channel.json_body["auth_chain"]
+        ]
+        self.assertCountEqual(
+            returned_auth_chain_events,
+            [
+                ("m.room.member", "@kermit:test"),
+            ],
+        )
+
+        # the room should show that the new user is a member
+        r = self.get_success(
+            self.hs.get_state_handler().get_current_state(self._room_id)
+        )
+        self.assertEqual(r[("m.room.member", joining_user)].membership, "join")
+
+
 def _create_acl_event(content):
     return make_event_from_dict(
         {
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 2bc89512f8..667ca90a4d 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -260,16 +260,16 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertCountEqual(auth_chain_ids, ["h", "i", "j", "k"])
 
         auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["h"]))
-        self.assertEqual(auth_chain_ids, ["k"])
+        self.assertEqual(auth_chain_ids, {"k"})
 
         auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["i"]))
-        self.assertEqual(auth_chain_ids, ["j"])
+        self.assertEqual(auth_chain_ids, {"j"})
 
         # j and k have no parents.
         auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["j"]))
-        self.assertEqual(auth_chain_ids, [])
+        self.assertEqual(auth_chain_ids, set())
         auth_chain_ids = self.get_success(self.store.get_auth_chain_ids(room_id, ["k"]))
-        self.assertEqual(auth_chain_ids, [])
+        self.assertEqual(auth_chain_ids, set())
 
         # More complex input sequences.
         auth_chain_ids = self.get_success(