summary refs log tree commit diff
path: root/synapse/http
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http')
-rw-r--r--synapse/http/client.py7
-rw-r--r--synapse/http/matrixfederationclient.py160
2 files changed, 131 insertions, 36 deletions
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5f40f16e24..1ca6624fd5 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -813,7 +813,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         if self.deferred.called:
             return
 
-        self.stream.write(data)
+        try:
+            self.stream.write(data)
+        except Exception:
+            self.deferred.errback()
+            return
+
         self.length += len(data)
         # The first time the maximum size is exceeded, error and cancel the
         # connection. dataReceived might be called again if data was received
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..f5503b394b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import cgi
 import codecs
 import logging
@@ -19,13 +20,24 @@ import sys
 import typing
 import urllib.parse
 from io import BytesIO, StringIO
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import attr
 import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
+from typing_extensions import Literal
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
@@ -48,6 +60,7 @@ from synapse.http.client import (
     BlacklistingAgentWrapper,
     BlacklistingReactorWrapper,
     BodyExceededMaxSize,
+    ByteWriteable,
     encode_query_args,
     read_body_with_max_size,
 )
@@ -88,6 +101,27 @@ _next_id = 1
 QueryArgs = Dict[str, Union[str, List[str]]]
 
 
+T = TypeVar("T")
+
+
+class ByteParser(ByteWriteable, Generic[T], abc.ABC):
+    """A `ByteWriteable` that has an additional `finish` function that returns
+    the parsed data.
+    """
+
+    CONTENT_TYPE = abc.abstractproperty()  # type: str  # type: ignore
+    """The expected content type of the response, e.g. `application/json`. If
+    the content type doesn't match we fail the request.
+    """
+
+    @abc.abstractmethod
+    def finish(self) -> T:
+        """Called when response has finished streaming and the parser should
+        return the final result (or error).
+        """
+        pass
+
+
 @attr.s(slots=True, frozen=True)
 class MatrixFederationRequest:
     method = attr.ib(type=str)
@@ -148,15 +182,32 @@ class MatrixFederationRequest:
         return self.json
 
 
-async def _handle_json_response(
+class JsonParser(ByteParser[Union[JsonDict, list]]):
+    """A parser that buffers the response and tries to parse it as JSON."""
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self):
+        self._buffer = StringIO()
+        self._binary_wrapper = BinaryIOWrapper(self._buffer)
+
+    def write(self, data: bytes) -> int:
+        return self._binary_wrapper.write(data)
+
+    def finish(self) -> Union[JsonDict, list]:
+        return json_decoder.decode(self._buffer.getvalue())
+
+
+async def _handle_response(
     reactor: IReactorTime,
     timeout_sec: float,
     request: MatrixFederationRequest,
     response: IResponse,
     start_ms: int,
-) -> JsonDict:
+    parser: ByteParser[T],
+) -> T:
     """
-    Reads the JSON body of a response, with a timeout
+    Reads the body of a response with a timeout and sends it to a parser
 
     Args:
         reactor: twisted reactor, for the timeout
@@ -164,23 +215,21 @@ async def _handle_json_response(
         request: the request that triggered the response
         response: response to the request
         start_ms: Timestamp when request was made
+        parser: The parser for the response
 
     Returns:
-        The parsed JSON response
+        The parsed response
     """
+
     try:
-        check_content_type_is_json(response.headers)
+        check_content_type_is(response.headers, parser.CONTENT_TYPE)
 
-        buf = StringIO()
-        d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
+        d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
-        def parse(_len: int):
-            return json_decoder.decode(buf.getvalue())
-
-        d.addCallback(parse)
+        length = await make_deferred_yieldable(d)
 
-        body = await make_deferred_yieldable(d)
+        value = parser.finish()
     except BodyExceededMaxSize as e:
         # The response was too big.
         logger.warning(
@@ -193,9 +242,9 @@ async def _handle_json_response(
         )
         raise RequestSendFailed(e, can_retry=False) from e
     except ValueError as e:
-        # The JSON content was invalid.
+        # The content was invalid.
         logger.warning(
-            "{%s} [%s] Failed to parse JSON response - %s %s",
+            "{%s} [%s] Failed to parse response - %s %s",
             request.txn_id,
             request.destination,
             request.method,
@@ -225,16 +274,17 @@ async def _handle_json_response(
     time_taken_secs = reactor.seconds() - start_ms / 1000
 
     logger.info(
-        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
         request.txn_id,
         request.destination,
         response.code,
         response.phrase.decode("ascii", errors="replace"),
         time_taken_secs,
+        length,
         request.method,
         request.uri.decode("ascii"),
     )
-    return body
+    return value
 
 
 class BinaryIOWrapper:
@@ -671,6 +721,7 @@ class MatrixFederationHttpClient:
             )
         return auth_headers
 
+    @overload
     async def put_json(
         self,
         destination: str,
@@ -683,7 +734,41 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
+        parser: Literal[None] = None,
     ) -> Union[JsonDict, list]:
+        ...
+
+    @overload
+    async def put_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        data: Optional[JsonDict] = None,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+        long_retries: bool = False,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        backoff_on_404: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser[T]] = None,
+    ) -> T:
+        ...
+
+    async def put_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        data: Optional[JsonDict] = None,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+        long_retries: bool = False,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        backoff_on_404: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser] = None,
+    ):
         """Sends the specified json data using PUT
 
         Args:
@@ -716,6 +801,8 @@ class MatrixFederationHttpClient:
                 of the request. Workaround for #3622 in Synapse <= v0.99.3. This
                 will be attempted before backing off if backing off has been
                 enabled.
+            parser: The parser to use to decode the response. Defaults to
+                parsing as JSON.
 
         Returns:
             Succeeds when we get a 2xx HTTP response. The
@@ -756,8 +843,16 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        if parser is None:
+            parser = JsonParser()
+
+        body = await _handle_response(
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            parser=parser,
         )
 
         return body
@@ -830,12 +925,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor,
-            _sec_timeout,
-            request,
-            response,
-            start_ms,
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
         return body
 
@@ -907,8 +998,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
 
         return body
@@ -975,8 +1066,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
         return body
 
@@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e):
         return repr(e)
 
 
-def check_content_type_is_json(headers: Headers) -> None:
+def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
     """
     Check that a set of HTTP headers have a Content-Type header, and that it
-    is application/json.
+    is the expected value..
 
     Args:
         headers: headers to check
 
     Raises:
-        RequestSendFailed: if the Content-Type header is missing or isn't JSON
+        RequestSendFailed: if the Content-Type header is missing or doesn't match
 
     """
     content_type_headers = headers.getRawHeaders(b"Content-Type")
@@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None:
 
     c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
-    if val != "application/json":
+    if val != expected_content_type:
         raise RequestSendFailed(
             RuntimeError(
-                "Remote server sent Content-Type header of '%s', not 'application/json'"
-                % c_type,
+                f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
             ),
             can_retry=False,
         )