summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/__init__.py3
-rw-r--r--synapse/http/client.py33
-rw-r--r--synapse/http/federation/matrix_federation_agent.py13
-rw-r--r--synapse/http/federation/well_known_resolver.py3
-rw-r--r--synapse/http/matrixfederationclient.py20
-rw-r--r--synapse/http/request_metrics.py3
-rw-r--r--synapse/http/server.py123
-rw-r--r--synapse/http/servlet.py2
-rw-r--r--synapse/http/site.py9
9 files changed, 136 insertions, 73 deletions
diff --git a/synapse/http/__init__.py b/synapse/http/__init__.py
index 4bc3cb53f0..c658862fe6 100644
--- a/synapse/http/__init__.py
+++ b/synapse/http/__init__.py
@@ -54,8 +54,7 @@ class QuieterFileBodyProducer(FileBodyProducer):
 
 
 def get_request_user_agent(request: IRequest, default: str = "") -> str:
-    """Return the last User-Agent header, or the given default.
-    """
+    """Return the last User-Agent header, or the given default."""
     # There could be raw utf-8 bytes in the User-Agent header.
 
     # N.B. if you don't do this, the logger explodes cryptically
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 37ccf5ab98..e54d9bd213 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -56,7 +56,7 @@ from twisted.web.client import (
 )
 from twisted.web.http import PotentialDataLoss
 from twisted.web.http_headers import Headers
-from twisted.web.iweb import IAgent, IBodyProducer, IResponse
+from twisted.web.iweb import UNKNOWN_LENGTH, IAgent, IBodyProducer, IResponse
 
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
 from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -398,7 +398,8 @@ class SimpleHttpClient:
                 body_producer = None
                 if data is not None:
                     body_producer = QuieterFileBodyProducer(
-                        BytesIO(data), cooperator=self._cooperator,
+                        BytesIO(data),
+                        cooperator=self._cooperator,
                     )
 
                 request_deferred = treq.request(
@@ -407,13 +408,18 @@ class SimpleHttpClient:
                     agent=self.agent,
                     data=body_producer,
                     headers=headers,
+                    # Avoid buffering the body in treq since we do not reuse
+                    # response bodies.
+                    unbuffered=True,
                     **self._extra_treq_args,
                 )  # type: defer.Deferred
 
                 # we use our own timeout mechanism rather than treq's as a workaround
                 # for https://twistedmatrix.com/trac/ticket/9534.
                 request_deferred = timeout_deferred(
-                    request_deferred, 60, self.hs.get_reactor(),
+                    request_deferred,
+                    60,
+                    self.hs.get_reactor(),
                 )
 
                 # turn timeouts into RequestTimedOutErrors
@@ -699,18 +705,6 @@ class SimpleHttpClient:
 
         resp_headers = dict(response.headers.getAllRawHeaders())
 
-        if (
-            b"Content-Length" in resp_headers
-            and max_size
-            and int(resp_headers[b"Content-Length"][0]) > max_size
-        ):
-            logger.warning("Requested URL is too large > %r bytes" % (max_size,))
-            raise SynapseError(
-                502,
-                "Requested file is too large > %r bytes" % (max_size,),
-                Codes.TOO_LARGE,
-            )
-
         if response.code > 299:
             logger.warning("Got %d when downloading %s" % (response.code, url))
             raise SynapseError(502, "Got error %d" % (response.code,), Codes.UNKNOWN)
@@ -777,7 +771,9 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         # in the meantime.
         if self.max_size is not None and self.length >= self.max_size:
             self.deferred.errback(BodyExceededMaxSize())
-            self.transport.loseConnection()
+            # Close the connection (forcefully) since all the data will get
+            # discarded anyway.
+            self.transport.abortConnection()
 
     def connectionLost(self, reason: Failure) -> None:
         # If the maximum size was already exceeded, there's nothing to do.
@@ -811,6 +807,11 @@ def read_body_with_max_size(
     Returns:
         A Deferred which resolves to the length of the read body.
     """
+    # If the Content-Length header gives a size larger than the maximum allowed
+    # size, do not bother downloading the body.
+    if max_size is not None and response.length != UNKNOWN_LENGTH:
+        if response.length > max_size:
+            return defer.fail(BodyExceededMaxSize())
 
     d = defer.Deferred()
     response.deliverBody(_ReadBodyWithMaxSizeProtocol(stream, d, max_size))
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 4c06a117d3..2e83fa6773 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -195,8 +195,7 @@ class MatrixFederationAgent:
 
 @implementer(IAgentEndpointFactory)
 class MatrixHostnameEndpointFactory:
-    """Factory for MatrixHostnameEndpoint for parsing to an Agent.
-    """
+    """Factory for MatrixHostnameEndpoint for parsing to an Agent."""
 
     def __init__(
         self,
@@ -261,8 +260,7 @@ class MatrixHostnameEndpoint:
         self._srv_resolver = srv_resolver
 
     def connect(self, protocol_factory: IProtocolFactory) -> defer.Deferred:
-        """Implements IStreamClientEndpoint interface
-        """
+        """Implements IStreamClientEndpoint interface"""
 
         return run_in_background(self._do_connect, protocol_factory)
 
@@ -323,12 +321,19 @@ class MatrixHostnameEndpoint:
         if port or _is_ip_literal(host):
             return [Server(host, port or 8448)]
 
+        logger.debug("Looking up SRV record for %s", host.decode(errors="replace"))
         server_list = await self._srv_resolver.resolve_service(b"_matrix._tcp." + host)
 
         if server_list:
+            logger.debug(
+                "Got %s from SRV lookup for %s",
+                ", ".join(map(str, server_list)),
+                host.decode(errors="replace"),
+            )
             return server_list
 
         # No SRV records, so we fallback to host and 8448
+        logger.debug("No SRV records for %s", host.decode(errors="replace"))
         return [Server(host, 8448)]
 
 
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index b3b6dbcab0..4def7d7633 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -81,8 +81,7 @@ class WellKnownLookupResult:
 
 
 class WellKnownResolver:
-    """Handles well-known lookups for matrix servers.
-    """
+    """Handles well-known lookups for matrix servers."""
 
     def __init__(
         self,
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 19293bf673..cde42e9f5e 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -254,7 +254,8 @@ class MatrixFederationHttpClient:
         # Use a BlacklistingAgentWrapper to prevent circumventing the IP
         # blacklist via IP literals in server names
         self.agent = BlacklistingAgentWrapper(
-            self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
+            self.agent,
+            ip_blacklist=hs.config.federation_ip_range_blacklist,
         )
 
         self.clock = hs.get_clock()
@@ -652,7 +653,7 @@ class MatrixFederationHttpClient:
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
     ) -> Union[JsonDict, list]:
-        """ Sends the specified json data using PUT
+        """Sends the specified json data using PUT
 
         Args:
             destination: The remote server to send the HTTP request to.
@@ -740,7 +741,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         args: Optional[QueryArgs] = None,
     ) -> Union[JsonDict, list]:
-        """ Sends the specified json data using POST
+        """Sends the specified json data using POST
 
         Args:
             destination: The remote server to send the HTTP request to.
@@ -799,7 +800,11 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms,
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
         )
         return body
 
@@ -813,7 +818,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
     ) -> Union[JsonDict, list]:
-        """ GETs some json from the given host homeserver and path
+        """GETs some json from the given host homeserver and path
 
         Args:
             destination: The remote server to send the HTTP request to.
@@ -994,7 +999,10 @@ class MatrixFederationHttpClient:
         except BodyExceededMaxSize:
             msg = "Requested file is too large > %r bytes" % (max_size,)
             logger.warning(
-                "{%s} [%s] %s", request.txn_id, request.destination, msg,
+                "{%s} [%s] %s",
+                request.txn_id,
+                request.destination,
+                msg,
             )
             raise SynapseError(502, msg, Codes.TOO_LARGE)
         except Exception as e:
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index 7c5defec82..0ec5d941b8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -213,8 +213,7 @@ class RequestMetrics:
         self.update_metrics()
 
     def update_metrics(self):
-        """Updates the in flight metrics with values from this request.
-        """
+        """Updates the in flight metrics with values from this request."""
         new_stats = self.start_context.get_resource_usage()
 
         diff = new_stats - self._request_stats
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..845db9b78d 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Pattern,
+    Tuple,
+    Union,
+)
 
 import jinja2
 from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
 from zope.interface import implementer
 
 from twisted.internet import defer, interfaces
@@ -64,8 +76,7 @@ HTML_ERROR_TEMPLATE = """<!DOCTYPE html>
 
 
 def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
-    """Sends a JSON error response to clients.
-    """
+    """Sends a JSON error response to clients."""
 
     if f.check(SynapseError):
         error_code = f.value.code
@@ -94,12 +105,17 @@ def return_json_error(f: failure.Failure, request: SynapseRequest) -> None:
                 pass
     else:
         respond_with_json(
-            request, error_code, error_dict, send_cors=True,
+            request,
+            error_code,
+            error_dict,
+            send_cors=True,
         )
 
 
 def return_html_error(
-    f: failure.Failure, request: Request, error_template: Union[str, jinja2.Template],
+    f: failure.Failure,
+    request: Request,
+    error_template: Union[str, jinja2.Template],
 ) -> None:
     """Sends an HTML error page corresponding to the given failure.
 
@@ -168,24 +184,39 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
 
 
-class HttpServer:
-    """ Interface for registering callbacks on a HTTP server
-    """
+# Type of a callback method for processing requests
+# it is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+    ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
 
-    def register_paths(self, method, path_patterns, callback):
-        """ Register a callback that gets fired if we receive a http request
+
+class HttpServer(Protocol):
+    """Interface for registering callbacks on a HTTP server"""
+
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
+        """Register a callback that gets fired if we receive a http request
         with the given method for a path that matches the given regex.
 
         If the regex contains groups these gets passed to the callback via
         an unpacked tuple.
 
         Args:
-            method (str): The method to listen to.
-            path_patterns (list<SRE_Pattern>): The regex used to match requests.
-            callback (function): The function to fire if we receive a matched
+            method: The HTTP method to listen to.
+            path_patterns: The regex used to match requests.
+            callback: The function to fire if we receive a matched
                 request. The first argument will be the request object and
                 subsequent arguments will be any matched groups from the regex.
-                This should return a tuple of (code, response).
+                This should return either tuple of (code, response), or None.
+            servlet_classname (str): The name of the handler to be used in prometheus
+                and opentracing logs.
         """
         pass
 
@@ -207,8 +238,7 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
         self._extract_context = extract_context
 
     def render(self, request):
-        """ This gets called by twisted every time someone sends us a request.
-        """
+        """This gets called by twisted every time someone sends us a request."""
         defer.ensureDeferred(self._async_render_wrapper(request))
         return NOT_DONE_YET
 
@@ -259,13 +289,18 @@ class _AsyncResource(resource.Resource, metaclass=abc.ABCMeta):
 
     @abc.abstractmethod
     def _send_response(
-        self, request: SynapseRequest, code: int, response_object: Any,
+        self,
+        request: SynapseRequest,
+        code: int,
+        response_object: Any,
     ) -> None:
         raise NotImplementedError()
 
     @abc.abstractmethod
     def _send_error_response(
-        self, f: failure.Failure, request: SynapseRequest,
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
     ) -> None:
         raise NotImplementedError()
 
@@ -280,10 +315,12 @@ class DirectServeJsonResource(_AsyncResource):
         self.canonical_json = canonical_json
 
     def _send_response(
-        self, request: Request, code: int, response_object: Any,
+        self,
+        request: Request,
+        code: int,
+        response_object: Any,
     ):
-        """Implements _AsyncResource._send_response
-        """
+        """Implements _AsyncResource._send_response"""
         # TODO: Only enable CORS for the requests that need it.
         respond_with_json(
             request,
@@ -294,15 +331,16 @@ class DirectServeJsonResource(_AsyncResource):
         )
 
     def _send_error_response(
-        self, f: failure.Failure, request: SynapseRequest,
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
     ) -> None:
-        """Implements _AsyncResource._send_error_response
-        """
+        """Implements _AsyncResource._send_error_response"""
         return_json_error(f, request)
 
 
 class JsonResource(DirectServeJsonResource):
-    """ This implements the HttpServer interface and provides JSON support for
+    """This implements the HttpServer interface and provides JSON support for
     Resources.
 
     Register callbacks via register_paths()
@@ -354,7 +392,7 @@ class JsonResource(DirectServeJsonResource):
 
     def _get_handler_for_request(
         self, request: SynapseRequest
-    ) -> Tuple[Callable, str, Dict[str, str]]:
+    ) -> Tuple[ServletCallback, str, Dict[str, str]]:
         """Finds a callback method to handle the given request.
 
         Returns:
@@ -415,10 +453,12 @@ class DirectServeHtmlResource(_AsyncResource):
     ERROR_TEMPLATE = HTML_ERROR_TEMPLATE
 
     def _send_response(
-        self, request: SynapseRequest, code: int, response_object: Any,
+        self,
+        request: SynapseRequest,
+        code: int,
+        response_object: Any,
     ):
-        """Implements _AsyncResource._send_response
-        """
+        """Implements _AsyncResource._send_response"""
         # We expect to get bytes for us to write
         assert isinstance(response_object, bytes)
         html_bytes = response_object
@@ -426,10 +466,11 @@ class DirectServeHtmlResource(_AsyncResource):
         respond_with_html_bytes(request, 200, html_bytes)
 
     def _send_error_response(
-        self, f: failure.Failure, request: SynapseRequest,
+        self,
+        f: failure.Failure,
+        request: SynapseRequest,
     ) -> None:
-        """Implements _AsyncResource._send_error_response
-        """
+        """Implements _AsyncResource._send_error_response"""
         return_html_error(f, request, self.ERROR_TEMPLATE)
 
 
@@ -506,7 +547,9 @@ class _ByteProducer:
     min_chunk_size = 1024
 
     def __init__(
-        self, request: Request, iterator: Iterator[bytes],
+        self,
+        request: Request,
+        iterator: Iterator[bytes],
     ):
         self._request = request
         self._iterator = iterator
@@ -626,7 +669,10 @@ def respond_with_json(
 
 
 def respond_with_json_bytes(
-    request: Request, code: int, json_bytes: bytes, send_cors: bool = False,
+    request: Request,
+    code: int,
+    json_bytes: bytes,
+    send_cors: bool = False,
 ):
     """Sends encoded JSON in response to the given request.
 
@@ -733,8 +779,15 @@ def set_clickjacking_protection_headers(request: Request):
     request.setHeader(b"Content-Security-Policy", b"frame-ancestors 'none';")
 
 
+def respond_with_redirect(request: Request, url: bytes) -> None:
+    """Write a 302 response to the request, if it is still alive."""
+    logger.debug("Redirect to %s", url.decode("utf-8"))
+    request.redirect(url)
+    finish_request(request)
+
+
 def finish_request(request: Request):
-    """ Finish writing the response to the request.
+    """Finish writing the response to the request.
 
     Twisted throws a RuntimeException if the connection closed before the
     response was written but doesn't provide a convenient or reliable way to
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index b361b7cbaf..0e637f4701 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -258,7 +258,7 @@ def assert_params_in_dict(body, required):
 
 class RestServlet:
 
-    """ A Synapse REST Servlet.
+    """A Synapse REST Servlet.
 
     An implementing class can either provide its own custom 'register' method,
     or use the automatic pattern handling provided by the base class.
diff --git a/synapse/http/site.py b/synapse/http/site.py
index 12ec3f851f..4a4fb5ef26 100644
--- a/synapse/http/site.py
+++ b/synapse/http/site.py
@@ -249,8 +249,7 @@ class SynapseRequest(Request):
         )
 
     def _finished_processing(self):
-        """Log the completion of this request and update the metrics
-        """
+        """Log the completion of this request and update the metrics"""
         assert self.logcontext is not None
         usage = self.logcontext.get_resource_usage()
 
@@ -276,7 +275,8 @@ class SynapseRequest(Request):
             # authenticated (e.g. and admin is puppetting a user) then we log both.
             if self.requester.user.to_string() != authenticated_entity:
                 authenticated_entity = "{},{}".format(
-                    authenticated_entity, self.requester.user.to_string(),
+                    authenticated_entity,
+                    self.requester.user.to_string(),
                 )
         elif self.requester is not None:
             # This shouldn't happen, but we log it so we don't lose information
@@ -322,8 +322,7 @@ class SynapseRequest(Request):
             logger.warning("Failed to stop metrics: %r", e)
 
     def _should_log_request(self) -> bool:
-        """Whether we should log at INFO that we processed the request.
-        """
+        """Whether we should log at INFO that we processed the request."""
         if self.path == b"/health":
             return False