summary refs log tree commit diff
path: root/synapse/federation/transport/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/federation/transport/client.py')
-rw-r--r--synapse/federation/transport/client.py70
1 files changed, 67 insertions, 3 deletions
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dca6e5c45d..7e510e224a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -65,13 +65,12 @@ class TransportLayerClient:
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
-        """Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
+        """Requests the IDs of all state for a given room at the given event.
 
         Args:
             destination: The host name of the remote homeserver we want
                 to get the state from.
-            context: The name of the context we want the state of
+            room_id: the room we want the state of
             event_id: The event we want the context at.
 
         Returns:
@@ -87,6 +86,29 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
+    async def get_room_state(
+        self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
+    ) -> "StateRequestResponse":
+        """Requests the full state for a given room at the given event.
+
+        Args:
+            room_version: the version of the room (required to build the event objects)
+            destination: The host name of the remote homeserver we want
+                to get the state from.
+            room_id: the room we want the state of
+            event_id: The event we want the context at.
+
+        Returns:
+            Results in a dict received from the remote homeserver.
+        """
+        path = _create_v1_path("/state/%s", room_id)
+        return await self.client.get_json(
+            destination,
+            path=path,
+            args={"event_id": event_id},
+            parser=_StateParser(room_version),
+        )
+
     async def get_event(
         self, destination: str, event_id: str, timeout: Optional[int] = None
     ) -> JsonDict:
@@ -1284,6 +1306,14 @@ class SendJoinResponse:
     servers_in_room: Optional[List[str]] = None
 
 
+@attr.s(slots=True, auto_attribs=True)
+class StateRequestResponse:
+    """The parsed response of a `/state` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
+
 @ijson.coroutine
 def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
     """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
@@ -1411,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
                 self._response.event_dict, self._room_version
             )
         return self._response
+
+
+class _StateParser(ByteParser[StateRequestResponse]):
+    """A parser for the response to `/state` requests.
+
+    Args:
+        room_version: The version of the room.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion):
+        self._response = StateRequestResponse([], [])
+        self._room_version = room_version
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                "pdus.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                "auth_chain.item",
+                use_float=True,
+            ),
+        ]
+
+    def write(self, data: bytes) -> int:
+        for c in self._coros:
+            c.send(data)
+        return len(data)
+
+    def finish(self) -> StateRequestResponse:
+        return self._response