summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-05-20 16:11:48 +0100
committerGitHub <noreply@github.com>2021-05-20 16:11:48 +0100
commit64887f06fcac63e069364d625d984b4951bf1ffc (patch)
treeeabd30c1ad56d057667271573d3bfc96276d11bd
parentAllow a user who could join a restricted room to see it in spaces summary. (#... (diff)
downloadsynapse-64887f06fcac63e069364d625d984b4951bf1ffc.tar.xz
Use ijson to parse the response to `/send_join`, reducing memory usage. (#9958)
Instead of parsing the full response to `/send_join` into Python objects (which can be huge for large rooms) and *then* parsing that into events, we instead use ijson to stream parse the response directly into `EventBase` objects.
-rw-r--r--changelog.d/9958.feature1
-rw-r--r--mypy.ini3
-rw-r--r--synapse/federation/federation_client.py28
-rw-r--r--synapse/federation/transport/client.py85
-rw-r--r--synapse/http/client.py7
-rw-r--r--synapse/http/matrixfederationclient.py160
-rw-r--r--synapse/python_dependencies.py1
7 files changed, 227 insertions, 58 deletions
diff --git a/changelog.d/9958.feature b/changelog.d/9958.feature
new file mode 100644
index 0000000000..d86ba36519
--- /dev/null
+++ b/changelog.d/9958.feature
@@ -0,0 +1 @@
+Reduce memory usage when joining very large rooms over federation.
diff --git a/mypy.ini b/mypy.ini
index ea655a0d4d..1d1d1ea0f2 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -174,3 +174,6 @@ ignore_missing_imports = True
 
 [mypy-pympler.*]
 ignore_missing_imports = True
+
+[mypy-ijson.*]
+ignore_missing_imports = True
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a5b6a61195..e0e9f5d0be 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -55,6 +55,7 @@ from synapse.api.room_versions import (
 )
 from synapse.events import EventBase, builder
 from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.transport.client import SendJoinResponse
 from synapse.logging.context import make_deferred_yieldable, preserve_fn
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict, get_domain_from_id
@@ -665,19 +666,10 @@ class FederationClient(FederationBase):
         """
 
         async def send_request(destination) -> Dict[str, Any]:
-            content = await self._do_send_join(destination, pdu)
+            response = await self._do_send_join(room_version, destination, pdu)
 
-            logger.debug("Got content: %s", content)
-
-            state = [
-                event_from_pdu_json(p, room_version, outlier=True)
-                for p in content.get("state", [])
-            ]
-
-            auth_chain = [
-                event_from_pdu_json(p, room_version, outlier=True)
-                for p in content.get("auth_chain", [])
-            ]
+            state = response.state
+            auth_chain = response.auth_events
 
             pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
 
@@ -752,11 +744,14 @@ class FederationClient(FederationBase):
 
         return await self._try_destination_list("send_join", destinations, send_request)
 
-    async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
+    async def _do_send_join(
+        self, room_version: RoomVersion, destination: str, pdu: EventBase
+    ) -> SendJoinResponse:
         time_now = self._clock.time_msec()
 
         try:
             return await self.transport_layer.send_join_v2(
+                room_version=room_version,
                 destination=destination,
                 room_id=pdu.room_id,
                 event_id=pdu.event_id,
@@ -771,17 +766,14 @@ class FederationClient(FederationBase):
 
         logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
 
-        resp = await self.transport_layer.send_join_v1(
+        return await self.transport_layer.send_join_v1(
+            room_version=room_version,
             destination=destination,
             room_id=pdu.room_id,
             event_id=pdu.event_id,
             content=pdu.get_pdu_json(time_now),
         )
 
-        # We expect the v1 API to respond with [200, content], so we only return the
-        # content.
-        return resp[1]
-
     async def send_invite(
         self,
         destination: str,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 497848a2b7..e93ab83f7f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -17,13 +17,19 @@ import logging
 import urllib
 from typing import Any, Dict, List, Optional
 
+import attr
+import ijson
+
 from synapse.api.constants import Membership
 from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.room_versions import RoomVersion
 from synapse.api.urls import (
     FEDERATION_UNSTABLE_PREFIX,
     FEDERATION_V1_PREFIX,
     FEDERATION_V2_PREFIX,
 )
+from synapse.events import EventBase, make_event_from_dict
+from synapse.http.matrixfederationclient import ByteParser
 from synapse.logging.utils import log_function
 from synapse.types import JsonDict
 
@@ -240,21 +246,36 @@ class TransportLayerClient:
         return content
 
     @log_function
-    async def send_join_v1(self, destination, room_id, event_id, content):
+    async def send_join_v1(
+        self,
+        room_version,
+        destination,
+        room_id,
+        event_id,
+        content,
+    ) -> "SendJoinResponse":
         path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
 
         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            parser=SendJoinParser(room_version, v1_api=True),
         )
 
         return response
 
     @log_function
-    async def send_join_v2(self, destination, room_id, event_id, content):
+    async def send_join_v2(
+        self, room_version, destination, room_id, event_id, content
+    ) -> "SendJoinResponse":
         path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
 
         response = await self.client.put_json(
-            destination=destination, path=path, data=content
+            destination=destination,
+            path=path,
+            data=content,
+            parser=SendJoinParser(room_version, v1_api=False),
         )
 
         return response
@@ -1053,3 +1074,59 @@ def _create_v2_path(path, *args):
         str
     """
     return _create_path(FEDERATION_V2_PREFIX, path, *args)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class SendJoinResponse:
+    """The parsed response of a `/send_join` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
+
+@ijson.coroutine
+def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
+    """Helper function for use with `ijson.items_coro` to parse an array of
+    events and add them to the given list.
+    """
+
+    while True:
+        obj = yield
+        event = make_event_from_dict(obj, room_version)
+        events.append(event)
+
+
+class SendJoinParser(ByteParser[SendJoinResponse]):
+    """A parser for the response to `/send_join` requests.
+
+    Args:
+        room_version: The version of the room.
+        v1_api: Whether the response is in the v1 format.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion, v1_api: bool):
+        self._response = SendJoinResponse([], [])
+
+        # The V1 API has the shape of `[200, {...}]`, which we handle by
+        # prefixing with `item.*`.
+        prefix = "item." if v1_api else ""
+
+        self._coro_state = ijson.items_coro(
+            _event_list_parser(room_version, self._response.state),
+            prefix + "state.item",
+        )
+        self._coro_auth = ijson.items_coro(
+            _event_list_parser(room_version, self._response.auth_events),
+            prefix + "auth_chain.item",
+        )
+
+    def write(self, data: bytes) -> int:
+        self._coro_state.send(data)
+        self._coro_auth.send(data)
+
+        return len(data)
+
+    def finish(self) -> SendJoinResponse:
+        return self._response
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5f40f16e24..1ca6624fd5 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -813,7 +813,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
         if self.deferred.called:
             return
 
-        self.stream.write(data)
+        try:
+            self.stream.write(data)
+        except Exception:
+            self.deferred.errback()
+            return
+
         self.length += len(data)
         # The first time the maximum size is exceeded, error and cancel the
         # connection. dataReceived might be called again if data was received
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..f5503b394b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
 import cgi
 import codecs
 import logging
@@ -19,13 +20,24 @@ import sys
 import typing
 import urllib.parse
 from io import BytesIO, StringIO
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Optional,
+    Tuple,
+    TypeVar,
+    Union,
+    overload,
+)
 
 import attr
 import treq
 from canonicaljson import encode_canonical_json
 from prometheus_client import Counter
 from signedjson.sign import sign_json
+from typing_extensions import Literal
 
 from twisted.internet import defer
 from twisted.internet.error import DNSLookupError
@@ -48,6 +60,7 @@ from synapse.http.client import (
     BlacklistingAgentWrapper,
     BlacklistingReactorWrapper,
     BodyExceededMaxSize,
+    ByteWriteable,
     encode_query_args,
     read_body_with_max_size,
 )
@@ -88,6 +101,27 @@ _next_id = 1
 QueryArgs = Dict[str, Union[str, List[str]]]
 
 
+T = TypeVar("T")
+
+
+class ByteParser(ByteWriteable, Generic[T], abc.ABC):
+    """A `ByteWriteable` that has an additional `finish` function that returns
+    the parsed data.
+    """
+
+    CONTENT_TYPE = abc.abstractproperty()  # type: str  # type: ignore
+    """The expected content type of the response, e.g. `application/json`. If
+    the content type doesn't match we fail the request.
+    """
+
+    @abc.abstractmethod
+    def finish(self) -> T:
+        """Called when response has finished streaming and the parser should
+        return the final result (or error).
+        """
+        pass
+
+
 @attr.s(slots=True, frozen=True)
 class MatrixFederationRequest:
     method = attr.ib(type=str)
@@ -148,15 +182,32 @@ class MatrixFederationRequest:
         return self.json
 
 
-async def _handle_json_response(
+class JsonParser(ByteParser[Union[JsonDict, list]]):
+    """A parser that buffers the response and tries to parse it as JSON."""
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self):
+        self._buffer = StringIO()
+        self._binary_wrapper = BinaryIOWrapper(self._buffer)
+
+    def write(self, data: bytes) -> int:
+        return self._binary_wrapper.write(data)
+
+    def finish(self) -> Union[JsonDict, list]:
+        return json_decoder.decode(self._buffer.getvalue())
+
+
+async def _handle_response(
     reactor: IReactorTime,
     timeout_sec: float,
     request: MatrixFederationRequest,
     response: IResponse,
     start_ms: int,
-) -> JsonDict:
+    parser: ByteParser[T],
+) -> T:
     """
-    Reads the JSON body of a response, with a timeout
+    Reads the body of a response with a timeout and sends it to a parser
 
     Args:
         reactor: twisted reactor, for the timeout
@@ -164,23 +215,21 @@ async def _handle_json_response(
         request: the request that triggered the response
         response: response to the request
         start_ms: Timestamp when request was made
+        parser: The parser for the response
 
     Returns:
-        The parsed JSON response
+        The parsed response
     """
+
     try:
-        check_content_type_is_json(response.headers)
+        check_content_type_is(response.headers, parser.CONTENT_TYPE)
 
-        buf = StringIO()
-        d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
+        d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE)
         d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
 
-        def parse(_len: int):
-            return json_decoder.decode(buf.getvalue())
-
-        d.addCallback(parse)
+        length = await make_deferred_yieldable(d)
 
-        body = await make_deferred_yieldable(d)
+        value = parser.finish()
     except BodyExceededMaxSize as e:
         # The response was too big.
         logger.warning(
@@ -193,9 +242,9 @@ async def _handle_json_response(
         )
         raise RequestSendFailed(e, can_retry=False) from e
     except ValueError as e:
-        # The JSON content was invalid.
+        # The content was invalid.
         logger.warning(
-            "{%s} [%s] Failed to parse JSON response - %s %s",
+            "{%s} [%s] Failed to parse response - %s %s",
             request.txn_id,
             request.destination,
             request.method,
@@ -225,16 +274,17 @@ async def _handle_json_response(
     time_taken_secs = reactor.seconds() - start_ms / 1000
 
     logger.info(
-        "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+        "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
         request.txn_id,
         request.destination,
         response.code,
         response.phrase.decode("ascii", errors="replace"),
         time_taken_secs,
+        length,
         request.method,
         request.uri.decode("ascii"),
     )
-    return body
+    return value
 
 
 class BinaryIOWrapper:
@@ -671,6 +721,7 @@ class MatrixFederationHttpClient:
             )
         return auth_headers
 
+    @overload
     async def put_json(
         self,
         destination: str,
@@ -683,7 +734,41 @@ class MatrixFederationHttpClient:
         ignore_backoff: bool = False,
         backoff_on_404: bool = False,
         try_trailing_slash_on_400: bool = False,
+        parser: Literal[None] = None,
     ) -> Union[JsonDict, list]:
+        ...
+
+    @overload
+    async def put_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        data: Optional[JsonDict] = None,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+        long_retries: bool = False,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        backoff_on_404: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser[T]] = None,
+    ) -> T:
+        ...
+
+    async def put_json(
+        self,
+        destination: str,
+        path: str,
+        args: Optional[QueryArgs] = None,
+        data: Optional[JsonDict] = None,
+        json_data_callback: Optional[Callable[[], JsonDict]] = None,
+        long_retries: bool = False,
+        timeout: Optional[int] = None,
+        ignore_backoff: bool = False,
+        backoff_on_404: bool = False,
+        try_trailing_slash_on_400: bool = False,
+        parser: Optional[ByteParser] = None,
+    ):
         """Sends the specified json data using PUT
 
         Args:
@@ -716,6 +801,8 @@ class MatrixFederationHttpClient:
                 of the request. Workaround for #3622 in Synapse <= v0.99.3. This
                 will be attempted before backing off if backing off has been
                 enabled.
+            parser: The parser to use to decode the response. Defaults to
+                parsing as JSON.
 
         Returns:
             Succeeds when we get a 2xx HTTP response. The
@@ -756,8 +843,16 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        if parser is None:
+            parser = JsonParser()
+
+        body = await _handle_response(
+            self.reactor,
+            _sec_timeout,
+            request,
+            response,
+            start_ms,
+            parser=parser,
         )
 
         return body
@@ -830,12 +925,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor,
-            _sec_timeout,
-            request,
-            response,
-            start_ms,
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
         return body
 
@@ -907,8 +998,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
 
         return body
@@ -975,8 +1066,8 @@ class MatrixFederationHttpClient:
         else:
             _sec_timeout = self.default_timeout
 
-        body = await _handle_json_response(
-            self.reactor, _sec_timeout, request, response, start_ms
+        body = await _handle_response(
+            self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
         )
         return body
 
@@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e):
         return repr(e)
 
 
-def check_content_type_is_json(headers: Headers) -> None:
+def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
     """
     Check that a set of HTTP headers have a Content-Type header, and that it
-    is application/json.
+    is the expected value..
 
     Args:
         headers: headers to check
 
     Raises:
-        RequestSendFailed: if the Content-Type header is missing or isn't JSON
+        RequestSendFailed: if the Content-Type header is missing or doesn't match
 
     """
     content_type_headers = headers.getRawHeaders(b"Content-Type")
@@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None:
 
     c_type = content_type_headers[0].decode("ascii")  # only the first header
     val, options = cgi.parse_header(c_type)
-    if val != "application/json":
+    if val != expected_content_type:
         raise RequestSendFailed(
             RuntimeError(
-                "Remote server sent Content-Type header of '%s', not 'application/json'"
-                % c_type,
+                f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
             ),
             can_retry=False,
         )
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 989523c823..546231bec0 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -87,6 +87,7 @@ REQUIREMENTS = [
     # We enforce that we have a `cryptography` version that bundles an `openssl`
     # with the latest security patches.
     "cryptography>=3.4.7",
+    "ijson>=3.0",
 ]
 
 CONDITIONAL_REQUIREMENTS = {