From 6e8fb42be7657f9d4958c02d87cff865225714d2 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Thu, 24 Jun 2021 15:30:49 +0100 Subject: Improve validation for `send_{join,leave,knock}` (#10225) The idea here is to stop people sending things that aren't joins/leaves/knocks through these endpoints: previously you could send anything you liked through them. I wasn't able to find any security holes from doing so, but it doesn't sound like a good thing. --- synapse/federation/federation_server.py | 121 +++++++++++++++++++------------- synapse/federation/transport/server.py | 12 ++-- 2 files changed, 78 insertions(+), 55 deletions(-) (limited to 'synapse/federation') diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b07f18529..341965047a 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -34,7 +34,7 @@ from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure -from synapse.api.constants import EduTypes, EventTypes +from synapse.api.constants import EduTypes, EventTypes, Membership from synapse.api.errors import ( AuthError, Codes, @@ -46,6 +46,7 @@ from synapse.api.errors import ( ) from synapse.api.room_versions import KNOWN_ROOM_VERSIONS from synapse.events import EventBase +from synapse.events.snapshot import EventContext from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions from synapse.federation.units import Edu, Transaction @@ -537,26 +538,21 @@ class FederationServer(FederationBase): return {"event": ret_pdu.get_pdu_json(time_now)} async def on_send_join_request( - self, origin: str, content: JsonDict + self, origin: str, content: JsonDict, room_id: str ) -> Dict[str, Any]: - logger.debug("on_send_join_request: content: %s", content) - - assert_params_in_dict(content, ["room_id"]) - room_version = await self.store.get_room_version(content["room_id"]) - pdu = event_from_pdu_json(content, room_version) - - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, pdu.room_id) - - logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) + context = await self._on_send_membership_event( + origin, content, Membership.JOIN, room_id + ) - pdu = await self._check_sigs_and_hash(room_version, pdu) + prev_state_ids = await context.get_prev_state_ids() + state_ids = list(prev_state_ids.values()) + auth_chain = await self.store.get_auth_chain(room_id, state_ids) + state = await self.store.get_events(state_ids) - res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return { - "state": [p.get_pdu_json(time_now) for p in res_pdus["state"]], - "auth_chain": [p.get_pdu_json(time_now) for p in res_pdus["auth_chain"]], + "state": [p.get_pdu_json(time_now) for p in state.values()], + "auth_chain": [p.get_pdu_json(time_now) for p in auth_chain], } async def on_make_leave_request( @@ -571,21 +567,11 @@ class FederationServer(FederationBase): time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - async def on_send_leave_request(self, origin: str, content: JsonDict) -> dict: + async def on_send_leave_request( + self, origin: str, content: JsonDict, room_id: str + ) -> dict: logger.debug("on_send_leave_request: content: %s", content) - - assert_params_in_dict(content, ["room_id"]) - room_version = await self.store.get_room_version(content["room_id"]) - pdu = event_from_pdu_json(content, room_version) - - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, pdu.room_id) - - logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - - pdu = await self._check_sigs_and_hash(room_version, pdu) - - await self.handler.on_send_leave_request(origin, pdu) + await self._on_send_membership_event(origin, content, Membership.LEAVE, room_id) return {} async def on_make_knock_request( @@ -651,39 +637,76 @@ class FederationServer(FederationBase): Returns: The stripped room state. """ - logger.debug("on_send_knock_request: content: %s", content) + event_context = await self._on_send_membership_event( + origin, content, Membership.KNOCK, room_id + ) + + # Retrieve stripped state events from the room and send them back to the remote + # server. This will allow the remote server's clients to display information + # related to the room while the knock request is pending. + stripped_room_state = ( + await self.store.get_stripped_room_state_from_event_context( + event_context, self._room_prejoin_state_types + ) + ) + return {"knock_state_events": stripped_room_state} + + async def _on_send_membership_event( + self, origin: str, content: JsonDict, membership_type: str, room_id: str + ) -> EventContext: + """Handle an on_send_{join,leave,knock} request + + Does some preliminary validation before passing the request on to the + federation handler. + + Args: + origin: The (authenticated) requesting server + content: The body of the send_* request - a complete membership event + membership_type: The expected membership type (join or leave, depending + on the endpoint) + room_id: The room_id from the request, to be validated against the room_id + in the event + + Returns: + The context of the event after inserting it into the room graph. + + Raises: + SynapseError if there is a problem with the request, including things like + the room_id not matching or the event not being authorized. + """ + assert_params_in_dict(content, ["room_id"]) + if content["room_id"] != room_id: + raise SynapseError( + 400, + "Room ID in body does not match that in request path", + Codes.BAD_JSON, + ) room_version = await self.store.get_room_version(room_id) - # Check that this room supports knocking as defined by its room version - if not room_version.msc2403_knocking: + if membership_type == Membership.KNOCK and not room_version.msc2403_knocking: raise SynapseError( 403, "This room version does not support knocking", errcode=Codes.FORBIDDEN, ) - pdu = event_from_pdu_json(content, room_version) + event = event_from_pdu_json(content, room_version) - origin_host, _ = parse_server_name(origin) - await self.check_server_matches_acl(origin_host, pdu.room_id) + if event.type != EventTypes.Member or not event.is_state(): + raise SynapseError(400, "Not an m.room.member event", Codes.BAD_JSON) - logger.debug("on_send_knock_request: pdu sigs: %s", pdu.signatures) + if event.content.get("membership") != membership_type: + raise SynapseError(400, "Not a %s event" % membership_type, Codes.BAD_JSON) - pdu = await self._check_sigs_and_hash(room_version, pdu) + origin_host, _ = parse_server_name(origin) + await self.check_server_matches_acl(origin_host, event.room_id) - # Handle the event, and retrieve the EventContext - event_context = await self.handler.on_send_knock_request(origin, pdu) + logger.debug("_on_send_membership_event: pdu sigs: %s", event.signatures) - # Retrieve stripped state events from the room and send them back to the remote - # server. This will allow the remote server's clients to display information - # related to the room while the knock request is pending. - stripped_room_state = ( - await self.store.get_stripped_room_state_from_event_context( - event_context, self._room_prejoin_state_types - ) - ) - return {"knock_state_events": stripped_room_state} + event = await self._check_sigs_and_hash(room_version, event) + + return await self.handler.on_send_membership_event(origin, event) async def on_event_auth( self, origin: str, room_id: str, event_id: str diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index bed47f8abd..676fbd3750 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -553,7 +553,7 @@ class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content) + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, (200, content) @@ -563,7 +563,7 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content) + content = await self.handler.on_send_leave_request(origin, content, room_id) return 200, content @@ -602,9 +602,9 @@ class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" async def on_PUT(self, origin, content, query, room_id, event_id): - # TODO(paul): assert that room_id/event_id parsed from path actually + # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content) + content = await self.handler.on_send_join_request(origin, content, room_id) return 200, (200, content) @@ -614,9 +614,9 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX async def on_PUT(self, origin, content, query, room_id, event_id): - # TODO(paul): assert that room_id/event_id parsed from path actually + # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content) + content = await self.handler.on_send_join_request(origin, content, room_id) return 200, content -- cgit 1.5.1 From 0555d7b0dc18fff489a31afccb47b79afa082113 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 28 Jun 2021 07:36:41 -0400 Subject: Add additional types to the federation transport server. (#10213) --- changelog.d/10213.misc | 1 + synapse/federation/transport/server.py | 588 ++++++++++++++++++++++++++------- synapse/http/servlet.py | 50 ++- 3 files changed, 521 insertions(+), 118 deletions(-) create mode 100644 changelog.d/10213.misc (limited to 'synapse/federation') diff --git a/changelog.d/10213.misc b/changelog.d/10213.misc new file mode 100644 index 0000000000..9adb0fbd02 --- /dev/null +++ b/changelog.d/10213.misc @@ -0,0 +1 @@ +Add type hints to the federation servlets. diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py index 676fbd3750..d37d9565fc 100644 --- a/synapse/federation/transport/server.py +++ b/synapse/federation/transport/server.py @@ -15,7 +15,19 @@ import functools import logging import re -from typing import Container, Mapping, Optional, Sequence, Tuple, Type +from typing import ( + Container, + Dict, + List, + Mapping, + Optional, + Sequence, + Tuple, + Type, + Union, +) + +from typing_extensions import Literal import synapse from synapse.api.constants import MAX_GROUP_CATEGORYID_LENGTH, MAX_GROUP_ROLEID_LENGTH @@ -56,15 +68,15 @@ logger = logging.getLogger(__name__) class TransportLayerServer(JsonResource): """Handles incoming federation HTTP requests""" - def __init__(self, hs, servlet_groups=None): + def __init__(self, hs: HomeServer, servlet_groups: Optional[List[str]] = None): """Initialize the TransportLayerServer Will by default register all servlets. For custom behaviour, pass in a list of servlet_groups to register. Args: - hs (synapse.server.HomeServer): homeserver - servlet_groups (list[str], optional): List of servlet groups to register. + hs: homeserver + servlet_groups: List of servlet groups to register. Defaults to ``DEFAULT_SERVLET_GROUPS``. """ self.hs = hs @@ -78,7 +90,7 @@ class TransportLayerServer(JsonResource): self.register_servlets() - def register_servlets(self): + def register_servlets(self) -> None: register_servlets( self.hs, resource=self, @@ -91,14 +103,10 @@ class TransportLayerServer(JsonResource): class AuthenticationError(SynapseError): """There was a problem authenticating the request""" - pass - class NoAuthenticationError(AuthenticationError): """The request had no authentication information""" - pass - class Authenticator: def __init__(self, hs: HomeServer): @@ -410,13 +418,18 @@ class FederationSendServlet(BaseFederationServerServlet): RATELIMIT = False # This is when someone is trying to send us a bunch of data. - async def on_PUT(self, origin, content, query, transaction_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + transaction_id: str, + ) -> Tuple[int, JsonDict]: """Called on PUT /send// Args: - request (twisted.web.http.Request): The HTTP request. - transaction_id (str): The transaction_id associated with this - request. This is *not* None. + transaction_id: The transaction_id associated with this request. This + is *not* None. Returns: Tuple of `(code, response)`, where @@ -461,7 +474,13 @@ class FederationEventServlet(BaseFederationServerServlet): PATH = "/event/(?P[^/]*)/?" # This is when someone asks for a data item for a given server data_id pair. - async def on_GET(self, origin, content, query, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + event_id: str, + ) -> Tuple[int, Union[JsonDict, str]]: return await self.handler.on_pdu_request(origin, event_id) @@ -469,7 +488,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): PATH = "/state/(?P[^/]*)/?" # This is when someone asks for all data for a given room. - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_room_state_request( origin, room_id, @@ -480,7 +505,13 @@ class FederationStateV1Servlet(BaseFederationServerServlet): class FederationStateIdsServlet(BaseFederationServerServlet): PATH = "/state_ids/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_state_ids_request( origin, room_id, @@ -491,7 +522,13 @@ class FederationStateIdsServlet(BaseFederationServerServlet): class FederationBackfillServlet(BaseFederationServerServlet): PATH = "/backfill/(?P[^/]*)/?" - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: versions = [x.decode("ascii") for x in query[b"v"]] limit = parse_integer_from_args(query, "limit", None) @@ -505,7 +542,13 @@ class FederationQueryServlet(BaseFederationServerServlet): PATH = "/query/(?P[^/]*)" # This is when we receive a server-server Query - async def on_GET(self, origin, content, query, query_type): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + query_type: str, + ) -> Tuple[int, JsonDict]: args = {k.decode("utf8"): v[0].decode("utf-8") for k, v in query.items()} args["origin"] = origin return await self.handler.on_query_request(query_type, args) @@ -514,47 +557,66 @@ class FederationQueryServlet(BaseFederationServerServlet): class FederationMakeJoinServlet(BaseFederationServerServlet): PATH = "/make_join/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, _content, query, room_id, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: """ Args: - origin (unicode): The authenticated server_name of the calling server + origin: The authenticated server_name of the calling server - _content (None): (GETs don't have bodies) + content: (GETs don't have bodies) - query (dict[bytes, list[bytes]]): Query params from the request. + query: Query params from the request. - **kwargs (dict[unicode, unicode]): the dict mapping keys to path - components as specified in the path match regexp. + **kwargs: the dict mapping keys to path components as specified in + the path match regexp. Returns: - Tuple[int, object]: (response code, response object) + Tuple of (response code, response object) """ - versions = query.get(b"ver") - if versions is not None: - supported_versions = [v.decode("utf-8") for v in versions] - else: + supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") + if supported_versions is None: supported_versions = ["1"] - content = await self.handler.on_make_join_request( + result = await self.handler.on_make_join_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationMakeLeaveServlet(BaseFederationServerServlet): PATH = "/make_leave/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): - content = await self.handler.on_make_leave_request(origin, room_id, user_id) - return 200, content + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_make_leave_request(origin, room_id, user_id) + return 200, result class FederationV1SendLeaveServlet(BaseFederationServerServlet): PATH = "/send_leave/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, (200, content) + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, (200, result) class FederationV2SendLeaveServlet(BaseFederationServerServlet): @@ -562,50 +624,84 @@ class FederationV2SendLeaveServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_leave_request(origin, content, room_id) - return 200, content + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_leave_request(origin, content, room_id) + return 200, result class FederationMakeKnockServlet(BaseFederationServerServlet): PATH = "/make_knock/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, user_id): - try: - # Retrieve the room versions the remote homeserver claims to support - supported_versions = parse_strings_from_args(query, "ver", encoding="utf-8") - except KeyError: - raise SynapseError(400, "Missing required query parameter 'ver'") + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: + # Retrieve the room versions the remote homeserver claims to support + supported_versions = parse_strings_from_args( + query, "ver", required=True, encoding="utf-8" + ) - content = await self.handler.on_make_knock_request( + result = await self.handler.on_make_knock_request( origin, room_id, user_id, supported_versions=supported_versions ) - return 200, content + return 200, result class FederationV1SendKnockServlet(BaseFederationServerServlet): PATH = "/send_knock/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): - content = await self.handler.on_send_knock_request(origin, content, room_id) - return 200, content + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: + result = await self.handler.on_send_knock_request(origin, content, room_id) + return 200, result class FederationEventAuthServlet(BaseFederationServerServlet): PATH = "/event_auth/(?P[^/]*)/(?P[^/]*)" - async def on_GET(self, origin, content, query, room_id, event_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_event_auth(origin, room_id, event_id) class FederationV1SendJoinServlet(BaseFederationServerServlet): PATH = "/send_join/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, room_id) - return 200, (200, content) + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, (200, result) class FederationV2SendJoinServlet(BaseFederationServerServlet): @@ -613,28 +709,42 @@ class FederationV2SendJoinServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that event_id parsed from path actually # match those given in content - content = await self.handler.on_send_join_request(origin, content, room_id) - return 200, content + result = await self.handler.on_send_join_request(origin, content, room_id) + return 200, result class FederationV1InviteServlet(BaseFederationServerServlet): PATH = "/invite/(?P[^/]*)/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, Tuple[int, JsonDict]]: # We don't get a room version, so we have to assume its EITHER v1 or # v2. This is "fine" as the only difference between V1 and V2 is the # state resolution algorithm, and we don't use that for processing # invites - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, content, room_version_id=RoomVersions.V1.identifier ) # V1 federation API is defined to return a content of `[200, {...}]` # due to a historical bug. - return 200, (200, content) + return 200, (200, result) class FederationV2InviteServlet(BaseFederationServerServlet): @@ -642,7 +752,14 @@ class FederationV2InviteServlet(BaseFederationServerServlet): PREFIX = FEDERATION_V2_PREFIX - async def on_PUT(self, origin, content, query, room_id, event_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + event_id: str, + ) -> Tuple[int, JsonDict]: # TODO(paul): assert that room_id/event_id parsed from path actually # match those given in content @@ -655,16 +772,22 @@ class FederationV2InviteServlet(BaseFederationServerServlet): event.setdefault("unsigned", {})["invite_room_state"] = invite_room_state - content = await self.handler.on_invite_request( + result = await self.handler.on_invite_request( origin, event, room_version_id=room_version ) - return 200, content + return 200, result class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): PATH = "/exchange_third_party_invite/(?P[^/]*)" - async def on_PUT(self, origin, content, query, room_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: await self.handler.on_exchange_third_party_invite_request(content) return 200, {} @@ -672,21 +795,31 @@ class FederationThirdPartyInviteExchangeServlet(BaseFederationServerServlet): class FederationClientKeysQueryServlet(BaseFederationServerServlet): PATH = "/user/keys/query" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_client_keys(origin, content) class FederationUserDevicesQueryServlet(BaseFederationServerServlet): PATH = "/user/devices/(?P[^/]*)" - async def on_GET(self, origin, content, query, user_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + user_id: str, + ) -> Tuple[int, JsonDict]: return await self.handler.on_query_user_devices(origin, user_id) class FederationClientKeysClaimServlet(BaseFederationServerServlet): PATH = "/user/keys/claim" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: response = await self.handler.on_claim_client_keys(origin, content) return 200, response @@ -695,12 +828,18 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): # TODO(paul): Why does this path alone end with "/?" optional? PATH = "/get_missing_events/(?P[^/]*)/?" - async def on_POST(self, origin, content, query, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: limit = int(content.get("limit", 10)) earliest_events = content.get("earliest_events", []) latest_events = content.get("latest_events", []) - content = await self.handler.on_get_missing_events( + result = await self.handler.on_get_missing_events( origin, room_id=room_id, earliest_events=earliest_events, @@ -708,7 +847,7 @@ class FederationGetMissingEventsServlet(BaseFederationServerServlet): limit=limit, ) - return 200, content + return 200, result class On3pidBindServlet(BaseFederationServerServlet): @@ -716,7 +855,9 @@ class On3pidBindServlet(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: Optional[str], content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if "invites" in content: last_exception = None for invite in content["invites"]: @@ -762,15 +903,20 @@ class OpenIdUserInfo(BaseFederationServerServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): - token = query.get(b"access_token", [None])[0] + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: + token = parse_string_from_args(query, "access_token") if token is None: return ( 401, {"errcode": "M_MISSING_TOKEN", "error": "Access Token required"}, ) - user_id = await self.handler.on_openid_userinfo(token.decode("ascii")) + user_id = await self.handler.on_openid_userinfo(token) if user_id is None: return ( @@ -829,7 +975,9 @@ class PublicRoomList(BaseFederationServlet): self.handler = hs.get_room_list_handler() self.allow_access = allow_access - async def on_GET(self, origin, content, query): + async def on_GET( + self, origin: str, content: Literal[None], query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: if not self.allow_access: raise FederationDeniedError(origin) @@ -858,7 +1006,9 @@ class PublicRoomList(BaseFederationServlet): ) return 200, data - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: # This implements MSC2197 (Search Filtering over Federation) if not self.allow_access: raise FederationDeniedError(origin) @@ -904,7 +1054,12 @@ class FederationVersionServlet(BaseFederationServlet): REQUIRE_AUTH = False - async def on_GET(self, origin, content, query): + async def on_GET( + self, + origin: Optional[str], + content: Literal[None], + query: Dict[bytes, List[bytes]], + ) -> Tuple[int, JsonDict]: return ( 200, {"server": {"name": "Synapse", "version": get_version_string(synapse)}}, @@ -933,7 +1088,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/profile" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -942,7 +1103,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): return 200, new_content - async def on_POST(self, origin, content, query, group_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -957,7 +1124,13 @@ class FederationGroupsProfileServlet(BaseGroupsServerServlet): class FederationGroupsSummaryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/summary" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -972,7 +1145,13 @@ class FederationGroupsRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/rooms" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -987,7 +1166,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/room/(?P[^/]*)" - async def on_POST(self, origin, content, query, group_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -998,7 +1184,14 @@ class FederationGroupsAddRoomsServlet(BaseGroupsServerServlet): return 200, new_content - async def on_DELETE(self, origin, content, query, group_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1018,7 +1211,15 @@ class FederationGroupsAddRoomsConfigServlet(BaseGroupsServerServlet): "/config/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, room_id, config_key): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + room_id: str, + config_key: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1035,7 +1236,13 @@ class FederationGroupsUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1050,7 +1257,13 @@ class FederationGroupsInvitedUsersServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/invited_users" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1067,7 +1280,14 @@ class FederationGroupsInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1084,7 +1304,14 @@ class FederationGroupsAcceptInviteServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/accept_invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1098,7 +1325,14 @@ class FederationGroupsJoinServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/join" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(user_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1112,7 +1346,14 @@ class FederationGroupsRemoveUserServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/users/(?P[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1146,7 +1387,14 @@ class FederationGroupsLocalInviteServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/invite" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "group_id doesn't match origin") @@ -1164,7 +1412,14 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): PATH = "/groups/local/(?P[^/]*)/users/(?P[^/]*)/remove" - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, None]: if get_domain_from_id(group_id) != origin: raise SynapseError(403, "user_id doesn't match origin") @@ -1172,11 +1427,9 @@ class FederationGroupsRemoveLocalUserServlet(BaseGroupsLocalServlet): self.handler, GroupsLocalHandler ), "Workers cannot handle group removals." - new_content = await self.handler.user_removed_from_group( - group_id, user_id, content - ) + await self.handler.user_removed_from_group(group_id, user_id, content) - return 200, new_content + return 200, None class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): @@ -1194,7 +1447,14 @@ class FederationGroupsRenewAttestaionServlet(BaseFederationServlet): super().__init__(hs, authenticator, ratelimiter, server_name) self.handler = hs.get_groups_attestation_renewer() - async def on_POST(self, origin, content, query, group_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: # We don't need to check auth here as we check the attestation signatures new_content = await self.handler.on_renew_attestation( @@ -1218,7 +1478,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): "/rooms/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, category_id, room_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1246,7 +1514,15 @@ class FederationGroupsSummaryRoomsServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id, room_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + room_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1266,7 +1542,13 @@ class FederationGroupsCategoriesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/categories/?" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1281,7 +1563,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/categories/(?P[^/]+)" - async def on_GET(self, origin, content, query, group_id, category_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1292,7 +1581,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): return 200, resp - async def on_POST(self, origin, content, query, group_id, category_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1314,7 +1610,14 @@ class FederationGroupsCategoryServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, category_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + category_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1334,7 +1637,13 @@ class FederationGroupsRolesServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/roles/?" - async def on_GET(self, origin, content, query, group_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1349,7 +1658,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/roles/(?P[^/]+)" - async def on_GET(self, origin, content, query, group_id, role_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1358,7 +1674,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): return 200, resp - async def on_POST(self, origin, content, query, group_id, role_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1382,7 +1705,14 @@ class FederationGroupsRoleServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1411,7 +1741,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): "/users/(?P[^/]*)" ) - async def on_POST(self, origin, content, query, group_id, role_id, user_id): + async def on_POST( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1437,7 +1775,15 @@ class FederationGroupsSummaryUsersServlet(BaseGroupsServerServlet): return 200, resp - async def on_DELETE(self, origin, content, query, group_id, role_id, user_id): + async def on_DELETE( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + group_id: str, + role_id: str, + user_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1457,7 +1803,9 @@ class FederationGroupsBulkPublicisedServlet(BaseGroupsLocalServlet): PATH = "/get_groups_publicised" - async def on_POST(self, origin, content, query): + async def on_POST( + self, origin: str, content: JsonDict, query: Dict[bytes, List[bytes]] + ) -> Tuple[int, JsonDict]: resp = await self.handler.bulk_get_publicised_groups( content["user_ids"], proxy=False ) @@ -1470,7 +1818,13 @@ class FederationGroupsSettingJoinPolicyServlet(BaseGroupsServerServlet): PATH = "/groups/(?P[^/]*)/settings/m.join_policy" - async def on_PUT(self, origin, content, query, group_id): + async def on_PUT( + self, + origin: str, + content: JsonDict, + query: Dict[bytes, List[bytes]], + group_id: str, + ) -> Tuple[int, JsonDict]: requester_user_id = parse_string_from_args(query, "requester_user_id") if get_domain_from_id(requester_user_id) != origin: raise SynapseError(403, "requester_user_id doesn't match origin") @@ -1499,7 +1853,7 @@ class FederationSpaceSummaryServlet(BaseFederationServlet): async def on_GET( self, origin: str, - content: JsonDict, + content: Literal[None], query: Mapping[bytes, Sequence[bytes]], room_id: str, ) -> Tuple[int, JsonDict]: @@ -1571,7 +1925,13 @@ class RoomComplexityServlet(BaseFederationServlet): super().__init__(hs, authenticator, ratelimiter, server_name) self._store = self.hs.get_datastore() - async def on_GET(self, origin, content, query, room_id): + async def on_GET( + self, + origin: str, + content: Literal[None], + query: Dict[bytes, List[bytes]], + room_id: str, + ) -> Tuple[int, JsonDict]: is_public = await self._store.is_room_world_readable_or_publicly_joinable( room_id ) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py index fda8da21b7..6ba2ce1e53 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py @@ -109,12 +109,22 @@ def parse_boolean_from_args(args, name, default=None, required=False): return default +@overload +def parse_bytes_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[bytes] = None, +) -> Optional[bytes]: + ... + + @overload def parse_bytes_from_args( args: Dict[bytes, List[bytes]], name: str, default: Literal[None] = None, - required: Literal[True] = True, + *, + required: Literal[True], ) -> bytes: ... @@ -197,7 +207,12 @@ def parse_string( """ args = request.args # type: Dict[bytes, List[bytes]] # type: ignore return parse_string_from_args( - args, name, default, required, allowed_values, encoding + args, + name, + default, + required=required, + allowed_values=allowed_values, + encoding=encoding, ) @@ -227,7 +242,20 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[List[str]]: + ... + + +@overload +def parse_strings_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[List[str]] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> List[str]: @@ -239,6 +267,7 @@ def parse_strings_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[List[str]] = None, + *, required: bool = False, allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", @@ -299,7 +328,20 @@ def parse_string_from_args( args: Dict[bytes, List[bytes]], name: str, default: Optional[str] = None, - required: Literal[True] = True, + *, + allowed_values: Optional[Iterable[str]] = None, + encoding: str = "ascii", +) -> Optional[str]: + ... + + +@overload +def parse_string_from_args( + args: Dict[bytes, List[bytes]], + name: str, + default: Optional[str] = None, + *, + required: Literal[True], allowed_values: Optional[Iterable[str]] = None, encoding: str = "ascii", ) -> str: -- cgit 1.5.1 From a0ed0f363eb84f273b2cc706fcc5542d77a94463 Mon Sep 17 00:00:00 2001 From: Richard van der Hoff <1389908+richvdh@users.noreply.github.com> Date: Tue, 29 Jun 2021 11:08:06 +0100 Subject: Soft-fail spammy events received over federation (#10263) --- changelog.d/10263.feature | 1 + synapse/federation/federation_base.py | 12 ++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) create mode 100644 changelog.d/10263.feature (limited to 'synapse/federation') diff --git a/changelog.d/10263.feature b/changelog.d/10263.feature new file mode 100644 index 0000000000..7b1d2fe60f --- /dev/null +++ b/changelog.d/10263.feature @@ -0,0 +1 @@ +Mark events received over federation which fail a spam check as "soft-failed". diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py index c066617b92..2bfe6a3d37 100644 --- a/synapse/federation/federation_base.py +++ b/synapse/federation/federation_base.py @@ -89,12 +89,12 @@ class FederationBase: result = await self.spam_checker.check_event_for_spam(pdu) if result: - logger.warning( - "Event contains spam, redacting %s: %s", - pdu.event_id, - pdu.get_pdu_json(), - ) - return prune_event(pdu) + logger.warning("Event contains spam, soft-failing %s", pdu.event_id) + # we redact (to save disk space) as well as soft-failing (to stop + # using the event in prev_events). + redacted_event = prune_event(pdu) + redacted_event.internal_metadata.soft_failed = True + return redacted_event return pdu -- cgit 1.5.1 From c54db67d0ea5b5967b7ea918c66a222a75b8ced1 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 29 Jun 2021 19:55:22 +0100 Subject: Handle inbound events from federation asynchronously (#10272) Fixes #9490 This will break a couple of SyTest that are expecting failures to be added to the response of a federation /send, which obviously doesn't happen now that things are asynchronous. Two drawbacks: Currently there is no logic to handle any events left in the staging area after restart, and so they'll only be handled on the next incoming event in that room. That can be fixed separately. We now only process one event per room at a time. This can be fixed up further down the line. --- changelog.d/10272.bugfix | 1 + synapse/federation/federation_server.py | 98 +++++++++++++++++- synapse/storage/databases/main/event_federation.py | 109 ++++++++++++++++++++- .../main/delta/59/16federation_inbound_staging.sql | 32 ++++++ sytest-blacklist | 6 ++ 5 files changed, 241 insertions(+), 5 deletions(-) create mode 100644 changelog.d/10272.bugfix create mode 100644 synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql (limited to 'synapse/federation') diff --git a/changelog.d/10272.bugfix b/changelog.d/10272.bugfix new file mode 100644 index 0000000000..3cefa05788 --- /dev/null +++ b/changelog.d/10272.bugfix @@ -0,0 +1 @@ +Handle inbound events from federation asynchronously. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 2b07f18529..1d050e54e2 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -44,7 +44,7 @@ from synapse.api.errors import ( SynapseError, UnsupportedRoomVersionError, ) -from synapse.api.room_versions import KNOWN_ROOM_VERSIONS +from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase from synapse.federation.federation_base import FederationBase, event_from_pdu_json from synapse.federation.persistence import TransactionActions @@ -57,10 +57,12 @@ from synapse.logging.context import ( ) from synapse.logging.opentracing import log_kv, start_active_span_from_edu, trace from synapse.logging.utils import log_function +from synapse.metrics.background_process_metrics import wrap_as_background_process from synapse.replication.http.federation import ( ReplicationFederationSendEduRestServlet, ReplicationGetQueryRestServlet, ) +from synapse.storage.databases.main.lock import Lock from synapse.types import JsonDict from synapse.util import glob_to_regex, json_decoder, unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute @@ -96,6 +98,11 @@ last_pdu_ts_metric = Gauge( ) +# The name of the lock to use when process events in a room received over +# federation. +_INBOUND_EVENT_HANDLING_LOCK_NAME = "federation_inbound_pdu" + + class FederationServer(FederationBase): def __init__(self, hs: "HomeServer"): super().__init__(hs) @@ -834,7 +841,94 @@ class FederationServer(FederationBase): except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + # Add the event to our staging area + await self.store.insert_received_event_to_staging(origin, pdu) + + # Try and acquire the processing lock for the room, if we get it start a + # background process for handling the events in the room. + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, pdu.room_id + ) + if lock: + self._process_incoming_pdus_in_room_inner( + pdu.room_id, room_version, lock, origin, pdu + ) + + @wrap_as_background_process("_process_incoming_pdus_in_room_inner") + async def _process_incoming_pdus_in_room_inner( + self, + room_id: str, + room_version: RoomVersion, + lock: Lock, + latest_origin: str, + latest_event: EventBase, + ) -> None: + """Process events in the staging area for the given room. + + The latest_origin and latest_event args are the latest origin and event + received. + """ + + # The common path is for the event we just received be the only event in + # the room, so instead of pulling the event out of the DB and parsing + # the event we just pull out the next event ID and check if that matches. + next_origin, next_event_id = await self.store.get_next_staged_event_id_for_room( + room_id + ) + if next_origin == latest_origin and next_event_id == latest_event.event_id: + origin = latest_origin + event = latest_event + else: + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + return + + origin, event = next + + # We loop round until there are no more events in the room in the + # staging area, or we fail to get the lock (which means another process + # has started processing). + while True: + async with lock: + try: + await self.handler.on_receive_pdu( + origin, event, sent_to_us_directly=True + ) + except FederationError as e: + # XXX: Ideally we'd inform the remote we failed to process + # the event, but we can't return an error in the transaction + # response (as we've already responded). + logger.warning("Error handling PDU %s: %s", event.event_id, e) + except Exception: + f = failure.Failure() + logger.error( + "Failed to handle PDU %s", + event.event_id, + exc_info=(f.type, f.value, f.getTracebackObject()), # type: ignore + ) + + await self.store.remove_received_event_from_staging( + origin, event.event_id + ) + + # We need to do this check outside the lock to avoid a race between + # a new event being inserted by another instance and it attempting + # to acquire the lock. + next = await self.store.get_next_staged_event_for_room( + room_id, room_version + ) + if not next: + break + + origin, event = next + + lock = await self.store.try_acquire_lock( + _INBOUND_EVENT_HANDLING_LOCK_NAME, room_id + ) + if not lock: + return def __str__(self) -> str: return "" % self.server_name diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index c0ea445550..f23f8c6ecf 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -14,18 +14,20 @@ import itertools import logging from queue import Empty, PriorityQueue -from typing import Collection, Dict, Iterable, List, Set, Tuple +from typing import Collection, Dict, Iterable, List, Optional, Set, Tuple from synapse.api.constants import MAX_DEPTH from synapse.api.errors import StoreError -from synapse.events import EventBase +from synapse.api.room_versions import RoomVersion +from synapse.events import EventBase, make_event_from_dict from synapse.metrics.background_process_metrics import wrap_as_background_process -from synapse.storage._base import SQLBaseStore, make_in_list_sql_clause +from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause from synapse.storage.database import DatabasePool, LoggingTransaction from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.databases.main.signatures import SignatureWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.types import Cursor +from synapse.util import json_encoder from synapse.util.caches.descriptors import cached from synapse.util.caches.lrucache import LruCache from synapse.util.iterutils import batch_iter @@ -1044,6 +1046,107 @@ class EventFederationWorkerStore(EventsWorkerStore, SignatureWorkerStore, SQLBas _delete_old_forward_extrem_cache_txn, ) + async def insert_received_event_to_staging( + self, origin: str, event: EventBase + ) -> None: + """Insert a newly received event from federation into the staging area.""" + + # We use an upsert here to handle the case where we see the same event + # from the same server multiple times. + await self.db_pool.simple_upsert( + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event.event_id, + }, + values={}, + insertion_values={ + "room_id": event.room_id, + "received_ts": self._clock.time_msec(), + "event_json": json_encoder.encode(event.get_dict()), + "internal_metadata": json_encoder.encode( + event.internal_metadata.get_dict() + ), + }, + desc="insert_received_event_to_staging", + ) + + async def remove_received_event_from_staging( + self, + origin: str, + event_id: str, + ) -> None: + """Remove the given event from the staging area""" + await self.db_pool.simple_delete( + table="federation_inbound_events_staging", + keyvalues={ + "origin": origin, + "event_id": event_id, + }, + desc="remove_received_event_from_staging", + ) + + async def get_next_staged_event_id_for_room( + self, + room_id: str, + ) -> Optional[Tuple[str, str]]: + """Get the next event ID in the staging area for the given room.""" + + def _get_next_staged_event_id_for_room_txn(txn): + sql = """ + SELECT origin, event_id + FROM federation_inbound_events_staging + WHERE room_id = ? + ORDER BY received_ts ASC + LIMIT 1 + """ + + txn.execute(sql, (room_id,)) + + return txn.fetchone() + + return await self.db_pool.runInteraction( + "get_next_staged_event_id_for_room", _get_next_staged_event_id_for_room_txn + ) + + async def get_next_staged_event_for_room( + self, + room_id: str, + room_version: RoomVersion, + ) -> Optional[Tuple[str, EventBase]]: + """Get the next event in the staging area for the given room.""" + + def _get_next_staged_event_for_room_txn(txn): + sql = """ + SELECT event_json, internal_metadata, origin + FROM federation_inbound_events_staging + WHERE room_id = ? + ORDER BY received_ts ASC + LIMIT 1 + """ + txn.execute(sql, (room_id,)) + + return txn.fetchone() + + row = await self.db_pool.runInteraction( + "get_next_staged_event_for_room", _get_next_staged_event_for_room_txn + ) + + if not row: + return None + + event_d = db_to_json(row[0]) + internal_metadata_d = db_to_json(row[1]) + origin = row[2] + + event = make_event_from_dict( + event_dict=event_d, + room_version=room_version, + internal_metadata_dict=internal_metadata_d, + ) + + return origin, event + class EventFederationStore(EventFederationWorkerStore): """Responsible for storing and serving up the various graphs associated diff --git a/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql b/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql new file mode 100644 index 0000000000..43bc5c025f --- /dev/null +++ b/synapse/storage/schema/main/delta/59/16federation_inbound_staging.sql @@ -0,0 +1,32 @@ +/* Copyright 2021 The Matrix.org Foundation C.I.C + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +-- A staging area for newly received events over federation. +-- +-- Note we may store the same event multiple times if it comes from different +-- servers; this is to handle the case if we get a redacted and non-redacted +-- versions of the event. +CREATE TABLE federation_inbound_events_staging ( + origin TEXT NOT NULL, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + received_ts BIGINT NOT NULL, + event_json TEXT NOT NULL, + internal_metadata TEXT NOT NULL +); + +CREATE INDEX federation_inbound_events_staging_room ON federation_inbound_events_staging(room_id, received_ts); +CREATE UNIQUE INDEX federation_inbound_events_staging_instance_event ON federation_inbound_events_staging(origin, event_id); diff --git a/sytest-blacklist b/sytest-blacklist index de9986357b..89c4e828fd 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -41,3 +41,9 @@ We can't peek into rooms with invited history_visibility We can't peek into rooms with joined history_visibility Local users can peek by room alias Peeked rooms only turn up in the sync for the device who peeked them + + +# Blacklisted due to changes made in #10272 +Outbound federation will ignore a missing event with bad JSON for room version 6 +Backfilled events whose prev_events are in a different room do not allow cross-room back-pagination +Federation rejects inbound events where the prev_events cannot be found -- cgit 1.5.1