diff --git a/changelog.d/9958.feature b/changelog.d/9958.feature
new file mode 100644
index 0000000000..d86ba36519
--- /dev/null
+++ b/changelog.d/9958.feature
@@ -0,0 +1 @@
+Reduce memory usage when joining very large rooms over federation.
diff --git a/mypy.ini b/mypy.ini
index ea655a0d4d..1d1d1ea0f2 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -174,3 +174,6 @@ ignore_missing_imports = True
[mypy-pympler.*]
ignore_missing_imports = True
+
+[mypy-ijson.*]
+ignore_missing_imports = True
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a5b6a61195..e0e9f5d0be 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -55,6 +55,7 @@ from synapse.api.room_versions import (
)
from synapse.events import EventBase, builder
from synapse.federation.federation_base import FederationBase, event_from_pdu_json
+from synapse.federation.transport.client import SendJoinResponse
from synapse.logging.context import make_deferred_yieldable, preserve_fn
from synapse.logging.utils import log_function
from synapse.types import JsonDict, get_domain_from_id
@@ -665,19 +666,10 @@ class FederationClient(FederationBase):
"""
async def send_request(destination) -> Dict[str, Any]:
- content = await self._do_send_join(destination, pdu)
+ response = await self._do_send_join(room_version, destination, pdu)
- logger.debug("Got content: %s", content)
-
- state = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in content.get("state", [])
- ]
-
- auth_chain = [
- event_from_pdu_json(p, room_version, outlier=True)
- for p in content.get("auth_chain", [])
- ]
+ state = response.state
+ auth_chain = response.auth_events
pdus = {p.event_id: p for p in itertools.chain(state, auth_chain)}
@@ -752,11 +744,14 @@ class FederationClient(FederationBase):
return await self._try_destination_list("send_join", destinations, send_request)
- async def _do_send_join(self, destination: str, pdu: EventBase) -> JsonDict:
+ async def _do_send_join(
+ self, room_version: RoomVersion, destination: str, pdu: EventBase
+ ) -> SendJoinResponse:
time_now = self._clock.time_msec()
try:
return await self.transport_layer.send_join_v2(
+ room_version=room_version,
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
@@ -771,17 +766,14 @@ class FederationClient(FederationBase):
logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API")
- resp = await self.transport_layer.send_join_v1(
+ return await self.transport_layer.send_join_v1(
+ room_version=room_version,
destination=destination,
room_id=pdu.room_id,
event_id=pdu.event_id,
content=pdu.get_pdu_json(time_now),
)
- # We expect the v1 API to respond with [200, content], so we only return the
- # content.
- return resp[1]
-
async def send_invite(
self,
destination: str,
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 497848a2b7..e93ab83f7f 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -17,13 +17,19 @@ import logging
import urllib
from typing import Any, Dict, List, Optional
+import attr
+import ijson
+
from synapse.api.constants import Membership
from synapse.api.errors import Codes, HttpResponseException, SynapseError
+from synapse.api.room_versions import RoomVersion
from synapse.api.urls import (
FEDERATION_UNSTABLE_PREFIX,
FEDERATION_V1_PREFIX,
FEDERATION_V2_PREFIX,
)
+from synapse.events import EventBase, make_event_from_dict
+from synapse.http.matrixfederationclient import ByteParser
from synapse.logging.utils import log_function
from synapse.types import JsonDict
@@ -240,21 +246,36 @@ class TransportLayerClient:
return content
@log_function
- async def send_join_v1(self, destination, room_id, event_id, content):
+ async def send_join_v1(
+ self,
+ room_version,
+ destination,
+ room_id,
+ event_id,
+ content,
+ ) -> "SendJoinResponse":
path = _create_v1_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
- destination=destination, path=path, data=content
+ destination=destination,
+ path=path,
+ data=content,
+ parser=SendJoinParser(room_version, v1_api=True),
)
return response
@log_function
- async def send_join_v2(self, destination, room_id, event_id, content):
+ async def send_join_v2(
+ self, room_version, destination, room_id, event_id, content
+ ) -> "SendJoinResponse":
path = _create_v2_path("/send_join/%s/%s", room_id, event_id)
response = await self.client.put_json(
- destination=destination, path=path, data=content
+ destination=destination,
+ path=path,
+ data=content,
+ parser=SendJoinParser(room_version, v1_api=False),
)
return response
@@ -1053,3 +1074,87 @@ def _create_v2_path(path, *args):
         str
     """
     return _create_path(FEDERATION_V2_PREFIX, path, *args)
+
+
+@attr.s(slots=True, auto_attribs=True)
+class SendJoinResponse:
+    """The parsed response of a `/send_join` request."""
+
+    auth_events: List[EventBase]
+    state: List[EventBase]
+
+
+@ijson.coroutine
+def _event_list_parser(room_version: RoomVersion, events: List[EventBase]):
+    """Helper function for use with `ijson.items_coro` to parse an array of
+    events and add them to the given list.
+    """
+
+    while True:
+        obj = yield
+        event = make_event_from_dict(obj, room_version)
+        events.append(event)
+
+
+class SendJoinParser(ByteParser[SendJoinResponse]):
+    """A parser for the response to `/send_join` requests.
+
+    The body is streamed through `ijson`, so events are built as chunks
+    arrive instead of buffering the whole response first.
+
+    Args:
+        room_version: The version of the room.
+        v1_api: Whether the response is in the v1 format.
+    """
+
+    CONTENT_TYPE = "application/json"
+
+    def __init__(self, room_version: RoomVersion, v1_api: bool):
+        self._response = SendJoinResponse([], [])
+
+        # The V1 API has the shape of `[200, {...}]`, which we handle by
+        # prefixing with `item.*`.
+        prefix = "item." if v1_api else ""
+
+        # NOTE(review): ijson yields non-integer numbers as `Decimal` by
+        # default; `use_float=True` (ijson >= 3.1) may be wanted -- confirm.
+        self._coro_state = ijson.items_coro(
+            _event_list_parser(room_version, self._response.state),
+            prefix + "state.item",
+        )
+        self._coro_auth = ijson.items_coro(
+            _event_list_parser(room_version, self._response.auth_events),
+            prefix + "auth_chain.item",
+        )
+
+    def write(self, data: bytes) -> int:
+        """Feed a chunk of the body to both event extractors.
+
+        Raises:
+            ValueError: if `ijson` rejects the data as malformed JSON.
+        """
+        try:
+            self._coro_state.send(data)
+            self._coro_auth.send(data)
+        except ijson.JSONError as e:
+            # ijson's errors do not derive from ValueError; convert them so
+            # callers that treat ValueError as "invalid content" see this too.
+            raise ValueError(f"Invalid /send_join response: {e}") from e
+
+        return len(data)
+
+    def finish(self) -> SendJoinResponse:
+        """Signal end of input to the parsers and return the parsed events.
+
+        Raises:
+            ValueError: if `ijson` reports the JSON document as incomplete.
+        """
+        try:
+            # Closing tells ijson that no more data is coming, so a body
+            # that stops mid-document is reported instead of accepted.
+            self._coro_state.close()
+            self._coro_auth.close()
+        except ijson.JSONError as e:
+            raise ValueError(f"Incomplete /send_join response: {e}") from e
+
+        return self._response
diff --git a/synapse/http/client.py b/synapse/http/client.py
index 5f40f16e24..1ca6624fd5 100644
--- a/synapse/http/client.py
+++ b/synapse/http/client.py
@@ -813,7 +813,12 @@ class _ReadBodyWithMaxSizeProtocol(protocol.Protocol):
if self.deferred.called:
return
- self.stream.write(data)
+ try:
+ self.stream.write(data)
+ except Exception:
+ self.deferred.errback()
+ return
+
self.length += len(data)
# The first time the maximum size is exceeded, error and cancel the
# connection. dataReceived might be called again if data was received
diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py
index bb837b7b19..f5503b394b 100644
--- a/synapse/http/matrixfederationclient.py
+++ b/synapse/http/matrixfederationclient.py
@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import cgi
import codecs
import logging
@@ -19,13 +20,24 @@ import sys
import typing
import urllib.parse
from io import BytesIO, StringIO
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import (
+ Callable,
+ Dict,
+ Generic,
+ List,
+ Optional,
+ Tuple,
+ TypeVar,
+ Union,
+ overload,
+)
import attr
import treq
from canonicaljson import encode_canonical_json
from prometheus_client import Counter
from signedjson.sign import sign_json
+from typing_extensions import Literal
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
@@ -48,6 +60,7 @@ from synapse.http.client import (
BlacklistingAgentWrapper,
BlacklistingReactorWrapper,
BodyExceededMaxSize,
+ ByteWriteable,
encode_query_args,
read_body_with_max_size,
)
@@ -88,6 +101,27 @@ _next_id = 1
QueryArgs = Dict[str, Union[str, List[str]]]
+T = TypeVar("T")
+
+
+class ByteParser(ByteWriteable, Generic[T], abc.ABC):
+ """A `ByteWriteable` that has an additional `finish` function that returns
+ the parsed data.
+ """
+
+ CONTENT_TYPE = abc.abstractproperty() # type: str # type: ignore
+ """The expected content type of the response, e.g. `application/json`. If
+ the content type doesn't match we fail the request.
+ """
+
+ @abc.abstractmethod
+ def finish(self) -> T:
+ """Called when response has finished streaming and the parser should
+ return the final result (or error).
+ """
+ pass
+
+
@attr.s(slots=True, frozen=True)
class MatrixFederationRequest:
method = attr.ib(type=str)
@@ -148,15 +182,32 @@ class MatrixFederationRequest:
return self.json
-async def _handle_json_response(
+class JsonParser(ByteParser[Union[JsonDict, list]]):
+ """A parser that buffers the response and tries to parse it as JSON."""
+
+ CONTENT_TYPE = "application/json"
+
+ def __init__(self):
+ self._buffer = StringIO()
+ self._binary_wrapper = BinaryIOWrapper(self._buffer)
+
+ def write(self, data: bytes) -> int:
+ return self._binary_wrapper.write(data)
+
+ def finish(self) -> Union[JsonDict, list]:
+ return json_decoder.decode(self._buffer.getvalue())
+
+
+async def _handle_response(
reactor: IReactorTime,
timeout_sec: float,
request: MatrixFederationRequest,
response: IResponse,
start_ms: int,
-) -> JsonDict:
+ parser: ByteParser[T],
+) -> T:
"""
- Reads the JSON body of a response, with a timeout
+ Reads the body of a response with a timeout and sends it to a parser
Args:
reactor: twisted reactor, for the timeout
@@ -164,23 +215,21 @@ async def _handle_json_response(
request: the request that triggered the response
response: response to the request
start_ms: Timestamp when request was made
+ parser: The parser for the response
Returns:
- The parsed JSON response
+ The parsed response
"""
+
try:
- check_content_type_is_json(response.headers)
+ check_content_type_is(response.headers, parser.CONTENT_TYPE)
- buf = StringIO()
- d = read_body_with_max_size(response, BinaryIOWrapper(buf), MAX_RESPONSE_SIZE)
+ d = read_body_with_max_size(response, parser, MAX_RESPONSE_SIZE)
d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor)
- def parse(_len: int):
- return json_decoder.decode(buf.getvalue())
-
- d.addCallback(parse)
+ length = await make_deferred_yieldable(d)
- body = await make_deferred_yieldable(d)
+ value = parser.finish()
except BodyExceededMaxSize as e:
# The response was too big.
logger.warning(
@@ -193,9 +242,9 @@ async def _handle_json_response(
)
raise RequestSendFailed(e, can_retry=False) from e
except ValueError as e:
- # The JSON content was invalid.
+ # The content was invalid.
logger.warning(
- "{%s} [%s] Failed to parse JSON response - %s %s",
+ "{%s} [%s] Failed to parse response - %s %s",
request.txn_id,
request.destination,
request.method,
@@ -225,16 +274,17 @@ async def _handle_json_response(
time_taken_secs = reactor.seconds() - start_ms / 1000
logger.info(
- "{%s} [%s] Completed request: %d %s in %.2f secs - %s %s",
+ "{%s} [%s] Completed request: %d %s in %.2f secs, got %d bytes - %s %s",
request.txn_id,
request.destination,
response.code,
response.phrase.decode("ascii", errors="replace"),
time_taken_secs,
+ length,
request.method,
request.uri.decode("ascii"),
)
- return body
+ return value
class BinaryIOWrapper:
@@ -671,6 +721,7 @@ class MatrixFederationHttpClient:
)
return auth_headers
+ @overload
async def put_json(
self,
destination: str,
@@ -683,7 +734,41 @@ class MatrixFederationHttpClient:
ignore_backoff: bool = False,
backoff_on_404: bool = False,
try_trailing_slash_on_400: bool = False,
+ parser: Literal[None] = None,
) -> Union[JsonDict, list]:
+ ...
+
+ @overload
+ async def put_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser[T]] = None,
+ ) -> T:
+ ...
+
+ async def put_json(
+ self,
+ destination: str,
+ path: str,
+ args: Optional[QueryArgs] = None,
+ data: Optional[JsonDict] = None,
+ json_data_callback: Optional[Callable[[], JsonDict]] = None,
+ long_retries: bool = False,
+ timeout: Optional[int] = None,
+ ignore_backoff: bool = False,
+ backoff_on_404: bool = False,
+ try_trailing_slash_on_400: bool = False,
+ parser: Optional[ByteParser] = None,
+ ):
"""Sends the specified json data using PUT
Args:
@@ -716,6 +801,8 @@ class MatrixFederationHttpClient:
of the request. Workaround for #3622 in Synapse <= v0.99.3. This
will be attempted before backing off if backing off has been
enabled.
+ parser: The parser to use to decode the response. Defaults to
+ parsing as JSON.
Returns:
Succeeds when we get a 2xx HTTP response. The
@@ -756,8 +843,16 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ if parser is None:
+ parser = JsonParser()
+
+ body = await _handle_response(
+ self.reactor,
+ _sec_timeout,
+ request,
+ response,
+ start_ms,
+ parser=parser,
)
return body
@@ -830,12 +925,8 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor,
- _sec_timeout,
- request,
- response,
- start_ms,
+ body = await _handle_response(
+ self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
)
return body
@@ -907,8 +998,8 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ body = await _handle_response(
+ self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
)
return body
@@ -975,8 +1066,8 @@ class MatrixFederationHttpClient:
else:
_sec_timeout = self.default_timeout
- body = await _handle_json_response(
- self.reactor, _sec_timeout, request, response, start_ms
+ body = await _handle_response(
+ self.reactor, _sec_timeout, request, response, start_ms, parser=JsonParser()
)
return body
@@ -1068,16 +1159,16 @@ def _flatten_response_never_received(e):
return repr(e)
-def check_content_type_is_json(headers: Headers) -> None:
+def check_content_type_is(headers: Headers, expected_content_type: str) -> None:
"""
Check that a set of HTTP headers have a Content-Type header, and that it
- is application/json.
+    is the expected value.
Args:
headers: headers to check
Raises:
- RequestSendFailed: if the Content-Type header is missing or isn't JSON
+ RequestSendFailed: if the Content-Type header is missing or doesn't match
"""
content_type_headers = headers.getRawHeaders(b"Content-Type")
@@ -1089,11 +1180,10 @@ def check_content_type_is_json(headers: Headers) -> None:
c_type = content_type_headers[0].decode("ascii") # only the first header
val, options = cgi.parse_header(c_type)
- if val != "application/json":
+ if val != expected_content_type:
raise RequestSendFailed(
RuntimeError(
- "Remote server sent Content-Type header of '%s', not 'application/json'"
- % c_type,
+ f"Remote server sent Content-Type header of '{c_type}', not '{expected_content_type}'",
),
can_retry=False,
)
diff --git a/synapse/python_dependencies.py b/synapse/python_dependencies.py
index 989523c823..546231bec0 100644
--- a/synapse/python_dependencies.py
+++ b/synapse/python_dependencies.py
@@ -87,6 +87,7 @@ REQUIREMENTS = [
# We enforce that we have a `cryptography` version that bundles an `openssl`
# with the latest security patches.
"cryptography>=3.4.7",
+ "ijson>=3.0",
]
CONDITIONAL_REQUIREMENTS = {
|