diff --git a/changelog.d/12013.misc b/changelog.d/12013.misc
new file mode 100644
index 0000000000..c0fca8dccb
--- /dev/null
+++ b/changelog.d/12013.misc
@@ -0,0 +1 @@
+Preparation for faster-room-join work: Support for calling `/federation/v1/state` on a remote server.
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 896168c05c..fab6da3c08 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -47,6 +47,11 @@ class FederationBase:
) -> EventBase:
"""Checks that event is correctly signed by the sending server.
+ Also checks the content hash, and redacts the event if there is a mismatch.
+
+ Also runs the event through the spam checker; if it fails, redacts the event
+ and flags it as soft-failed.
+
Args:
room_version: The room version of the PDU
pdu: the event to be checked
@@ -55,7 +60,10 @@ class FederationBase:
* the original event if the checks pass
* a redacted version of the event (if the signature
matched but the hash did not)
- * throws a SynapseError if the signature check failed."""
+
+ Raises:
+ SynapseError if the signature check failed.
+ """
try:
await _check_sigs_on_pdu(self.keyring, room_version, pdu)
except SynapseError as e:
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 48c90bf0bb..c2997997da 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -419,26 +419,90 @@ class FederationClient(FederationBase):
return state_event_ids, auth_event_ids
+ async def get_room_state(
+ self,
+ destination: str,
+ room_id: str,
+ event_id: str,
+ room_version: RoomVersion,
+ ) -> Tuple[List[EventBase], List[EventBase]]:
+ """Calls the /state endpoint to fetch the state at a particular point
+ in the room.
+
+ Any invalid events (those with incorrect or unverifiable signatures or hashes)
+ are filtered out from the response, and any duplicate events are removed.
+
+ (Size limits and other event-format checks are *not* performed.)
+
+ Note that the result is not ordered, so callers must be careful to process
+ the events in an order that handles dependencies.
+
+ Returns:
+ a tuple of (state events, auth events)
+ """
+ result = await self.transport_layer.get_room_state(
+ room_version,
+ destination,
+ room_id,
+ event_id,
+ )
+ state_events = result.state
+ auth_events = result.auth_events
+
+ # we may as well filter out any duplicates from the response, to save
+ # processing them multiple times. (In particular, events may be present in
+ # `auth_events` as well as `state`, which is redundant).
+ #
+ # We don't rely on the sort order of the events, so we can just stick them
+ # in a dict.
+ state_event_map = {event.event_id: event for event in state_events}
+ auth_event_map = {
+ event.event_id: event
+ for event in auth_events
+ if event.event_id not in state_event_map
+ }
+
+ logger.info(
+ "Processing from /state: %d state events, %d auth events",
+ len(state_event_map),
+ len(auth_event_map),
+ )
+
+ valid_auth_events = await self._check_sigs_and_hash_and_fetch(
+ destination, auth_event_map.values(), room_version
+ )
+
+ valid_state_events = await self._check_sigs_and_hash_and_fetch(
+ destination, state_event_map.values(), room_version
+ )
+
+ return valid_state_events, valid_auth_events
+
async def _check_sigs_and_hash_and_fetch(
self,
origin: str,
pdus: Collection[EventBase],
room_version: RoomVersion,
) -> List[EventBase]:
- """Takes a list of PDUs and checks the signatures and hashes of each
- one. If a PDU fails its signature check then we check if we have it in
- the database and if not then request if from the originating server of
- that PDU.
+ """Checks the signatures and hashes of a list of events.
+
+ If a PDU fails its signature check then we check if we have it in
+ the database, and if not then request it from the sender's server (if that
+ is different from `origin`). If that still fails, the event is omitted from
+ the returned list.
If a PDU fails its content hash check then it is redacted.
- The given list of PDUs are not modified, instead the function returns
+ Also runs each event through the spam checker; if it fails, redacts the event
+ and flags it as soft-failed.
+
+ The given list of PDUs is not modified; instead the function returns
a new list.
Args:
- origin
- pdu
- room_version
+ origin: The server that sent us these events
+ pdus: The events to be checked
+ room_version: the version of the room these events are in
Returns:
A list of PDUs that have valid signatures and hashes.
@@ -469,11 +533,16 @@ class FederationClient(FederationBase):
origin: str,
room_version: RoomVersion,
) -> Optional[EventBase]:
- """Takes a PDU and checks its signatures and hashes. If the PDU fails
- its signature check then we check if we have it in the database and if
- not then request if from the originating server of that PDU.
+ """Takes a PDU and checks its signatures and hashes.
+
+ If the PDU fails its signature check then we check if we have it in the
+ database; if not, we request it from the sender's server (if that is not the
+ same as `origin`). If that still fails, we return None.
+
+ If the PDU fails its content hash check, it is redacted.
- If then PDU fails its content hash check then it is redacted.
+ Also runs the event through the spam checker; if it fails, redacts the event
+ and flags it as soft-failed.
Args:
origin
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index dca6e5c45d..7e510e224a 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -65,13 +65,12 @@ class TransportLayerClient:
async def get_room_state_ids(
self, destination: str, room_id: str, event_id: str
) -> JsonDict:
- """Requests all state for a given room from the given server at the
- given event. Returns the state's event_id's
+ """Requests the IDs of all state for a given room at the given event.
Args:
destination: The host name of the remote homeserver we want
to get the state from.
- context: The name of the context we want the state of
+ room_id: the room we want the state of
event_id: The event we want the context at.
Returns:
@@ -87,6 +86,29 @@ class TransportLayerClient:
try_trailing_slash_on_400=True,
)
+ async def get_room_state(
+ self, room_version: RoomVersion, destination: str, room_id: str, event_id: str
+ ) -> "StateRequestResponse":
+ """Requests the full state for a given room at the given event.
+
+ Args:
+ room_version: the version of the room (required to build the event objects)
+ destination: The host name of the remote homeserver we want
+ to get the state from.
+ room_id: the room we want the state of
+ event_id: The event we want the state at.
+
+ Returns:
+ The parsed state and auth-chain events returned by the remote homeserver.
+ """
+ path = _create_v1_path("/state/%s", room_id)
+ return await self.client.get_json(
+ destination,
+ path=path,
+ args={"event_id": event_id},
+ parser=_StateParser(room_version),
+ )
+
async def get_event(
self, destination: str, event_id: str, timeout: Optional[int] = None
) -> JsonDict:
@@ -1284,6 +1306,14 @@ class SendJoinResponse:
servers_in_room: Optional[List[str]] = None
+@attr.s(slots=True, auto_attribs=True)
+class StateRequestResponse:
+ """The parsed response of a `/state` request."""
+
+ auth_events: List[EventBase]
+ state: List[EventBase]
+
+
@ijson.coroutine
def _event_parser(event_dict: JsonDict) -> Generator[None, Tuple[str, Any], None]:
"""Helper function for use with `ijson.kvitems_coro` to parse key-value pairs
@@ -1411,3 +1441,37 @@ class SendJoinParser(ByteParser[SendJoinResponse]):
self._response.event_dict, self._room_version
)
return self._response
+
+
+class _StateParser(ByteParser[StateRequestResponse]):
+ """A parser for the response to `/state` requests.
+
+ Args:
+ room_version: The version of the room.
+ """
+
+ CONTENT_TYPE = "application/json"
+
+ def __init__(self, room_version: RoomVersion):
+ self._response = StateRequestResponse([], [])
+ self._room_version = room_version
+ self._coros = [
+ ijson.items_coro(
+ _event_list_parser(room_version, self._response.state),
+ "pdus.item",
+ use_float=True,
+ ),
+ ijson.items_coro(
+ _event_list_parser(room_version, self._response.auth_events),
+ "auth_chain.item",
+ use_float=True,
+ ),
+ ]
+
+ def write(self, data: bytes) -> int:
+ for c in self._coros:
+ c.send(data)
+ return len(data)
+
+ def finish(self) -> StateRequestResponse:
+ return self._response
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c5f8fcbb2a..e7656fbb9f 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -958,6 +958,7 @@ class MatrixFederationHttpClient:
)
return body
+ @overload
async def get_json(
self,
destination: str,
@@ -967,7 +968,38 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
+ parser: Literal[None] = None,
+ max_response_size: Optional[int] = None,
) -> Union[JsonDict, list]:
+ ...
+
+ @overload
+ async def get_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = ...,
+ retry_on_dns_fail: bool = ...,
+ timeout: Optional[int] = ...,
+ ignore_backoff: bool = ...,
+ try_trailing_slash_on_400: bool = ...,
+ parser: ByteParser[T] = ...,
+ max_response_size: Optional[int] = ...,
+ ) -> T:
+ ...
+
+ async def get_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ retry_on_dns_fail: bool = True,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser] = None,
+ max_response_size: Optional[int] = None,
+ ):
"""GETs some json from the given host homeserver and path
Args:
@@ -992,6 +1024,13 @@ class MatrixFederationHttpClient:
try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
response we should try appending a trailing slash to the end of
the request. Workaround for #3622 in Synapse <= v0.99.3.
+
+ parser: The parser to use to decode the response. Defaults to
+ parsing as JSON.
+
+ max_response_size: The maximum size to read from the response. If None,
+ uses the default.
+
Returns:
Succeeds when we get a 2xx HTTP response. The
result will be the decoded JSON body.
@@ -1026,8 +1065,17 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
+ if parser is None:
+ parser = JsonParser()
+
body = await _handle_response(
- self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
+ parser=parser,
+ max_response_size=max_response_size,
)
return body
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
new file mode 100644
index 0000000000..ec8864dafe
--- /dev/null
+++ b/tests/federation/test_federation_client.py
@@ -0,0 +1,149 @@
+# Copyright 2022 Matrix.org Federation C.I.C
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from unittest import mock
+
+import twisted.web.client
+from twisted.internet import defer
+from twisted.internet.protocol import Protocol
+from twisted.python.failure import Failure
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.api.room_versions import RoomVersions
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import FederatingHomeserverTestCase
+
+
+class FederationClientTest(FederatingHomeserverTestCase):
+ def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
+ super().prepare(reactor, clock, homeserver)
+
+ # mock out the Agent used by the federation client, which is easier than
+ # catching the HTTPS connection and doing the TLS stuff.
+ self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
+ homeserver.get_federation_http_client().agent = self._mock_agent
+
+ def test_get_room_state(self):
+ creator = f"@creator:{self.OTHER_SERVER_NAME}"
+ test_room_id = "!room_id"
+
+ # mock up some events to use in the response.
+ # In real life, these would have things in `prev_events` and `auth_events`, but that's
+ # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
+ create_event_dict = self.add_hashes_and_signatures(
+ {
+ "room_id": test_room_id,
+ "type": "m.room.create",
+ "state_key": "",
+ "sender": creator,
+ "content": {"creator": creator},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 500,
+ }
+ )
+ member_event_dict = self.add_hashes_and_signatures(
+ {
+ "room_id": test_room_id,
+ "type": "m.room.member",
+ "sender": creator,
+ "state_key": creator,
+ "content": {"membership": "join"},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 600,
+ }
+ )
+ pl_event_dict = self.add_hashes_and_signatures(
+ {
+ "room_id": test_room_id,
+ "type": "m.room.power_levels",
+ "sender": creator,
+ "state_key": "",
+ "content": {},
+ "prev_events": [],
+ "auth_events": [],
+ "origin_server_ts": 700,
+ }
+ )
+
+ # mock up the response, and have the agent return it
+ self._mock_agent.request.return_value = defer.succeed(
+ _mock_response(
+ {
+ "pdus": [
+ create_event_dict,
+ member_event_dict,
+ pl_event_dict,
+ ],
+ "auth_chain": [
+ create_event_dict,
+ member_event_dict,
+ ],
+ }
+ )
+ )
+
+ # now fire off the request
+ state_resp, auth_resp = self.get_success(
+ self.hs.get_federation_client().get_room_state(
+ "yet_another_server",
+ test_room_id,
+ "event_id",
+ RoomVersions.V9,
+ )
+ )
+
+ # check the right call got made to the agent
+ self._mock_agent.request.assert_called_once_with(
+ b"GET",
+ b"matrix://yet_another_server/_matrix/federation/v1/state/%21room_id?event_id=event_id",
+ headers=mock.ANY,
+ bodyProducer=None,
+ )
+
+ # ... and that the response is correct.
+
+ # the auth_resp should be empty because all the events are also in state
+ self.assertEqual(auth_resp, [])
+
+ # all of the events should be returned in state_resp, though not necessarily
+ # in the same order. We just check the type on the assumption that if the type
+ # is right, so is the rest of the event.
+ self.assertCountEqual(
+ [e.type for e in state_resp],
+ ["m.room.create", "m.room.member", "m.room.power_levels"],
+ )
+
+
+def _mock_response(resp: JsonDict):
+ body = json.dumps(resp).encode("utf-8")
+
+ def deliver_body(p: Protocol):
+ p.dataReceived(body)
+ p.connectionLost(Failure(twisted.web.client.ResponseDone()))
+
+ response = mock.Mock(
+ code=200,
+ phrase=b"OK",
+ headers=twisted.web.client.Headers({"content-Type": ["application/json"]}),
+ length=len(body),
+ deliverBody=deliver_body,
+ )
+ mock.seal(response)
+ return response
diff --git a/tests/unittest.py b/tests/unittest.py
index a71892cb9d..7983c1e8b8 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -51,7 +51,10 @@ from twisted.web.server import Request
from synapse import events
from synapse.api.constants import EventTypes, Membership
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
from synapse.config.homeserver import HomeServerConfig
+from synapse.config.server import DEFAULT_ROOM_VERSION
+from synapse.crypto.event_signing import add_hashes_and_signatures
from synapse.federation.transport.server import TransportLayerServer
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
@@ -839,6 +842,24 @@ class FederatingHomeserverTestCase(HomeserverTestCase):
client_ip=client_ip,
)
+ def add_hashes_and_signatures(
+ self,
+ event_dict: JsonDict,
+ room_version: RoomVersion = KNOWN_ROOM_VERSIONS[DEFAULT_ROOM_VERSION],
+ ) -> JsonDict:
+ """Adds hashes and signatures to the given event dict
+
+ Returns:
+ The modified event dict, for convenience
+ """
+ add_hashes_and_signatures(
+ room_version,
+ event_dict,
+ signature_name=self.OTHER_SERVER_NAME,
+ signing_key=self.OTHER_SERVER_SIGNATURE_KEY,
+ )
+ return event_dict
+
def _auth_header_for_request(
origin: str,
|