diff options
Diffstat (limited to 'synapse/http/matrixfederationclient.py')
-rw-r--r-- | synapse/http/matrixfederationclient.py | 160 |
1 files changed, 125 insertions, 35 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index bb837b7b19..f5503b394b 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc import cgi import codecs import logging @@ -19,13 +20,24 @@ import sys import typing import urllib.parse from io import BytesIO, StringIO -from typing import Callable, Dict, List, Optional, Tuple, Union +from typing import ( + Callable, + Dict, + Generic, + List, + Optional, + Tuple, + TypeVar, + Union, + overload, +) import attr import treq from canonicaljson import encode_canonical_json from prometheus_client import Counter from signedjson.sign import sign_json +from typing_extensions import Literal from twisted.internet import defer from twisted.internet.error import DNSLookupError @@ -48,6 +60,7 @@ from synapse.http.client import ( BlacklistingAgentWrapper, BlacklistingReactorWrapper, BodyExceededMaxSize, + ByteWriteable, encode_query_args, read_body_with_max_size, ) @@ -88,6 +101,27 @@ _next_id = 1 QueryArgs = Dict[str, Union[str, List[str]]] +T = TypeVar("T") + + +class ByteParser(ByteWriteable, Generic[T], abc.ABC): + """A `ByteWriteable` that has an additional `finish` function that returns + the parsed data. + """ + + CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore + """The expected content type of the response, e.g. `application/json`. If + the content type doesn't match we fail the request. + """ + + @abc.abstractmethod + def finish(self) -> T: + """Called when response has finished streaming and the parser should + return the final result (or error). + """ + pass + + @attr.s(slots=True, frozen=True) class MatrixFederationRequest: method = attr.ib(type=str) @@ -148,15 +182,32 @@ class MatrixFederationRequest: return self.json -async def _handle_json_response( +class JsonParser(ByteParser[Union[JsonDict, list]]): + """A parser that buffers the response and tries to parse it as JSON.""" + + CONTENT_TYPE = "application/json" + + def __init__(self): + self._buffer = StringIO() + self._binary_wrapper = BinaryIOWrapper(self._buffer) + + def write(self, data: bytes) -> int: + return self._binary_wrapper.write(data) + + def finish(self) -> Union[JsonDict, list]: + return json_decoder.decode(self._buffer.getvalue()) + + +async def _handle_response( reactor: IReactorTime, timeout_sec: float, request: MatrixFederationRequest, response: IResponse, start_ms: int, -) -> JsonDict: + parser: ByteParser[T], +) -> T: """ - Reads the JSON body of a response, with a timeout + Reads the body of a response with a timeout and sends it to a parser Args: reactor: twisted reactor, for the timeout @@ -164,23 +215,21 @@ async def _handle_json_response( request: the request that triggered the response response: response to the request start_ms: Timestamp when request was made + parser: The parser for the response Returns: - The parsed JSON response + The parsed response """ + try: - check_content_type_is_json(response.headers) + check_content_type_is(response.headers, parser.CONTENT_TYPE) - buf = StringIO() - d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE) + d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - def parse(_len: int): - return json_decoder.decode(buf.getvalue()) - - d.addCallback(parse) + length = await make_deferred_yieldable(d) - body = await make_deferred_yieldable(d) + value = parser.finish() except BodyExceededMaxSize as e: # The response was too big. logger.warning( @@ -193,9 +242,9 @@ async def _handle_json_response( ) raise RequestSendFailed(e, can_retry=False) from e except ValueError as e: - # The JSON content was invalid. + # The content was invalid. logger.warning( - "{%s} [%s] Failed to parse JSON response - %s %s", + "{%s} [%s] Failed to parse response - %s %s", request.txn_id, request.destination, request.method, @@ -225,16 +274,17 @@ async def _handle_json_response( time_taken_secs = reactor.seconds() - start_ms / 1000 logger.info( - "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s", + "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s", request.txn_id, request.destination, response.code, response.phrase.decode("ascii", errors="replace"), time_taken_secs, + length, request.method, request.uri.decode("ascii"), ) - return body + return value class BinaryIOWrapper: @@ -671,6 +721,7 @@ class MatrixFederationHttpClient: ) return auth_headers + @overload async def put_json( self, destination: str, @@ -683,7 +734,41 @@ class MatrixFederationHttpClient: ignore_backoff: bool = False, backoff_on_404: bool = False, try_trailing_slash_on_400: bool = False, + parser: Literal[None] = None, ) -> Union[JsonDict, list]: + ... + + @overload + async def put_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser[T]] = None, + ) -> T: + ... + + async def put_json( + self, + destination: str, + path: str, + args: Optional[QueryArgs] = None, + data: Optional[JsonDict] = None, + json_data_callback: Optional[Callable[[], JsonDict]] = None, + long_retries: bool = False, + timeout: Optional[int] = None, + ignore_backoff: bool = False, + backoff_on_404: bool = False, + try_trailing_slash_on_400: bool = False, + parser: Optional[ByteParser] = None, + ): """Sends the specified json data using PUT Args: @@ -716,6 +801,8 @@ class MatrixFederationHttpClient: of the request. Workaround for #3622 in Synapse <= v0.99.3. This will be attempted before backing off if backing off has been enabled. + parser: The parser to use to decode the response. Defaults to + parsing as JSON. Returns: Succeeds when we get a 2xx HTTP response. The @@ -756,8 +843,16 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + if parser is None: + parser = JsonParser() + + body = await _handle_response( + self.reactor, + _sec_timeout, + request, + response, + start_ms, + parser=parser, ) return body @@ -830,12 +925,8 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, - _sec_timeout, - request, - response, - start_ms, + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -907,8 +998,8 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -975,8 +1066,8 @@ class MatrixFederationHttpClient: else: _sec_timeout = self.default_timeout - body = await _handle_json_response( - self.reactor, _sec_timeout, request, response, start_ms + body = await _handle_response( + self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser() ) return body @@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e): return repr(e) -def check_content_type_is_json(headers: Headers) -> None: +def check_content_type_is(headers: Headers, expected_content_type: str) -> None: """ Check that a set of HTTP headers have a Content-Type header, and that it - is application/json. + is the expected value.. Args: headers: headers to check Raises: - RequestSendFailed: if the Content-Type header is missing or isn't JSON + RequestSendFailed: if the Content-Type header is missing or doesn't match """ content_type_headers = headers.getRawHeaders(b"Content-Type") @@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None: c_type = content_type_headers[0].decode("ascii") # only the first header val, options = cgi.parse_header(c_type) - if val != "application/json": + if val != expected_content_type: raise RequestSendFailed( RuntimeError( - "Remote server sent Content-Type header of '%s', not 'application/json'" - % c_type, + f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'", ), can_retry=False, ) |