summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/client.py18
-rw-r--r--synapse/http/connectproxyclient.py39
-rw-r--r--synapse/http/federation/matrix_federation_agent.py2
-rw-r--r--synapse/http/federation/srv_resolver.py4
-rw-r--r--synapse/http/federation/well_known_resolver.py6
-rw-r--r--synapse/http/matrixfederationclient.py31
-rw-r--r--synapse/http/proxyagent.py2
-rw-r--r--synapse/http/request_metrics.py10
-rw-r--r--synapse/http/server.py74
-rw-r--r--synapse/http/site.py25
10 files changed, 162 insertions, 49 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8310fb466a..084d0a5b84 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -43,8 +43,10 @@ from twisted.internet import defer, error as twisted_error, protocol, ssl
 from twisted.internet.address import IPv4Address, IPv6Address
 from twisted.internet.interfaces import (
     IAddress,
+    IDelayedCall,
     IHostResolution,
     IReactorPluggableNameResolver,
+    IReactorTime,
     IResolutionReceiver,
     ITCPTransport,
 )
@@ -121,13 +123,15 @@ def check_against_blacklist(
 _EPSILON = 0.00000001
 
 
-def _make_scheduler(reactor):
+def _make_scheduler(
+    reactor: IReactorTime,
+) -> Callable[[Callable[[], object]], IDelayedCall]:
     """Makes a schedular suitable for a Cooperator using the given reactor.
 
     (This is effectively just a copy from `twisted.internet.task`)
     """
 
-    def _scheduler(x):
+    def _scheduler(x: Callable[[], object]) -> IDelayedCall:
         return reactor.callLater(_EPSILON, x)
 
     return _scheduler
@@ -348,7 +352,7 @@ class SimpleHttpClient:
         # XXX: The justification for using the cache factor here is that larger instances
         # will need both more cache and more connections.
         # Still, this should probably be a separate dial
-        pool.maxPersistentPerHost = max((100 * hs.config.caches.global_factor, 5))
+        pool.maxPersistentPerHost = max(int(100 * hs.config.caches.global_factor), 5)
         pool.cachedConnectionTimeout = 2 * 60
 
         self.agent: IAgent = ProxyAgent(
@@ -775,7 +779,7 @@ class SimpleHttpClient:
         )
 
 
-def _timeout_to_request_timed_out_error(f: Failure):
+def _timeout_to_request_timed_out_error(f: Failure) -> Failure:
     if f.check(twisted_error.TimeoutError, twisted_error.ConnectingCancelledError):
         # The TCP connection has its own timeout (set by the 'connectTimeout' param
         # on the Agent), which raises twisted_error.TimeoutError exception.
@@ -809,7 +813,7 @@ class _DiscardBodyWithMaxSizeProtocol(protocol.Protocol):
     def __init__(self, deferred: defer.Deferred):
         self.deferred = deferred
 
-    def _maybe_fail(self):
+    def _maybe_fail(self) -> None:
         """
         Report a max size exceed error and disconnect the first time this is called.
         """
@@ -933,12 +937,12 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory):
     Do not use this since it allows an attacker to intercept your communications.
     """
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._context = SSL.Context(SSL.SSLv23_METHOD)
         self._context.set_verify(VERIFY_NONE, lambda *_: False)
 
     def getContext(self, hostname=None, port=None):
         return self._context
 
-    def creatorForNetloc(self, hostname, port):
+    def creatorForNetloc(self, hostname: bytes, port: int):
         return self
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index 203e995bb7..23a60af171 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -14,15 +14,22 @@
 
 import base64
 import logging
-from typing import Optional
+from typing import Optional, Union
 
 import attr
 from zope.interface import implementer
 
 from twisted.internet import defer, protocol
 from twisted.internet.error import ConnectError
-from twisted.internet.interfaces import IReactorCore, IStreamClientEndpoint
+from twisted.internet.interfaces import (
+    IAddress,
+    IConnector,
+    IProtocol,
+    IReactorCore,
+    IStreamClientEndpoint,
+)
 from twisted.internet.protocol import ClientFactory, Protocol, connectionDone
+from twisted.python.failure import Failure
 from twisted.web import http
 
 logger = logging.getLogger(__name__)
@@ -81,14 +88,14 @@ class HTTPConnectProxyEndpoint:
         self._port = port
         self._proxy_creds = proxy_creds
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         return "<HTTPConnectProxyEndpoint %s>" % (self._proxy_endpoint,)
 
     # Mypy encounters a false positive here: it complains that ClientFactory
     # is incompatible with IProtocolFactory. But ClientFactory inherits from
     # Factory, which implements IProtocolFactory. So I think this is a bug
     # in mypy-zope.
-    def connect(self, protocolFactory: ClientFactory):  # type: ignore[override]
+    def connect(self, protocolFactory: ClientFactory) -> "defer.Deferred[IProtocol]":  # type: ignore[override]
         f = HTTPProxiedClientFactory(
             self._host, self._port, protocolFactory, self._proxy_creds
         )
@@ -125,10 +132,10 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
         self.proxy_creds = proxy_creds
         self.on_connection: "defer.Deferred[None]" = defer.Deferred()
 
-    def startedConnecting(self, connector):
+    def startedConnecting(self, connector: IConnector) -> None:
         return self.wrapped_factory.startedConnecting(connector)
 
-    def buildProtocol(self, addr):
+    def buildProtocol(self, addr: IAddress) -> "HTTPConnectProtocol":
         wrapped_protocol = self.wrapped_factory.buildProtocol(addr)
         if wrapped_protocol is None:
             raise TypeError("buildProtocol produced None instead of a Protocol")
@@ -141,13 +148,13 @@ class HTTPProxiedClientFactory(protocol.ClientFactory):
             self.proxy_creds,
         )
 
-    def clientConnectionFailed(self, connector, reason):
+    def clientConnectionFailed(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy failed: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
         return self.wrapped_factory.clientConnectionFailed(connector, reason)
 
-    def clientConnectionLost(self, connector, reason):
+    def clientConnectionLost(self, connector: IConnector, reason: Failure) -> None:
         logger.debug("Connection to proxy lost: %s", reason)
         if not self.on_connection.called:
             self.on_connection.errback(reason)
@@ -191,10 +198,10 @@ class HTTPConnectProtocol(protocol.Protocol):
         )
         self.http_setup_client.on_connected.addCallback(self.proxyConnected)
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         self.http_setup_client.makeConnection(self.transport)
 
-    def connectionLost(self, reason=connectionDone):
+    def connectionLost(self, reason: Failure = connectionDone) -> None:
         if self.wrapped_protocol.connected:
             self.wrapped_protocol.connectionLost(reason)
 
@@ -203,7 +210,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if not self.connected_deferred.called:
             self.connected_deferred.errback(reason)
 
-    def proxyConnected(self, _):
+    def proxyConnected(self, _: Union[None, "defer.Deferred[None]"]) -> None:
         self.wrapped_protocol.makeConnection(self.transport)
 
         self.connected_deferred.callback(self.wrapped_protocol)
@@ -213,7 +220,7 @@ class HTTPConnectProtocol(protocol.Protocol):
         if buf:
             self.wrapped_protocol.dataReceived(buf)
 
-    def dataReceived(self, data: bytes):
+    def dataReceived(self, data: bytes) -> None:
         # if we've set up the HTTP protocol, we can send the data there
         if self.wrapped_protocol.connected:
             return self.wrapped_protocol.dataReceived(data)
@@ -243,7 +250,7 @@ class HTTPConnectSetupClient(http.HTTPClient):
         self.proxy_creds = proxy_creds
         self.on_connected: "defer.Deferred[None]" = defer.Deferred()
 
-    def connectionMade(self):
+    def connectionMade(self) -> None:
         logger.debug("Connected to proxy, sending CONNECT")
         self.sendCommand(b"CONNECT", b"%s:%d" % (self.host, self.port))
 
@@ -257,14 +264,14 @@ class HTTPConnectSetupClient(http.HTTPClient):
 
         self.endHeaders()
 
-    def handleStatus(self, version: bytes, status: bytes, message: bytes):
+    def handleStatus(self, version: bytes, status: bytes, message: bytes) -> None:
         logger.debug("Got Status: %s %s %s", status, message, version)
         if status != b"200":
             raise ProxyConnectError(f"Unexpected status on CONNECT: {status!s}")
 
-    def handleEndHeaders(self):
+    def handleEndHeaders(self) -> None:
         logger.debug("End Headers")
         self.on_connected.callback(None)
 
-    def handleResponse(self, body):
+    def handleResponse(self, body: bytes) -> None:
         pass
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index a8a520f809..2f0177f1e2 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -239,7 +239,7 @@ class MatrixHostnameEndpointFactory:
 
         self._srv_resolver = srv_resolver
 
-    def endpointForURI(self, parsed_uri: URI):
+    def endpointForURI(self, parsed_uri: URI) -> "MatrixHostnameEndpoint":
         return MatrixHostnameEndpoint(
             self._reactor,
             self._proxy_reactor,
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index f68646fd0d..de0e882b33 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -16,7 +16,7 @@
 import logging
 import random
 import time
-from typing import Callable, Dict, List
+from typing import Any, Callable, Dict, List
 
 import attr
 
@@ -109,7 +109,7 @@ class SrvResolver:
 
     def __init__(
         self,
-        dns_client=client,
+        dns_client: Any = client,
         cache: Dict[bytes, List[Server]] = SERVER_CACHE,
         get_time: Callable[[], float] = time.time,
     ):
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 43f2140429..71b685fade 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -74,9 +74,9 @@ _well_known_cache: TTLCache[bytes, Optional[bytes]] = TTLCache("well-known")
 _had_valid_well_known_cache: TTLCache[bytes, bool] = TTLCache("had-valid-well-known")
 
 
-@attr.s(slots=True, frozen=True)
+@attr.s(slots=True, frozen=True, auto_attribs=True)
 class WellKnownLookupResult:
-    delegated_server = attr.ib()
+    delegated_server: Optional[bytes]
 
 
 class WellKnownResolver:
@@ -336,4 +336,4 @@ def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
 class _FetchWellKnownFailure(Exception):
     # True if we didn't get a non-5xx HTTP response, i.e. this may or may not be
     # a temporary failure.
-    temporary = attr.ib()
+    temporary: bool = attr.ib()
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index c2ec3caa0e..725b5c33b8 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -23,6 +23,8 @@ from http import HTTPStatus
 from io import BytesIO, StringIO
 from typing import (
     TYPE_CHECKING,
+    Any,
+    BinaryIO,
     Callable,
     Dict,
     Generic,
@@ -44,7 +46,7 @@ from typing_extensions import Literal
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
 from twisted.internet.interfaces import IReactorTime
-from twisted.internet.task import _EPSILON, Cooperator
+from twisted.internet.task import Cooperator
 from twisted.web.client import ResponseFailed
 from twisted.web.http_headers import Headers
 from twisted.web.iweb import IBodyProducer, IResponse
@@ -58,11 +60,13 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.crypto.context_factory import FederationPolicyForHTTPS
 from synapse.http import QuieterFileBodyProducer
 from synapse.http.client import (
     BlacklistingAgentWrapper,
     BodyExceededMaxSize,
     ByteWriteable,
+    _make_scheduler,
     encode_query_args,
     read_body_with_max_size,
 )
@@ -181,7 +185,7 @@ class JsonParser(ByteParser[Union[JsonDict, list]]):
 
     CONTENT_TYPE = "application/json"
 
-    def __init__(self):
+    def __init__(self) -> None:
         self._buffer = StringIO()
         self._binary_wrapper = BinaryIOWrapper(self._buffer)
 
@@ -299,7 +303,9 @@ async def _handle_response(
 class BinaryIOWrapper:
     """A wrapper for a TextIO which converts from bytes on the fly."""
 
-    def __init__(self, file: typing.TextIO, encoding="utf-8", errors="strict"):
+    def __init__(
+        self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
+    ):
         self.decoder = codecs.getincrementaldecoder(encoding)(errors)
         self.file = file
 
@@ -317,7 +323,11 @@ class MatrixFederationHttpClient:
             requests.
     """
 
-    def __init__(self, hs: "HomeServer", tls_client_options_factory):
+    def __init__(
+        self,
+        hs: "HomeServer",
+        tls_client_options_factory: Optional[FederationPolicyForHTTPS],
+    ):
         self.hs = hs
         self.signing_key = hs.signing_key
         self.server_name = hs.hostname
@@ -348,10 +358,7 @@ class MatrixFederationHttpClient:
         self.version_string_bytes = hs.version_string.encode("ascii")
         self.default_timeout = 60
 
-        def schedule(x):
-            self.reactor.callLater(_EPSILON, x)
-
-        self._cooperator = Cooperator(scheduler=schedule)
+        self._cooperator = Cooperator(scheduler=_make_scheduler(self.reactor))
 
         self._sleeper = AwakenableSleeper(self.reactor)
 
@@ -364,7 +371,7 @@ class MatrixFederationHttpClient:
         self,
         request: MatrixFederationRequest,
         try_trailing_slash_on_400: bool = False,
-        **send_request_args,
+        **send_request_args: Any,
     ) -> IResponse:
         """Wrapper for _send_request which can optionally retry the request
         upon receiving a combination of a 400 HTTP response code and a
@@ -1159,7 +1166,7 @@ class MatrixFederationHttpClient:
         self,
         destination: str,
         path: str,
-        output_stream,
+        output_stream: BinaryIO,
         args: Optional[QueryParams] = None,
         retry_on_dns_fail: bool = True,
         max_size: Optional[int] = None,
@@ -1250,10 +1257,10 @@ class MatrixFederationHttpClient:
         return length, headers
 
 
-def _flatten_response_never_received(e):
+def _flatten_response_never_received(e: BaseException) -> str:
     if hasattr(e, "reasons"):
         reasons = ", ".join(
-            _flatten_response_never_received(f.value) for f in e.reasons
+            _flatten_response_never_received(f.value) for f in e.reasons  # type: ignore[attr-defined]
         )
 
         return "%s:[%s]" % (type(e).__name__, reasons)
diff --git a/synapse/http/proxyagent.py b/synapse/http/proxyagent.py
index a16dde2380..b2a50c9105 100644
--- a/synapse/http/proxyagent.py
+++ b/synapse/http/proxyagent.py
@@ -245,7 +245,7 @@ def http_proxy_endpoint(
     proxy: Optional[bytes],
     reactor: IReactorCore,
     tls_options_factory: Optional[IPolicyForHTTPS],
-    **kwargs,
+    **kwargs: object,
 ) -> Tuple[Optional[IStreamClientEndpoint], Optional[ProxyCredentials]]:
     """Parses an http proxy setting and returns an endpoint for the proxy
 
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 4886626d50..2b6d113544 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -162,7 +162,7 @@ class RequestMetrics:
         with _in_flight_requests_lock:
             _in_flight_requests.add(self)
 
-    def stop(self, time_sec, response_code, sent_bytes):
+    def stop(self, time_sec: float, response_code: int, sent_bytes: int) -> None:
         with _in_flight_requests_lock:
             _in_flight_requests.discard(self)
 
@@ -186,13 +186,13 @@ class RequestMetrics:
             )
             return
 
-        response_code = str(response_code)
+        response_code_str = str(response_code)
 
-        outgoing_responses_counter.labels(self.method, response_code).inc()
+        outgoing_responses_counter.labels(self.method, response_code_str).inc()
 
         response_count.labels(self.method, self.name, tag).inc()
 
-        response_timer.labels(self.method, self.name, tag, response_code).observe(
+        response_timer.labels(self.method, self.name, tag, response_code_str).observe(
             time_sec - self.start_ts
         )
 
@@ -221,7 +221,7 @@ class RequestMetrics:
         # flight.
         self.update_metrics()
 
-    def update_metrics(self):
+    def update_metrics(self) -> None:
         """Updates the in flight metrics with values from this request."""
         if not self.start_context:
             logger.error(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index 657bffcddd..e3dcc3f3dd 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -33,6 +33,7 @@ from typing import (
     Optional,
     Pattern,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -92,6 +93,68 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 HTTP_STATUS_REQUEST_CANCELLED = 499
 
 
+F = TypeVar("F", bound=Callable[..., Any])
+
+
+_cancellable_method_names = frozenset(
+    {
+        # `RestServlet`, `BaseFederationServlet` and `BaseFederationServerServlet`
+        # methods
+        "on_GET",
+        "on_PUT",
+        "on_POST",
+        "on_DELETE",
+        # `_AsyncResource`, `DirectServeHtmlResource` and `DirectServeJsonResource`
+        # methods
+        "_async_render_GET",
+        "_async_render_PUT",
+        "_async_render_POST",
+        "_async_render_DELETE",
+        "_async_render_OPTIONS",
+        # `ReplicationEndpoint` methods
+        "_handle_request",
+    }
+)
+
+
+def cancellable(method: F) -> F:
+    """Marks a servlet method as cancellable.
+
+    Methods with this decorator will be cancelled if the client disconnects before we
+    finish processing the request.
+
+    During cancellation, `Deferred.cancel()` will be invoked on the `Deferred` wrapping
+    the method. The `cancel()` call will propagate down to the `Deferred` that is
+    currently being waited on. That `Deferred` will raise a `CancelledError`, which will
+    propagate up, as per normal exception handling.
+
+    Before applying this decorator to a new endpoint, you MUST recursively check
+    that all `await`s in the function are on `async` functions or `Deferred`s that
+    handle cancellation cleanly, otherwise a variety of bugs may occur, ranging from
+    premature logging context closure, to stuck requests, to database corruption.
+
+    Usage:
+        class SomeServlet(RestServlet):
+            @cancellable
+            async def on_GET(self, request: SynapseRequest) -> ...:
+                ...
+    """
+    if method.__name__ not in _cancellable_method_names and not any(
+        method.__name__.startswith(prefix) for prefix in _cancellable_method_names
+    ):
+        raise ValueError(
+            "@cancellable decorator can only be applied to servlet methods."
+        )
+
+    method.cancellable = True  # type: ignore[attr-defined]
+    return method
+
+
+def is_method_cancellable(method: Callable[..., Any]) -> bool:
+    """Checks whether a servlet method has the `@cancellable` flag."""
+    return getattr(method, "cancellable", False)
+
+
 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
     """Sends a JSON error response to clients."""
 
@@ -253,6 +316,9 @@ class HttpServer(Protocol):
         If the regex contains groups these gets passed to the callback via
         an unpacked tuple.
 
+        The callback may be marked with the `@cancellable` decorator, which will
+        cause request processing to be cancelled when clients disconnect early.
+
         Args:
             method: The HTTP method to listen to.
             path_patterns: The regex used to match requests.
@@ -283,7 +349,9 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
     def render(self, request: SynapseRequest) -> int:
         """This gets called by twisted every time someone sends us a request."""
-        defer.ensureDeferred(self._async_render_wrapper(request))
+        request.render_deferred = defer.ensureDeferred(
+            self._async_render_wrapper(request)
+        )
         return NOT_DONE_YET
 
     @wrap_async_request_handler
@@ -319,6 +387,8 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
         method_handler = getattr(self, "_async_render_%s" % (request_method,), None)
         if method_handler:
+            request.is_render_cancellable = is_method_cancellable(method_handler)
+
             raw_callback_return = method_handler(request)
 
             # Is it synchronous? We'll allow this for now.
@@ -479,6 +549,8 @@ class JsonResource(DirectServeJsonResource):
     async def _async_render(self, request: SynapseRequest) -> Tuple[int, Any]:
         callback, servlet_classname, group_dict = self._get_handler_for_request(request)
 
+        request.is_render_cancellable = is_method_cancellable(callback)
+
         # Make sure we have an appropriate name for this handler in prometheus
         # (rather than the default of JsonResource).
         request.request_metrics.name = servlet_classname
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 0b85a57d77..eeec74b78a 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, Any, Generator, Optional, Tuple, Union
 import attr
 from zope.interface import implementer
 
+from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IAddress, IReactorTime
 from twisted.python.failure import Failure
 from twisted.web.http import HTTPChannel
@@ -91,6 +92,14 @@ class SynapseRequest(Request):
         # we can't yet create the logcontext, as we don't know the method.
         self.logcontext: Optional[LoggingContext] = None
 
+        # The `Deferred` to cancel if the client disconnects early and
+        # `is_render_cancellable` is set. Expected to be set by `Resource.render`.
+        self.render_deferred: Optional["Deferred[None]"] = None
+        # A boolean indicating whether `render_deferred` should be cancelled if the
+        # client disconnects early. Expected to be set by the coroutine started by
+        # `Resource.render`, if rendering is asynchronous.
+        self.is_render_cancellable = False
+
         global _next_request_seq
         self.request_seq = _next_request_seq
         _next_request_seq += 1
@@ -357,7 +366,21 @@ class SynapseRequest(Request):
                     {"event": "client connection lost", "reason": str(reason.value)}
                 )
 
-            if not self._is_processing:
+            if self._is_processing:
+                if self.is_render_cancellable:
+                    if self.render_deferred is not None:
+                        # Throw a cancellation into the request processing, in the hope
+                        # that it will finish up sooner than it normally would.
+                        # The `self.processing()` context manager will call
+                        # `_finished_processing()` when done.
+                        with PreserveLoggingContext():
+                            self.render_deferred.cancel()
+                    else:
+                        logger.error(
+                            "Connection from client lost, but have no Deferred to "
+                            "cancel even though the request is marked as cancellable."
+                        )
+            else:
                 self._finished_processing()
 
     def _started_processing(self, servlet_name: str) -> None: