summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/http/client.py213
-rw-r--r--synapse/http/federation/matrix_federation_agent.py100
-rw-r--r--synapse/http/federation/well_known_resolver.py18
-rw-r--r--synapse/http/matrixfederationclient.py364
-rw-r--r--synapse/http/request_metrics.py2
-rw-r--r--synapse/http/server.py48
-rw-r--r--synapse/http/servlet.py3
-rw-r--r--synapse/http/site.py50
8 files changed, 415 insertions, 383 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py

index 8324632cb6..e5b13593f2 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -14,9 +14,10 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -import urllib +import urllib.parse from io import BytesIO from typing import ( + TYPE_CHECKING, Any, BinaryIO, Dict, @@ -31,7 +32,7 @@ from typing import ( import treq from canonicaljson import encode_canonical_json -from netaddr import IPAddress +from netaddr import IPAddress, IPSet from prometheus_client import Counter from zope.interface import implementer, provider @@ -39,6 +40,8 @@ from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE from twisted.internet import defer, error as twisted_error, protocol, ssl from twisted.internet.interfaces import ( + IAddress, + IHostResolution, IReactorPluggableNameResolver, IResolutionReceiver, ) @@ -53,7 +56,7 @@ from twisted.web.client import ( ) from twisted.web.http import PotentialDataLoss from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IAgent, IBodyProducer, IResponse from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri @@ -63,6 +66,9 @@ from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred +if TYPE_CHECKING: + from synapse.app.homeserver import HomeServer + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) @@ -84,12 +90,19 @@ QueryParamValue = Union[str, bytes, Iterable[Union[str, bytes]]] QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]] -def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist): +def check_against_blacklist( + ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet +) -> bool: """ + Compares an IP address to allowed and disallowed IP sets. + Args: - ip_address (netaddr.IPAddress) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + ip_address: The IP address to check + ip_whitelist: Allowed IP addresses. + ip_blacklist: Disallowed IP addresses. + + Returns: + True if the IP address is in the blacklist and not in the whitelist. """ if ip_address in ip_blacklist: if ip_whitelist is None or ip_address not in ip_whitelist: @@ -118,23 +131,30 @@ class IPBlacklistingResolver: addresses, preventing DNS rebinding attacks on URL preview. """ - def __init__(self, reactor, ip_whitelist, ip_blacklist): + def __init__( + self, + reactor: IReactorPluggableNameResolver, + ip_whitelist: Optional[IPSet], + ip_blacklist: IPSet, + ): """ Args: - reactor (twisted.internet.reactor) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + reactor: The twisted reactor. + ip_whitelist: IP addresses to allow. + ip_blacklist: IP addresses to disallow. """ self._reactor = reactor self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - def resolveHostName(self, recv, hostname, portNumber=0): + def resolveHostName( + self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0 + ) -> IResolutionReceiver: r = recv() - addresses = [] + addresses = [] # type: List[IAddress] - def _callback(): + def _callback() -> None: r.resolutionBegan(None) has_bad_ip = False @@ -161,15 +181,15 @@ class IPBlacklistingResolver: @provider(IResolutionReceiver) class EndpointReceiver: @staticmethod - def resolutionBegan(resolutionInProgress): + def resolutionBegan(resolutionInProgress: IHostResolution) -> None: pass @staticmethod - def addressResolved(address): + def addressResolved(address: IAddress) -> None: addresses.append(address) @staticmethod - def resolutionComplete(): + def resolutionComplete() -> None: _callback() self._reactor.nameResolver.resolveHostName( @@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent): directly (without an IP address lookup). """ - def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None): + def __init__( + self, + agent: IAgent, + ip_whitelist: Optional[IPSet] = None, + ip_blacklist: Optional[IPSet] = None, + ): """ Args: - agent (twisted.web.client.Agent): The Agent to wrap. - reactor (twisted.internet.reactor) - ip_whitelist (netaddr.IPSet) - ip_blacklist (netaddr.IPSet) + agent: The Agent to wrap. + ip_whitelist: IP addresses to allow. + ip_blacklist: IP addresses to disallow. """ self._agent = agent self._ip_whitelist = ip_whitelist self._ip_blacklist = ip_blacklist - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: h = urllib.parse.urlparse(uri.decode("ascii")) try: @@ -226,23 +256,23 @@ class SimpleHttpClient: def __init__( self, - hs, - treq_args={}, - ip_whitelist=None, - ip_blacklist=None, - http_proxy=None, - https_proxy=None, + hs: "HomeServer", + treq_args: Dict[str, Any] = {}, + ip_whitelist: Optional[IPSet] = None, + ip_blacklist: Optional[IPSet] = None, + http_proxy: Optional[bytes] = None, + https_proxy: Optional[bytes] = None, ): """ Args: - hs (synapse.server.HomeServer) - treq_args (dict): Extra keyword arguments to be given to treq.request. - ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that + hs + treq_args: Extra keyword arguments to be given to treq.request. + ip_blacklist: The IP addresses that are blacklisted that we may not request. - ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can + ip_whitelist: The whitelisted IP addresses, that we can request if it were otherwise caught in a blacklist. - http_proxy (bytes): proxy server to use for http connections. host[:port] - https_proxy (bytes): proxy server to use for https connections. host[:port] + http_proxy: proxy server to use for http connections. host[:port] + https_proxy: proxy server to use for https connections. host[:port] """ self.hs = hs @@ -306,7 +336,6 @@ class SimpleHttpClient: # by the DNS resolution. self.agent = BlacklistingAgentWrapper( self.agent, - self.reactor, ip_whitelist=self._ip_whitelist, ip_blacklist=self._ip_blacklist, ) @@ -359,7 +388,7 @@ class SimpleHttpClient: agent=self.agent, data=body_producer, headers=headers, - **self._extra_treq_args + **self._extra_treq_args, ) # type: defer.Deferred # we use our own timeout mechanism rather than treq's as a workaround @@ -397,7 +426,7 @@ class SimpleHttpClient: async def post_urlencoded_get_json( self, uri: str, - args: Mapping[str, Union[str, List[str]]] = {}, + args: Optional[Mapping[str, Union[str, List[str]]]] = None, headers: Optional[RawHeaders] = None, ) -> Any: """ @@ -422,9 +451,7 @@ class SimpleHttpClient: # TODO: Do we ever want to log message contents? logger.debug("post_urlencoded_get_json args: %s", args) - query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode( - "utf8" - ) + query_bytes = encode_query_args(args) actual_headers = { b"Content-Type": [b"application/x-www-form-urlencoded"], @@ -432,7 +459,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "POST", uri, headers=Headers(actual_headers), data=query_bytes @@ -479,7 +506,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "POST", uri, headers=Headers(actual_headers), data=json_str @@ -495,7 +522,10 @@ class SimpleHttpClient: ) async def get_json( - self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None, + self, + uri: str, + args: Optional[QueryParams] = None, + headers: Optional[RawHeaders] = None, ) -> Any: """Gets some json from the given URI. @@ -516,7 +546,7 @@ class SimpleHttpClient: """ actual_headers = {b"Accept": [b"application/json"]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore body = await self.get_raw(uri, args, headers=headers) return json_decoder.decode(body.decode("utf-8")) @@ -525,7 +555,7 @@ class SimpleHttpClient: self, uri: str, json_body: Any, - args: QueryParams = {}, + args: Optional[QueryParams] = None, headers: RawHeaders = None, ) -> Any: """Puts some json to the given URI. @@ -546,9 +576,9 @@ class SimpleHttpClient: ValueError: if the response was not JSON """ - if len(args): - query_bytes = urllib.parse.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) + if args: + query_str = urllib.parse.urlencode(args, True) + uri = "%s?%s" % (uri, query_str) json_str = encode_canonical_json(json_body) @@ -558,7 +588,7 @@ class SimpleHttpClient: b"Accept": [b"application/json"], } if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request( "PUT", uri, headers=Headers(actual_headers), data=json_str @@ -574,7 +604,10 @@ class SimpleHttpClient: ) async def get_raw( - self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None + self, + uri: str, + args: Optional[QueryParams] = None, + headers: Optional[RawHeaders] = None, ) -> bytes: """Gets raw text from the given URI. @@ -592,13 +625,13 @@ class SimpleHttpClient: HttpResponseException on a non-2xx HTTP response. """ - if len(args): - query_bytes = urllib.parse.urlencode(args, True) - uri = "%s?%s" % (uri, query_bytes) + if args: + query_str = urllib.parse.urlencode(args, True) + uri = "%s?%s" % (uri, query_str) actual_headers = {b"User-Agent": [self.user_agent]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request("GET", uri, headers=Headers(actual_headers)) @@ -641,7 +674,7 @@ class SimpleHttpClient: actual_headers = {b"User-Agent": [self.user_agent]} if headers: - actual_headers.update(headers) + actual_headers.update(headers) # type: ignore response = await self.request("GET", url, headers=Headers(actual_headers)) @@ -649,12 +682,13 @@ class SimpleHttpClient: if ( b"Content-Length" in resp_headers + and max_size and int(resp_headers[b"Content-Length"][0]) > max_size ): - logger.warning("Requested URL is too large > %r bytes" % (self.max_size,)) + logger.warning("Requested URL is too large > %r bytes" % (max_size,)) raise SynapseError( 502, - "Requested file is too large > %r bytes" % (self.max_size,), + "Requested file is too large > %r bytes" % (max_size,), Codes.TOO_LARGE, ) @@ -668,7 +702,7 @@ class SimpleHttpClient: try: length = await make_deferred_yieldable( - _readBodyToFile(response, output_stream, max_size) + readBodyToFile(response, output_stream, max_size) ) except SynapseError: # This can happen e.g. because the body is too large. @@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure): return f -# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. -# The two should be factored out. - - class _ReadBodyToFileProtocol(protocol.Protocol): - def __init__(self, stream, deferred, max_size): + def __init__( + self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int] + ): self.stream = stream self.deferred = deferred self.length = 0 self.max_size = max_size - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self.stream.write(data) self.length += len(data) if self.max_size is not None and self.length >= self.max_size: @@ -721,7 +753,7 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred = defer.Deferred() self.transport.loseConnection() - def connectionLost(self, reason): + def connectionLost(self, reason: Failure) -> None: if reason.check(ResponseDone): self.deferred.callback(self.length) elif reason.check(PotentialDataLoss): @@ -732,35 +764,48 @@ class _ReadBodyToFileProtocol(protocol.Protocol): self.deferred.errback(reason) -# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient. -# The two should be factored out. +def readBodyToFile( + response: IResponse, stream: BinaryIO, max_size: Optional[int] +) -> defer.Deferred: + """ + Read a HTTP response body to a file-object. Optionally enforcing a maximum file size. + Args: + response: The HTTP response to read from. + stream: The file-object to write to. + max_size: The maximum file size to allow. + + Returns: + A Deferred which resolves to the length of the read body. + """ -def _readBodyToFile(response, stream, max_size): d = defer.Deferred() response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) return d -def encode_urlencode_args(args): - return {k: encode_urlencode_arg(v) for k, v in args.items()} +def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes: + """ + Encodes a map of query arguments to bytes which can be appended to a URL. + Args: + args: The query arguments, a mapping of string to string or list of strings. + + Returns: + The query arguments encoded as bytes. + """ + if args is None: + return b"" -def encode_urlencode_arg(arg): - if isinstance(arg, str): - return arg.encode("utf-8") - elif isinstance(arg, list): - return [encode_urlencode_arg(i) for i in arg] - else: - return arg + encoded_args = {} + for k, vs in args.items(): + if isinstance(vs, str): + vs = [vs] + encoded_args[k] = [v.encode("utf8") for v in vs] + query_str = urllib.parse.urlencode(encoded_args, True) -def _print_ex(e): - if hasattr(e, "reasons") and e.reasons: - for ex in e.reasons: - _print_ex(ex) - else: - logger.exception(e) + return query_str.encode("utf8") class InsecureInterceptableContextFactory(ssl.ContextFactory): diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 83d6196d4a..e77f9587d0 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -12,21 +12,25 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging -import urllib -from typing import List +import urllib.parse +from typing import List, Optional from netaddr import AddrFormatError, IPAddress from zope.interface import implementer from twisted.internet import defer from twisted.internet.endpoints import HostnameEndpoint, wrapClientTLS -from twisted.internet.interfaces import IStreamClientEndpoint -from twisted.web.client import Agent, HTTPConnectionPool +from twisted.internet.interfaces import ( + IProtocolFactory, + IReactorCore, + IStreamClientEndpoint, +) +from twisted.web.client import URI, Agent, HTTPConnectionPool from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IAgentEndpointFactory +from twisted.web.iweb import IAgent, IAgentEndpointFactory, IBodyProducer +from synapse.crypto.context_factory import FederationPolicyForHTTPS from synapse.http.federation.srv_resolver import Server, SrvResolver from synapse.http.federation.well_known_resolver import WellKnownResolver from synapse.logging.context import make_deferred_yieldable, run_in_background @@ -44,30 +48,30 @@ class MatrixFederationAgent: Doesn't implement any retries. (Those are done in MatrixFederationHttpClient.) Args: - reactor (IReactor): twisted reactor to use for underlying requests + reactor: twisted reactor to use for underlying requests - tls_client_options_factory (FederationPolicyForHTTPS|None): + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. - user_agent (bytes): + user_agent: The user agent header to use for federation requests. - _srv_resolver (SrvResolver|None): - SRVResolver impl to use for looking up SRV records. None to use a default - implementation. + _srv_resolver: + SrvResolver implementation to use for looking up SRV records. None + to use a default implementation. - _well_known_resolver (WellKnownResolver|None): + _well_known_resolver: WellKnownResolver to use to perform well-known lookups. None to use a default implementation. """ def __init__( self, - reactor, - tls_client_options_factory, - user_agent, - _srv_resolver=None, - _well_known_resolver=None, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + user_agent: bytes, + _srv_resolver: Optional[SrvResolver] = None, + _well_known_resolver: Optional[WellKnownResolver] = None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -99,15 +103,20 @@ class MatrixFederationAgent: self._well_known_resolver = _well_known_resolver @defer.inlineCallbacks - def request(self, method, uri, headers=None, bodyProducer=None): + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> defer.Deferred: """ Args: - method (bytes): HTTP method: GET/POST/etc - uri (bytes): Absolute URI to be retrieved - headers (twisted.web.http_headers.Headers|None): - HTTP headers to send with the request, or None to - send no extra headers. - bodyProducer (twisted.web.iweb.IBodyProducer|None): + method: HTTP method: GET/POST/etc + uri: Absolute URI to be retrieved + headers: + HTTP headers to send with the request, or None to send no extra headers. + bodyProducer: An object which can generate bytes to make up the body of this request (for example, the properly encoded contents of a file for a file upload). Or None if the request is to have @@ -123,6 +132,9 @@ class MatrixFederationAgent: # explicit port. parsed_uri = urllib.parse.urlparse(uri) + # There must be a valid hostname. + assert parsed_uri.hostname + # If this is a matrix:// URI check if the server has delegated matrix # traffic using well-known delegation. # @@ -179,7 +191,12 @@ class MatrixHostnameEndpointFactory: """Factory for MatrixHostnameEndpoint for parsing to an Agent. """ - def __init__(self, reactor, tls_client_options_factory, srv_resolver): + def __init__( + self, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + srv_resolver: Optional[SrvResolver], + ): self._reactor = reactor self._tls_client_options_factory = tls_client_options_factory @@ -203,15 +220,20 @@ class MatrixHostnameEndpoint: resolution (i.e. via SRV). Does not check for well-known delegation. Args: - reactor (IReactor) - tls_client_options_factory (ClientTLSOptionsFactory|None): + reactor: twisted reactor to use for underlying requests + tls_client_options_factory: factory to use for fetching client tls options, or none to disable TLS. - srv_resolver (SrvResolver): The SRV resolver to use - parsed_uri (twisted.web.client.URI): The parsed URI that we're wanting - to connect to. + srv_resolver: The SRV resolver to use + parsed_uri: The parsed URI that we're wanting to connect to. """ - def __init__(self, reactor, tls_client_options_factory, srv_resolver, parsed_uri): + def __init__( + self, + reactor: IReactorCore, + tls_client_options_factory: Optional[FederationPolicyForHTTPS], + srv_resolver: SrvResolver, + parsed_uri: URI, + ): self._reactor = reactor self._parsed_uri = parsed_uri @@ -231,13 +253,13 @@ class MatrixHostnameEndpoint: self._srv_resolver = srv_resolver - def connect(self, protocol_factory): + def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred: """Implements IStreamClientEndpoint interface """ return run_in_background(self._do_connect, protocol_factory) - async def _do_connect(self, protocol_factory): + async def _do_connect(self, protocol_factory: IProtocolFactory) -> None: first_exception = None server_list = await self._resolve_server() @@ -303,20 +325,20 @@ class MatrixHostnameEndpoint: return [Server(host, 8448)] -def _is_ip_literal(host): +def _is_ip_literal(host: bytes) -> bool: """Test if the given host name is either an IPv4 or IPv6 literal. Args: - host (bytes) + host: The host name to check Returns: - bool + True if the hostname is an IP address literal. """ - host = host.decode("ascii") + host_str = host.decode("ascii") try: - IPAddress(host) + IPAddress(host_str) return True except AddrFormatError: return False diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index a306faa267..5e08ef1664 100644 --- a/synapse/http/federation/well_known_resolver.py +++ b/synapse/http/federation/well_known_resolver.py
@@ -12,7 +12,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging import random import time @@ -21,10 +20,11 @@ from typing import Callable, Dict, Optional, Tuple import attr from twisted.internet import defer +from twisted.internet.interfaces import IReactorTime from twisted.web.client import RedirectAgent, readBody from twisted.web.http import stringToDatetime from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IAgent, IResponse from synapse.logging.context import make_deferred_yieldable from synapse.util import Clock, json_decoder @@ -81,11 +81,11 @@ class WellKnownResolver: def __init__( self, - reactor, - agent, - user_agent, - well_known_cache=None, - had_well_known_cache=None, + reactor: IReactorTime, + agent: IAgent, + user_agent: bytes, + well_known_cache: Optional[TTLCache] = None, + had_well_known_cache: Optional[TTLCache] = None, ): self._reactor = reactor self._clock = Clock(reactor) @@ -127,7 +127,7 @@ class WellKnownResolver: with Measure(self._clock, "get_well_known"): result, cache_period = await self._fetch_well_known( server_name - ) # type: Tuple[Optional[bytes], float] + ) # type: Optional[bytes], float except _FetchWellKnownFailure as e: if prev_result and e.temporary: @@ -172,7 +172,7 @@ class WellKnownResolver: had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False) # We do this in two steps to differentiate between possibly transient - # errors (e.g. can't connect to host, 503 response) and more permenant + # errors (e.g. can't connect to host, 503 response) and more permanent # errors (such as getting a 404 response). response, body = await self._make_well_known_request( server_name, retry=had_valid_well_known diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c23a4d7c0c..4e27f93b7a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -17,8 +17,9 @@ import cgi import logging import random import sys -import urllib +import urllib.parse from io import BytesIO +from typing import Callable, Dict, List, Optional, Tuple, Union import attr import treq @@ -27,25 +28,27 @@ from prometheus_client import Counter from signedjson.sign import sign_json from zope.interface import implementer -from twisted.internet import defer, protocol +from twisted.internet import defer from twisted.internet.error import DNSLookupError from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime from twisted.internet.task import _EPSILON, Cooperator -from twisted.web._newclient import ResponseDone from twisted.web.http_headers import Headers -from twisted.web.iweb import IResponse +from twisted.web.iweb import IBodyProducer, IResponse import synapse.metrics import synapse.util.retryutils from synapse.api.errors import ( - Codes, FederationDeniedError, HttpResponseException, RequestSendFailed, - SynapseError, ) from synapse.http import QuieterFileBodyProducer -from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver +from synapse.http.client import ( + BlacklistingAgentWrapper, + IPBlacklistingResolver, + encode_query_args, + readBodyToFile, +) from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent from synapse.logging.context import make_deferred_yieldable from synapse.logging.opentracing import ( @@ -54,6 +57,7 @@ from synapse.logging.opentracing import ( start_active_span, tags, ) +from synapse.types import JsonDict from synapse.util import json_decoder from synapse.util.async_helpers import timeout_deferred from synapse.util.metrics import Measure @@ -76,47 +80,44 @@ MAXINT = sys.maxsize _next_id = 1 +QueryArgs = Dict[str, Union[str, List[str]]] + + @attr.s(slots=True, frozen=True) class MatrixFederationRequest: - method = attr.ib() + method = attr.ib(type=str) """HTTP method - :type: str """ - path = attr.ib() + path = attr.ib(type=str) """HTTP path - :type: str """ - destination = attr.ib() + destination = attr.ib(type=str) """The remote server to send the HTTP request to. - :type: str""" + """ - json = attr.ib(default=None) + json = attr.ib(default=None, type=Optional[JsonDict]) """JSON to send in the body. - :type: dict|None """ - json_callback = attr.ib(default=None) + json_callback = attr.ib(default=None, type=Optional[Callable[[], JsonDict]]) """A callback to generate the JSON. - :type: func|None """ - query = attr.ib(default=None) + query = attr.ib(default=None, type=Optional[dict]) """Query arguments. - :type: dict|None """ - txn_id = attr.ib(default=None) + txn_id = attr.ib(default=None, type=Optional[str]) """Unique ID for this request (for logging) - :type: str|None """ uri = attr.ib(init=False, type=bytes) """The URI of this request """ - def __attrs_post_init__(self): + def __attrs_post_init__(self) -> None: global _next_id txn_id = "%s-O-%s" % (self.method, _next_id) _next_id = (_next_id + 1) % (MAXINT - 1) @@ -136,7 +137,7 @@ class MatrixFederationRequest: ) object.__setattr__(self, "uri", uri) - def get_json(self): + def get_json(self) -> Optional[JsonDict]: if self.json_callback: return self.json_callback() return self.json @@ -148,7 +149,7 @@ async def _handle_json_response( request: MatrixFederationRequest, response: IResponse, start_ms: int, -): +) -> JsonDict: """ Reads the JSON body of a response, with a timeout @@ -160,7 +161,7 @@ async def _handle_json_response( start_ms: Timestamp when request was made Returns: - dict: parsed JSON response + The parsed JSON response """ try: check_content_type_is_json(response.headers) @@ -250,9 +251,7 @@ class MatrixFederationHttpClient: # Use a BlacklistingAgentWrapper to prevent circumventing the IP # blacklist via IP literals in server names self.agent = BlacklistingAgentWrapper( - self.agent, - self.reactor, - ip_blacklist=hs.config.federation_ip_range_blacklist, + self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist, ) self.clock = hs.get_clock() @@ -266,27 +265,29 @@ class MatrixFederationHttpClient: self._cooperator = Cooperator(scheduler=schedule) async def _send_request_with_optional_trailing_slash( - self, request, try_trailing_slash_on_400=False, **send_request_args - ): + self, + request: MatrixFederationRequest, + try_trailing_slash_on_400: bool = False, + **send_request_args + ) -> IResponse: """Wrapper for _send_request which can optionally retry the request upon receiving a combination of a 400 HTTP response code and a 'M_UNRECOGNIZED' errcode. This is a workaround for Synapse <= v0.99.3 due to #3622. Args: - request (MatrixFederationRequest): details of request to be sent - try_trailing_slash_on_400 (bool): Whether on receiving a 400 + request: details of request to be sent + try_trailing_slash_on_400: Whether on receiving a 400 'M_UNRECOGNIZED' from the server to retry the request with a trailing slash appended to the request path. - send_request_args (Dict): A dictionary of arguments to pass to - `_send_request()`. + send_request_args: A dictionary of arguments to pass to `_send_request()`. Raises: HttpResponseException: If we get an HTTP response code >= 300 (except 429). Returns: - Dict: Parsed JSON response body. + Parsed JSON response body. """ try: response = await self._send_request(request, **send_request_args) @@ -313,24 +314,26 @@ class MatrixFederationHttpClient: async def _send_request( self, - request, - retry_on_dns_fail=True, - timeout=None, - long_retries=False, - ignore_backoff=False, - backoff_on_404=False, - ): + request: MatrixFederationRequest, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + long_retries: bool = False, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + ) -> IResponse: """ Sends a request to the given server. Args: - request (MatrixFederationRequest): details of request to be sent + request: details of request to be sent + + retry_on_dns_fail: true if the request should be retied on DNS failures - timeout (int|None): number of milliseconds to wait for the response headers + timeout: number of milliseconds to wait for the response headers (including connecting to the server), *for each attempt*. 60s by default. - long_retries (bool): whether to use the long retry algorithm. + long_retries: whether to use the long retry algorithm. The regular retry algorithm makes 4 attempts, with intervals [0.5s, 1s, 2s]. @@ -346,14 +349,13 @@ class MatrixFederationHttpClient: NB: the long retry algorithm takes over 20 minutes to complete, with a default timeout of 60s! - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - backoff_on_404 (bool): Back off if we get a 404 + backoff_on_404: Back off if we get a 404 Returns: - twisted.web.client.Response: resolves with the HTTP - response object on success. + Resolves with the HTTP response object on success. Raises: HttpResponseException: If we get an HTTP response code >= 300 @@ -404,7 +406,7 @@ class MatrixFederationHttpClient: ) # Inject the span into the headers - headers_dict = {} + headers_dict = {} # type: Dict[bytes, List[bytes]] inject_active_span_byte_dict(headers_dict, request.destination) headers_dict[b"User-Agent"] = [self.version_string_bytes] @@ -435,7 +437,7 @@ class MatrixFederationHttpClient: data = encode_canonical_json(json) producer = QuieterFileBodyProducer( BytesIO(data), cooperator=self._cooperator - ) + ) # type: Optional[IBodyProducer] else: producer = None auth_headers = self.build_auth_headers( @@ -524,14 +526,16 @@ class MatrixFederationHttpClient: ) body = None - e = HttpResponseException(response.code, response_phrase, body) + exc = HttpResponseException( + response.code, response_phrase, body + ) # Retry if the error is a 429 (Too Many Requests), # otherwise just raise a standard HttpResponseException if response.code == 429: - raise RequestSendFailed(e, can_retry=True) from e + raise RequestSendFailed(exc, can_retry=True) from exc else: - raise e + raise exc break except RequestSendFailed as e: @@ -582,22 +586,27 @@ class MatrixFederationHttpClient: return response def build_auth_headers( - self, destination, method, url_bytes, content=None, destination_is=None - ): + self, + destination: Optional[bytes], + method: bytes, + url_bytes: bytes, + content: Optional[JsonDict] = None, + destination_is: Optional[bytes] = None, + ) -> List[bytes]: """ Builds the Authorization headers for a federation request Args: - destination (bytes|None): The desination homeserver of the request. + destination: The destination homeserver of the request. May be None if the destination is an identity server, in which case destination_is must be non-None. - method (bytes): The HTTP method of the request - url_bytes (bytes): The URI path of the request - content (object): The body of the request - destination_is (bytes): As 'destination', but if the destination is an + method: The HTTP method of the request + url_bytes: The URI path of the request + content: The body of the request + destination_is: As 'destination', but if the destination is an identity server Returns: - list[bytes]: a list of headers to be added as "Authorization:" headers + A list of headers to be added as "Authorization:" headers """ request = { "method": method.decode("ascii"), @@ -629,33 +638,32 @@ class MatrixFederationHttpClient: async def put_json( self, - destination, - path, - args={}, - data={}, - json_data_callback=None, - long_retries=False, - timeout=None, - ignore_backoff=False, - backoff_on_404=False, - try_trailing_slash_on_400=False, - ): - """ Sends the specifed json data using PUT + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + ) -> Union[JsonDict, list]: + """ Sends the specified json data using PUT Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. - args (dict): query params - data (dict): A dict containing the data that will be used as + destination: The remote server to send the HTTP request to. + path: The HTTP path. + args: query params + data: A dict containing the data that will be used as the request body. This will be encoded as JSON. - json_data_callback (callable): A callable returning the dict to + json_data_callback: A callable returning the dict to use as the request body. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -663,19 +671,19 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - backoff_on_404 (bool): True if we should count a 404 response as + backoff_on_404: True if we should count a 404 response as a failure of the server (and should therefore back off future requests). - try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. This will be attempted before backing off if backing off has been enabled. Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -721,29 +729,28 @@ class MatrixFederationHttpClient: async def post_json( self, - destination, - path, - data={}, - long_retries=False, - timeout=None, - ignore_backoff=False, - args={}, - ): - """ Sends the specifed json data using POST + destination: str, + path: str, + data: Optional[JsonDict] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryArgs] = None, + ) -> Union[JsonDict, list]: + """ Sends the specified json data using POST Args: - destination (str): The remote server to send the HTTP request - to. + destination: The remote server to send the HTTP request to. - path (str): The HTTP path. + path: The HTTP path. - data (dict): A dict containing the data that will be used as + data: A dict containing the data that will be used as the request body. This will be encoded as JSON. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -751,10 +758,10 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data and + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - args (dict): query params + args: query params Returns: dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. @@ -795,26 +802,25 @@ class MatrixFederationHttpClient: async def get_json( self, - destination, - path, - args=None, - retry_on_dns_fail=True, - timeout=None, - ignore_backoff=False, - try_trailing_slash_on_400=False, - ): + destination: str, + path: str, + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + try_trailing_slash_on_400: bool = False, + ) -> Union[JsonDict, list]: """ GETs some json from the given host homeserver and path Args: - destination (str): The remote server to send the HTTP request - to. + destination: The remote server to send the HTTP request to. - path (str): The HTTP path. + path: The HTTP path. - args (dict|None): A dictionary used to create query strings, defaults to + args: A dictionary used to create query strings, defaults to None. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -822,14 +828,14 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - try_trailing_slash_on_400 (bool): True if on a 400 M_UNRECOGNIZED + try_trailing_slash_on_400: True if on a 400 M_UNRECOGNIZED response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -870,24 +876,23 @@ class MatrixFederationHttpClient: async def delete_json( self, - destination, - path, - long_retries=False, - timeout=None, - ignore_backoff=False, - args={}, - ): + destination: str, + path: str, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + args: Optional[QueryArgs] = None, + ) -> Union[JsonDict, list]: """Send a DELETE request to the remote expecting some json response Args: - destination (str): The remote server to send the HTTP request - to. - path (str): The HTTP path. + destination: The remote server to send the HTTP request to. + path: The HTTP path. - long_retries (bool): whether to use the long retry algorithm. See + long_retries: whether to use the long retry algorithm. See docs on _send_request for details. - timeout (int|None): number of milliseconds to wait for the response. + timeout: number of milliseconds to wait for the response. self._default_timeout (60s) by default. Note that we may make several attempts to send the request; this @@ -895,12 +900,12 @@ class MatrixFederationHttpClient: *each* attempt (including connection time) as well as the time spent reading the response body after a 200 response. - ignore_backoff (bool): true to ignore the historical backoff data and + ignore_backoff: true to ignore the historical backoff data and try the request anyway. - args (dict): query params + args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -938,25 +943,25 @@ class MatrixFederationHttpClient: async def get_file( self, - destination, - path, + destination: str, + path: str, output_stream, - args={}, - retry_on_dns_fail=True, - max_size=None, - ignore_backoff=False, - ): + args: Optional[QueryArgs] = None, + retry_on_dns_fail: bool = True, + max_size: Optional[int] = None, + ignore_backoff: bool = False, + ) -> Tuple[int, Dict[bytes, List[bytes]]]: """GETs a file from a given homeserver Args: - destination (str): The remote server to send the HTTP request to. - path (str): The HTTP path to GET. - output_stream (file): File to write the response body to. - args (dict): Optional dictionary used to create the query string. - ignore_backoff (bool): true to ignore the historical backoff data + destination: The remote server to send the HTTP request to. + path: The HTTP path to GET. + output_stream: File to write the response body to. + args: Optional dictionary used to create the query string. + ignore_backoff: true to ignore the historical backoff data and try the request anyway. Returns: - tuple[int, dict]: Resolves with an (int,dict) tuple of + Resolves with an (int,dict) tuple of the file length and a dict of the response headers. Raises: @@ -980,7 +985,7 @@ class MatrixFederationHttpClient: headers = dict(response.headers.getAllRawHeaders()) try: - d = _readBodyToFile(response, output_stream, max_size) + d = readBodyToFile(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) length = await make_deferred_yieldable(d) except Exception as e: @@ -1004,40 +1009,6 @@ class MatrixFederationHttpClient: return (length, headers) -class _ReadBodyToFileProtocol(protocol.Protocol): - def __init__(self, stream, deferred, max_size): - self.stream = stream - self.deferred = deferred - self.length = 0 - self.max_size = max_size - - def dataReceived(self, data): - self.stream.write(data) - self.length += len(data) - if self.max_size is not None and self.length >= self.max_size: - self.deferred.errback( - SynapseError( - 502, - "Requested file is too large > %r bytes" % (self.max_size,), - Codes.TOO_LARGE, - ) - ) - self.deferred = defer.Deferred() - self.transport.loseConnection() - - def connectionLost(self, reason): - if reason.check(ResponseDone): - self.deferred.callback(self.length) - else: - self.deferred.errback(reason) - - -def _readBodyToFile(response, stream, max_size): - d = defer.Deferred() - response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size)) - return d - - def _flatten_response_never_received(e): if hasattr(e, "reasons"): reasons = ", ".join( @@ -1049,13 +1020,13 @@ def _flatten_response_never_received(e): return repr(e) -def check_content_type_is_json(headers): +def check_content_type_is_json(headers: Headers) -> None: """ Check that a set of HTTP headers have a Content-Type header, and that it is application/json. Args: - headers (twisted.web.http_headers.Headers): headers to check + headers: headers to check Raises: RequestSendFailed: if the Content-Type header is missing or isn't JSON @@ -1063,27 +1034,18 @@ def check_content_type_is_json(headers): """ c_type = headers.getRawHeaders(b"Content-Type") if c_type is None: - raise RequestSendFailed(RuntimeError("No Content-Type header"), can_retry=False) + raise RequestSendFailed( + RuntimeError("No Content-Type header received from remote server"), + can_retry=False, + ) c_type = c_type[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) if val != "application/json": raise RequestSendFailed( - RuntimeError("Content-Type not application/json: was '%s'" % c_type), + RuntimeError( + "Remote server sent Content-Type header of '%s', not 'application/json'" + % c_type, + ), can_retry=False, ) - - -def encode_query_args(args): - if args is None: - return b"" - - encoded_args = {} - for k, vs in args.items(): - if isinstance(vs, str): - vs = [vs] - encoded_args[k] = [v.encode("UTF-8") for v in vs] - - query_bytes = urllib.parse.urlencode(encoded_args, True) - - return query_bytes.encode("utf8") diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index cd94e789e8..7c5defec82 100644 --- a/synapse/http/request_metrics.py +++ b/synapse/http/request_metrics.py
@@ -109,7 +109,7 @@ in_flight_requests_db_sched_duration = Counter( # The set of all in flight requests, set[RequestMetrics] _in_flight_requests = set() -# Protects the _in_flight_requests set from concurrent accesss +# Protects the _in_flight_requests set from concurrent access _in_flight_requests_lock = threading.Lock() diff --git a/synapse/http/server.py b/synapse/http/server.py
index 996a31a9ec..6a4e429a6c 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -25,7 +25,7 @@ from io import BytesIO from typing import Any, Callable, Dict, Iterator, List, Tuple, Union import jinja2 -from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json +from canonicaljson import iterencode_canonical_json from zope.interface import implementer from twisted.internet import defer, interfaces @@ -35,8 +35,6 @@ from twisted.web.server import NOT_DONE_YET, Request from twisted.web.static import File, NoRangeStaticProducer from twisted.web.util import redirectTo -import synapse.events -import synapse.metrics from synapse.api.errors import ( CodeMessageException, Codes, @@ -96,11 +94,7 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: pass else: respond_with_json( - request, - error_code, - error_dict, - send_cors=True, - pretty_print=_request_user_agent_is_curl(request), + request, error_code, error_dict, send_cors=True, ) @@ -182,7 +176,7 @@ class HttpServer: """ Register a callback that gets fired if we receive a http request with the given method for a path that matches the given regex. - If the regex contains groups these gets passed to the calback via + If the regex contains groups these gets passed to the callback via an unpacked tuple. Args: @@ -241,7 +235,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): async def _async_render(self, request: Request): """Delegates to `_async_render_<METHOD>` methods, or returns a 400 if - no appropriate method exists. Can be overriden in sub classes for + no appropriate method exists. Can be overridden in sub classes for different routing. """ # Treat HEAD requests as GET requests. @@ -257,7 +251,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return @@ -292,7 +286,6 @@ class DirectServeJsonResource(_AsyncResource): code, response_object, send_cors=True, - pretty_print=_request_user_agent_is_curl(request), canonical_json=self.canonical_json, ) @@ -386,7 +379,7 @@ class JsonResource(DirectServeJsonResource): async def _async_render(self, request): callback, servlet_classname, group_dict = self._get_handler_for_request(request) - # Make sure we have an appopriate name for this handler in prometheus + # Make sure we have an appropriate name for this handler in prometheus # (rather than the default of JsonResource). request.request_metrics.name = servlet_classname @@ -406,7 +399,7 @@ class JsonResource(DirectServeJsonResource): if isinstance(raw_callback_return, (defer.Deferred, types.CoroutineType)): callback_return = await raw_callback_return else: - callback_return = raw_callback_return + callback_return = raw_callback_return # type: ignore return callback_return @@ -589,7 +582,6 @@ def respond_with_json( code: int, json_object: Any, send_cors: bool = False, - pretty_print: bool = False, canonical_json: bool = True, ): """Sends encoded JSON in response to the given request. @@ -600,8 +592,6 @@ def respond_with_json( json_object: The object to serialize to JSON. send_cors: Whether to send Cross-Origin Resource Sharing headers https://fetch.spec.whatwg.org/#http-cors-protocol - pretty_print: Whether to include indentation and line-breaks in the - resulting JSON bytes. canonical_json: Whether to use the canonicaljson algorithm when encoding the JSON bytes. @@ -617,13 +607,10 @@ def respond_with_json( ) return None - if pretty_print: - encoder = iterencode_pretty_printed_json + if canonical_json: + encoder = iterencode_canonical_json else: - if canonical_json or synapse.events.USE_FROZEN_DICTS: - encoder = iterencode_canonical_json - else: - encoder = _encode_json_bytes + encoder = _encode_json_bytes request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -651,6 +638,11 @@ def respond_with_json_bytes( Returns: twisted.web.server.NOT_DONE_YET if the request is still active. """ + if request._disconnected: + logger.warning( + "Not sending response to request %s, already disconnected.", request + ) + return request.setResponseCode(code) request.setHeader(b"Content-Type", b"application/json") @@ -682,7 +674,7 @@ def set_cors_headers(request: Request): ) request.setHeader( b"Access-Control-Allow-Headers", - b"Origin, X-Requested-With, Content-Type, Accept, Authorization", + b"Origin, X-Requested-With, Content-Type, Accept, Authorization, Date", ) @@ -756,11 +748,3 @@ def finish_request(request: Request): request.finish() except RuntimeError as e: logger.info("Connection disconnected before response was written: %r", e) - - -def _request_user_agent_is_curl(request: Request) -> bool: - user_agents = request.requestHeaders.getRawHeaders(b"User-Agent", default=[]) - for user_agent in user_agents: - if b"curl" in user_agent: - return True - return False diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index fd90ba7828..b361b7cbaf 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -272,7 +272,6 @@ class RestServlet: on_PUT on_POST on_DELETE - on_OPTIONS Automatically handles turning CodeMessageExceptions thrown by these methods into the appropriate HTTP response. @@ -283,7 +282,7 @@ class RestServlet: if hasattr(self, "PATTERNS"): patterns = self.PATTERNS - for method in ("GET", "PUT", "POST", "OPTIONS", "DELETE"): + for method in ("GET", "PUT", "POST", "DELETE"): if hasattr(self, "on_%s" % (method,)): servlet_classname = self.__class__.__name__ method_handler = getattr(self, "on_%s" % (method,)) diff --git a/synapse/http/site.py b/synapse/http/site.py
index 6e79b47828..5f0581dc3f 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -14,7 +14,7 @@ import contextlib import logging import time -from typing import Optional +from typing import Optional, Union from twisted.python.failure import Failure from twisted.web.server import Request, Site @@ -23,6 +23,7 @@ from synapse.config.server import ListenerConfig from synapse.http import redact_uri from synapse.http.request_metrics import RequestMetrics, requests_counter from synapse.logging.context import LoggingContext, PreserveLoggingContext +from synapse.types import Requester logger = logging.getLogger(__name__) @@ -54,9 +55,12 @@ class SynapseRequest(Request): Request.__init__(self, channel, *args, **kw) self.site = channel.site self._channel = channel # this is used by the tests - self.authenticated_entity = None self.start_time = 0.0 + # The requester, if authenticated. For federation requests this is the + # server name, for client requests this is the Requester object. + self.requester = None # type: Optional[Union[Requester, str]] + # we can't yet create the logcontext, as we don't know the method. self.logcontext = None # type: Optional[LoggingContext] @@ -109,8 +113,14 @@ class SynapseRequest(Request): method = self.method.decode("ascii") return method - def get_user_agent(self): - return self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + def get_user_agent(self, default: str) -> str: + """Return the last User-Agent header, or the given default. + """ + user_agent = self.requestHeaders.getRawHeaders(b"User-Agent", [None])[-1] + if user_agent is None: + return default + + return user_agent.decode("ascii", "replace") def render(self, resrc): # this is called once a Resource has been found to serve the request; in our @@ -161,7 +171,9 @@ class SynapseRequest(Request): yield except Exception: # this should already have been caught, and sent back to the client as a 500. - logger.exception("Asynchronous messge handler raised an uncaught exception") + logger.exception( + "Asynchronous message handler raised an uncaught exception" + ) finally: # the request handler has finished its work and either sent the whole response # back, or handed over responsibility to a Producer. @@ -263,22 +275,30 @@ class SynapseRequest(Request): # to the client (nb may be negative) response_send_time = self.finish_time - self._processing_finished_time - # need to decode as it could be raw utf-8 bytes - # from a IDN servname in an auth header - authenticated_entity = self.authenticated_entity - if authenticated_entity is not None and isinstance(authenticated_entity, bytes): - authenticated_entity = authenticated_entity.decode("utf-8", "replace") + # Convert the requester into a string that we can log + authenticated_entity = None + if isinstance(self.requester, str): + authenticated_entity = self.requester + elif isinstance(self.requester, Requester): + authenticated_entity = self.requester.authenticated_entity + + # If this is a request where the target user doesn't match the user who + # authenticated (e.g. and admin is puppetting a user) then we log both. + if self.requester.user.to_string() != authenticated_entity: + authenticated_entity = "{},{}".format( + authenticated_entity, self.requester.user.to_string(), + ) + elif self.requester is not None: + # This shouldn't happen, but we log it so we don't lose information + # and can see that we're doing something wrong. + authenticated_entity = repr(self.requester) # type: ignore[unreachable] # ...or could be raw utf-8 bytes in the User-Agent header. # N.B. if you don't do this, the logger explodes cryptically # with maximum recursion trying to log errors about # the charset problem. # c.f. https://github.com/matrix-org/synapse/issues/3471 - user_agent = self.get_user_agent() - if user_agent is not None: - user_agent = user_agent.decode("utf-8", "replace") - else: - user_agent = "-" + user_agent = self.get_user_agent("-") code = str(self.code) if not self.finished: