summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/additional_resource.py3
-rw-r--r--synapse/http/client.py24
-rw-r--r--synapse/http/federation/matrix_federation_agent.py9
-rw-r--r--synapse/http/matrixfederationclient.py3
-rw-r--r--synapse/http/proxyagent.py27
-rw-r--r--synapse/http/server.py122
-rw-r--r--synapse/http/servlet.py49
-rw-r--r--synapse/http/site.py21
8 files changed, 150 insertions, 108 deletions
diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py

index 6a9f6635d2..8729630581 100644 --- a/synapse/http/additional_resource.py +++ b/synapse/http/additional_resource.py
@@ -45,8 +45,7 @@ class AdditionalResource(DirectServeJsonResource): Args: hs: homeserver - handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred): - function to be called to handle the request. + handler: function to be called to handle the request. """ super().__init__() self._handler = handler diff --git a/synapse/http/client.py b/synapse/http/client.py
index 084d0a5b84..4eb740c040 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py
@@ -25,7 +25,6 @@ from typing import ( List, Mapping, Optional, - Sequence, Tuple, Union, ) @@ -90,14 +89,29 @@ incoming_responses_counter = Counter( "synapse_http_client_responses", "", ["method", "code"] ) -# the type of the headers list, to be passed to the t.w.h.Headers. -# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so -# we simplify. +# the type of the headers map, to be passed to the t.w.h.Headers. +# +# The actual type accepted by Twisted is +# Mapping[Union[str, bytes], Sequence[Union[str, bytes]] , +# allowing us to mix and match str and bytes freely. However: any str is also a +# Sequence[str]; passing a header string value which is a +# standalone str is interpreted as a sequence of 1-codepoint strings. This is a disastrous footgun. +# We use a narrower value type (RawHeaderValue) to avoid this footgun. +# +# We also simplify the keys to be either all str or all bytes. This helps because +# Dict[K, V] is invariant in K (and indeed V). RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]] # the value actually has to be a List, but List is invariant so we can't specify that # the entries can either be Lists or bytes. -RawHeaderValue = Sequence[Union[str, bytes]] +RawHeaderValue = Union[ + List[str], + List[bytes], + List[Union[str, bytes]], + Tuple[str, ...], + Tuple[bytes, ...], + Tuple[Union[str, bytes], ...], +] def check_against_blacklist( diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2f0177f1e2..0359231e7d 100644 --- a/synapse/http/federation/matrix_federation_agent.py +++ b/synapse/http/federation/matrix_federation_agent.py
@@ -155,11 +155,10 @@ class MatrixFederationAgent: a file for a file upload). Or None if the request is to have no body. Returns: - Deferred[twisted.web.iweb.IResponse]: - fires when the header of the response has been received (regardless of the - response status code). Fails if there is any problem which prevents that - response from being received (including problems that prevent the request - from being sent). + A deferred which fires when the header of the response has been received + (regardless of the response status code). Fails if there is any problem + which prevents that response from being received (including problems that + prevent the request from being sent). """ # We use urlparse as that will set `port` to None if there is no # explicit port. diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3c35b1d2c7..b92f1d3d1a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py
@@ -951,8 +951,7 @@ class MatrixFederationHttpClient: args: query params Returns: - dict|list: Succeeds when we get a 2xx HTTP response. The - result will be the decoded JSON body. + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: HttpResponseException: If we get an HTTP response code >= 300 diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b2a50c9105..18899bc6d1 100644 --- a/synapse/http/proxyagent.py +++ b/synapse/http/proxyagent.py
@@ -34,8 +34,9 @@ from twisted.web.client import ( ) from twisted.web.error import SchemeNotSupported from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS +from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse +from synapse.http import redact_uri from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials from synapse.types import ISynapseReactor @@ -133,7 +134,7 @@ class ProxyAgent(_AgentBase): uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> defer.Deferred: + ) -> "defer.Deferred[IResponse]": """ Issue a request to the server indicated by the given uri. @@ -156,17 +157,17 @@ class ProxyAgent(_AgentBase): a file upload). Or, None if the request is to have no body. Returns: - Deferred[IResponse]: completes when the header of the response has - been received (regardless of the response status code). + A deferred which completes when the header of the response has + been received (regardless of the response status code). - Can fail with: - SchemeNotSupported: if the uri is not http or https + Can fail with: + SchemeNotSupported: if the uri is not http or https - twisted.internet.error.TimeoutError if the server we are connecting - to (proxy or destination) does not accept a connection before - connectTimeout. + twisted.internet.error.TimeoutError if the server we are connecting + to (proxy or destination) does not accept a connection before + connectTimeout. - ... other things too. + ... other things too. """ uri = uri.strip() if not _VALID_URI.match(uri): @@ -220,7 +221,11 @@ class ProxyAgent(_AgentBase): self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs ) - logger.debug("Requesting %s via %s", uri, endpoint) + logger.debug( + "Requesting %s via %s", + redact_uri(uri.decode("ascii", errors="replace")), + endpoint, + ) if parsed_uri.scheme == b"https": tls_connection_creator = self._policy_for_https.creatorForNetloc( diff --git a/synapse/http/server.py b/synapse/http/server.py
index 19f42159b8..051a1899a0 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py
@@ -19,6 +19,7 @@ import logging import types import urllib from http import HTTPStatus +from http.client import FOUND from inspect import isawaitable from typing import ( TYPE_CHECKING, @@ -33,7 +34,6 @@ from typing import ( Optional, Pattern, Tuple, - TypeVar, Union, ) @@ -64,6 +64,7 @@ from synapse.logging.context import defer_to_thread, preserve_fn, run_in_backgro from synapse.logging.opentracing import active_span, start_active_span, trace_servlet from synapse.util import json_encoder from synapse.util.caches import intern_dict +from synapse.util.cancellation import is_function_cancellable from synapse.util.iterutils import chunk_seq if TYPE_CHECKING: @@ -94,68 +95,6 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> HTTP_STATUS_REQUEST_CANCELLED = 499 -F = TypeVar("F", bound=Callable[..., Any]) - - -_cancellable_method_names = frozenset( - { - # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet` - # methods - "on_GET", - "on_PUT", - "on_POST", - "on_DELETE", - # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource` - # methods - "_async_render_GET", - "_async_render_PUT", - "_async_render_POST", - "_async_render_DELETE", - "_async_render_OPTIONS", - # `ReplicationEndpoint` methods - "_handle_request", - } -) - - -def cancellable(method: F) -> F: - """Marks a servlet method as cancellable. - - Methods with this decorator will be cancelled if the client disconnects before we - finish processing the request. - - During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping - the method. The `cancel()` call will propagate down to the `Deferred` that is - currently being waited on. That `Deferred` will raise a `CancelledError`, which will - propagate up, as per normal exception handling. - - Before applying this decorator to a new endpoint, you MUST recursively check - that all `await`s in the function are on `async` functions or `Deferred`s that - handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from - premature logging context closure, to stuck requests, to database corruption. - - Usage: - class SomeServlet(RestServlet): - @cancellable - async def on_GET(self, request: SynapseRequest) -> ...: - ... - """ - if method.__name__ not in _cancellable_method_names and not any( - method.__name__.startswith(prefix) for prefix in _cancellable_method_names - ): - raise ValueError( - "@cancellable decorator can only be applied to servlet methods." - ) - - method.cancellable = True # type: ignore[attr-defined] - return method - - -def is_method_cancellable(method: Callable[..., Any]) -> bool: - """Checks whether a servlet method has the `@cancellable` flag.""" - return getattr(method, "cancellable", False) - - def return_json_error( f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig] ) -> None: @@ -328,7 +267,7 @@ class HttpServer(Protocol): request. The first argument will be the request object and subsequent arguments will be any matched groups from the regex. This should return either tuple of (code, response), or None. - servlet_classname (str): The name of the handler to be used in prometheus + servlet_classname: The name of the handler to be used in prometheus and opentracing logs. """ @@ -389,7 +328,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): method_handler = getattr(self, "_async_render_%s" % (request_method,), None) if method_handler: - request.is_render_cancellable = is_method_cancellable(method_handler) + request.is_render_cancellable = is_function_cancellable(method_handler) raw_callback_return = method_handler(request) @@ -401,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): return callback_return - _unrecognised_request_handler(request) + return _unrecognised_request_handler(request) @abc.abstractmethod def _send_response( @@ -551,7 +490,7 @@ class JsonResource(DirectServeJsonResource): async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) - request.is_render_cancellable = is_method_cancellable(callback) + request.is_render_cancellable = is_function_cancellable(callback) # Make sure we have an appropriate name for this handler in prometheus # (rather than the default of JsonResource). @@ -660,7 +599,7 @@ class RootRedirect(resource.Resource): class OptionsResource(resource.Resource): """Responds to OPTION requests for itself and all children.""" - def render_OPTIONS(self, request: Request) -> bytes: + def render_OPTIONS(self, request: SynapseRequest) -> bytes: request.setResponseCode(204) request.setHeader(b"Content-Length", b"0") @@ -767,7 +706,7 @@ class _ByteProducer: self._request = None -def _encode_json_bytes(json_object: Any) -> bytes: +def _encode_json_bytes(json_object: object) -> bytes: """ Encode an object into JSON. Returns an iterator of bytes. """ @@ -808,7 +747,7 @@ def respond_with_json( return None if canonical_json: - encoder = encode_canonical_json + encoder: Callable[[object], bytes] = encode_canonical_json else: encoder = _encode_json_bytes @@ -825,7 +764,7 @@ def respond_with_json( def respond_with_json_bytes( - request: Request, + request: SynapseRequest, code: int, json_bytes: bytes, send_cors: bool = False, @@ -921,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None: _ByteProducer(request, bytes_generator) -def set_cors_headers(request: Request) -> None: +def set_cors_headers(request: SynapseRequest) -> None: """Set the CORS headers so that javascript running in a web browsers can use this API @@ -932,10 +871,20 @@ def set_cors_headers(request: Request) -> None: request.setHeader( b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS" ) - request.setHeader( - b"Access-Control-Allow-Headers", - b"X-Requested-With, Content-Type, Authorization, Date", - ) + if request.experimental_cors_msc3886: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match", + ) + request.setHeader( + b"Access-Control-Expose-Headers", + b"ETag, Location, X-Max-Bytes", + ) + else: + request.setHeader( + b"Access-Control-Allow-Headers", + b"X-Requested-With, Content-Type, Authorization, Date", + ) def set_corp_headers(request: Request) -> None: @@ -1004,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None: request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';") -def respond_with_redirect(request: Request, url: bytes) -> None: - """Write a 302 response to the request, if it is still alive.""" +def respond_with_redirect( + request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False +) -> None: + """ + Write a 302 (or other specified status code) response to the request, if it is still alive. + + Args: + request: The http request to respond to. + url: The URL to redirect to. + statusCode: The HTTP status code to use for the redirect (defaults to 302). + cors: Whether to set CORS headers on the response. + """ logger.debug("Redirect to %s", url.decode("utf-8")) - request.redirect(url) + + if cors: + set_cors_headers(request) + + request.setResponseCode(statusCode) + request.setHeader(b"location", url) finish_request(request) diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0e..dead02cd5c 100644 --- a/synapse/http/servlet.py +++ b/synapse/http/servlet.py
@@ -23,14 +23,19 @@ from typing import ( Optional, Sequence, Tuple, + Type, + TypeVar, overload, ) +from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError +from pydantic.error_wrappers import ErrorWrapper from typing_extensions import Literal from twisted.web.server import Request from synapse.api.errors import Codes, SynapseError +from synapse.http import redact_uri from synapse.http.server import HttpServer from synapse.types import JsonDict, RoomAlias, RoomID from synapse.util import json_decoder @@ -660,7 +665,13 @@ def parse_json_value_from_request( try: content = json_decoder.decode(content_bytes.decode("utf-8")) except Exception as e: - logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes) + logger.warning( + "Unable to parse JSON from %s %s response: %s (%s)", + request.method.decode("ascii", errors="replace"), + redact_uri(request.uri.decode("ascii", errors="replace")), + e, + content_bytes, + ) raise SynapseError( HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON ) @@ -694,6 +705,42 @@ def parse_json_object_from_request( return content +Model = TypeVar("Model", bound=BaseModel) + + +def parse_and_validate_json_object_from_request( + request: Request, model_type: Type[Model] +) -> Model: + """Parse a JSON object from the body of a twisted HTTP request, then deserialise and + validate using the given pydantic model. + + Raises: + SynapseError if the request body couldn't be decoded as JSON or + if it wasn't a JSON object. + """ + content = parse_json_object_from_request(request, allow_empty_body=False) + try: + instance = model_type.parse_obj(content) + except ValidationError as e: + # Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a + # more specific error if possible (which occasionally helps us to be spec- + # compliant) This is a bit awkward because the spec's error codes aren't very + # clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM. + errcode = Codes.BAD_JSON + + raw_errors = e.raw_errors + if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper): + raw_error = raw_errors[0].exc + if isinstance(raw_error, MissingError): + errcode = Codes.MISSING_PARAM + elif isinstance(raw_error, PydanticValueError): + errcode = Codes.INVALID_PARAM + + raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode) + + return instance + + def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None: absent = [] for k in required: diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78a..6a1dbf7f33 100644 --- a/synapse/http/site.py +++ b/synapse/http/site.py
@@ -72,14 +72,17 @@ class SynapseRequest(Request): site: "SynapseSite", *args: Any, max_request_body_size: int = 1024, + request_id_header: Optional[str] = None, **kw: Any, ): super().__init__(channel, *args, **kw) self._max_request_body_size = max_request_body_size + self.request_id_header = request_id_header self.synapse_site = site self.reactor = site.reactor self._channel = channel # this is used by the tests self.start_time = 0.0 + self.experimental_cors_msc3886 = site.experimental_cors_msc3886 # The requester, if authenticated. For federation requests this is the # server name, for client requests this is the Requester object. @@ -172,7 +175,14 @@ class SynapseRequest(Request): self._opentracing_span = span def get_request_id(self) -> str: - return "%s-%i" % (self.get_method(), self.request_seq) + request_id_value = None + if self.request_id_header: + request_id_value = self.getHeader(self.request_id_header) + + if request_id_value is None: + request_id_value = str(self.request_seq) + + return "%s-%s" % (self.get_method(), request_id_value) def get_redacted_uri(self) -> str: """Gets the redacted URI associated with the request (or placeholder if the URI @@ -226,7 +236,7 @@ class SynapseRequest(Request): # If this is a request where the target user doesn't match the user who # authenticated (e.g. and admin is puppetting a user) then we return both. - if self._requester.user.to_string() != authenticated_entity: + if requester != authenticated_entity: return requester, authenticated_entity return requester, None @@ -390,7 +400,7 @@ class SynapseRequest(Request): be sure to call finished_processing. Args: - servlet_name (str): the name of the servlet which will be + servlet_name: the name of the servlet which will be processing this request. This is used in the metrics. It is possible to update this afterwards by updating @@ -611,12 +621,17 @@ class SynapseSite(Site): proxied = config.http_options.x_forwarded request_class = XForwardedForRequest if proxied else SynapseRequest + request_id_header = config.http_options.request_id_header + + self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886 + def request_factory(channel: HTTPChannel, queued: bool) -> Request: return request_class( channel, self, max_request_body_size=max_request_body_size, queued=queued, + request_id_header=request_id_header, ) self.requestFactory = request_factory # type: ignore