summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorRichard van der Hoff <1389908+richvdh@users.noreply.github.com>2022-02-22 12:17:10 +0000
committerGitHub <noreply@github.com>2022-02-22 12:17:10 +0000
commit7273011f60afbb1c9754ec73ee3661b19dca6bbd (patch)
tree4fb48fc09526d15d2e5d0b6ebaff7d9cf95790bb /synapse
parentFetch images when previewing Twitter URLs. (#11985) (diff)
downloadsynapse-7273011f60afbb1c9754ec73ee3661b19dca6bbd.tar.xz
Faster joins: Support for calling `/federation/v1/state` (#12013)
This is an endpoint that we have server-side support for, but no client-side support. It's going to be useful for resyncing partial-stated rooms, so let's introduce it.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/federation/federation_base.py10
-rw-r--r--synapse/federation/federation_client.py93
-rw-r--r--synapse/federation/transport/client.py70
-rw-r--r--synapse/http/matrixfederationclient.py50
4 files changed, 206 insertions, 17 deletions
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 896168c05c..fab6da3c08 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -47,6 +47,11 @@ class FederationBase:
     ) -> EventBase:
         """Checks that event is correctly signed by the sending server.
 
+        Also checks the content hash, and redacts the event if there is a mismatch.
+
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
         Args:
             room_version: The room version of the PDU
             pdu: the event to be checked
@@ -55,7 +60,10 @@ class FederationBase:
               * the original event if the checks pass
               * a redacted version of the event (if the signature
                 matched but the hash did not)
-              * throws a SynapseError if the signature check failed."""
+
+        Raises:
+              SynapseError if the signature check failed.
+        """
         try:
             await _check_sigs_on_pdu(self.keyring, room_version, pdu)
         except SynapseError as e:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 48c90bf0bb..c2997997da 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -419,26 +419,90 @@ class FederationClient(FederationBase):
 
         return state_event_ids, auth_event_ids
 
+    async def get_room_state(
+        self,
+        destination: str,
+        room_id: str,
+        event_id: str,
+        room_version: RoomVersion,
+    ) -> Tuple[List[EventBase], List[EventBase]]:
+        """Calls the /state endpoint to fetch the state at a particular point
+        in the room.
+
+        Any invalid events (those with incorrect or unverifiable signatures or hashes)
+        are filtered out from the response, and any duplicate events are removed.
+
+        (Size limits and other event-format checks are *not* performed.)
+
+        Note that the result is not ordered, so callers must be careful to process
+        the events in an order that handles dependencies.
+
+        Returns:
+            a tuple of (state events, auth events)
+        """
+        result = await self.transport_layer.get_room_state(
+            room_version,
+            destination,
+            room_id,
+            event_id,
+        )
+        state_events = result.state
+        auth_events = result.auth_events
+
+        # we may as well filter out any duplicates from the response, to save
+        # processing them multiple times. (In particular, events may be present in
+        # `auth_events` as well as `state`, which is redundant).
+        #
+        # We don't rely on the sort order of the events, so we can just stick them
+        # in a dict.
+        state_event_map = {event.event_id: event for event in state_events}
+        auth_event_map = {
+            event.event_id: event
+            for event in auth_events
+            if event.event_id not in state_event_map
+        }
+
+        logger.info(
+            "Processing from /state: %d state events, %d auth events",
+            len(state_event_map),
+            len(auth_event_map),
+        )
+
+        valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+            destination, auth_event_map.values(), room_version
+        )
+
+        valid_state_events = await self._check_sigs_and_hash_and_fetch(
+            destination, state_event_map.values(), room_version
+        )
+
+        return valid_state_events, valid_auth_events
+
     async def _check_sigs_and_hash_and_fetch(
         self,
         origin: str,
         pdus: Collection[EventBase],
         room_version: RoomVersion,
     ) -> List[EventBase]:
-        """Takes a list of PDUs and checks the signatures and hashes of each
-        one. If a PDU fails its signature check then we check if we have it in
-        the database and if not then request if from the originating server of
-        that PDU.
+        """Checks the signatures and hashes of a list of events.
+
+        If a PDU fails its signature check then we check if we have it in
+        the database, and if not then request it from the sender's server (if that
+        is different from `origin`). If that still fails, the event is omitted from
+        the returned list.
 
         If a PDU fails its content hash check then it is redacted.
 
-        The given list of PDUs are not modified, instead the function returns
+        Also runs each event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
+
+        The given list of PDUs are not modified; instead the function returns
         a new list.
 
         Args:
-            origin
-            pdu
-            room_version
+            origin: The server that sent us these events
+            pdus: The events to be checked
+            room_version: the version of the room these events are in
 
         Returns:
             A list of PDUs that have valid signatures and hashes.
@@ -469,11 +533,16 @@ class FederationClient(FederationBase):
         origin: str,
         room_version: RoomVersion,
     ) -> Optional[EventBase]:
-        """Takes a PDU and checks its signatures and hashes. If the PDU fails
-        its signature check then we check if we have it in the database and if
-        not then request if from the originating server of that PDU.
+        """Takes a PDU and checks its signatures and hashes.
+
+        If the PDU fails its signature check then we check if we have it in the
+        database; if not, we then request it from sender's server (if that is not the
+        same as `origin`). If that still fails, we return None.
+
+        If the PDU fails its content hash check, it is redacted.
 
-        If then PDU fails its content hash check then it is redacted.
+        Also runs the event through the spam checker; if it fails, redacts the event
+        and flags it as soft-failed.
 
         Args:
             origin
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dca6e5c45d..7e510e224a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -65,13 +65,12 @@ class TransportLayerClient:
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
     ) -> JsonDict:
-        """Requests all state for a given room from the given server at the
-        given event. Returns the state's event_id's
+        """Requests the IDs of all state for a given room at the given event.
 
         Args:
             destination: The host name of the remote homeserver we want
                 to get the state from.
-            context: The name of the context we want the state of
+            room_id: the room we want the state of
             event_id: The event we want the context at.
 
         Returns:
@@ -87,6 +86,29 @@ class TransportLayerClient:
             try_trailing_slash_on_400=True,
         )
 
+    async def get_room_state(
+        self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
+    ) -> "StateRequestResponse":
+        """Requests the full state for a given room at the given event.
+
+        Args:
+            room_version: the version of the room (required to build the event objects)
+            destination: The host name of the remote homeserver we want
+                to get the state from.
+            room_id: the room we want the state of
+            event_id: The event we want the context at.
+
+        Returns:
+            Results in a dict received from the remote homeserver.
+        """
+        path = _create_v1_path("/state/%s", room_id)
+        return await self.client.get_json(
+            destination,
+            path=path,
+            args={"event_id": event_id},
+            parser=_StateParser(room_version),
+        )
+
     async def get_event(
         self, destination: str, event_id: str, timeout: Optional[int] = None
     ) -> JsonDict:
@@ -1284,6 +1306,14 @@ class SendJoinResponse:
     servers_in_room: Optional[List[str]] = None
 
 
+@attr.s(slots=True, auto_attribs=True)
+class StateRequestResponse:
+    """The parsed response of a `/state` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
+
 @ijson.coroutine
 def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
     """Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
@@ -1411,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
                 self._response.event_dict, self._room_version
             )
         return self._response
+
+
+class _StateParser(ByteParser[StateRequestResponse]):
+    """A parser for the response to `/state` requests.
+
+    Args:
+        room_version: The version of the room.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion):
+        self._response = StateRequestResponse([], [])
+        self._room_version = room_version
+        self._coros = [
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.state),
+                "pdus.item",
+                use_float=True,
+            ),
+            ijson.items_coro(
+                _event_list_parser(room_version, self._response.auth_events),
+                "auth_chain.item",
+                use_float=True,
+            ),
+        ]
+
+    def write(self, data: bytes) -> int:
+        for c in self._coros:
+            c.send(data)
+        return len(data)
+
+    def finish(self) -> StateRequestResponse:
+        return self._response
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c5f8fcbb2a..e7656fbb9f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
         )
         return body
 
+    @overload
     async def get_json(
         self,
         destination: str,
@@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
+        parser: Literal[None] = None,
+        max_response_size: Optional[int] = None,
     ) -> Union[JsonDict, list]:
+        ...
+
+    @overload
+    async def get_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = ...,
+        retry_on_dns_fail: bool = ...,
+        timeout: Optional[int] = ...,
+        ignore_backoff: bool = ...,
+        try_trailing_slash_on_400: bool = ...,
+        parser: ByteParser[T] = ...,
+        max_response_size: Optional[int] = ...,
+    ) -> T:
+        ...
+
+    async def get_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        retry_on_dns_fail: bool = True,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser] = None,
+        max_response_size: Optional[int] = None,
+    ):
         """GETs some json from the given host homeserver and path
 
         Args:
@@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
             try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                 response we should try appending a trailing slash to the end of
                 the request. Workaround for #3622 in Synapse <= v0.99.3.
+
+            parser: The parser to use to decode the response. Defaults to
+                parsing as JSON.
+
+            max_response_size: The maximum size to read from the response. If None,
+                uses the default.
+
         Returns:
             Succeeds when we get a 2xx HTTP response. The
             result will be the decoded JSON body.
@@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
+        if parser is None:
+            parser = JsonParser()
+
         body = await _handle_response(
-            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            parser=parser,
+            max_response_size=max_response_size,
         )
 
         return body