diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4cf4957a42..ba34573d46 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -280,15 +280,11 @@ class FederationClient(FederationBase):
logger.debug("backfill transaction_data=%r", transaction_data)
if not isinstance(transaction_data, dict):
- # TODO we probably want an exception type specific to federation
- # client validation.
- raise TypeError("Backfill transaction_data is not a dict.")
+ raise InvalidResponseError("Backfill transaction_data is not a dict.")
transaction_data_pdus = transaction_data.get("pdus")
if not isinstance(transaction_data_pdus, list):
- # TODO we probably want an exception type specific to federation
- # client validation.
- raise TypeError("transaction_data.pdus is not a list.")
+ raise InvalidResponseError("transaction_data.pdus is not a list.")
room_version = await self.store.get_room_version(room_id)
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index c05d598b70..bedbd23ded 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -16,6 +16,7 @@
import logging
import urllib
from typing import (
+ TYPE_CHECKING,
Any,
Callable,
Collection,
@@ -42,18 +43,21 @@ from synapse.api.urls import (
)
from synapse.events import EventBase, make_event_from_dict
from synapse.federation.units import Transaction
-from synapse.http.matrixfederationclient import ByteParser
+from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
from synapse.http.types import QueryParams
from synapse.types import JsonDict
from synapse.util import ExceptionBundle
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
class TransportLayerClient:
"""Sends federation HTTP requests to other servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.server_name = hs.hostname
self.client = hs.get_federation_http_client()
self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
@@ -133,7 +137,7 @@ class TransportLayerClient:
async def backfill(
self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
- ) -> Optional[JsonDict]:
+ ) -> Optional[Union[JsonDict, list]]:
"""Requests `limit` previous PDUs in a given context before list of
PDUs.
@@ -388,6 +392,7 @@ class TransportLayerClient:
# server was just having a momentary blip, the room will be out of
# sync.
ignore_backoff=True,
+ parser=LegacyJsonSendParser(),
)
async def send_leave_v2(
@@ -445,7 +450,11 @@ class TransportLayerClient:
path = _create_v1_path("/invite/%s/%s", room_id, event_id)
return await self.client.put_json(
- destination=destination, path=path, data=content, ignore_backoff=True
+ destination=destination,
+ path=path,
+ data=content,
+ ignore_backoff=True,
+ parser=LegacyJsonSendParser(),
)
async def send_invite_v2(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3302d4e48a..634882487c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,7 +17,6 @@ import codecs
import logging
import random
import sys
-import typing
import urllib.parse
from http import HTTPStatus
from io import BytesIO, StringIO
@@ -30,9 +29,11 @@ from typing import (
Generic,
List,
Optional,
+ TextIO,
Tuple,
TypeVar,
Union,
+ cast,
overload,
)
@@ -183,20 +184,61 @@ class MatrixFederationRequest:
return self.json
-class JsonParser(ByteParser[Union[JsonDict, list]]):
+class _BaseJsonParser(ByteParser[T]):
"""A parser that buffers the response and tries to parse it as JSON."""
CONTENT_TYPE = "application/json"
- def __init__(self) -> None:
+ def __init__(
+ self, validator: Optional[Callable[[Optional[object]], bool]] = None
+ ) -> None:
+ """
+ Args:
+ validator: A callable which takes the parsed JSON value and returns
+ true if the value is valid.
+ """
self._buffer = StringIO()
self._binary_wrapper = BinaryIOWrapper(self._buffer)
+ self._validator = validator
def write(self, data: bytes) -> int:
return self._binary_wrapper.write(data)
- def finish(self) -> Union[JsonDict, list]:
- return json_decoder.decode(self._buffer.getvalue())
+ def finish(self) -> T:
+ result = json_decoder.decode(self._buffer.getvalue())
+ if self._validator is not None and not self._validator(result):
+ raise ValueError(
+ f"Received incorrect JSON value: {result.__class__.__name__}"
+ )
+ return result
+
+
+class JsonParser(_BaseJsonParser[JsonDict]):
+ """A parser that buffers the response and tries to parse it as a JSON object."""
+
+ def __init__(self) -> None:
+ super().__init__(self._validate)
+
+ @staticmethod
+ def _validate(v: Any) -> bool:
+ return isinstance(v, dict)
+
+
+class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
+ """Ensure the legacy responses of /send_join & /send_leave are correct."""
+
+ def __init__(self) -> None:
+ super().__init__(self._validate)
+
+ @staticmethod
+ def _validate(v: Any) -> bool:
+ # Match [integer, JSON dict]
+ return (
+ isinstance(v, list)
+ and len(v) == 2
+ and type(v[0]) == int
+ and isinstance(v[1], dict)
+ )
async def _handle_response(
@@ -313,9 +355,7 @@ async def _handle_response(
class BinaryIOWrapper:
"""A wrapper for a TextIO which converts from bytes on the fly."""
- def __init__(
- self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
- ):
+ def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
self.decoder = codecs.getincrementaldecoder(encoding)(errors)
self.file = file
@@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
...
@overload
@@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
- parser: Optional[ByteParser] = None,
- ):
+ parser: Optional[ByteParser[T]] = None,
+ ) -> Union[JsonDict, T]:
"""Sends the specified json data using PUT
Args:
@@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if parser is None:
- parser = JsonParser()
+ parser = cast(ByteParser[T], JsonParser())
body = await _handle_response(
self.reactor,
@@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
"""Sends the specified json data using POST
Args:
@@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
parser: Literal[None] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
...
@overload
@@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
try_trailing_slash_on_400: bool = False,
- parser: Optional[ByteParser] = None,
- ):
+ parser: Optional[ByteParser[T]] = None,
+ ) -> Union[JsonDict, T]:
"""GETs some json from the given host homeserver and path
Args:
@@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
_sec_timeout = self.default_timeout
if parser is None:
- parser = JsonParser()
+ parser = cast(ByteParser[T], JsonParser())
body = await _handle_response(
self.reactor,
@@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
timeout: Optional[int] = None,
ignore_backoff: bool = False,
args: Optional[QueryParams] = None,
- ) -> Union[JsonDict, list]:
+ ) -> JsonDict:
"""Send a DELETE request to the remote expecting some json response
Args:
|