summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/client.py19
-rw-r--r--synapse/http/connectproxyclient.py2
-rw-r--r--synapse/http/federation/matrix_federation_agent.py10
-rw-r--r--synapse/http/federation/srv_resolver.py4
-rw-r--r--synapse/http/federation/well_known_resolver.py66
-rw-r--r--synapse/http/matrixfederationclient.py11
-rw-r--r--synapse/http/request_metrics.py2
-rw-r--r--synapse/http/server.py110
-rw-r--r--synapse/http/servlet.py7
9 files changed, 165 insertions, 66 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 8aeb70cdec..13fcab3378 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -19,7 +19,7 @@ import urllib
 from io import BytesIO
 
 import treq
-from canonicaljson import encode_canonical_json, json
+from canonicaljson import encode_canonical_json
 from netaddr import IPAddress
 from prometheus_client import Counter
 from zope.interface import implementer, provider
@@ -47,6 +47,7 @@ from synapse.http import (
 from synapse.http.proxyagent import ProxyAgent
 from synapse.logging.context import make_deferred_yieldable
 from synapse.logging.opentracing import set_tag, start_active_span, tags
+from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 
 logger = logging.getLogger(__name__)
@@ -85,7 +86,7 @@ def _make_scheduler(reactor):
     return _scheduler
 
 
-class IPBlacklistingResolver(object):
+class IPBlacklistingResolver:
     """
     A proxy for reactor.nameResolver which only produces non-blacklisted IP
     addresses, preventing DNS rebinding attacks on URL preview.
@@ -132,7 +133,7 @@ class IPBlacklistingResolver(object):
             r.resolutionComplete()
 
         @provider(IResolutionReceiver)
-        class EndpointReceiver(object):
+        class EndpointReceiver:
             @staticmethod
             def resolutionBegan(resolutionInProgress):
                 pass
@@ -191,7 +192,7 @@ class BlacklistingAgentWrapper(Agent):
         )
 
 
-class SimpleHttpClient(object):
+class SimpleHttpClient:
     """
     A simple, no-frills HTTP client with methods that wrap up common ways of
     using HTTP in Matrix
@@ -243,7 +244,7 @@ class SimpleHttpClient(object):
             )
 
             @implementer(IReactorPluggableNameResolver)
-            class Reactor(object):
+            class Reactor:
                 def __getattr__(_self, attr):
                     if attr == "nameResolver":
                         return nameResolver
@@ -391,7 +392,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -433,7 +434,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
@@ -463,7 +464,7 @@ class SimpleHttpClient(object):
             actual_headers.update(headers)
 
         body = await self.get_raw(uri, args, headers=headers)
-        return json.loads(body.decode("utf-8"))
+        return json_decoder.decode(body.decode("utf-8"))
 
     async def put_json(self, uri, json_body, args={}, headers=None):
         """ Puts some json to the given URI.
@@ -506,7 +507,7 @@ class SimpleHttpClient(object):
         body = await make_deferred_yieldable(readBody(response))
 
         if 200 <= response.code < 300:
-            return json.loads(body.decode("utf-8"))
+            return json_decoder.decode(body.decode("utf-8"))
         else:
             raise HttpResponseException(
                 response.code, response.phrase.decode("ascii", errors="replace"), body
diff --git a/synapse/http/connectproxyclient.py b/synapse/http/connectproxyclient.py
index be7b2ceb8e..856e28454f 100644
--- a/synapse/http/connectproxyclient.py
+++ b/synapse/http/connectproxyclient.py
@@ -31,7 +31,7 @@ class ProxyConnectError(ConnectError):
 
 
 @implementer(IStreamClientEndpoint)
-class HTTPConnectProxyEndpoint(object):
+class HTTPConnectProxyEndpoint:
     """An Endpoint implementation which will send a CONNECT request to an http proxy
 
     Wraps an existing HostnameEndpoint for the proxy.
diff --git a/synapse/http/federation/matrix_federation_agent.py b/synapse/http/federation/matrix_federation_agent.py
index 369bf9c2fc..83d6196d4a 100644
--- a/synapse/http/federation/matrix_federation_agent.py
+++ b/synapse/http/federation/matrix_federation_agent.py
@@ -36,7 +36,7 @@ logger = logging.getLogger(__name__)
 
 
 @implementer(IAgent)
-class MatrixFederationAgent(object):
+class MatrixFederationAgent:
     """An Agent-like thing which provides a `request` method which correctly
     handles resolving matrix server names when using matrix://. Handles standard
     https URIs as normal.
@@ -134,8 +134,8 @@ class MatrixFederationAgent(object):
             and not _is_ip_literal(parsed_uri.hostname)
             and not parsed_uri.port
         ):
-            well_known_result = yield self._well_known_resolver.get_well_known(
-                parsed_uri.hostname
+            well_known_result = yield defer.ensureDeferred(
+                self._well_known_resolver.get_well_known(parsed_uri.hostname)
             )
             delegated_server = well_known_result.delegated_server
 
@@ -175,7 +175,7 @@ class MatrixFederationAgent(object):
 
 
 @implementer(IAgentEndpointFactory)
-class MatrixHostnameEndpointFactory(object):
+class MatrixHostnameEndpointFactory:
     """Factory for MatrixHostnameEndpoint for parsing to an Agent.
     """
 
@@ -198,7 +198,7 @@ class MatrixHostnameEndpointFactory(object):
 
 
 @implementer(IStreamClientEndpoint)
-class MatrixHostnameEndpoint(object):
+class MatrixHostnameEndpoint:
     """An endpoint that resolves matrix:// URLs using Matrix server name
     resolution (i.e. via SRV). Does not check for well-known delegation.
 
diff --git a/synapse/http/federation/srv_resolver.py b/synapse/http/federation/srv_resolver.py
index 2ede90a9b1..d9620032d2 100644
--- a/synapse/http/federation/srv_resolver.py
+++ b/synapse/http/federation/srv_resolver.py
@@ -33,7 +33,7 @@ SERVER_CACHE = {}
 
 
 @attr.s(slots=True, frozen=True)
-class Server(object):
+class Server:
     """
     Our record of an individual server which can be tried to reach a destination.
 
@@ -96,7 +96,7 @@ def _sort_server_list(server_list):
     return results
 
 
-class SrvResolver(object):
+class SrvResolver:
     """Interface to the dns client to do SRV lookups, with result caching.
 
     The default resolver in twisted.names doesn't do any caching (it has a CacheResolver,
diff --git a/synapse/http/federation/well_known_resolver.py b/synapse/http/federation/well_known_resolver.py
index 89a3b041ce..e6f067ca29 100644
--- a/synapse/http/federation/well_known_resolver.py
+++ b/synapse/http/federation/well_known_resolver.py
@@ -13,10 +13,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import json
 import logging
 import random
 import time
+from typing import Callable, Dict, Optional, Tuple
 
 import attr
 
@@ -24,9 +24,10 @@ from twisted.internet import defer
 from twisted.web.client import RedirectAgent, readBody
 from twisted.web.http import stringToDatetime
 from twisted.web.http_headers import Headers
+from twisted.web.iweb import IResponse
 
 from synapse.logging.context import make_deferred_yieldable
-from synapse.util import Clock
+from synapse.util import Clock, json_decoder
 from synapse.util.caches.ttlcache import TTLCache
 from synapse.util.metrics import Measure
 
@@ -70,11 +71,11 @@ _had_valid_well_known_cache = TTLCache("had-valid-well-known")
 
 
 @attr.s(slots=True, frozen=True)
-class WellKnownLookupResult(object):
+class WellKnownLookupResult:
     delegated_server = attr.ib()
 
 
-class WellKnownResolver(object):
+class WellKnownResolver:
     """Handles well-known lookups for matrix servers.
     """
 
@@ -100,15 +101,14 @@ class WellKnownResolver(object):
         self._well_known_agent = RedirectAgent(agent)
         self.user_agent = user_agent
 
-    @defer.inlineCallbacks
-    def get_well_known(self, server_name):
+    async def get_well_known(self, server_name: bytes) -> WellKnownLookupResult:
         """Attempt to fetch and parse a .well-known file for the given server
 
         Args:
-            server_name (bytes): name of the server, from the requested url
+            server_name: name of the server, from the requested url
 
         Returns:
-            Deferred[WellKnownLookupResult]: The result of the lookup
+            The result of the lookup
         """
         try:
             prev_result, expiry, ttl = self._well_known_cache.get_with_expiry(
@@ -125,7 +125,9 @@ class WellKnownResolver(object):
         # requests for the same server in parallel?
         try:
             with Measure(self._clock, "get_well_known"):
-                result, cache_period = yield self._fetch_well_known(server_name)
+                result, cache_period = await self._fetch_well_known(
+                    server_name
+                )  # type: Tuple[Optional[bytes], float]
 
         except _FetchWellKnownFailure as e:
             if prev_result and e.temporary:
@@ -154,18 +156,17 @@ class WellKnownResolver(object):
 
         return WellKnownLookupResult(delegated_server=result)
 
-    @defer.inlineCallbacks
-    def _fetch_well_known(self, server_name):
+    async def _fetch_well_known(self, server_name: bytes) -> Tuple[bytes, float]:
         """Actually fetch and parse a .well-known, without checking the cache
 
         Args:
-            server_name (bytes): name of the server, from the requested url
+            server_name: name of the server, from the requested url
 
         Raises:
             _FetchWellKnownFailure if we fail to lookup a result
 
         Returns:
-            Deferred[Tuple[bytes,int]]: The lookup result and cache period.
+            The lookup result and cache period.
         """
 
         had_valid_well_known = self._had_valid_well_known_cache.get(server_name, False)
@@ -173,7 +174,7 @@ class WellKnownResolver(object):
         # We do this in two steps to differentiate between possibly transient
         # errors (e.g. can't connect to host, 503 response) and more permenant
         # errors (such as getting a 404 response).
-        response, body = yield self._make_well_known_request(
+        response, body = await self._make_well_known_request(
             server_name, retry=had_valid_well_known
         )
 
@@ -181,7 +182,7 @@ class WellKnownResolver(object):
             if response.code != 200:
                 raise Exception("Non-200 response %s" % (response.code,))
 
-            parsed_body = json.loads(body.decode("utf-8"))
+            parsed_body = json_decoder.decode(body.decode("utf-8"))
             logger.info("Response from .well-known: %s", parsed_body)
 
             result = parsed_body["m.server"].encode("ascii")
@@ -216,20 +217,20 @@ class WellKnownResolver(object):
 
         return result, cache_period
 
-    @defer.inlineCallbacks
-    def _make_well_known_request(self, server_name, retry):
+    async def _make_well_known_request(
+        self, server_name: bytes, retry: bool
+    ) -> Tuple[IResponse, bytes]:
         """Make the well known request.
 
         This will retry the request if requested and it fails (with unable
         to connect or receives a 5xx error).
 
         Args:
-            server_name (bytes)
-            retry (bool): Whether to retry the request if it fails.
+            server_name: name of the server, from the requested url
+            retry: Whether to retry the request if it fails.
 
         Returns:
-            Deferred[tuple[IResponse, bytes]] Returns the response object and
-            body. Response may be a non-200 response.
+            Returns the response object and body. Response may be a non-200 response.
         """
         uri = b"https://%s/.well-known/matrix/server" % (server_name,)
         uri_str = uri.decode("ascii")
@@ -244,12 +245,12 @@ class WellKnownResolver(object):
 
             logger.info("Fetching %s", uri_str)
             try:
-                response = yield make_deferred_yieldable(
+                response = await make_deferred_yieldable(
                     self._well_known_agent.request(
                         b"GET", uri, headers=Headers(headers)
                     )
                 )
-                body = yield make_deferred_yieldable(readBody(response))
+                body = await make_deferred_yieldable(readBody(response))
 
                 if 500 <= response.code < 600:
                     raise Exception("Non-200 response %s" % (response.code,))
@@ -266,21 +267,24 @@ class WellKnownResolver(object):
                 logger.info("Error fetching %s: %s. Retrying", uri_str, e)
 
             # Sleep briefly in the hopes that they come back up
-            yield self._clock.sleep(0.5)
+            await self._clock.sleep(0.5)
 
 
-def _cache_period_from_headers(headers, time_now=time.time):
+def _cache_period_from_headers(
+    headers: Headers, time_now: Callable[[], float] = time.time
+) -> Optional[float]:
     cache_controls = _parse_cache_control(headers)
 
     if b"no-store" in cache_controls:
         return 0
 
     if b"max-age" in cache_controls:
-        try:
-            max_age = int(cache_controls[b"max-age"])
-            return max_age
-        except ValueError:
-            pass
+        max_age = cache_controls[b"max-age"]
+        if max_age:
+            try:
+                return int(max_age)
+            except ValueError:
+                pass
 
     expires = headers.getRawHeaders(b"expires")
     if expires is not None:
@@ -296,7 +300,7 @@ def _cache_period_from_headers(headers, time_now=time.time):
     return None
 
 
-def _parse_cache_control(headers):
+def _parse_cache_control(headers: Headers) -> Dict[bytes, Optional[bytes]]:
     cache_controls = {}
     for hdr in headers.getRawHeaders(b"cache-control", []):
         for directive in hdr.split(b","):
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 738be43f46..5eaf3151ce 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -54,6 +54,7 @@ from synapse.logging.opentracing import (
     start_active_span,
     tags,
 )
+from synapse.util import json_decoder
 from synapse.util.async_helpers import timeout_deferred
 from synapse.util.metrics import Measure
 
@@ -76,7 +77,7 @@ _next_id = 1
 
 
 @attr.s(frozen=True)
-class MatrixFederationRequest(object):
+class MatrixFederationRequest:
     method = attr.ib()
     """HTTP method
     :type: str
@@ -164,7 +165,9 @@ async def _handle_json_response(
     try:
         check_content_type_is_json(response.headers)
 
-        d = treq.json_content(response)
+        # Use the custom JSON decoder (partially re-implements treq.json_content).
+        d = treq.text_content(response, encoding="utf-8")
+        d.addCallback(json_decoder.decode)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
         body = await make_deferred_yieldable(d)
@@ -203,7 +206,7 @@ async def _handle_json_response(
     return body
 
 
-class MatrixFederationHttpClient(object):
+class MatrixFederationHttpClient:
     """HTTP client used to talk to other homeservers over the federation
     protocol. Send client certificates and signs requests.
 
@@ -226,7 +229,7 @@ class MatrixFederationHttpClient(object):
         )
 
         @implementer(IReactorPluggableNameResolver)
-        class Reactor(object):
+        class Reactor:
             def __getattr__(_self, attr):
                 if attr == "nameResolver":
                     return nameResolver
diff --git a/synapse/http/request_metrics.py b/synapse/http/request_metrics.py
index b58ae3d9db..cd94e789e8 100644
--- a/synapse/http/request_metrics.py
+++ b/synapse/http/request_metrics.py
@@ -145,7 +145,7 @@ LaterGauge(
 )
 
 
-class RequestMetrics(object):
+class RequestMetrics:
     def start(self, time_sec, name, method):
         self.start = time_sec
         self.start_context = current_context()
diff --git a/synapse/http/server.py b/synapse/http/server.py
index ffe6cfa09e..996a31a9ec 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,12 +22,13 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
 
 import jinja2
-from canonicaljson import encode_canonical_json, encode_pretty_printed_json
+from canonicaljson import iterencode_canonical_json, iterencode_pretty_printed_json
+from zope.interface import implementer
 
-from twisted.internet import defer
+from twisted.internet import defer, interfaces
 from twisted.python import failure
 from twisted.web import resource
 from twisted.web.server import NOT_DONE_YET, Request
@@ -173,7 +174,7 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
 
 
-class HttpServer(object):
+class HttpServer:
     """ Interface for registering callbacks on a HTTP server
     """
 
@@ -499,6 +500,90 @@ class RootOptionsRedirectResource(OptionsResource, RootRedirect):
     pass
 
 
+@implementer(interfaces.IPushProducer)
+class _ByteProducer:
+    """
+    Iteratively write bytes to the request.
+    """
+
+    # The minimum number of bytes for each chunk. Note that the last chunk will
+    # usually be smaller than this.
+    min_chunk_size = 1024
+
+    def __init__(
+        self, request: Request, iterator: Iterator[bytes],
+    ):
+        self._request = request
+        self._iterator = iterator
+        self._paused = False
+
+        # Register the producer and start producing data.
+        self._request.registerProducer(self, True)
+        self.resumeProducing()
+
+    def _send_data(self, data: List[bytes]) -> None:
+        """
+        Send a list of bytes as a chunk of a response.
+        """
+        if not data:
+            return
+        self._request.write(b"".join(data))
+
+    def pauseProducing(self) -> None:
+        self._paused = True
+
+    def resumeProducing(self) -> None:
+        # We've stopped producing in the meantime (note that this might be
+        # re-entrant after calling write).
+        if not self._request:
+            return
+
+        self._paused = False
+
+        # Write until there's backpressure telling us to stop.
+        while not self._paused:
+            # Get the next chunk and write it to the request.
+            #
+            # The output of the JSON encoder is buffered and coalesced until
+            # min_chunk_size is reached. This is because JSON encoders produce
+            # very small output per iteration and the Request object converts
+            # each call to write() to a separate chunk. Without this there would
+            # be an explosion in bytes written (e.g. b"{" becoming "1\r\n{\r\n").
+            #
+            # Note that buffer stores a list of bytes (instead of appending to
+            # bytes) to hopefully avoid many allocations.
+            buffer = []
+            buffered_bytes = 0
+            while buffered_bytes < self.min_chunk_size:
+                try:
+                    data = next(self._iterator)
+                    buffer.append(data)
+                    buffered_bytes += len(data)
+                except StopIteration:
+                    # The entire JSON object has been serialized, write any
+                    # remaining data, finalize the producer and the request, and
+                    # clean-up any references.
+                    self._send_data(buffer)
+                    self._request.unregisterProducer()
+                    self._request.finish()
+                    self.stopProducing()
+                    return
+
+            self._send_data(buffer)
+
+    def stopProducing(self) -> None:
+        # Clear a circular reference.
+        self._request = None
+
+
+def _encode_json_bytes(json_object: Any) -> Iterator[bytes]:
+    """
+    Encode an object into JSON. Returns an iterator of bytes.
+    """
+    for chunk in json_encoder.iterencode(json_object):
+        yield chunk.encode("utf-8")
+
+
 def respond_with_json(
     request: Request,
     code: int,
@@ -533,15 +618,22 @@ def respond_with_json(
         return None
 
     if pretty_print:
-        json_bytes = encode_pretty_printed_json(json_object) + b"\n"
+        encoder = iterencode_pretty_printed_json
     else:
         if canonical_json or synapse.events.USE_FROZEN_DICTS:
-            # canonicaljson already encodes to bytes
-            json_bytes = encode_canonical_json(json_object)
+            encoder = iterencode_canonical_json
         else:
-            json_bytes = json_encoder.encode(json_object).encode("utf-8")
+            encoder = _encode_json_bytes
+
+    request.setResponseCode(code)
+    request.setHeader(b"Content-Type", b"application/json")
+    request.setHeader(b"Cache-Control", b"no-cache, no-store, must-revalidate")
 
-    return respond_with_json_bytes(request, code, json_bytes, send_cors=send_cors)
+    if send_cors:
+        set_cors_headers(request)
+
+    _ByteProducer(request, encoder(json_object))
+    return NOT_DONE_YET
 
 
 def respond_with_json_bytes(
diff --git a/synapse/http/servlet.py b/synapse/http/servlet.py
index a34e5ead88..fd90ba7828 100644
--- a/synapse/http/servlet.py
+++ b/synapse/http/servlet.py
@@ -17,9 +17,8 @@
 
 import logging
 
-from canonicaljson import json
-
 from synapse.api.errors import Codes, SynapseError
+from synapse.util import json_decoder
 
 logger = logging.getLogger(__name__)
 
@@ -215,7 +214,7 @@ def parse_json_value_from_request(request, allow_empty_body=False):
         return None
 
     try:
-        content = json.loads(content_bytes.decode("utf-8"))
+        content = json_decoder.decode(content_bytes.decode("utf-8"))
     except Exception as e:
         logger.warning("Unable to parse JSON: %s", e)
         raise SynapseError(400, "Content not JSON.", errcode=Codes.NOT_JSON)
@@ -257,7 +256,7 @@ def assert_params_in_dict(body, required):
         raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM)
 
 
-class RestServlet(object):
+class RestServlet:
 
     """ A Synapse REST Servlet.