summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15465.misc1
-rw-r--r--mypy.ini6
-rw-r--r--synapse/federation/federation_client.py8
-rw-r--r--synapse/federation/transport/client.py17
-rw-r--r--synapse/http/matrixfederationclient.py76
-rw-r--r--tests/federation/test_complexity.py10
-rw-r--r--tests/http/test_matrixfederationclient.py6
7 files changed, 82 insertions, 42 deletions
diff --git a/changelog.d/15465.misc b/changelog.d/15465.misc
new file mode 100644
index 0000000000..93ceaeafc9
--- /dev/null
+++ b/changelog.d/15465.misc
@@ -0,0 +1 @@
+Improve type hints.
diff --git a/mypy.ini b/mypy.ini
index 945f7925cb..8fb87b9b74 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -33,12 +33,6 @@ exclude = (?x)
    |synapse/storage/schema/
    )$
 
-[mypy-synapse.federation.transport.client]
-disallow_untyped_defs = False
-
-[mypy-synapse.http.matrixfederationclient]
-disallow_untyped_defs = False
-
 [mypy-synapse.metrics._reactor_metrics]
 disallow_untyped_defs = False
 # This module imports select.epoll. That exists on Linux, but doesn't on macOS.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 4cf4957a42..ba34573d46 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -280,15 +280,11 @@ class FederationClient(FederationBase):
         logger.debug("backfill transaction_data=%r", transaction_data)
 
         if not isinstance(transaction_data, dict):
-            # TODO we probably want an exception type specific to federation
-            # client validation.
-            raise TypeError("Backfill transaction_data is not a dict.")
+            raise InvalidResponseError("Backfill transaction_data is not a dict.")
 
         transaction_data_pdus = transaction_data.get("pdus")
         if not isinstance(transaction_data_pdus, list):
-            # TODO we probably want an exception type specific to federation
-            # client validation.
-            raise TypeError("transaction_data.pdus is not a list.")
+            raise InvalidResponseError("transaction_data.pdus is not a list.")
 
         room_version = await self.store.get_room_version(room_id)
 
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index c05d598b70..bedbd23ded 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -16,6 +16,7 @@
 import logging
 import urllib
 from typing import (
+    TYPE_CHECKING,
     Any,
     Callable,
     Collection,
@@ -42,18 +43,21 @@ from synapse.api.urls import (
 )
 from synapse.events import EventBase, make_event_from_dict
 from synapse.federation.units import Transaction
-from synapse.http.matrixfederationclient import ByteParser
+from synapse.http.matrixfederationclient import ByteParser, LegacyJsonSendParser
 from synapse.http.types import QueryParams
 from synapse.types import JsonDict
 from synapse.util import ExceptionBundle
 
+if TYPE_CHECKING:
+    from synapse.app.homeserver import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class TransportLayerClient:
     """Sends federation HTTP requests to other servers"""
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         self.server_name = hs.hostname
         self.client = hs.get_federation_http_client()
         self._faster_joins_enabled = hs.config.experimental.faster_joins_enabled
@@ -133,7 +137,7 @@ class TransportLayerClient:
 
     async def backfill(
         self, destination: str, room_id: str, event_tuples: Collection[str], limit: int
-    ) -> Optional[JsonDict]:
+    ) -> Optional[Union[JsonDict, list]]:
         """Requests `limit` previous PDUs in a given context before list of
         PDUs.
 
@@ -388,6 +392,7 @@ class TransportLayerClient:
             # server was just having a momentary blip, the room will be out of
             # sync.
             ignore_backoff=True,
+            parser=LegacyJsonSendParser(),
         )
 
     async def send_leave_v2(
@@ -445,7 +450,11 @@ class TransportLayerClient:
         path = _create_v1_path("/invite/%s/%s", room_id, event_id)
 
         return await self.client.put_json(
-            destination=destination, path=path, data=content, ignore_backoff=True
+            destination=destination,
+            path=path,
+            data=content,
+            ignore_backoff=True,
+            parser=LegacyJsonSendParser(),
         )
 
     async def send_invite_v2(
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index 3302d4e48a..634882487c 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -17,7 +17,6 @@ import codecs
 import logging
 import random
 import sys
-import typing
 import urllib.parse
 from http import HTTPStatus
 from io import BytesIO, StringIO
@@ -30,9 +29,11 @@ from typing import (
     Generic,
     List,
     Optional,
+    TextIO,
     Tuple,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 
@@ -183,20 +184,61 @@ class MatrixFederationRequest:
         return self.json
 
 
-class JsonParser(ByteParser[Union[JsonDict, list]]):
+class _BaseJsonParser(ByteParser[T]):
     """A parser that buffers the response and tries to parse it as JSON."""
 
     CONTENT_TYPE = "application/json"
 
-    def __init__(self) -> None:
+    def __init__(
+        self, validator: Optional[Callable[[Optional[object]], bool]] = None
+    ) -> None:
+        """
+        Args:
+            validator: A callable which takes the parsed JSON value and returns
+                true if the value is valid.
+        """
         self._buffer = StringIO()
         self._binary_wrapper = BinaryIOWrapper(self._buffer)
+        self._validator = validator
 
     def write(self, data: bytes) -> int:
         return self._binary_wrapper.write(data)
 
-    def finish(self) -> Union[JsonDict, list]:
-        return json_decoder.decode(self._buffer.getvalue())
+    def finish(self) -> T:
+        result = json_decoder.decode(self._buffer.getvalue())
+        if self._validator is not None and not self._validator(result):
+            raise ValueError(
+                f"Received incorrect JSON value: {result.__class__.__name__}"
+            )
+        return result
+
+
+class JsonParser(_BaseJsonParser[JsonDict]):
+    """A parser that buffers the response and tries to parse it as a JSON object."""
+
+    def __init__(self) -> None:
+        super().__init__(self._validate)
+
+    @staticmethod
+    def _validate(v: Any) -> bool:
+        return isinstance(v, dict)
+
+
+class LegacyJsonSendParser(_BaseJsonParser[Tuple[int, JsonDict]]):
+    """Ensure the legacy responses of /send_join & /send_leave are correct."""
+
+    def __init__(self) -> None:
+        super().__init__(self._validate)
+
+    @staticmethod
+    def _validate(v: Any) -> bool:
+        # Match [integer, JSON dict]
+        return (
+            isinstance(v, list)
+            and len(v) == 2
+            and type(v[0]) == int
+            and isinstance(v[1], dict)
+        )
 
 
 async def _handle_response(
@@ -313,9 +355,7 @@ async def _handle_response(
 class BinaryIOWrapper:
     """A wrapper for a TextIO which converts from bytes on the fly."""
 
-    def __init__(
-        self, file: typing.TextIO, encoding: str = "utf-8", errors: str = "strict"
-    ):
+    def __init__(self, file: TextIO, encoding: str = "utf-8", errors: str = "strict"):
         self.decoder = codecs.getincrementaldecoder(encoding)(errors)
         self.file = file
 
@@ -793,7 +833,7 @@ class MatrixFederationHttpClient:
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
         parser: Literal[None] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         ...
 
     @overload
@@ -825,8 +865,8 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
-        parser: Optional[ByteParser] = None,
-    ):
+        parser: Optional[ByteParser[T]] = None,
+    ) -> Union[JsonDict, T]:
         """Sends the specified json data using PUT
 
         Args:
@@ -902,7 +942,7 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         if parser is None:
-            parser = JsonParser()
+            parser = cast(ByteParser[T], JsonParser())
 
         body = await _handle_response(
             self.reactor,
@@ -924,7 +964,7 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         args: Optional[QueryParams] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         """Sends the specified json data using POST
 
         Args:
@@ -998,7 +1038,7 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
         parser: Literal[None] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         ...
 
     @overload
@@ -1024,8 +1064,8 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         try_trailing_slash_on_400: bool = False,
-        parser: Optional[ByteParser] = None,
-    ):
+        parser: Optional[ByteParser[T]] = None,
+    ) -> Union[JsonDict, T]:
         """GETs some json from the given host homeserver and path
 
         Args:
@@ -1091,7 +1131,7 @@ class MatrixFederationHttpClient:
             _sec_timeout = self.default_timeout
 
         if parser is None:
-            parser = JsonParser()
+            parser = cast(ByteParser[T], JsonParser())
 
         body = await _handle_response(
             self.reactor,
@@ -1112,7 +1152,7 @@ class MatrixFederationHttpClient:
         timeout: Optional[int] = None,
         ignore_backoff: bool = False,
         args: Optional[QueryParams] = None,
-    ) -> Union[JsonDict, list]:
+    ) -> JsonDict:
         """Send a DELETE request to the remote expecting some json response
 
         Args:
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 33af8770fd..129d7cfd93 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -75,7 +75,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
         handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
@@ -106,7 +106,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
         handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
@@ -143,7 +143,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
         handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
@@ -200,7 +200,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
         handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
@@ -230,7 +230,7 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))
+        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
         handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
             return_value=make_awaitable(("", 1))
         )
diff --git a/tests/http/test_matrixfederationclient.py b/tests/http/test_matrixfederationclient.py
index fdd22a8e94..d89a91c59d 100644
--- a/tests/http/test_matrixfederationclient.py
+++ b/tests/http/test_matrixfederationclient.py
@@ -26,7 +26,7 @@ from twisted.web.http import HTTPChannel
 
 from synapse.api.errors import RequestSendFailed
 from synapse.http.matrixfederationclient import (
-    JsonParser,
+    ByteParser,
     MatrixFederationHttpClient,
     MatrixFederationRequest,
 )
@@ -618,9 +618,9 @@ class FederationClientTests(HomeserverTestCase):
         while not test_d.called:
             protocol.dataReceived(b"a" * chunk_size)
             sent += chunk_size
-            self.assertLessEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
+            self.assertLessEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
 
-        self.assertEqual(sent, JsonParser.MAX_RESPONSE_SIZE)
+        self.assertEqual(sent, ByteParser.MAX_RESPONSE_SIZE)
 
         f = self.failureResultOf(test_d)
         self.assertIsInstance(f.value, RequestSendFailed)