summary refs log tree commit diff
path: root/synapse/http/matrixfederationclient.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/http/matrixfederationclient.py')
-rw-r--r--synapse/http/matrixfederationclient.py76
1 files changed, 58 insertions, 18 deletions
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3302d4e48a..634882487c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,7 +17,6 @@ import codecs
 import logging
 import random
 import sys
-import typing
 import urllib.parse
 from http import HTTPStatus
 from io import BytesIO, StringIO
@@ -30,9 +29,11 @@ from typing import (
     Generic,
     List,
     Optional,
+    TextIO,
     Tuple,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -183,20 +184,61 @@ class MatrixFederationRequest:
         return self.json
 
 
-class JsonParser(ByteParser[Union[JsonDict, list]]):
+class _BaseJsonParser(ByteParser[T]):
     """A parser that buffers the response and tries to parse it as JSON."""
 
     CONTENT_TYPE = "application/json"
 
-    def __init__(self) -> None:
+    def __init__(
+        self, validator: Optional[Callable[[Optional[object]], bool]] = None
+    ) -> None:
+        """
+        Args:
+            validator: A callable which takes the parsed JSON value and returns
+                true if the value is valid.
+        """
         self._buffer = StringIO()
         self._binary_wrapper = BinaryIOWrapper(self._buffer)
+        self._validator = validator
 
     def write(self, data: bytes) -> int:
         return self._binary_wrapper.write(data)
 
-    def finish(self) -> Union[JsonDict, list]:
-        return json_decoder.decode(self._buffer.getvalue())
+    def finish(self) -> T:
+        result = json_decoder.decode(self._buffer.getvalue())
+        if self._validator is not None and not self._validator(result):
+            raise ValueError(
+                f"Received incorrect JSON value: {result.__class__.__name__}"
+            )
+        return result
+
+
+class JsonParser(_BaseJsonParser[JsonDict]):
+    """A parser that buffers the response and tries to parse it as a JSON object."""
+
+    def __init__(self) -> None:
+        super().__init__(self._validate)
+
+    @staticmethod
+    def _validate(v: Any) -> bool:
+        return isinstance(v, dict)
+
+
+class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
+    """Ensure the legacy responses of /send_join & /send_leave are correct."""
+
+    def __init__(self) -> None:
+        super().__init__(self._validate)
+
+    @staticmethod
+    def _validate(v: Any) -> bool:
+        # Match [integer, JSON dict]
+        return (
+            isinstance(v, list)
+            and len(v) == 2
+            and type(v[0]) == int
+            and isinstance(v[1], dict)
+        )
 
 
 async def _handle_response(
@@ -313,9 +355,7 @@ async def _handle_response(
 class BinaryIOWrapper:
     """A wrapper for a TextIO which converts from bytes on the fly."""
 
-    def __init__(
-        self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
-    ):
+    def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
         self.decoder = codecs.getincrementaldecoder(encoding)(errors)
         self.file = file
 
@@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
         parser: Literal[None] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         ...
 
     @overload
@@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
-        parser: Optional[ByteParser] = None,
-    ):
+        parser: Optional[ByteParser[T]] = None,
+    ) -> Union[JsonDict, T]:
         """Sends the specified json data using PUT
 
         Args:
@@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         if parser is None:
-            parser = JsonParser()
+            parser = cast(ByteParser[T], JsonParser())
 
         body = await _handle_response(
             self.reactor,
@@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         args: Optional[QueryParams] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         """Sends the specified json data using POST
 
         Args:
@@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
         parser: Literal[None] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         ...
 
     @overload
@@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
-        parser: Optional[ByteParser] = None,
-    ):
+        parser: Optional[ByteParser[T]] = None,
+    ) -> Union[JsonDict, T]:
         """GETs some json from the given host homeserver and path
 
         Args:
@@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         if parser is None:
-            parser = JsonParser()
+            parser = cast(ByteParser[T], JsonParser())
 
         body = await _handle_response(
             self.reactor,
@@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         args: Optional[QueryParams] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         """Send a DELETE request to the remote expecting some json response
 
         Args: