summary refs log tree commit diff
path: root/synapse/http/server.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/server.py')
-rw-r--r--synapse/http/server.py74
1 files changed, 73 insertions, 1 deletions
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 657bffcddd..e3dcc3f3dd 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,6 +33,7 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 HTTP_STATUS_REQUEST_CANCELLED = 499
 
 
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+_cancellable_method_names = frozenset(
+    {
+        # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
+        # methods
+        "on_GET",
+        "on_PUT",
+        "on_POST",
+        "on_DELETE",
+        # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
+        # methods
+        "_async_render_GET",
+        "_async_render_PUT",
+        "_async_render_POST",
+        "_async_render_DELETE",
+        "_async_render_OPTIONS",
+        # `ReplicationEndpoint` methods
+        "_handle_request",
+    }
+)
+
+
+def cancellable(method: F) -> F:
+    """Marks a servlet method as cancellable.
+
+    Methods with this decorator will be cancelled if the client disconnects before we
+    finish processing the request.
+
+    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+    the method. The `cancel()` call will propagate down to the `Deferred` that is
+    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+    propagate up, as per normal exception handling.
+
+    Before applying this decorator to a new endpoint, you MUST recursively check
+    that all `await`s in the function are on `async` functions or `Deferred`s that
+    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+    premature logging context closure, to stuck requests, to database corruption.
+
+    Usage:
+        class SomeServlet(RestServlet):
+            @cancellable
+            async def on_GET(self, request: SynapseRequest) -> ...:
+                ...
+    """
+    if method.__name__ not in _cancellable_method_names and not any(
+        method.__name__.startswith(prefix) for prefix in _cancellable_method_names
+    ):
+        raise ValueError(
+            "@cancellable decorator can only be applied to servlet methods."
+        )
+
+    method.cancellable = True  # type: ignore[attr-defined]
+    return method
+
+
+def is_method_cancellable(method: Callable[..., Any]) -> bool:
+    """Checks whether a servlet method has the `@cancellable` flag."""
+    return getattr(method, "cancellable", False)
+
+
 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
@@ -253,6 +316,9 @@ class HttpServer(Protocol):
         If the regex contains groups these gets passed to the callback via
         an unpacked tuple.
 
+        The callback may be marked with the `@cancellable` decorator, which will
+        cause request processing to be cancelled when clients disconnect early.
+
         Args:
             method: The HTTP method to listen to.
             path_patterns: The regex used to match requests.
@@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
     def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
-        defer.ensureDeferred(self._async_render_wrapper(request))
+        request.render_deferred = defer.ensureDeferred(
+            self._async_render_wrapper(request)
+        )
         return NOT_DONE_YET
 
     @wrap_async_request_handler
@@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
+            request.is_render_cancellable = is_method_cancellable(method_handler)
+
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
@@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource):
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
+        request.is_render_cancellable = is_method_cancellable(callback)
+
         # Make sure we have an appropriate name for this handler in prometheus
         # (rather than the default of JsonResource).
         request.request_metrics.name = servlet_classname