diff --git a/synapse/http/server.py b/synapse/http/server.py
index 19f42159b8..051a1899a0 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -19,6 +19,7 @@ import logging
import types
import urllib
from http import HTTPStatus
+from http.client import FOUND
from inspect import isawaitable
from typing import (
TYPE_CHECKING,
@@ -33,7 +34,6 @@ from typing import (
Optional,
Pattern,
Tuple,
- TypeVar,
Union,
)
@@ -64,6 +64,7 @@ from synapse.logging.context import defer_to_thread, preserve_fn, run_in_backgro
from synapse.logging.opentracing import active_span, start_active_span, trace_servlet
from synapse.util import json_encoder
from synapse.util.caches import intern_dict
+from synapse.util.cancellation import is_function_cancellable
from synapse.util.iterutils import chunk_seq
if TYPE_CHECKING:
@@ -94,68 +95,6 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
HTTP_STATUS_REQUEST_CANCELLED = 499
-F = TypeVar("F", bound=Callable[..., Any])
-
-
-_cancellable_method_names = frozenset(
- {
- # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
- # methods
- "on_GET",
- "on_PUT",
- "on_POST",
- "on_DELETE",
- # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
- # methods
- "_async_render_GET",
- "_async_render_PUT",
- "_async_render_POST",
- "_async_render_DELETE",
- "_async_render_OPTIONS",
- # `ReplicationEndpoint` methods
- "_handle_request",
- }
-)
-
-
-def cancellable(method: F) -> F:
- """Marks a servlet method as cancellable.
-
- Methods with this decorator will be cancelled if the client disconnects before we
- finish processing the request.
-
- During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
- the method. The `cancel()` call will propagate down to the `Deferred` that is
- currently being waited on. That `Deferred` will raise a `CancelledError`, which will
- propagate up, as per normal exception handling.
-
- Before applying this decorator to a new endpoint, you MUST recursively check
- that all `await`s in the function are on `async` functions or `Deferred`s that
- handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
- premature logging context closure, to stuck requests, to database corruption.
-
- Usage:
- class SomeServlet(RestServlet):
- @cancellable
- async def on_GET(self, request: SynapseRequest) -> ...:
- ...
- """
- if method.__name__ not in _cancellable_method_names and not any(
- method.__name__.startswith(prefix) for prefix in _cancellable_method_names
- ):
- raise ValueError(
- "@cancellable decorator can only be applied to servlet methods."
- )
-
- method.cancellable = True # type: ignore[attr-defined]
- return method
-
-
-def is_method_cancellable(method: Callable[..., Any]) -> bool:
- """Checks whether a servlet method has the `@cancellable` flag."""
- return getattr(method, "cancellable", False)
-
-
def return_json_error(
f: failure.Failure, request: SynapseRequest, config: Optional[HomeServerConfig]
) -> None:
@@ -328,7 +267,7 @@ class HttpServer(Protocol):
request. The first argument will be the request object and
subsequent arguments will be any matched groups from the regex.
This should return either tuple of (code, response), or None.
- servlet_classname (str): The name of the handler to be used in prometheus
+ servlet_classname: The name of the handler to be used in prometheus
and opentracing logs.
"""
@@ -389,7 +328,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
if method_handler:
- request.is_render_cancellable = is_method_cancellable(method_handler)
+ request.is_render_cancellable = is_function_cancellable(method_handler)
raw_callback_return = method_handler(request)
@@ -401,7 +340,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
return callback_return
- _unrecognised_request_handler(request)
+ return _unrecognised_request_handler(request)
@abc.abstractmethod
def _send_response(
@@ -551,7 +490,7 @@ class JsonResource(DirectServeJsonResource):
async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
callback, servlet_classname, group_dict = self._get_handler_for_request(request)
- request.is_render_cancellable = is_method_cancellable(callback)
+ request.is_render_cancellable = is_function_cancellable(callback)
# Make sure we have an appropriate name for this handler in prometheus
# (rather than the default of JsonResource).
@@ -660,7 +599,7 @@ class RootRedirect(resource.Resource):
class OptionsResource(resource.Resource):
"""Responds to OPTION requests for itself and all children."""
- def render_OPTIONS(self, request: Request) -> bytes:
+ def render_OPTIONS(self, request: SynapseRequest) -> bytes:
request.setResponseCode(204)
request.setHeader(b"Content-Length", b"0")
@@ -767,7 +706,7 @@ class _ByteProducer:
self._request = None
-def _encode_json_bytes(json_object: Any) -> bytes:
+def _encode_json_bytes(json_object: object) -> bytes:
"""
Encode an object into JSON. Returns an iterator of bytes.
"""
@@ -808,7 +747,7 @@ def respond_with_json(
return None
if canonical_json:
- encoder = encode_canonical_json
+ encoder: Callable[[object], bytes] = encode_canonical_json
else:
encoder = _encode_json_bytes
@@ -825,7 +764,7 @@ def respond_with_json(
def respond_with_json_bytes(
- request: Request,
+ request: SynapseRequest,
code: int,
json_bytes: bytes,
send_cors: bool = False,
@@ -921,7 +860,7 @@ def _write_bytes_to_request(request: Request, bytes_to_write: bytes) -> None:
_ByteProducer(request, bytes_generator)
-def set_cors_headers(request: Request) -> None:
+def set_cors_headers(request: SynapseRequest) -> None:
"""Set the CORS headers so that javascript running in a web browsers can
use this API
@@ -932,10 +871,20 @@ def set_cors_headers(request: Request) -> None:
request.setHeader(
b"Access-Control-Allow-Methods", b"GET, HEAD, POST, PUT, DELETE, OPTIONS"
)
- request.setHeader(
- b"Access-Control-Allow-Headers",
- b"X-Requested-With, Content-Type, Authorization, Date",
- )
+ if request.experimental_cors_msc3886:
+ request.setHeader(
+ b"Access-Control-Allow-Headers",
+ b"X-Requested-With, Content-Type, Authorization, Date, If-Match, If-None-Match",
+ )
+ request.setHeader(
+ b"Access-Control-Expose-Headers",
+ b"ETag, Location, X-Max-Bytes",
+ )
+ else:
+ request.setHeader(
+ b"Access-Control-Allow-Headers",
+ b"X-Requested-With, Content-Type, Authorization, Date",
+ )
def set_corp_headers(request: Request) -> None:
@@ -1004,10 +953,25 @@ def set_clickjacking_protection_headers(request: Request) -> None:
request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
-def respond_with_redirect(request: Request, url: bytes) -> None:
- """Write a 302 response to the request, if it is still alive."""
+def respond_with_redirect(
+ request: SynapseRequest, url: bytes, statusCode: int = FOUND, cors: bool = False
+) -> None:
+ """
+ Write a 302 (or other specified status code) response to the request, if it is still alive.
+
+ Args:
+ request: The http request to respond to.
+ url: The URL to redirect to.
+ statusCode: The HTTP status code to use for the redirect (defaults to 302).
+ cors: Whether to set CORS headers on the response.
+ """
logger.debug("Redirect to %s", url.decode("utf-8"))
- request.redirect(url)
+
+ if cors:
+ set_cors_headers(request)
+
+ request.setResponseCode(statusCode)
+ request.setHeader(b"location", url)
finish_request(request)
|