diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3302d4e48a..634882487c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,7 +17,6 @@ import codecs
import logging
import random
import sys
-import typing
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
@@ -30,9 +29,11 @@ from typing import (
Generic,
List,
Optional,
+ TextIO,
Tuple,
TypeVar,
Union,
+ cast,
overload,
)
@@ -183,20 +184,61 @@ class MatrixFederationRequest:
return self.json
-class JsonParser(ByteParser[Union[JsonDict, list]]):
+class _BaseJsonParser(ByteParser[T]):
"""A parser that buffers the response and tries to parse it as JSON."""
CONTENT_TYPE = "application/json"
- def __init__(self) -> None:
+ def __init__(
+ self, validator: Optional[Callable[[Optional[object]], bool]] = None
+ ) -> None:
+ """
+ Args:
+ validator: A callable which takes the parsed JSON value and returns
+ true if the value is valid.
+ """
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
+ self._validator = validator
def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data)
- def finish(self) -> Union[JsonDict, list]:
- return json_decoder.decode(self._buffer.getvalue())
+ def finish(self) -> T:
+ result = json_decoder.decode(self._buffer.getvalue())
+ if self._validator is not None and not self._validator(result):
+ raise ValueError(
+ f"Received incorrect JSON value: {result.__class__.__name__}"
+ )
+ return result
+
+
+class JsonParser(_BaseJsonParser[JsonDict]):
+ """A parser that buffers the response and tries to parse it as a JSON object."""
+
+ def __init__(self) -> None:
+ super().__init__(self._validate)
+
+ @staticmethod
+ def _validate(v: Any) -> bool:
+ return isinstance(v, dict)
+
+
+class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
+ """Ensure the legacy responses of /send_join & /send_leave are correct."""
+
+ def __init__(self) -> None:
+ super().__init__(self._validate)
+
+ @staticmethod
+ def _validate(v: Any) -> bool:
+ # Match [integer, JSON dict]
+ return (
+ isinstance(v, list)
+ and len(v) == 2
+ and type(v[0]) == int
+ and isinstance(v[1], dict)
+ )
async def _handle_response(
@@ -313,9 +355,7 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""
- def __init__(
- self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
- ):
+ def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file
@@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
...
@overload
@@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
- parser: Optional[ByteParser] = None,
- ):
+ parser: Optional[ByteParser[T]] = None,
+ ) -> Union[JsonDict, T]:
"""Sends the specified json data using PUT
Args:
@@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if parser is None:
- parser = JsonParser()
+ parser = cast(ByteParser[T], JsonParser())
body = await _handle_response(
self.reactor,
@@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
"""Sends the specified json data using POST
Args:
@@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
...
@overload
@@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
- parser: Optional[ByteParser] = None,
- ):
+ parser: Optional[ByteParser[T]] = None,
+ ) -> Union[JsonDict, T]:
"""GETs some json from the given host homeserver and path
Args:
@@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if parser is None:
- parser = JsonParser()
+ parser = cast(ByteParser[T], JsonParser())
body = await _handle_response(
self.reactor,
@@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
"""Send a DELETE request to the remote expecting some json response
Args:
|