#
# This file is licensed under the Affero General Public License (AGPL) version 3.
#
# Copyright 2014-2021 The Matrix.org Foundation C.I.C.
# Copyright (C) 2023 New Vector, Ltd
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# See the GNU Affero General Public License for more details:
# <https://www.gnu.org/licenses/agpl-3.0.html>.
#
# Originally licensed under the Apache License, Version 2.0:
# <http://www.apache.org/licenses/LICENSE-2.0>.
#
# [This file includes modifications made by New Vector Limited]
#
#
import abc
import cgi
import codecs
import logging
import random
import sys
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
from typing import (
    TYPE_CHECKING,
    Any,
    BinaryIO,
    Callable,
    Dict,
    Generic,
    List,
    Optional,
    TextIO,
    Tuple,
    TypeVar,
    Union,
    cast,
    overload,
)

import attr
import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
from typing_extensions import Literal

from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorTime
from twisted.internet.task import Cooperator
from twisted.web.client import ResponseFailed
from twisted.web.http_headers import Headers
from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse

import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
    Codes,
    FederationDeniedError,
    HttpResponseException,
    RequestSendFailed,
    SynapseError,
)
from synapse.api.ratelimiting import Ratelimiter
from synapse.crypto.context_factory import FederationPolicyForHTTPS
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import (
    BlocklistingAgentWrapper,
    BodyExceededMaxSize,
    ByteWriteable,
    _make_scheduler,
    encode_query_args,
    read_body_with_max_size,
)
from synapse.http.connectproxyclient import BearerProxyCredentials
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.http.proxyagent import ProxyAgent
from synapse.http.types import QueryParams
from synapse.logging import opentracing
from synapse.logging.context import make_deferred_yieldable, run_in_background
from synapse.logging.opentracing import set_tag, start_active_span, tags
from synapse.types import JsonDict
from synapse.util import json_decoder
from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred
from synapse.util.metrics import Measure
from synapse.util.stringutils import parse_and_validate_server_name

if TYPE_CHECKING:
    from synapse.server import HomeServer

logger = logging.getLogger(__name__)

# Incremented once per send attempt in `_send_request` (so retries are counted
# too), labelled by the HTTP method.
outgoing_requests_counter = Counter(
    "synapse_http_matrixfederationclient_requests", "", ["method"]
)
# Incremented each time response headers are received in `_send_request`,
# labelled by the HTTP method and the response status code.
incoming_responses_counter = Counter(
    "synapse_http_matrixfederationclient_responses", "", ["method", "code"]
)


# Used as the modulus bound when advancing `_next_id`, so the counter wraps
# around instead of growing forever.
MAXINT = sys.maxsize


# Module-wide counter used to build `MatrixFederationRequest.txn_id`; it is
# advanced every time a `MatrixFederationRequest` is constructed.
_next_id = 1

# The type of the value a `ByteParser` returns from `finish()`.
T = TypeVar("T")


class ByteParser(ByteWriteable, Generic[T], abc.ABC):
    """A `ByteWriteable` that has an additional `finish` function that returns
    the parsed data.

    `_handle_response` streams a response body into an instance of this class
    and then calls `finish()` to obtain the parsed value. It reads
    `CONTENT_TYPE` to check the response's content type and
    `MAX_RESPONSE_SIZE` to bound how much of the body is read.
    """

    # Declared abstract: concrete parsers are expected to override this with a
    # plain string (as `_BaseJsonParser` does).
    CONTENT_TYPE: str = abc.abstractproperty()  # type: ignore
    """The expected content type of the response, e.g. `application/json`. If
    the content type doesn't match we fail the request.
    """

    # a federation response can be rather large (eg a big state_ids is 50M or so), so we
    # need a generous limit here.
    MAX_RESPONSE_SIZE: int = 100 * 1024 * 1024
    """The largest response this parser will accept."""

    @abc.abstractmethod
    def finish(self) -> T:
        """Called when response has finished streaming and the parser should
        return the final result (or error).

        Note that `_handle_response` also calls this when reading the body
        failed, purely so the parser can release its resources; any exception
        raised in that case is ignored by the caller.

        Returns:
            The parsed data.
        """


@attr.s(slots=True, frozen=True, auto_attribs=True)
class MatrixFederationRequest:
    """An immutable description of a single outgoing federation request.

    On construction a `txn_id` is allocated from the module-wide counter and,
    unless `generate_uri=False` is passed, the `uri` is computed from
    `destination`, `path` and `query`.
    """

    method: str
    """HTTP method
    """

    path: str
    """HTTP path
    """

    destination: str
    """The remote server to send the HTTP request to.
    """

    json: Optional[JsonDict] = None
    """JSON to send in the body.
    """

    json_callback: Optional[Callable[[], JsonDict]] = None
    """A callback to generate the JSON.
    """

    query: Optional[QueryParams] = None
    """Query arguments.
    """

    txn_id: str = attr.ib(init=False)
    """Unique ID for this request (for logging), this is autogenerated.
    """

    uri: bytes = b""
    """The URI of this request, usually generated from the above information.
    """

    _generate_uri: bool = True
    """True to automatically generate the uri field based on the above information.

    Set to False if manually configuring the URI.
    """

    def __attrs_post_init__(self) -> None:
        """Fill in the derived fields (`txn_id`, and `uri` when requested)."""
        global _next_id

        # The instance is frozen, so derived fields have to be written through
        # `object.__setattr__` rather than plain attribute assignment.
        object.__setattr__(self, "txn_id", "%s-O-%s" % (self.method, _next_id))
        _next_id = (_next_id + 1) % (MAXINT - 1)

        if not self._generate_uri:
            # The caller supplied the URI explicitly; leave it untouched.
            return

        # Because the object is frozen the URI can be computed once, up front.
        uri_components = (
            b"matrix-federation",
            self.destination.encode("ascii"),
            self.path.encode("ascii"),
            None,
            encode_query_args(self.query),
            b"",
        )
        object.__setattr__(self, "uri", urllib.parse.urlunparse(uri_components))

    def get_json(self) -> Optional[JsonDict]:
        """Return the JSON body for this request.

        The result of `json_callback` takes precedence when a callback is set;
        otherwise the static `json` field (possibly None) is returned.
        """
        callback = self.json_callback
        return callback() if callback else self.json


class _BaseJsonParser(ByteParser[T]):
    """Accumulates the whole response body as text and decodes it as JSON
    once the response has finished streaming.
    """

    CONTENT_TYPE = "application/json"

    def __init__(
        self, validator: Optional[Callable[[Optional[object]], bool]] = None
    ) -> None:
        """
        Args:
            validator: Optional predicate applied to the decoded JSON value;
                a falsy result makes `finish` raise `ValueError`.
        """
        text_buffer = StringIO()
        self._buffer = text_buffer
        # Incoming data is bytes; the wrapper decodes it incrementally into
        # the text buffer.
        self._binary_wrapper = BinaryIOWrapper(text_buffer)
        self._validator = validator

    def write(self, data: bytes) -> int:
        """Append a chunk of the response body; returns the wrapper's result."""
        written = self._binary_wrapper.write(data)
        return written

    def finish(self) -> T:
        """Decode the buffered text as JSON and validate it.

        Raises:
            ValueError: if a validator was supplied and it rejects the value.
        """
        result = json_decoder.decode(self._buffer.getvalue())
        check = self._validator
        if check is None or check(result):
            return result
        raise ValueError(
            f"Received incorrect JSON value: {result.__class__.__name__}"
        )


class JsonParser(_BaseJsonParser[JsonDict]):
    """JSON parser which only accepts a response whose top-level value is a
    JSON object (i.e. decodes to a `dict`).
    """

    def __init__(self) -> None:
        super().__init__(validator=self._validate)

    @staticmethod
    def _validate(v: Any) -> bool:
        """Return True if the decoded value is a dict."""
        is_object = isinstance(v, dict)
        return is_object


class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
    """Parser for the legacy /send_join & /send_leave response format, which
    must be a two-element list of `[integer, JSON object]`.
    """

    def __init__(self) -> None:
        super().__init__(validator=self._validate)

    @staticmethod
    def _validate(v: Any) -> bool:
        """Return True if the decoded value has the shape [integer, JSON dict]."""
        if not isinstance(v, list) or len(v) != 2:
            return False

        first, second = v
        # An exact type check is used for the first element (rather than
        # isinstance) so that booleans, which subclass int, are rejected.
        return type(first) is int and isinstance(second, dict)


async def _handle_response(
    reactor: IReactorTime,
    timeout_sec: float,
    request: MatrixFederationRequest,
    response: IResponse,
    start_ms: int,
    parser: ByteParser[T],
) -> T:
    """
    Streams the body of a response into a parser, subject to a timeout and the
    parser's size limit, and returns whatever the parser produces.

    Args:
        reactor: twisted reactor, used for the timeout and for timing the request
        timeout_sec: number of seconds to wait for the body to be read
        request: the request that triggered the response
        response: response to the request
        start_ms: timestamp (in milliseconds) at which the request was made
        parser: the parser which receives the body

    Returns:
        The parsed response

    Raises:
        RequestSendFailed: if the body was too large or could not be parsed
            (not retryable), or if reading it timed out or failed (retryable).
    """

    size_limit = parser.MAX_RESPONSE_SIZE

    # Tracks whether we reached the point of asking the parser for its result,
    # so that the `finally` clause below knows whether it still needs to.
    parser_finished = False
    try:
        check_content_type_is(response.headers, parser.CONTENT_TYPE)

        body_deferred = timeout_deferred(
            read_body_with_max_size(response, parser, size_limit),
            timeout=timeout_sec,
            reactor=reactor,
        )
        body_length = await make_deferred_yieldable(body_deferred)

        parser_finished = True
        parsed = parser.finish()
    except BodyExceededMaxSize as exc:
        # The body was bigger than the parser is willing to accept.
        logger.warning(
            "{%s} [%s] JSON response exceeded max size %i - %s %s",
            request.txn_id,
            request.destination,
            size_limit,
            request.method,
            request.uri.decode("ascii"),
        )
        raise RequestSendFailed(exc, can_retry=False) from exc
    except ValueError as exc:
        # The body (or the content type) was not acceptable.
        logger.warning(
            "{%s} [%s] Failed to parse response - %s %s",
            request.txn_id,
            request.destination,
            request.method,
            request.uri.decode("ascii"),
        )
        raise RequestSendFailed(exc, can_retry=False) from exc
    except defer.TimeoutError as exc:
        logger.warning(
            "{%s} [%s] Timed out reading response - %s %s",
            request.txn_id,
            request.destination,
            request.method,
            request.uri.decode("ascii"),
        )
        raise RequestSendFailed(exc, can_retry=True) from exc
    except ResponseFailed as exc:
        logger.warning(
            "{%s} [%s] Failed to read response - %s %s",
            request.txn_id,
            request.destination,
            request.method,
            request.uri.decode("ascii"),
        )
        raise RequestSendFailed(exc, can_retry=True) from exc
    except Exception as exc:
        # Anything unexpected is logged and propagated unchanged.
        logger.warning(
            "{%s} [%s] Error reading response %s %s: %s",
            request.txn_id,
            request.destination,
            request.method,
            request.uri.decode("ascii"),
            exc,
        )
        raise
    finally:
        if not parser_finished:
            # We failed before calling `finish()`; call it now so the parser
            # gets a chance to release whatever it is holding on to.
            try:
                parser.finish()
            except Exception:
                # We are already propagating an error, so anything raised
                # here is deliberately dropped.
                pass

    elapsed_secs = reactor.seconds() - start_ms / 1000

    logger.info(
        "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
        request.txn_id,
        request.destination,
        response.code,
        response.phrase.decode("ascii", errors="replace"),
        elapsed_secs,
        body_length,
        request.method,
        request.uri.decode("ascii"),
    )
    return parsed


class BinaryIOWrapper:
    """Adapts a text stream so that it can be written to with bytes.

    Bytes are decoded incrementally, so a multi-byte character may be split
    across successive `write` calls.
    """

    def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
        """
        Args:
            file: the text stream that decoded output is written to
            encoding: the encoding of the incoming bytes
            errors: the error handling scheme passed to the decoder
        """
        make_decoder = codecs.getincrementaldecoder(encoding)
        self.decoder = make_decoder(errors)
        self.file = file

    def write(self, b: Union[bytes, bytearray]) -> int:
        """Decode `b` and write the resulting text to the wrapped stream.

        Returns:
            The number of bytes consumed, which is always `len(b)`.
        """
        decoded_text = self.decoder.decode(b)
        self.file.write(decoded_text)
        return len(b)


class MatrixFederationHttpClient:
    """HTTP client used to talk to other homeservers over the federation
    protocol. Send client certificates and signs requests.

    Attributes:
        agent (twisted.web.client.Agent): The twisted Agent used to send the
            requests.
    """

    def __init__(
        self,
        hs: "HomeServer",
        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
    ):
        """
        Args:
            hs: the homeserver; supplies the config, signing key, reactor,
                clock and datastore used by this client.
            tls_client_options_factory: passed through to the agent which
                makes the outgoing connections.
        """
        self.hs = hs
        self.signing_key = hs.signing_key
        self.server_name = hs.hostname

        self.reactor = hs.get_reactor()

        user_agent = hs.version_string
        user_agent_suffix = hs.config.server.user_agent_suffix
        if user_agent_suffix:
            user_agent = "%s %s" % (user_agent, user_agent_suffix)

        restricted_to = hs.config.worker.outbound_federation_restricted_to

        federation_agent: IAgent
        if hs.get_instance_name() not in restricted_to:
            # This instance is not one of those allowed to talk to federation
            # itself, so its requests go via a proxy at one of the configured
            # locations, authenticated with the replication secret.
            replication_secret = hs.config.worker.worker_replication_secret
            assert (
                replication_secret is not None
            ), "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)"
            proxy_credentials = BearerProxyCredentials(
                replication_secret.encode("ascii")
            )

            federation_agent = ProxyAgent(
                self.reactor,
                self.reactor,
                tls_client_options_factory,
                federation_proxy_locations=restricted_to.locations,
                federation_proxy_credentials=proxy_credentials,
            )
        else:
            # This instance talks to remote homeservers directly.
            federation_agent = MatrixFederationAgent(
                self.reactor,
                tls_client_options_factory,
                user_agent.encode("ascii"),
                hs.config.server.federation_ip_range_allowlist,
                hs.config.server.federation_ip_range_blocklist,
            )

        # Wrap the agent so that IP literals in server names cannot be used
        # to get around the IP blocklist.
        self.agent: IAgent = BlocklistingAgentWrapper(
            federation_agent,
            ip_blocklist=hs.config.server.federation_ip_range_blocklist,
        )

        self.clock = hs.get_clock()
        self._store = hs.get_datastores().main
        self.version_string_bytes = hs.version_string.encode("ascii")

        # The config holds these durations in milliseconds; we keep seconds.
        federation_config = hs.config.federation
        self.default_timeout_seconds = federation_config.client_timeout_ms / 1000
        self.max_long_retry_delay_seconds = (
            federation_config.max_long_retry_delay_ms / 1000
        )
        self.max_short_retry_delay_seconds = (
            federation_config.max_short_retry_delay_ms / 1000
        )
        self.max_long_retries = federation_config.max_long_retries
        self.max_short_retries = federation_config.max_short_retries

        self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))

        # Used by `_send_request` to wait between retries, and by
        # `wake_destination` to cut such a wait short.
        self._sleeper = AwakenableSleeper(self.reactor)

    def wake_destination(self, destination: str) -> None:
        """Called when the remote server may have come back online.

        Wakes the sleeper for `destination`; `_send_request` sleeps on the
        same sleeper, keyed by destination, while waiting between retries.

        Args:
            destination: the name of the remote server.
        """

        self._sleeper.wake(destination)

    async def _send_request_with_optional_trailing_slash(
        self,
        request: MatrixFederationRequest,
        try_trailing_slash_on_400: bool = False,
        **send_request_args: Any,
    ) -> IResponse:
        """Sends a request via `_send_request`, optionally retrying it once
        with a trailing slash appended to the path if the remote responds
        with a 400 whose errcode is 'M_UNRECOGNIZED'. This is a workaround
        for Synapse <= v0.99.3 due to
        https://github.com/matrix-org/synapse/issues/3622.

        Args:
            request: details of request to be sent
            try_trailing_slash_on_400: Whether to retry the request with a
                trailing slash appended to the request path on receiving a
                400 'M_UNRECOGNIZED' from the server.
            send_request_args: Keyword arguments passed on to `_send_request()`
                (for both the initial attempt and the retry).

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).

        Returns:
            The response object returned by `_send_request()`.
        """
        try:
            return await self._send_request(request, **send_request_args)
        except HttpResponseException as e:
            # Only a 400 carrying 'M_UNRECOGNIZED' qualifies for the retry,
            # and only if the caller opted in. Anything else is re-raised.
            wants_slash_retry = (
                try_trailing_slash_on_400
                and e.code == 400
                and e.to_synapse_error().errcode == "M_UNRECOGNIZED"
            )
            if not wants_slash_retry:
                raise

            # Some endpoints on Synapse <= v0.99.3 return this error when the
            # trailing slash is omitted, so try again with one.
            logger.info("Retrying request with trailing slash")

            # The request object is frozen, so build a modified copy.
            slashed_request = attr.evolve(request, path=request.path + "/")

            return await self._send_request(slashed_request, **send_request_args)

    async def _send_request(
        self,
        request: MatrixFederationRequest,
        retry_on_dns_fail: bool = True,
        timeout: Optional[int] = None,
        long_retries: bool = False,
        ignore_backoff: bool = False,
        backoff_on_404: bool = False,
        backoff_on_all_error_codes: bool = False,
        follow_redirects: bool = False,
    ) -> IResponse:
        """
        Sends a request to the given server.

        Args:
            request: details of request to be sent

            retry_on_dns_fail: true if the request should be retried on DNS failures

            timeout: number of milliseconds to wait for the response headers
                (including connecting to the server), *for each attempt*.
                If None, `self.default_timeout_seconds` (taken from the
                federation config) is used.

                NB: as written, a failed attempt is only retried when `timeout`
                is None (or 0); see the note on the retry condition below.

            long_retries: whether to use the long retry algorithm.

                The regular retry algorithm makes 4 attempts, with intervals
                [0.5s, 1s, 2s].

                The long retry algorithm makes 11 attempts, with intervals
                [4s, 16s, 60s, 60s, ...]

                (The number of retries and the cap on each interval are read
                from `self.max_short_retries` / `self.max_long_retries` and
                `self.max_short_retry_delay_seconds` /
                `self.max_long_retry_delay_seconds`, so the figures above only
                hold for particular settings of those.)

                Both algorithms add -20%/+40% jitter to the retry intervals.

                Note that the above intervals are *in addition* to the time spent
                waiting for the request to complete (up to `timeout` ms).

                NB: the long retry algorithm takes over 20 minutes to complete, with a
                default timeout of 60s! It's best not to use the `long_retries` option
                for something that is blocking a client so we don't make them wait for
                aaaaages, whereas some things like sending transactions (server to
                server) we can be a lot more lenient but it's very fuzzy / hand-wavey.

                In the future, we could be more intelligent about doing this sort of
                thing by looking at things with the bigger picture in mind,
                https://github.com/matrix-org/synapse/issues/8917

            ignore_backoff: true to ignore the historical backoff data
                and try the request anyway.

            backoff_on_404: Back off if we get a 404
            backoff_on_all_error_codes: Back off if we get any error response

            follow_redirects: True to follow the Location header of 307/308 redirect
                responses. This does not recurse.

        Returns:
            Resolves with the HTTP response object on success.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist, or is not a valid server name.
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
                Also raised (wrapping an `HttpResponseException`) for 5xx and
                429 responses.
        """
        # Validate server name and log if it is an invalid destination, this is
        # partially to help track down code paths where we haven't validated before here
        try:
            parse_and_validate_server_name(request.destination)
        except ValueError:
            logger.exception(f"Invalid destination: {request.destination}.")
            raise FederationDeniedError(request.destination)

        # `timeout` is given in milliseconds; everything below works in seconds.
        if timeout is not None:
            _sec_timeout = timeout / 1000
        else:
            _sec_timeout = self.default_timeout_seconds

        # If a federation whitelist is configured, refuse any destination
        # which is not on it.
        if (
            self.hs.config.federation.federation_domain_whitelist is not None
            and request.destination
            not in self.hs.config.federation.federation_domain_whitelist
        ):
            raise FederationDeniedError(request.destination)

        # NOTE(review): presumably this is what raises NotRetryingDestination
        # when we are still backing off from the destination (unless
        # `ignore_backoff` is set) - confirm against `retryutils`.
        limiter = await synapse.util.retryutils.get_retry_limiter(
            request.destination,
            self.clock,
            self._store,
            backoff_on_404=backoff_on_404,
            ignore_backoff=ignore_backoff,
            notifier=self.hs.get_notifier(),
            replication_client=self.hs.get_replication_command_handler(),
            backoff_on_all_error_codes=backoff_on_all_error_codes,
        )

        method_bytes = request.method.encode("ascii")
        destination_bytes = request.destination.encode("ascii")
        path_bytes = request.path.encode("ascii")
        query_bytes = encode_query_args(request.query)

        scope = start_active_span(
            "outgoing-federation-request",
            tags={
                tags.SPAN_KIND: tags.SPAN_KIND_RPC_CLIENT,
                tags.PEER_ADDRESS: request.destination,
                tags.HTTP_METHOD: request.method,
                tags.HTTP_URL: request.path,
            },
            finish_on_close=True,
        )

        # Inject the span into the headers
        headers_dict: Dict[bytes, List[bytes]] = {}
        opentracing.inject_header_dict(headers_dict, request.destination)

        headers_dict[b"User-Agent"] = [self.version_string_bytes]

        with limiter, scope:
            # XXX: Would be much nicer to retry only at the transaction-layer
            # (once we have reliable transactions in place)
            if long_retries:
                retries_left = self.max_long_retries
            else:
                retries_left = self.max_short_retries

            # The URI we actually request (precomputed on the request object).
            url_bytes = request.uri
            url_str = url_bytes.decode("ascii")

            # Only the path and query string are included in what gets signed
            # (the scheme and host components are left empty here).
            url_to_sign_bytes = urllib.parse.urlunparse(
                (b"", b"", path_bytes, None, query_bytes, b"")
            )

            # Each iteration is one attempt. We leave the loop via `break` on a
            # 2xx response, via `return` when following a redirect, or via an
            # exception once we give up.
            while True:
                try:
                    # The body is regenerated on every attempt, so a
                    # `json_callback` is invoked once per attempt.
                    json = request.get_json()
                    # NOTE(review): this is a truthiness check, so an empty
                    # body such as `{}` takes the "no body" branch and is
                    # neither sent nor included in the signature.
                    if json:
                        headers_dict[b"Content-Type"] = [b"application/json"]
                        auth_headers = self.build_auth_headers(
                            destination_bytes, method_bytes, url_to_sign_bytes, json
                        )
                        data = encode_canonical_json(json)
                        producer: Optional[IBodyProducer] = QuieterFileBodyProducer(
                            BytesIO(data), cooperator=self._cooperator
                        )
                    else:
                        producer = None
                        auth_headers = self.build_auth_headers(
                            destination_bytes, method_bytes, url_to_sign_bytes
                        )

                    headers_dict[b"Authorization"] = auth_headers

                    logger.debug(
                        "{%s} [%s] Sending request: %s %s; timeout %fs",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _sec_timeout,
                    )

                    outgoing_requests_counter.labels(request.method).inc()

                    try:
                        with Measure(self.clock, "outbound_request"):
                            # we don't want all the fancy cookie and redirect handling
                            # that treq.request gives: just use the raw Agent.

                            # To preserve the logging context, the timeout is treated
                            # in a similar way to `defer.gatherResults`:
                            # * Each logging context-preserving fork is wrapped in
                            #   `run_in_background`. In this case there is only one,
                            #   since the timeout fork is not logging-context aware.
                            # * The `Deferred` that joins the forks back together is
                            #   wrapped in `make_deferred_yieldable` to restore the
                            #   logging context regardless of the path taken.
                            request_deferred = run_in_background(
                                self.agent.request,
                                method_bytes,
                                url_bytes,
                                headers=Headers(headers_dict),
                                bodyProducer=producer,
                            )
                            request_deferred = timeout_deferred(
                                request_deferred,
                                timeout=_sec_timeout,
                                reactor=self.reactor,
                            )

                            response = await make_deferred_yieldable(request_deferred)
                    # Failing to get response headers at all is always treated
                    # as retryable, except for DNS lookup failures, where the
                    # caller decides via `retry_on_dns_fail`.
                    except DNSLookupError as e:
                        raise RequestSendFailed(e, can_retry=retry_on_dns_fail) from e
                    except Exception as e:
                        raise RequestSendFailed(e, can_retry=True) from e

                    incoming_responses_counter.labels(
                        request.method, response.code
                    ).inc()

                    set_tag(tags.HTTP_STATUS_CODE, response.code)
                    response_phrase = response.phrase.decode("ascii", errors="replace")

                    if 200 <= response.code < 300:
                        logger.debug(
                            "{%s} [%s] Got response headers: %d %s",
                            request.txn_id,
                            request.destination,
                            response.code,
                            response_phrase,
                        )
                    elif (
                        response.code in (307, 308)
                        and follow_redirects
                        and response.headers.hasHeader("Location")
                    ):
                        # The Location header *might* be relative so resolve it.
                        location = response.headers.getRawHeaders(b"Location")[0]
                        new_uri = urllib.parse.urljoin(request.uri, location)

                        # Re-send to the new URI. The request object is frozen,
                        # so a copy is made with the URI set explicitly (and
                        # URI generation turned off so it is not overwritten).
                        return await self._send_request(
                            attr.evolve(request, uri=new_uri, generate_uri=False),
                            retry_on_dns_fail,
                            timeout,
                            long_retries,
                            ignore_backoff,
                            backoff_on_404,
                            backoff_on_all_error_codes,
                            # Do not continue following redirects.
                            follow_redirects=False,
                        )
                    else:
                        logger.info(
                            "{%s} [%s] Got response headers: %d %s",
                            request.txn_id,
                            request.destination,
                            response.code,
                            response_phrase,
                        )
                        # :'(
                        # Update transactions table?
                        # Read the error body (with the same per-attempt
                        # timeout) so it can be attached to the exception.
                        d = treq.content(response)
                        d = timeout_deferred(
                            d, timeout=_sec_timeout, reactor=self.reactor
                        )

                        try:
                            body = await make_deferred_yieldable(d)
                        except Exception as e:
                            # Eh, we're already going to raise an exception so lets
                            # ignore if this fails.
                            logger.warning(
                                "{%s} [%s] Failed to get error response: %s %s: %s",
                                request.txn_id,
                                request.destination,
                                request.method,
                                url_str,
                                _flatten_response_never_received(e),
                            )
                            body = None

                        exc = HttpResponseException(
                            response.code, response_phrase, body
                        )

                        # Retry if the error is a 5xx or a 429 (Too Many
                        # Requests), otherwise just raise a standard
                        # `HttpResponseException`
                        if 500 <= response.code < 600 or response.code == 429:
                            raise RequestSendFailed(exc, can_retry=True) from exc
                        else:
                            raise exc

                    # Only reached for a 2xx response: stop retrying.
                    break
                except RequestSendFailed as e:
                    logger.info(
                        "{%s} [%s] Request failed: %s %s: %s",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _flatten_response_never_received(e.inner_exception),
                    )

                    if not e.can_retry:
                        raise

                    # NOTE(review): `not timeout` means that whenever the
                    # caller passes an explicit non-zero `timeout`, the failure
                    # is re-raised immediately and no retry is made. Confirm
                    # that this is intended.
                    if retries_left and not timeout:
                        # Exponential backoff (base 4 for long retries, base 2
                        # starting at 0.5s for short ones), capped at the
                        # configured maximum and then jittered by -20%/+40%.
                        if long_retries:
                            delay_seconds = 4 ** (
                                self.max_long_retries + 1 - retries_left
                            )
                            delay_seconds = min(
                                delay_seconds, self.max_long_retry_delay_seconds
                            )
                            delay_seconds *= random.uniform(0.8, 1.4)
                        else:
                            delay_seconds = 0.5 * 2 ** (
                                self.max_short_retries - retries_left
                            )
                            delay_seconds = min(
                                delay_seconds, self.max_short_retry_delay_seconds
                            )
                            delay_seconds *= random.uniform(0.8, 1.4)

                        logger.debug(
                            "{%s} [%s] Waiting %ss before re-sending...",
                            request.txn_id,
                            request.destination,
                            delay_seconds,
                        )

                        # Sleep for the calculated delay, or wake up immediately
                        # if we get notified that the server is back up.
                        await self._sleeper.sleep(
                            request.destination, delay_seconds * 1000
                        )
                        retries_left -= 1
                    else:
                        raise

                except Exception as e:
                    # Anything that is not a RequestSendFailed (including the
                    # HttpResponseException raised above for non-retryable
                    # status codes) is logged and propagated without retrying.
                    logger.warning(
                        "{%s} [%s] Request failed: %s %s: %s",
                        request.txn_id,
                        request.destination,
                        request.method,
                        url_str,
                        _flatten_response_never_received(e),
                    )
                    raise
        return response

    def build_auth_headers(
        self,
        destination: Optional[bytes],
        method: bytes,
        url_bytes: bytes,
        content: Optional[JsonDict] = None,
        destination_is: Optional[bytes] = None,
    ) -> List[bytes]:
        """
        Signs a description of a federation request with our signing key and
        renders each resulting signature as an `X-Matrix` Authorization header.

        Args:
            destination: The destination homeserver of the request.
                May be None if the destination is an identity server, in which case
                destination_is must be non-None.
            method: The HTTP method of the request
            url_bytes: The URI path of the request
            content: The body of the request
            destination_is: As 'destination', but if the destination is an
                identity server

        Returns:
            A list of headers to be added as "Authorization:" headers

        Raises:
            ValueError: if neither `destination` nor `destination_is` is a
                non-empty bytestring.
        """
        if not (destination or destination_is):
            raise ValueError(
                "At least one of the arguments destination and destination_is "
                "must be a nonempty bytestring."
            )

        # This is the object which gets signed.
        to_sign: JsonDict = {
            "method": method.decode("ascii"),
            "uri": url_bytes.decode("ascii"),
            "origin": self.server_name,
        }

        # Each of the destination fields is only included if it was supplied.
        for field_name, raw_value in (
            ("destination", destination),
            ("destination_is", destination_is),
        ):
            if raw_value is not None:
                to_sign[field_name] = raw_value.decode("ascii")

        if content is not None:
            to_sign["content"] = content

        signed = sign_json(to_sign, self.server_name, self.signing_key)

        # The check at the top guarantees that at least one of the two is
        # present and non-empty, so this lookup cannot fail.
        header_destination = signed.get("destination") or signed["destination_is"]

        # One header per signature made under our server name.
        return [
            (
                'X-Matrix origin="%s",key="%s",sig="%s",destination="%s"'
                % (self.server_name, key_id, signature, header_destination)
            ).encode("ascii")
            for key_id, signature in signed["signatures"][self.server_name].items()
        ]

    # Typing-only overload (the `...` body is never executed): when `parser`
    # is left as None, the result is typed as a JsonDict.
    @overload
    async def put_json(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        data: Optional[JsonDict] = None,
        json_data_callback: Optional[Callable[[], JsonDict]] = None,
        long_retries: bool = False,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        backoff_on_404: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Literal[None] = None,
        backoff_on_all_error_codes: bool = False,
    ) -> JsonDict: ...

    # Typing-only overload (the `...` body is never executed): when a
    # `ByteParser[T]` is passed as `parser`, the result is typed as T.
    #
    # NOTE(review): this signature also admits `parser=None`, which overlaps
    # with the overload above; get_json's equivalent overload makes `parser`
    # non-optional. Confirm whether the difference is intended.
    @overload
    async def put_json(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        data: Optional[JsonDict] = None,
        json_data_callback: Optional[Callable[[], JsonDict]] = None,
        long_retries: bool = False,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        backoff_on_404: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Optional[ByteParser[T]] = None,
        backoff_on_all_error_codes: bool = False,
    ) -> T: ...

    async def put_json(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        data: Optional[JsonDict] = None,
        json_data_callback: Optional[Callable[[], JsonDict]] = None,
        long_retries: bool = False,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        backoff_on_404: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Optional[ByteParser[T]] = None,
        backoff_on_all_error_codes: bool = False,
    ) -> Union[JsonDict, T]:
        """PUTs a JSON body to a remote homeserver and parses the response.

        Args:
            destination: The remote server to send the HTTP request to.
            path: The HTTP path.
            args: query params
            data: A dict containing the data that will be used as
                the request body. This will be encoded as JSON.
            json_data_callback: A callable returning the dict to
                use as the request body.

            long_retries: whether to use the long retry algorithm. See
                docs on _send_request for details.

            timeout: number of milliseconds to wait for the response.
                If None, self.default_timeout_seconds is used.

                Note that we may make several attempts to send the request; this
                timeout applies to the time spent waiting for response headers for
                *each* attempt (including connection time) as well as the time spent
                reading the response body after a 200 response.

            ignore_backoff: true to ignore the historical backoff data
                and try the request anyway.
            backoff_on_404: True if we should count a 404 response as
                a failure of the server (and should therefore back off future
                requests).
            try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                response we should try appending a trailing slash to the end
                of the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
                in Synapse <= v0.99.3. This will be attempted before backing off if
                backing off has been enabled.
            parser: The parser to use to decode the response. Defaults to
                parsing as JSON.
            backoff_on_all_error_codes: Back off if we get any error response

        Returns:
            Succeeds when we get a 2xx HTTP response. The result will be
            whatever the parser produced: the decoded JSON body by default.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
        """
        put_request = MatrixFederationRequest(
            method="PUT",
            destination=destination,
            path=path,
            query=args,
            json_callback=json_data_callback,
            json=data,
        )

        # Recorded before the request goes out, and handed to _handle_response
        # together with the response.
        started_at_ms = self.clock.time_msec()

        put_response = await self._send_request_with_optional_trailing_slash(
            put_request,
            try_trailing_slash_on_400,
            backoff_on_404=backoff_on_404,
            ignore_backoff=ignore_backoff,
            long_retries=long_retries,
            timeout=timeout,
            backoff_on_all_error_codes=backoff_on_all_error_codes,
        )

        # `timeout` is in milliseconds; the body-read timeout is in seconds.
        read_timeout_sec = (
            timeout / 1000 if timeout is not None else self.default_timeout_seconds
        )

        response_parser = (
            parser if parser is not None else cast(ByteParser[T], JsonParser())
        )

        return await _handle_response(
            self.reactor,
            read_timeout_sec,
            put_request,
            put_response,
            started_at_ms,
            parser=response_parser,
        )

    async def post_json(
        self,
        destination: str,
        path: str,
        data: Optional[JsonDict] = None,
        long_retries: bool = False,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        args: Optional[QueryParams] = None,
    ) -> JsonDict:
        """POSTs a JSON body to a remote homeserver and decodes the JSON response.

        Args:
            destination: The remote server to send the HTTP request to.

            path: The HTTP path.

            data: A dict containing the data that will be used as
                the request body. This will be encoded as JSON.

            long_retries: whether to use the long retry algorithm. See
                docs on _send_request for details.

            timeout: number of milliseconds to wait for the response.
                If None, self.default_timeout_seconds is used.

                Note that we may make several attempts to send the request; this
                timeout applies to the time spent waiting for response headers for
                *each* attempt (including connection time) as well as the time spent
                reading the response body after a 200 response.

            ignore_backoff: true to ignore the historical backoff data and
                try the request anyway.

            args: query params

        Returns:
            Succeeds when we get a 2xx HTTP response. The result will be the
            decoded JSON body.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
        """
        post_request = MatrixFederationRequest(
            method="POST",
            destination=destination,
            path=path,
            query=args,
            json=data,
        )

        # Recorded before the request goes out, and handed to _handle_response
        # together with the response.
        started_at_ms = self.clock.time_msec()

        post_response = await self._send_request(
            post_request,
            long_retries=long_retries,
            timeout=timeout,
            ignore_backoff=ignore_backoff,
        )

        # `timeout` is in milliseconds; the body-read timeout is in seconds.
        read_timeout_sec = (
            timeout / 1000 if timeout is not None else self.default_timeout_seconds
        )

        return await _handle_response(
            self.reactor,
            read_timeout_sec,
            post_request,
            post_response,
            started_at_ms,
            parser=JsonParser(),
        )

    # Typing-only overload (the `...` body is never executed): when `parser`
    # is left as None, the result is typed as a JsonDict.
    @overload
    async def get_json(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        retry_on_dns_fail: bool = True,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Literal[None] = None,
    ) -> JsonDict: ...

    # Typing-only overload (the `...` body is never executed): when a
    # `ByteParser[T]` is passed as `parser`, the result is typed as T.
    @overload
    async def get_json(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = ...,
        retry_on_dns_fail: bool = ...,
        timeout: Optional[int] = ...,
        ignore_backoff: bool = ...,
        try_trailing_slash_on_400: bool = ...,
        parser: ByteParser[T] = ...,
    ) -> T: ...

    async def get_json(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        retry_on_dns_fail: bool = True,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Optional[ByteParser[T]] = None,
    ) -> Union[JsonDict, T]:
        """GETs some json from the given host homeserver and path, discarding
        the response headers.

        This is a thin wrapper around get_json_with_headers which returns only
        the parsed body.

        Args:
            destination: The remote server to send the HTTP request to.

            path: The HTTP path.

            args: A dictionary used to create query strings, defaults to
                None.

            retry_on_dns_fail: true if the request should be retried on DNS failures

            timeout: number of milliseconds to wait for the response.
                If None, self.default_timeout_seconds is used.

                Note that we may make several attempts to send the request; this
                timeout applies to the time spent waiting for response headers for
                *each* attempt (including connection time) as well as the time spent
                reading the response body after a 200 response.

            ignore_backoff: true to ignore the historical backoff data
                and try the request anyway.

            try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                response we should try appending a trailing slash to the end of
                the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
                in Synapse <= v0.99.3.

            parser: The parser to use to decode the response. Defaults to
                parsing as JSON.

        Returns:
            Succeeds when we get a 2xx HTTP response. The result will be
            whatever the parser produced: the decoded JSON body by default.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
        """
        # Every argument is forwarded untouched; only the first element of the
        # (body, headers) result is of interest here.
        parsed_body, _response_headers = await self.get_json_with_headers(
            destination=destination,
            path=path,
            args=args,
            retry_on_dns_fail=retry_on_dns_fail,
            timeout=timeout,
            ignore_backoff=ignore_backoff,
            try_trailing_slash_on_400=try_trailing_slash_on_400,
            parser=parser,
        )
        return parsed_body

    # Typing-only overload (the `...` body is never executed): when `parser`
    # is left as None, the body in the returned tuple is typed as a JsonDict.
    @overload
    async def get_json_with_headers(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        retry_on_dns_fail: bool = True,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Literal[None] = None,
    ) -> Tuple[JsonDict, Dict[bytes, List[bytes]]]: ...

    # Typing-only overload (the `...` body is never executed): when a
    # `ByteParser[T]` is passed as `parser`, the body in the returned tuple is
    # typed as T.
    @overload
    async def get_json_with_headers(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = ...,
        retry_on_dns_fail: bool = ...,
        timeout: Optional[int] = ...,
        ignore_backoff: bool = ...,
        try_trailing_slash_on_400: bool = ...,
        parser: ByteParser[T] = ...,
    ) -> Tuple[T, Dict[bytes, List[bytes]]]: ...

    async def get_json_with_headers(
        self,
        destination: str,
        path: str,
        args: Optional[QueryParams] = None,
        retry_on_dns_fail: bool = True,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        try_trailing_slash_on_400: bool = False,
        parser: Optional[ByteParser[T]] = None,
    ) -> Tuple[Union[JsonDict, T], Dict[bytes, List[bytes]]]:
        """GETs some json from the given host homeserver and path, returning
        both the parsed body and the raw response headers.

        Args:
            destination: The remote server to send the HTTP request to.

            path: The HTTP path.

            args: A dictionary used to create query strings, defaults to
                None.

            retry_on_dns_fail: true if the request should be retried on DNS failures

            timeout: number of milliseconds to wait for the response.
                If None, self.default_timeout_seconds is used.

                Note that we may make several attempts to send the request; this
                timeout applies to the time spent waiting for response headers for
                *each* attempt (including connection time) as well as the time spent
                reading the response body after a 200 response.

            ignore_backoff: true to ignore the historical backoff data
                and try the request anyway.

            try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED
                response we should try appending a trailing slash to the end of
                the request. Workaround for https://github.com/matrix-org/synapse/issues/3622
                in Synapse <= v0.99.3.

            parser: The parser to use to decode the response. Defaults to
                parsing as JSON.

        Returns:
            Succeeds when we get a 2xx HTTP response. The result will be a tuple of the
            parsed body (the decoded JSON body by default) and a dict of the
            response headers.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
        """
        get_request = MatrixFederationRequest(
            method="GET",
            destination=destination,
            path=path,
            query=args,
        )

        # Recorded before the request goes out, and handed to _handle_response
        # together with the response.
        started_at_ms = self.clock.time_msec()

        # A 404 is never counted as a failure of the remote server here.
        get_response = await self._send_request_with_optional_trailing_slash(
            get_request,
            try_trailing_slash_on_400,
            backoff_on_404=False,
            ignore_backoff=ignore_backoff,
            retry_on_dns_fail=retry_on_dns_fail,
            timeout=timeout,
        )

        # Snapshot the raw headers (name -> list of values) for the caller.
        response_headers = dict(get_response.headers.getAllRawHeaders())

        # `timeout` is in milliseconds; the body-read timeout is in seconds.
        read_timeout_sec = (
            timeout / 1000 if timeout is not None else self.default_timeout_seconds
        )

        response_parser = (
            parser if parser is not None else cast(ByteParser[T], JsonParser())
        )

        parsed_body = await _handle_response(
            self.reactor,
            read_timeout_sec,
            get_request,
            get_response,
            started_at_ms,
            parser=response_parser,
        )

        return parsed_body, response_headers

    async def delete_json(
        self,
        destination: str,
        path: str,
        long_retries: bool = False,
        timeout: Optional[int] = None,
        ignore_backoff: bool = False,
        args: Optional[QueryParams] = None,
    ) -> JsonDict:
        """Sends a DELETE request to the remote homeserver and decodes the
        JSON response.

        Args:
            destination: The remote server to send the HTTP request to.
            path: The HTTP path.

            long_retries: whether to use the long retry algorithm. See
                docs on _send_request for details.

            timeout: number of milliseconds to wait for the response.
                If None, self.default_timeout_seconds is used.

                Note that we may make several attempts to send the request; this
                timeout applies to the time spent waiting for response headers for
                *each* attempt (including connection time) as well as the time spent
                reading the response body after a 200 response.

            ignore_backoff: true to ignore the historical backoff data and
                try the request anyway.

            args: query params

        Returns:
            Succeeds when we get a 2xx HTTP response. The
            result will be the decoded JSON body.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
        """
        delete_request = MatrixFederationRequest(
            method="DELETE",
            destination=destination,
            path=path,
            query=args,
        )

        # Recorded before the request goes out, and handed to _handle_response
        # together with the response.
        started_at_ms = self.clock.time_msec()

        delete_response = await self._send_request(
            delete_request,
            long_retries=long_retries,
            timeout=timeout,
            ignore_backoff=ignore_backoff,
        )

        # `timeout` is in milliseconds; the body-read timeout is in seconds.
        read_timeout_sec = (
            timeout / 1000 if timeout is not None else self.default_timeout_seconds
        )

        return await _handle_response(
            self.reactor,
            read_timeout_sec,
            delete_request,
            delete_response,
            started_at_ms,
            parser=JsonParser(),
        )

    async def get_file(
        self,
        destination: str,
        path: str,
        output_stream: BinaryIO,
        download_ratelimiter: Ratelimiter,
        ip_address: str,
        max_size: int,
        args: Optional[QueryParams] = None,
        retry_on_dns_fail: bool = True,
        ignore_backoff: bool = False,
        follow_redirects: bool = False,
    ) -> Tuple[int, Dict[bytes, List[bytes]]]:
        """GETs a file from a given homeserver

        Args:
            destination: The remote server to send the HTTP request to.
            path: The HTTP path to GET.
            output_stream: File to write the response body to.
            download_ratelimiter: a ratelimiter to limit remote media downloads, keyed to
                requester IP
            ip_address: IP address of the requester
            max_size: maximum allowable size in bytes of the file. A response
                which declares a larger length is rejected without its body
                being read; a response which declares no length is assumed to
                be exactly this size.
            args: Optional dictionary used to create the query string.
            retry_on_dns_fail: true if the request should be retried on DNS failures
            ignore_backoff: true to ignore the historical backoff data
                and try the request anyway.
            follow_redirects: True to follow the Location header of 307/308 redirect
                responses. This does not recurse.

        Returns:
            Resolves with an (int,dict) tuple of
            the file length and a dict of the response headers.

        Raises:
            HttpResponseException: If we get an HTTP response code >= 300
                (except 429).
            NotRetryingDestination: If we are not yet ready to retry this
                server.
            FederationDeniedError: If this destination is not on our
                federation whitelist
            RequestSendFailed: If there were problems connecting to the
                remote, due to e.g. DNS failures, connection timeouts etc.
            SynapseError: If the requested file exceeds ratelimits, or is
                larger than `max_size`
        """
        request = MatrixFederationRequest(
            method="GET", destination=destination, path=path, query=args
        )

        # check for a minimum balance of 1MiB in ratelimiter before initiating request
        # (update=False: this is only a look at the balance, nothing is charged yet)
        send_req, _ = await download_ratelimiter.can_do_action(
            requester=None, key=ip_address, n_actions=1048576, update=False
        )

        if not send_req:
            msg = "Requested file size exceeds ratelimits"
            logger.warning(
                "{%s} [%s] %s",
                request.txn_id,
                request.destination,
                msg,
            )
            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)

        response = await self._send_request(
            request,
            retry_on_dns_fail=retry_on_dns_fail,
            ignore_backoff=ignore_backoff,
            follow_redirects=follow_redirects,
        )

        headers = dict(response.headers.getAllRawHeaders())

        expected_size = response.length
        if expected_size == UNKNOWN_LENGTH:
            # if we don't get an expected length then use the max length
            expected_size = max_size
            logger.debug(
                "File size unknown, assuming file is max allowable size: %s",
                max_size,
            )
        elif expected_size > max_size:
            # The remote has told us up front that the file is bigger than we
            # are prepared to accept. Without this check `max_size` would only
            # be honoured for responses of unknown length, since the read limit
            # below is derived from the declared length. Reject before charging
            # the ratelimiter for bytes we are never going to download.
            msg = "Requested file is too large > %r bytes" % (max_size,)
            logger.warning(
                "{%s} [%s] %s",
                request.txn_id,
                request.destination,
                msg,
            )
            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)

        # Charge the ratelimiter for the number of bytes we expect to read.
        read_body, _ = await download_ratelimiter.can_do_action(
            requester=None,
            key=ip_address,
            n_actions=expected_size,
        )
        if not read_body:
            msg = "Requested file size exceeds ratelimits"
            logger.warning(
                "{%s} [%s] %s",
                request.txn_id,
                request.destination,
                msg,
            )
            raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED)

        try:
            # add a byte of headroom to max size as function errs at >=
            d = read_body_with_max_size(response, output_stream, expected_size + 1)
            d.addTimeout(self.default_timeout_seconds, self.reactor)
            length = await make_deferred_yieldable(d)
        except BodyExceededMaxSize:
            msg = "Requested file is too large > %r bytes" % (expected_size,)
            logger.warning(
                "{%s} [%s] %s",
                request.txn_id,
                request.destination,
                msg,
            )
            raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE)
        except defer.TimeoutError as e:
            logger.warning(
                "{%s} [%s] Timed out reading response - %s %s",
                request.txn_id,
                request.destination,
                request.method,
                request.uri.decode("ascii"),
            )
            raise RequestSendFailed(e, can_retry=True) from e
        except ResponseFailed as e:
            logger.warning(
                "{%s} [%s] Failed to read response - %s %s",
                request.txn_id,
                request.destination,
                request.method,
                request.uri.decode("ascii"),
            )
            raise RequestSendFailed(e, can_retry=True) from e
        except Exception as e:
            # Anything else is logged for context and propagated unchanged.
            logger.warning(
                "{%s} [%s] Error reading response: %s",
                request.txn_id,
                request.destination,
                e,
            )
            raise
        logger.info(
            "{%s} [%s] Completed: %d %s [%d bytes] %s %s",
            request.txn_id,
            request.destination,
            response.code,
            response.phrase.decode("ascii", errors="replace"),
            length,
            request.method,
            request.uri.decode("ascii"),
        )
        return length, headers


def _flatten_response_never_received(e: BaseException) -> str:
    """Render an exception as a single-line string for logging.

    An exception which carries a `reasons` attribute (presumably Twisted's
    ResponseNeverReceived and friends, whose reasons are Failure objects; the
    code only relies on each entry having a `.value`) is expanded recursively
    into "<ExceptionType>:[<reason>, <reason>, ...]". Anything else is
    rendered with `repr`.

    Args:
        e: the exception to describe

    Returns:
        The flattened description.
    """
    if not hasattr(e, "reasons"):
        return repr(e)

    flattened_reasons = [
        _flatten_response_never_received(failure.value) for failure in e.reasons
    ]
    return "%s:[%s]" % (type(e).__name__, ", ".join(flattened_reasons))


def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
    """
    Check that a set of HTTP headers have a Content-Type header, and that it
    is the expected value.

    Only the first Content-Type header is considered, and any parameters on it
    (such as a charset) are ignored.

    Args:
        headers: headers to check
        expected_content_type: the value which the Content-Type header is
            required to have, without parameters (e.g. "application/json")

    Raises:
        RequestSendFailed: if the Content-Type header is missing, is not valid
            ASCII, or doesn't match. The failure is never marked as retryable.

    """
    content_type_headers = headers.getRawHeaders(b"Content-Type")
    # Treat a header which is present but has no values the same as a missing
    # one, rather than failing with an IndexError below.
    if not content_type_headers:
        raise RequestSendFailed(
            RuntimeError("No Content-Type header received from remote server"),
            can_retry=False,
        )

    raw_c_type = content_type_headers[0]  # only the first header
    try:
        c_type = raw_c_type.decode("ascii")
    except UnicodeDecodeError as e:
        # The header comes from the remote server, so it may contain arbitrary
        # bytes. Report that in the same way as any other unacceptable
        # Content-Type instead of leaking a UnicodeDecodeError to the caller.
        raise RequestSendFailed(
            RuntimeError(
                f"Remote server sent a non-ASCII Content-Type header: {raw_c_type!r}",
            ),
            can_retry=False,
        ) from e

    # NOTE(review): the `cgi` module is deprecated (PEP 594) and removed in
    # Python 3.13; this call will need replacing before that version is supported.
    val, _ = cgi.parse_header(c_type)
    if val != expected_content_type:
        raise RequestSendFailed(
            RuntimeError(
                f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
            ),
            can_retry=False,
        )