summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r-- synapse/http/client.py 42
-rw-r--r-- synapse/http/matrixfederationclient.py 20
-rw-r--r-- synapse/http/proxy.py 40
-rw-r--r-- synapse/http/proxyagent.py 31
-rw-r--r-- synapse/http/replicationagent.py 4
-rw-r--r-- synapse/http/server.py 19
-rw-r--r-- synapse/http/servlet.py 25
-rw-r--r-- synapse/http/site.py 43
8 files changed, 130 insertions, 94 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py

index 56ad28eabf..84a510fb42 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -31,18 +31,17 @@ from typing import ( List, Mapping, Optional, + Protocol, Tuple, Union, ) import attr -import multipart import treq from canonicaljson import encode_canonical_json from netaddr import AddrFormatError, IPAddress, IPSet from prometheus_client import Counter -from typing_extensions import Protocol -from zope.interface import implementer, provider +from zope.interface import implementer from OpenSSL import SSL from OpenSSL.SSL import VERIFY_NONE @@ -93,6 +92,20 @@ from synapse.util.async_helpers import timeout_deferred if TYPE_CHECKING: from synapse.server import HomeServer +# Support both import names for the `python-multipart` (PyPI) library, +# which renamed its package name from `multipart` to `python_multipart` +# in 0.0.13 (though supports the old import name for compatibility). +# Note that the `multipart` package name conflicts with `multipart` (PyPI) +# so we should prefer importing from `python_multipart` when possible. +try: + from python_multipart import MultipartParser + + if TYPE_CHECKING: + from python_multipart import multipart +except ImportError: + from multipart import MultipartParser # type: ignore[no-redef] + + logger = logging.getLogger(__name__) outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"]) @@ -212,7 +225,7 @@ class _IPBlockingResolver: recv.addressResolved(address) recv.resolutionComplete() - @provider(IResolutionReceiver) + @implementer(IResolutionReceiver) class EndpointReceiver: @staticmethod def resolutionBegan(resolutionInProgress: IHostResolution) -> None: @@ -226,8 +239,9 @@ class _IPBlockingResolver: def resolutionComplete() -> None: _callback() + endpoint_receiver_wrapper = EndpointReceiver() self._reactor.nameResolver.resolveHostName( - EndpointReceiver, hostname, portNumber=portNumber + endpoint_receiver_wrapper, hostname, portNumber=portNumber ) return recv @@ -1039,7 +1053,7 @@ class _MultipartParserProtocol(protocol.Protocol): self.deferred = deferred self.boundary = 
boundary self.max_length = max_length - self.parser = None + self.parser: Optional[MultipartParser] = None self.multipart_response = MultipartResponse() self.has_redirect = False self.in_json = False @@ -1057,11 +1071,11 @@ class _MultipartParserProtocol(protocol.Protocol): if not self.parser: def on_header_field(data: bytes, start: int, end: int) -> None: - if data[start:end] == b"Location": + if data[start:end].lower() == b"location": self.has_redirect = True - if data[start:end] == b"Content-Disposition": + if data[start:end].lower() == b"content-disposition": self.in_disposition = True - if data[start:end] == b"Content-Type": + if data[start:end].lower() == b"content-type": self.in_content_type = True def on_header_value(data: bytes, start: int, end: int) -> None: @@ -1088,7 +1102,6 @@ class _MultipartParserProtocol(protocol.Protocol): return # otherwise we are in the file part else: - logger.info("Writing multipart file data to stream") try: self.stream.write(data[start:end]) except Exception as e: @@ -1098,12 +1111,12 @@ class _MultipartParserProtocol(protocol.Protocol): self.deferred.errback() self.file_length += end - start - callbacks = { + callbacks: "multipart.MultipartCallbacks" = { "on_header_field": on_header_field, "on_header_value": on_header_value, "on_part_data": on_part_data, } - self.parser = multipart.MultipartParser(self.boundary, callbacks) + self.parser = MultipartParser(self.boundary, callbacks) self.total_length += len(incoming_data) if self.max_length is not None and self.total_length >= self.max_length: @@ -1114,7 +1127,7 @@ class _MultipartParserProtocol(protocol.Protocol): self.transport.abortConnection() try: - self.parser.write(incoming_data) # type: ignore[attr-defined] + self.parser.write(incoming_data) except Exception as e: logger.warning(f"Exception writing to multipart parser: {e}") self.deferred.errback() @@ -1314,6 +1327,5 @@ def is_unknown_endpoint( ) ) or ( # Older Synapses returned a 400 error. 
- e.code == 400 - and synapse_error.errcode == Codes.UNRECOGNIZED + e.code == 400 and synapse_error.errcode == Codes.UNRECOGNIZED ) diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 6fd75fd381..88bf98045c 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -19,7 +19,6 @@ # # import abc -import cgi import codecs import logging import random @@ -35,6 +34,7 @@ from typing import ( Dict, Generic, List, + Literal, Optional, TextIO, Tuple, @@ -49,7 +49,6 @@ import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json -from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -426,9 +425,9 @@ class MatrixFederationHttpClient: ) else: proxy_authorization_secret = hs.config.worker.worker_replication_secret - assert ( - proxy_authorization_secret is not None - ), "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)" + assert proxy_authorization_secret is not None, ( + "`worker_replication_secret` must be set when using `outbound_federation_restricted_to` (used to authenticate requests across workers)" + ) federation_proxy_credentials = BearerProxyCredentials( proxy_authorization_secret.encode("ascii") ) @@ -792,7 +791,7 @@ class MatrixFederationHttpClient: url_str, _flatten_response_never_received(e), ) - body = None + body = b"" exc = HttpResponseException( response.code, response_phrase, body @@ -1756,8 +1755,10 @@ class MatrixFederationHttpClient: request.destination, str_url, ) + # We don't know how large the response will be upfront, so limit it to + # the `max_size` config value. 
length, headers, _, _ = await self._simple_http_client.get_file( - str_url, output_stream, expected_size + str_url, output_stream, max_size ) logger.info( @@ -1811,8 +1812,9 @@ def check_content_type_is(headers: Headers, expected_content_type: str) -> None: ) c_type = content_type_headers[0].decode("ascii") # only the first header - val, options = cgi.parse_header(c_type) - if val != expected_content_type: + # Extract the 'essence' of the mimetype, removing any parameter + c_type_parsed = c_type.split(";", 1)[0].strip() + if c_type_parsed != expected_content_type: raise RequestSendFailed( RuntimeError( f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'", diff --git a/synapse/http/proxy.py b/synapse/http/proxy.py
index 97aa429e7d..5cd990b0d0 100644 --- a/synapse/http/proxy.py +++ b/synapse/http/proxy.py
@@ -51,25 +51,17 @@ logger = logging.getLogger(__name__) # "Hop-by-hop" headers (as opposed to "end-to-end" headers) as defined by RFC2616 # section 13.5.1 and referenced in RFC9110 section 7.6.1. These are meant to only be # consumed by the immediate recipient and not be forwarded on. -HOP_BY_HOP_HEADERS = { - "Connection", - "Keep-Alive", - "Proxy-Authenticate", - "Proxy-Authorization", - "TE", - "Trailers", - "Transfer-Encoding", - "Upgrade", +HOP_BY_HOP_HEADERS_LOWERCASE = { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", } - -if hasattr(Headers, "_canonicalNameCaps"): - # Twisted < 24.7.0rc1 - _canonicalHeaderName = Headers()._canonicalNameCaps # type: ignore[attr-defined] -else: - # Twisted >= 24.7.0rc1 - # But note that `_encodeName` still exists on prior versions, - # it just encodes differently - _canonicalHeaderName = Headers()._encodeName +assert all(header.lower() == header for header in HOP_BY_HOP_HEADERS_LOWERCASE) def parse_connection_header_value( @@ -92,12 +84,12 @@ def parse_connection_header_value( Returns: The set of header names that should not be copied over from the remote response. - The keys are capitalized in canonical capitalization. + The keys are lowercased. """ extra_headers_to_remove: Set[str] = set() if connection_header_value: extra_headers_to_remove = { - _canonicalHeaderName(connection_option.strip()).decode("ascii") + connection_option.decode("ascii").strip().lower() for connection_option in connection_header_value.split(b",") } @@ -194,7 +186,7 @@ class ProxyResource(_AsyncResource): # The `Connection` header also defines which headers should not be copied over. 
connection_header = response_headers.getRawHeaders(b"connection") - extra_headers_to_remove = parse_connection_header_value( + extra_headers_to_remove_lowercase = parse_connection_header_value( connection_header[0] if connection_header else None ) @@ -202,10 +194,10 @@ class ProxyResource(_AsyncResource): for k, v in response_headers.getAllRawHeaders(): # Do not copy over any hop-by-hop headers. These are meant to only be # consumed by the immediate recipient and not be forwarded on. - header_key = k.decode("ascii") + header_key_lowercase = k.decode("ascii").lower() if ( - header_key in HOP_BY_HOP_HEADERS - or header_key in extra_headers_to_remove + header_key_lowercase in HOP_BY_HOP_HEADERS_LOWERCASE + or header_key_lowercase in extra_headers_to_remove_lowercase ): continue diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index f80f67acc6..6817199035 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
@@ -21,7 +21,7 @@ import logging import random import re -from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple +from typing import Any, Collection, Dict, List, Optional, Sequence, Tuple, Union from urllib.parse import urlparse from urllib.request import ( # type: ignore[attr-defined] getproxies_environment, @@ -150,6 +150,12 @@ class ProxyAgent(_AgentBase): http_proxy = proxies["http"].encode() if "http" in proxies else None https_proxy = proxies["https"].encode() if "https" in proxies else None no_proxy = proxies["no"] if "no" in proxies else None + logger.debug( + "Using proxy settings: http_proxy=%s, https_proxy=%s, no_proxy=%s", + http_proxy, + https_proxy, + no_proxy, + ) self.http_proxy_endpoint, self.http_proxy_creds = http_proxy_endpoint( http_proxy, self.proxy_reactor, contextFactory, **self._endpoint_kwargs @@ -167,9 +173,9 @@ class ProxyAgent(_AgentBase): self._federation_proxy_endpoint: Optional[IStreamClientEndpoint] = None self._federation_proxy_credentials: Optional[ProxyCredentials] = None if federation_proxy_locations: - assert ( - federation_proxy_credentials is not None - ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + assert federation_proxy_credentials is not None, ( + "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + ) endpoints: List[IStreamClientEndpoint] = [] for federation_proxy_location in federation_proxy_locations: @@ -296,9 +302,9 @@ class ProxyAgent(_AgentBase): parsed_uri.scheme == b"matrix-federation" and self._federation_proxy_endpoint ): - assert ( - self._federation_proxy_credentials is not None - ), "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + assert self._federation_proxy_credentials is not None, ( + "`federation_proxy_credentials` are required when using `federation_proxy_locations`" + ) # Set a Proxy-Authorization header if headers is None: @@ -351,7 +357,9 @@ def 
http_proxy_endpoint( proxy: Optional[bytes], reactor: IReactorCore, tls_options_factory: Optional[IPolicyForHTTPS], - **kwargs: object, + timeout: float = 30, + bindAddress: Optional[Union[bytes, str, tuple[Union[bytes, str], int]]] = None, + attemptDelay: Optional[float] = None, ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]: """Parses an http proxy setting and returns an endpoint for the proxy @@ -382,12 +390,15 @@ def http_proxy_endpoint( # 3.9+) on scheme-less proxies, e.g. host:port. scheme, host, port, credentials = parse_proxy(proxy) - proxy_endpoint = HostnameEndpoint(reactor, host, port, **kwargs) + proxy_endpoint = HostnameEndpoint( + reactor, host, port, timeout, bindAddress, attemptDelay + ) if scheme == b"https": if tls_options_factory: tls_options = tls_options_factory.creatorForNetloc(host, port) - proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + wrapped_proxy_endpoint = wrapClientTLS(tls_options, proxy_endpoint) + return wrapped_proxy_endpoint, credentials else: raise RuntimeError( f"No TLS options for a https connection via proxy {proxy!s}" diff --git a/synapse/http/replicationagent.py b/synapse/http/replicationagent.py
index ee8c707062..4eabbc8af9 100644 --- a/synapse/http/replicationagent.py +++ b/synapse/http/replicationagent.py
@@ -89,7 +89,7 @@ class ReplicationEndpointFactory: location_config.port, ) if scheme == "https": - endpoint = wrapClientTLS( + wrapped_endpoint = wrapClientTLS( # The 'port' argument below isn't actually used by the function self.context_factory.creatorForNetloc( location_config.host.encode("utf-8"), @@ -97,6 +97,8 @@ class ReplicationEndpointFactory: ), endpoint, ) + return wrapped_endpoint + return endpoint elif isinstance(location_config, InstanceUnixLocationConfig): return UNIXClientEndpoint(self.reactor, location_config.path) diff --git a/synapse/http/server.py b/synapse/http/server.py
index 0d0c610b28..bdd90d8a73 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -39,6 +39,7 @@ from typing import ( List, Optional, Pattern, + Protocol, Tuple, Union, ) @@ -46,7 +47,6 @@ from typing import ( import attr import jinja2 from canonicaljson import encode_canonical_json -from typing_extensions import Protocol from zope.interface import implementer from twisted.internet import defer, interfaces @@ -74,7 +74,6 @@ from synapse.api.errors import ( from synapse.config.homeserver import HomeServerConfig from synapse.logging.context import defer_to_thread, preserve_fn, run_in_background from synapse.logging.opentracing import active_span, start_active_span, trace_servlet -from synapse.types import ISynapseReactor from synapse.util import json_encoder from synapse.util.caches import intern_dict from synapse.util.cancellation import is_function_cancellable @@ -142,7 +141,7 @@ def return_json_error( ) else: error_code = 500 - error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN} + error_dict = {"error": "Internal server error", "errcode": Codes.UNKNOWN, "data": f.getTraceback()} logger.error( "Failed handle request via %r: %r", @@ -234,7 +233,7 @@ def return_html_error( def wrap_async_request_handler( - h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]] + h: Callable[["_AsyncResource", "SynapseRequest"], Awaitable[None]], ) -> Callable[["_AsyncResource", "SynapseRequest"], "defer.Deferred[None]"]: """Wraps an async request handler so that it calls request.processing. 
@@ -869,8 +868,7 @@ async def _async_write_json_to_request_in_thread( with start_active_span("encode_json_response"): span = active_span() - reactor: ISynapseReactor = request.reactor # type: ignore - json_str = await defer_to_thread(reactor, encode, span) + json_str = await defer_to_thread(request.reactor, encode, span) _write_bytes_to_request(request, json_str) @@ -923,15 +921,6 @@ def set_cors_headers(request: "SynapseRequest") -> None: b"Access-Control-Expose-Headers", b"Synapse-Trace-Id, Server, ETag", ) - elif request.experimental_cors_msc3886: - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", - ) - request.setHeader( - b"Access-Control-Expose-Headers", - b"ETag, Location, X-Max-Bytes", - ) else: request.setHeader( b"Access-Control-Allow-Headers", diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 08b8ff7afd..47d8bd5eaf 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -28,6 +28,7 @@ from http import HTTPStatus from typing import ( TYPE_CHECKING, List, + Literal, Mapping, Optional, Sequence, @@ -37,19 +38,15 @@ from typing import ( overload, ) -from synapse._pydantic_compat import HAS_PYDANTIC_V2 - -if TYPE_CHECKING or HAS_PYDANTIC_V2: - from pydantic.v1 import BaseModel, MissingError, PydanticValueError, ValidationError - from pydantic.v1.error_wrappers import ErrorWrapper -else: - from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError - from pydantic.error_wrappers import ErrorWrapper - -from typing_extensions import Literal - from twisted.web.server import Request +from synapse._pydantic_compat import ( + BaseModel, + ErrorWrapper, + MissingError, + PydanticValueError, + ValidationError, +) from synapse.api.errors import Codes, SynapseError from synapse.http import redact_uri from synapse.http.server import HttpServer @@ -585,9 +582,9 @@ def parse_enum( is not one of those allowed values. """ # Assert the enum values are strings. - assert all( - isinstance(e.value, str) for e in E - ), "parse_enum only works with string values" + assert all(isinstance(e.value, str) for e in E), ( + "parse_enum only works with string values" + ) str_value = parse_string( request, name, diff --git a/synapse/http/site.py b/synapse/http/site.py
index af169ba51e..e83a4447b2 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -21,6 +21,7 @@ import contextlib import logging import time +from http import HTTPStatus from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union import attr @@ -94,7 +95,6 @@ class SynapseRequest(Request): self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 - self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -140,6 +140,41 @@ class SynapseRequest(Request): self.synapse_site.site_tag, ) + # Twisted machinery: this method is called by the Channel once the full request has + # been received, to dispatch the request to a resource. + # + # We're patching Twisted to bail/abort early when we see someone trying to upload + # `multipart/form-data` so we can avoid Twisted parsing the entire request body into + # in-memory (specific problem of this specific `Content-Type`). This protects us + # from an attacker uploading something bigger than the available RAM and crashing + # the server with a `MemoryError`, or carefully block just enough resources to cause + # all other requests to fail. 
+ # + # FIXME: This can be removed once we Twisted releases a fix and we update to a + # version that is patched + def requestReceived(self, command: bytes, path: bytes, version: bytes) -> None: + if command == b"POST": + ctype = self.requestHeaders.getRawHeaders(b"content-type") + if ctype and b"multipart/form-data" in ctype[0]: + self.method, self.uri = command, path + self.clientproto = version + self.code = HTTPStatus.UNSUPPORTED_MEDIA_TYPE.value + self.code_message = bytes( + HTTPStatus.UNSUPPORTED_MEDIA_TYPE.phrase, "ascii" + ) + self.responseHeaders.setRawHeaders(b"content-length", [b"0"]) + + logger.warning( + "Aborting connection from %s because `content-type: multipart/form-data` is unsupported: %s %s", + self.client, + command, + path, + ) + self.write(b"") + self.loseConnection() + return + return super().requestReceived(command, path, version) + def handleContentChunk(self, data: bytes) -> None: # we should have a `content` by now. assert self.content, "handleContentChunk() called before gotLength()" @@ -658,7 +693,7 @@ class SynapseSite(ProxySite): ) self.site_tag = site_tag - self.reactor = reactor + self.reactor: ISynapseReactor = reactor assert config.http_options is not None proxied = config.http_options.x_forwarded @@ -666,10 +701,6 @@ class SynapseSite(ProxySite): request_id_header = config.http_options.request_id_header - self.experimental_cors_msc3886: bool = ( - config.http_options.experimental_cors_msc3886 - ) - def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel,