diff options
Diffstat (limited to 'synapse/http/server.py')
-rw-r--r-- | synapse/http/server.py | 74 |
1 files changed, 73 insertions, 1 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py index 657bffcddd..e3dcc3f3dd 100644 --- a/synapse/http/server.py +++ b/synapse/http/server.py @@ -33,6 +33,7 @@ from typing import ( Optional, Pattern, Tuple, + TypeVar, Union, ) @@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html> HTTP_STATUS_REQUEST_CANCELLED = 499 +F = TypeVar("F", bound=Callable[..., Any]) + + +_cancellable_method_names = frozenset( + { + # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet` + # methods + "on_GET", + "on_PUT", + "on_POST", + "on_DELETE", + # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource` + # methods + "_async_render_GET", + "_async_render_PUT", + "_async_render_POST", + "_async_render_DELETE", + "_async_render_OPTIONS", + # `ReplicationEndpoint` methods + "_handle_request", + } +) + + +def cancellable(method: F) -> F: + """Marks a servlet method as cancellable. + + Methods with this decorator will be cancelled if the client disconnects before we + finish processing the request. + + During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping + the method. The `cancel()` call will propagate down to the `Deferred` that is + currently being waited on. That `Deferred` will raise a `CancelledError`, which will + propagate up, as per normal exception handling. + + Before applying this decorator to a new endpoint, you MUST recursively check + that all `await`s in the function are on `async` functions or `Deferred`s that + handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from + premature logging context closure, to stuck requests, to database corruption. + + Usage: + class SomeServlet(RestServlet): + @cancellable + async def on_GET(self, request: SynapseRequest) -> ...: + ... + """ + if method.__name__ not in _cancellable_method_names and not any( + method.__name__.startswith(prefix) for prefix in _cancellable_method_names + ): + raise ValueError( + "@cancellable decorator can only be applied to servlet methods." + ) + + method.cancellable = True # type: ignore[attr-defined] + return method + + +def is_method_cancellable(method: Callable[..., Any]) -> bool: + """Checks whether a servlet method has the `@cancellable` flag.""" + return getattr(method, "cancellable", False) + + def return_json_error(f: failure.Failure, request: SynapseRequest) -> None: """Sends a JSON error response to clients.""" @@ -253,6 +316,9 @@ class HttpServer(Protocol): If the regex contains groups these gets passed to the callback via an unpacked tuple. + The callback may be marked with the `@cancellable` decorator, which will + cause request processing to be cancelled when clients disconnect early. + Args: method: The HTTP method to listen to. path_patterns: The regex used to match requests. @@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): def render(self, request: SynapseRequest) -> int: """This gets called by twisted every time someone sends us a request.""" - defer.ensureDeferred(self._async_render_wrapper(request)) + request.render_deferred = defer.ensureDeferred( + self._async_render_wrapper(request) + ) return NOT_DONE_YET @wrap_async_request_handler @@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta): method_handler = getattr(self, "_async_render_%s" % (request_method,), None) if method_handler: + request.is_render_cancellable = is_method_cancellable(method_handler) + raw_callback_return = method_handler(request) # Is it synchronous? We'll allow this for now. @@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource): async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]: callback, servlet_classname, group_dict = self._get_handler_for_request(request) + request.is_render_cancellable = is_method_cancellable(callback) + # Make sure we have an appropriate name for this handler in prometheus # (rather than the default of JsonResource). request.request_metrics.name = servlet_classname |