diff --git a/synapse/http/additional_resource.py b/synapse/http/additional_resource.py
index 6a9f6635d2..8729630581 100644
--- a/synapse/http/additional_resource.py
+++ b/synapse/http/additional_resource.py
@@ -45,8 +45,7 @@ class AdditionalResource(DirectServeJsonResource):
Args:
hs: homeserver
- handler ((twisted.web.server.Request) -> twisted.internet.defer.Deferred):
- function to be called to handle the request.
+ handler: function to be called to handle the request.
"""
super().__init__()
self._handler = handler
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 084d0a5b84..4eb740c040 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -25,7 +25,6 @@ from typing import (
List,
Mapping,
Optional,
- Sequence,
Tuple,
Union,
)
@@ -90,14 +89,29 @@ incoming_responses_counter = Counter(
"synapse_http_client_responses", "", ["method", "code"]
)
-# the type of the headers list, to be passed to the t.w.h.Headers.
-# Actually we can mix str and bytes keys, but Mapping treats 'key' as invariant so
-# we simplify.
+# the type of the headers map, to be passed to the t.w.h.Headers.
+#
+# The actual type accepted by Twisted is
+# Mapping[Union[str, bytes], Sequence[Union[str, bytes]]],
+# allowing us to mix and match str and bytes freely. However: any str is also a
+# Sequence[str]; passing a header string value which is a
+# standalone str is interpreted as a sequence of 1-codepoint strings. This is a disastrous footgun.
+# We use a narrower value type (RawHeaderValue) to avoid this footgun.
+#
+# We also simplify the keys to be either all str or all bytes. This helps because
+# Dict[K, V] is invariant in K (and indeed V).
RawHeaders = Union[Mapping[str, "RawHeaderValue"], Mapping[bytes, "RawHeaderValue"]]
# the value actually has to be a List, but List is invariant so we can't specify that
# the entries can either be Lists or bytes.
-RawHeaderValue = Sequence[Union[str, bytes]]
+RawHeaderValue = Union[
+ List[str],
+ List[bytes],
+ List[Union[str, bytes]],
+ Tuple[str, ...],
+ Tuple[bytes, ...],
+ Tuple[Union[str, bytes], ...],
+]
def check_against_blacklist(
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 2f0177f1e2..0359231e7d 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -155,11 +155,10 @@ class MatrixFederationAgent:
a file for a file upload). Or None if the request is to have
no body.
Returns:
- Deferred[twisted.web.iweb.IResponse]:
- fires when the header of the response has been received (regardless of the
- response status code). Fails if there is any problem which prevents that
- response from being received (including problems that prevent the request
- from being sent).
+ A deferred which fires when the header of the response has been received
+ (regardless of the response status code). Fails if there is any problem
+ which prevents that response from being received (including problems that
+ prevent the request from being sent).
"""
# We use urlparse as that will set `port` to None if there is no
# explicit port.
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3c35b1d2c7..b92f1d3d1a 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -951,8 +951,7 @@ class MatrixFederationHttpClient:
args: query params
Returns:
- dict|list: Succeeds when we get a 2xx HTTP response. The
- result will be the decoded JSON body.
+ Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body.
Raises:
HttpResponseException: If we get an HTTP response code >= 300
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index b2a50c9105..18899bc6d1 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -34,8 +34,9 @@ from twisted.web.client import (
)
from twisted.web.error import SchemeNotSupported
from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS
+from twisted.web.iweb import IAgent, IBodyProducer, IPolicyForHTTPS, IResponse
+from synapse.http import redact_uri
from synapse.http.connectproxyclient import HTTPConnectProxyEndpoint, ProxyCredentials
from synapse.types import ISynapseReactor
@@ -133,7 +134,7 @@ class ProxyAgent(_AgentBase):
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
- ) -> defer.Deferred:
+ ) -> "defer.Deferred[IResponse]":
"""
Issue a request to the server indicated by the given uri.
@@ -156,17 +157,17 @@ class ProxyAgent(_AgentBase):
a file upload). Or, None if the request is to have no body.
Returns:
- Deferred[IResponse]: completes when the header of the response has
- been received (regardless of the response status code).
+ A deferred which completes when the header of the response has
+ been received (regardless of the response status code).
- Can fail with:
- SchemeNotSupported: if the uri is not http or https
+ Can fail with:
+ SchemeNotSupported: if the uri is not http or https
- twisted.internet.error.TimeoutError if the server we are connecting
- to (proxy or destination) does not accept a connection before
- connectTimeout.
+ twisted.internet.error.TimeoutError: if the server we are connecting
+ to (proxy or destination) does not accept a connection before
+ connectTimeout.
- ... other things too.
+ ... other things too.
"""
uri = uri.strip()
if not _VALID_URI.match(uri):
@@ -220,7 +221,11 @@ class ProxyAgent(_AgentBase):
self._reactor, parsed_uri.host, parsed_uri.port, **self._endpoint_kwargs
)
- logger.debug("Requesting %s via %s", uri, endpoint)
+ logger.debug(
+ "Requesting %s via %s",
+ redact_uri(uri.decode("ascii", errors="replace")),
+ endpoint,
+ )
if parsed_uri.scheme == b"https":
tls_connection_creator = self._policy_for_https.creatorForNetloc(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 19f42159b8..051a1899a0 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -19,6 +19,7 @@ import logging
import types
import urllib
from http import HTTPStatus
+from http.client import FOUND
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
@@ -33,7 +34,6 @@ from typing import (
Optional,
Pattern,
Tuple,
- TypeVar,
Union,
)
@@ -64,6 +64,7 @@ from synapse.logging.context import defer_to_thread, preserve_fn, run_in_backgro
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
+from synapse.util.cancellation import is_function_cancellable
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -94,68 +95,6 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
HTTP_STATUS_REQUEST_CANCELLED = 499
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-_cancellable_method_names = frozenset(
- {
- # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
- # methods
- "on_GET",
- "on_PUT",
- "on_POST",
- "on_DELETE",
- # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
- # methods
- "_async_render_GET",
- "_async_render_PUT",
- "_async_render_POST",
- "_async_render_DELETE",
- "_async_render_OPTIONS",
- # `ReplicationEndpoint` methods
- "_handle_request",
- }
-)
-
-
-def cancellable(method: F) -> F:
- """Marks a servlet method as cancellable.
-
- Methods with this decorator will be cancelled if the client disconnects before we
- finish processing the request.
-
- During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
- the method. The `cancel()` call will propagate down to the `Deferred` that is
- currently being waited on. That `Deferred` will raise a `CancelledError`, which will
- propagate up, as per normal exception handling.
-
- Before applying this decorator to a new endpoint, you MUST recursively check
- that all `await`s in the function are on `async` functions or `Deferred`s that
- handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
- premature logging context closure, to stuck requests, to database corruption.
-
- Usage:
- class SomeServlet(RestServlet):
- @cancellable
- async def on_GET(self, request: SynapseRequest) -> ...:
- ...
- """
- if method.__name__ not in _cancellable_method_names and not any(
- method.__name__.startswith(prefix) for prefix in _cancellable_method_names
- ):
- raise ValueError(
- "@cancellable decorator can only be applied to servlet methods."
- )
-
- method.cancellable = True # type: ignore[attr-defined]
- return method
-
-
-def is_method_cancellable(method: Callable[..., Any]) -> bool:
- """Checks whether a servlet method has the `@cancellable` flag."""
- return getattr(method, "cancellable", False)
-
-
def return_json_error(
f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
) -> None:
@@ -328,7 +267,7 @@ class HttpServer(Protocol):
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
This should return either tuple of (code, response), or None.
- servlet_classname (str): The name of the handler to be used in prometheus
+ servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
@@ -389,7 +328,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
if method_handler:
- request.is_render_cancellable = is_method_cancellable(method_handler)
+ request.is_render_cancellable = is_function_cancellable(method_handler)
raw_callback_return = method_handler(request)
@@ -401,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return callback_return
- _unrecognised_request_handler(request)
+ return _unrecognised_request_handler(request)
@abc.abstractmethod
def _send_response(
@@ -551,7 +490,7 @@ class JsonResource(DirectServeJsonResource):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- request.is_render_cancellable = is_method_cancellable(callback)
+ request.is_render_cancellable = is_function_cancellable(callback)
# Make sure we have an appropriate name for this handler in prometheus
# (rather than the default of JsonResource).
@@ -660,7 +599,7 @@ class RootRedirect(resource.Resource):
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
- def render_OPTIONS(self, request: Request) -> bytes:
+ def render_OPTIONS(self, request: SynapseRequest) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@@ -767,7 +706,7 @@ class _ByteProducer:
self._request = None
-def _encode_json_bytes(json_object: Any) -> bytes:
+def _encode_json_bytes(json_object: object) -> bytes:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
@@ -808,7 +747,7 @@ def respond_with_json(
return None
if canonical_json:
- encoder = encode_canonical_json
+ encoder: Callable[[object], bytes] = encode_canonical_json
else:
encoder = _encode_json_bytes
@@ -825,7 +764,7 @@ def respond_with_json(
def respond_with_json_bytes(
- request: Request,
+ request: SynapseRequest,
code: int,
json_bytes: bytes,
send_cors: bool = False,
@@ -921,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
-def set_cors_headers(request: Request) -> None:
+def set_cors_headers(request: SynapseRequest) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -932,10 +871,20 @@ def set_cors_headers(request: Request) -> None:
request.setHeader(
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
)
- request.setHeader(
- b"Access-Control-Allow-Headers",
- b"X-Requested-With, Content-Type, Authorization, Date",
- )
+ if request.experimental_cors_msc3886:
+ request.setHeader(
+ b"Access-Control-Allow-Headers",
+ b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match",
+ )
+ request.setHeader(
+ b"Access-Control-Expose-Headers",
+ b"ETag, Location, X-Max-Bytes",
+ )
+ else:
+ request.setHeader(
+ b"Access-Control-Allow-Headers",
+ b"X-Requested-With, Content-Type, Authorization, Date",
+ )
def set_corp_headers(request: Request) -> None:
@@ -1004,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None:
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
-def respond_with_redirect(request: Request, url: bytes) -> None:
- """Write a 302 response to the request, if it is still alive."""
+def respond_with_redirect(
+ request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False
+) -> None:
+ """
+ Write a 302 (or other specified status code) response to the request, if it is still alive.
+
+ Args:
+ request: The http request to respond to.
+ url: The URL to redirect to.
+ statusCode: The HTTP status code to use for the redirect (defaults to 302).
+ cors: Whether to set CORS headers on the response.
+ """
logger.debug("Redirect to %s", url.decode("utf-8"))
- request.redirect(url)
+
+ if cors:
+ set_cors_headers(request)
+
+ request.setResponseCode(statusCode)
+ request.setHeader(b"location", url)
finish_request(request)
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index 4ff840ca0e..dead02cd5c 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -23,14 +23,19 @@ from typing import (
Optional,
Sequence,
Tuple,
+ Type,
+ TypeVar,
overload,
)
+from pydantic import BaseModel, MissingError, PydanticValueError, ValidationError
+from pydantic.error_wrappers import ErrorWrapper
from typing_extensions import Literal
from twisted.web.server import Request
from synapse.api.errors import Codes, SynapseError
+from synapse.http import redact_uri
from synapse.http.server import HttpServer
from synapse.types import JsonDict, RoomAlias, RoomID
from synapse.util import json_decoder
@@ -660,7 +665,13 @@ def parse_json_value_from_request(
try:
content = json_decoder.decode(content_bytes.decode("utf-8"))
except Exception as e:
- logger.warning("Unable to parse JSON: %s (%s)", e, content_bytes)
+ logger.warning(
+ "Unable to parse JSON from %s %s response: %s (%s)",
+ request.method.decode("ascii", errors="replace"),
+ redact_uri(request.uri.decode("ascii", errors="replace")),
+ e,
+ content_bytes,
+ )
raise SynapseError(
HTTPStatus.BAD_REQUEST, "Content not JSON.", errcode=Codes.NOT_JSON
)
@@ -694,6 +705,42 @@ def parse_json_object_from_request(
return content
+Model = TypeVar("Model", bound=BaseModel)
+
+
+def parse_and_validate_json_object_from_request(
+ request: Request, model_type: Type[Model]
+) -> Model:
+ """Parse a JSON object from the body of a twisted HTTP request, then deserialise and
+ validate using the given pydantic model.
+
+ Raises:
+ SynapseError if the request body couldn't be decoded as JSON, if it wasn't a
+ JSON object, or if it failed validation against the given pydantic model.
+ """
+ content = parse_json_object_from_request(request, allow_empty_body=False)
+ try:
+ instance = model_type.parse_obj(content)
+ except ValidationError as e:
+ # Choose a matrix error code. The catch-all is BAD_JSON, but we try to find a
+ # more specific error if possible (which occasionally helps us to be spec-
+ # compliant). This is a bit awkward because the spec's error codes aren't very
+ # clear-cut: BAD_JSON arguably overlaps with MISSING_PARAM and INVALID_PARAM.
+ errcode = Codes.BAD_JSON
+
+ raw_errors = e.raw_errors
+ if len(raw_errors) == 1 and isinstance(raw_errors[0], ErrorWrapper):
+ raw_error = raw_errors[0].exc
+ if isinstance(raw_error, MissingError):
+ errcode = Codes.MISSING_PARAM
+ elif isinstance(raw_error, PydanticValueError):
+ errcode = Codes.INVALID_PARAM
+
+ raise SynapseError(HTTPStatus.BAD_REQUEST, str(e), errcode=errcode)
+
+ return instance
+
+
def assert_params_in_dict(body: JsonDict, required: Iterable[str]) -> None:
absent = []
for k in required:
diff --git a/synapse/http/site.py b/synapse/http/site.py
index eeec74b78a..6a1dbf7f33 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -72,14 +72,17 @@ class SynapseRequest(Request):
site: "SynapseSite",
*args: Any,
max_request_body_size: int = 1024,
+ request_id_header: Optional[str] = None,
**kw: Any,
):
super().__init__(channel, *args, **kw)
self._max_request_body_size = max_request_body_size
+ self.request_id_header = request_id_header
self.synapse_site = site
self.reactor = site.reactor
self._channel = channel # this is used by the tests
self.start_time = 0.0
+ self.experimental_cors_msc3886 = site.experimental_cors_msc3886
# The requester, if authenticated. For federation requests this is the
# server name, for client requests this is the Requester object.
@@ -172,7 +175,14 @@ class SynapseRequest(Request):
self._opentracing_span = span
def get_request_id(self) -> str:
- return "%s-%i" % (self.get_method(), self.request_seq)
+ request_id_value = None
+ if self.request_id_header:
+ request_id_value = self.getHeader(self.request_id_header)
+
+ if request_id_value is None:
+ request_id_value = str(self.request_seq)
+
+ return "%s-%s" % (self.get_method(), request_id_value)
def get_redacted_uri(self) -> str:
"""Gets the redacted URI associated with the request (or placeholder if the URI
@@ -226,7 +236,7 @@ class SynapseRequest(Request):
# If this is a request where the target user doesn't match the user who
# authenticated (e.g. and admin is puppetting a user) then we return both.
- if self._requester.user.to_string() != authenticated_entity:
+ if requester != authenticated_entity:
return requester, authenticated_entity
return requester, None
@@ -390,7 +400,7 @@ class SynapseRequest(Request):
be sure to call finished_processing.
Args:
- servlet_name (str): the name of the servlet which will be
+ servlet_name: the name of the servlet which will be
processing this request. This is used in the metrics.
It is possible to update this afterwards by updating
@@ -611,12 +621,17 @@ class SynapseSite(Site):
proxied = config.http_options.x_forwarded
request_class = XForwardedForRequest if proxied else SynapseRequest
+ request_id_header = config.http_options.request_id_header
+
+ self.experimental_cors_msc3886 = config.http_options.experimental_cors_msc3886
+
def request_factory(channel: HTTPChannel, queued: bool) -> Request:
return request_class(
channel,
self,
max_request_body_size=max_request_body_size,
queued=queued,
+ request_id_header=request_id_header,
)
self.requestFactory = request_factory # type: ignore
|