From da0e9f8efdac1571eab35ad2cc842073ca54769c Mon Sep 17 00:00:00 2001
From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Date: Thu, 17 Feb 2022 16:11:59 +0000
Subject: Faster joins: parse msc3706 fields in send_join response (#12011)

Part of my work on #11249: add code to handle the new fields added in MSC3706.
---
 synapse/federation/transport/client.py | 118 ++++++++++++++++++++++++---------
 1 file changed, 87 insertions(+), 31 deletions(-)

(limited to 'synapse/federation/transport')

diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 8782586cd6..dca6e5c45d 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -1,4 +1,4 @@
-# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
+# Copyright 2014-2022 The Matrix.org Foundation C.I.C.
 # Copyright 2020 Sorunome
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,6 +60,7 @@ class TransportLayerClient:
     def __init__(self, hs):
         self.server_name = hs.hostname
         self.client = hs.get_federation_http_client()
+        self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
 
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
@@ -336,10 +337,15 @@ class TransportLayerClient:
         content: JsonDict,
     ) -> "SendJoinResponse":
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
+        query_params: Dict[str, str] = {}
+        if self._faster_joins_enabled:
+            # lazy-load state on join
+            query_params["org.matrix.msc3706.partial_state"] = "true"
 
         return await self.client.put_json(
             destination=destination,
             path=path,
+            args=query_params,
             data=content,
             parser=SendJoinParser(room_version, v1_api=False),
             max_response_size=MAX_RESPONSE_SIZE_SEND_JOIN,
@@ -1271,6 +1277,12 @@ class SendJoinResponse:
     # "event" is not included in the response.
     event: Optional[EventBase] = None
 
+    # The room state is incomplete
+    partial_state: bool = False
+
+    # List of servers in the room
+    servers_in_room: Optional[List[str]] = None
+
 
 @ijson.coroutine
 def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
@@ -1297,6 +1309,32 @@ def _event_list_parser(
         events.append(event)
 
 
+@ijson.coroutine
+def _partial_state_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+    """Helper function for use with `ijson.items_coro`
+
+    Parses the partial_state field in send_join responses
+    """
+    while True:
+        val = yield
+        if not isinstance(val, bool):
+            raise TypeError("partial_state must be a boolean")
+        response.partial_state = val
+
+
+@ijson.coroutine
+def _servers_in_room_parser(response: SendJoinResponse) -> Generator[None, Any, None]:
+    """Helper function for use with `ijson.items_coro`
+
+    Parses the servers_in_room field in send_join responses
+    """
+    while True:
+        val = yield
+        if not isinstance(val, list) or any(not isinstance(x, str) for x in val):
+            raise TypeError("servers_in_room must be a list of strings")
+        response.servers_in_room = val
+
+
 class SendJoinParser(ByteParser[SendJoinResponse]):
     """A parser for the response to `/send_join` requests.
 
@@ -1308,44 +1346,62 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
     CONTENT_TYPE = "application/json"
 
     def __init__(self, room_version: RoomVersion, v1_api: bool):
-        self._response = SendJoinResponse([], [], {})
+        self._response = SendJoinResponse([], [], event_dict={})
         self._room_version = room_version
+        self._coros = []
 
         # The V1 API has the shape of `[200, {...}]`, which we handle by
         # prefixing with `item.*`.
         prefix = "item." if v1_api else ""
 
-        self._coro_state = ijson.items_coro(
-            _event_list_parser(room_version, self._response.state),
-            prefix + "state.item",
-            use_float=True,
-        )
-        self._coro_auth = ijson.items_coro(
-            _event_list_parser(room_version, self._response.auth_events),
-            prefix + "auth_chain.item",
-            use_float=True,
-        )
-        # TODO Remove the unstable prefix when servers have updated.
-        #
-        # By re-using the same event dictionary this will cause the parsing of
-        # org.matrix.msc3083.v2.event and event to stomp over each other.
-        # Generally this should be fine.
-        self._coro_unstable_event = ijson.kvitems_coro(
-            _event_parser(self._response.event_dict),
-            prefix + "org.matrix.msc3083.v2.event",
-            use_float=True,
-        )
-        self._coro_event = ijson.kvitems_coro(
-            _event_parser(self._response.event_dict),
-            prefix + "event",
-            use_float=True,
-        )
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                prefix + "state.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                prefix + "auth_chain.item",
+                use_float=True,
+            ),
+            # TODO Remove the unstable prefix when servers have updated.
+            #
+            # By re-using the same event dictionary this will cause the parsing of
+            # org.matrix.msc3083.v2.event and event to stomp over each other.
+            # Generally this should be fine.
+            ijson.kvitems_coro(
+                _event_parser(self._response.event_dict),
+                prefix + "org.matrix.msc3083.v2.event",
+                use_float=True,
+            ),
+            ijson.kvitems_coro(
+                _event_parser(self._response.event_dict),
+                prefix + "event",
+                use_float=True,
+            ),
+        ]
+
+        if not v1_api:
+            self._coros.append(
+                ijson.items_coro(
+                    _partial_state_parser(self._response),
+                    "org.matrix.msc3706.partial_state",
+                    use_float="True",
+                )
+            )
+
+            self._coros.append(
+                ijson.items_coro(
+                    _servers_in_room_parser(self._response),
+                    "org.matrix.msc3706.servers_in_room",
+                    use_float="True",
+                )
+            )
 
     def write(self, data: bytes) -> int:
-        self._coro_state.send(data)
-        self._coro_auth.send(data)
-        self._coro_unstable_event.send(data)
-        self._coro_event.send(data)
+        for c in self._coros:
+            c.send(data)
 
         return len(data)
 
-- 
cgit 1.4.1