From 4f6194492a4f16c37e55d15d0434d5357266fed0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 15 Jul 2024 11:42:59 +0200 Subject: Make sure we use the right logic for enabling the media repo. (#17424) This removes the `enable_media_repo` attribute on the server config in favour of always using the `can_load_media_repo` in the media config. This should avoid issues like in #17420 in the future --- synapse/app/homeserver.py | 2 +- synapse/config/repository.py | 2 +- synapse/config/server.py | 6 ------ 3 files changed, 2 insertions(+), 8 deletions(-) (limited to 'synapse') diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py index 2b111847b7..e114ab7ec4 100644 --- a/synapse/app/homeserver.py +++ b/synapse/app/homeserver.py @@ -217,7 +217,7 @@ class SynapseHomeServer(HomeServer): ) if name in ["media", "federation", "client"]: - if self.config.server.enable_media_repo: + if self.config.media.can_load_media_repo: media_repo = self.get_media_repository_resource() resources.update( { diff --git a/synapse/config/repository.py b/synapse/config/repository.py index 1645470499..dc0e93ffa1 100644 --- a/synapse/config/repository.py +++ b/synapse/config/repository.py @@ -126,7 +126,7 @@ class ContentRepositoryConfig(Config): # Only enable the media repo if either the media repo is enabled or the # current worker app is the media repo. if ( - self.root.server.enable_media_repo is False + config.get("enable_media_repo", True) is False and config.get("worker_app") != "synapse.app.media_repository" ): self.can_load_media_repo = False diff --git a/synapse/config/server.py b/synapse/config/server.py index a2b2305776..8bb97df175 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -395,12 +395,6 @@ class ServerConfig(Config): self.presence_router_config, ) = load_module(presence_router_config, ("presence", "presence_router")) - # whether to enable the media repository endpoints. This should be set - # to false if the media repository is running as a separate endpoint; - # doing so ensures that we will not run cache cleanup jobs on the - # master, potentially causing inconsistency. - self.enable_media_repo = config.get("enable_media_repo", True) - # Whether to require authentication to retrieve profile data (avatars, # display names) of other users through the client API. self.require_auth_for_profile_requests = config.get( -- cgit 1.5.1 From df11af14dbd2faad916924cab96e75bd7c95a66a Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Jul 2024 16:13:04 +0100 Subject: Fix bug where sync could get stuck when using workers (#17438) This is because we serialized the token wrong if the instance map contained entries from before the minimum token. --- changelog.d/17438.bugfix | 1 + synapse/handlers/sliding_sync.py | 11 +++++-- synapse/types/__init__.py | 65 +++++++++++++++++++++++++++++++----- tests/test_types.py | 71 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 138 insertions(+), 10 deletions(-) create mode 100644 changelog.d/17438.bugfix (limited to 'synapse') diff --git a/changelog.d/17438.bugfix b/changelog.d/17438.bugfix new file mode 100644 index 0000000000..cff6eecd48 --- /dev/null +++ b/changelog.d/17438.bugfix @@ -0,0 +1 @@ +Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index be98b379eb..1b5262d667 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -699,10 +699,17 @@ class SlidingSyncHandler: instance_to_max_stream_ordering_map[instance_name] = stream_ordering # Then assemble the `RoomStreamToken` + min_stream_pos = min(instance_to_max_stream_ordering_map.values()) membership_snapshot_token = RoomStreamToken( # Minimum position in the `instance_map` - stream=min(instance_to_max_stream_ordering_map.values()), - instance_map=immutabledict(instance_to_max_stream_ordering_map), + stream=min_stream_pos, + instance_map=immutabledict( + { + instance_name: stream_pos + for instance_name, stream_pos in instance_to_max_stream_ordering_map.items() + if stream_pos > min_stream_pos + } + ), ) # Since we fetched the users room list at some point in time after the from/to diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index b22a13ef01..3962ecc996 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -20,6 +20,7 @@ # # import abc +import logging import re import string from enum import Enum @@ -74,6 +75,9 @@ if TYPE_CHECKING: from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + +logger = logging.getLogger(__name__) + # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") @@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): represented by a default `stream` attribute and a map of instance name to stream position of any writers that are ahead of the default stream position. + + The values in `instance_map` must be greater than the `stream` attribute. """ stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True) @@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): kw_only=True, ) + def __attrs_post_init__(self) -> None: + # Enforce that all instances have a value greater than the min stream + # position. + for i, v in self.instance_map.items(): + if v <= self.stream: + raise ValueError( + f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}" + ) + @classmethod @abc.abstractmethod async def parse(cls, store: "DataStore", string: str) -> "Self": @@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): for instance in set(self.instance_map).union(other.instance_map) } + # Filter out any redundant entries. + instance_map = {i: s for i, s in instance_map.items() if s > max_stream} + return attr.evolve( self, stream=max_stream, instance_map=immutabledict(instance_map) ) @@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): def bound_stream_token(self, max_stream: int) -> "Self": """Bound the stream positions to a maximum value""" + min_pos = min(self.stream, max_stream) return type(self)( - stream=min(self.stream, max_stream), + stream=min_pos, instance_map=immutabledict( - {k: min(s, max_stream) for k, s in self.instance_map.items()} + { + k: min(s, max_stream) + for k, s in self.instance_map.items() + if min(s, max_stream) > min_pos + } ), ) @@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." ) + super().__attrs_post_init__() + @classmethod async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken": try: @@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): instance_map = {} for part in parts[1:]: + if not part: + # Handle tokens of the form `m5~`, which were created by + # a bug + continue + key, value = part.split(".") instance_id = int(key) pos = int(value) @@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): except CancelledError: raise except Exception: - pass + # We log an exception here as even though this *might* be a client + # handing a bad token, its more likely that Synapse returned a bad + # token (and we really want to catch those!). + logger.exception("Failed to parse stream token: %r", string) raise SynapseError(400, "Invalid room stream token %r" % (string,)) @classmethod @@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): return self.instance_map.get(instance_name, self.stream) async def to_string(self, store: "DataStore") -> str: + """See class level docstring for information about the format.""" + if self.topological is not None: return "t%d-%d" % (self.topological, self.stream) elif self.instance_map: @@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): instance_id = await store.get_id_for_instance(name) entries.append(f"{instance_id}.{pos}") - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" + if entries: + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + return f"s{self.stream}" else: return "s%d" % (self.stream,) @@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken): instance_map = {} for part in parts[1:]: + if not part: + # Handle tokens of the form `m5~`, which were created by + # a bug + continue + key, value = part.split(".") instance_id = int(key) pos = int(value) @@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken): except CancelledError: raise except Exception: - pass + # We log an exception here as even though this *might* be a client + # handing a bad token, its more likely that Synapse returned a bad + # token (and we really want to catch those!). + logger.exception("Failed to parse stream token: %r", string) raise SynapseError(400, "Invalid stream token %r" % (string,)) async def to_string(self, store: "DataStore") -> str: + """See class level docstring for information about the format.""" + if self.instance_map: entries = [] for name, pos in self.instance_map.items(): @@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken): instance_id = await store.get_id_for_instance(name) entries.append(f"{instance_id}.{pos}") - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" + if entries: + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + return str(self.stream) else: return str(self.stream) diff --git a/tests/test_types.py b/tests/test_types.py index 944aa784fc..00adc65a5a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -19,9 +19,18 @@ # # +from typing import Type +from unittest import skipUnless + +from immutabledict import immutabledict +from parameterized import parameterized_class + from synapse.api.errors import SynapseError from synapse.types import ( + AbstractMultiWriterStreamToken, + MultiWriterStreamToken, RoomAlias, + RoomStreamToken, UserID, get_domain_from_id, get_localpart_from_id, @@ -29,6 +38,7 @@ from synapse.types import ( ) from tests import unittest +from tests.utils import USE_POSTGRES_FOR_TESTS class IsMineIDTests(unittest.HomeserverTestCase): @@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") + + +@parameterized_class( + ("token_type",), + [ + (MultiWriterStreamToken,), + (RoomStreamToken,), + ], + class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}", +) +class MultiWriterTokenTestCase(unittest.HomeserverTestCase): + """Tests for the different types of multi writer tokens.""" + + token_type: Type[AbstractMultiWriterStreamToken] + + def test_basic_token(self) -> None: + """Test that a simple stream token can be serialized and unserialized""" + store = self.hs.get_datastores().main + + token = self.token_type(stream=5) + + string_token = self.get_success(token.to_string(store)) + + if isinstance(token, RoomStreamToken): + self.assertEqual(string_token, "s5") + else: + self.assertEqual(string_token, "5") + + parsed_token = self.get_success(self.token_type.parse(store, string_token)) + self.assertEqual(parsed_token, token) + + @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres") + def test_instance_map(self) -> None: + """Test for stream token with instance map""" + store = self.hs.get_datastores().main + + token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6})) + + string_token = self.get_success(token.to_string(store)) + self.assertEqual(string_token, "m5~1.6") + + parsed_token = self.get_success(self.token_type.parse(store, string_token)) + self.assertEqual(parsed_token, token) + + def test_instance_map_assertion(self) -> None: + """Test that we assert values in the instance map are greater than the + min stream position""" + + with self.assertRaises(ValueError): + self.token_type(stream=5, instance_map=immutabledict({"foo": 4})) + + with self.assertRaises(ValueError): + self.token_type(stream=5, instance_map=immutabledict({"foo": 5})) + + def test_parse_bad_token(self) -> None: + """Test that we can parse tokens produced by a bug in Synapse of the + form `m5~`""" + store = self.hs.get_datastores().main + + parsed_token = self.get_success(self.token_type.parse(store, "m5~")) + self.assertEqual(parsed_token, self.token_type(stream=5)) -- cgit 1.5.1 From 429ecb7564640b4d04588173e9e54bf53524aadf Mon Sep 17 00:00:00 2001 From: Shay Date: Tue, 16 Jul 2024 03:13:55 -0700 Subject: Handle remote download responses with `UNKNOWN_LENGTH` more gracefully (#17439) Prior to this PR, remote downloads which did not provide a `content-length` were decremented from the remote download ratelimiter at the max allowable size, leading to excessive ratelimiting - see https://github.com/element-hq/synapse/issues/17394. This PR adds a linearizer to limit concurrent remote downloads to 6 per IP address, and decrements remote downloads without a `content-length` from the ratelimiter *after* the download is complete and the response length is known. Also adds logic to ensure that responses with a known length respect the `max_download_size`. --- changelog.d/17439.bugfix | 1 + synapse/http/matrixfederationclient.py | 126 +++++++++++++++++++++------------ tests/media/test_media_storage.py | 49 ++++++++++--- tests/rest/client/test_media.py | 50 +++++++++++-- 4 files changed, 166 insertions(+), 60 deletions(-) create mode 100644 changelog.d/17439.bugfix (limited to 'synapse') diff --git a/changelog.d/17439.bugfix b/changelog.d/17439.bugfix new file mode 100644 index 0000000000..f36c3ec255 --- /dev/null +++ b/changelog.d/17439.bugfix @@ -0,0 +1 @@ +Limit concurrent remote downloads to 6 per IP address, and decrement remote downloads without a content-length from the ratelimiter after the download is complete. \ No newline at end of file diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index 749b01dd0e..6fd75fd381 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -90,7 +90,7 @@ from synapse.logging.context import make_deferred_yieldable, run_in_background from synapse.logging.opentracing import set_tag, start_active_span, tags from synapse.types import JsonDict from synapse.util import json_decoder -from synapse.util.async_helpers import AwakenableSleeper, timeout_deferred +from synapse.util.async_helpers import AwakenableSleeper, Linearizer, timeout_deferred from synapse.util.metrics import Measure from synapse.util.stringutils import parse_and_validate_server_name @@ -475,6 +475,8 @@ class MatrixFederationHttpClient: use_proxy=True, ) + self.remote_download_linearizer = Linearizer("remote_download_linearizer", 6) + def wake_destination(self, destination: str) -> None: """Called when the remote server may have come back online.""" @@ -1486,35 +1488,44 @@ class MatrixFederationHttpClient: ) headers = dict(response.headers.getAllRawHeaders()) - expected_size = response.length - # if we don't get an expected length then use the max length + if expected_size == UNKNOWN_LENGTH: expected_size = max_size - logger.debug( - f"File size unknown, assuming file is max allowable size: {max_size}" - ) + else: + if int(expected_size) > max_size: + msg = "Requested file is too large > %r bytes" % (max_size,) + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) - read_body, _ = await download_ratelimiter.can_do_action( - requester=None, - key=ip_address, - n_actions=expected_size, - ) - if not read_body: - msg = "Requested file size exceeds ratelimits" - logger.warning( - "{%s} [%s] %s", - request.txn_id, - request.destination, - msg, + read_body, _ = await download_ratelimiter.can_do_action( + requester=None, + key=ip_address, + n_actions=expected_size, ) - raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + if not read_body: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError( + HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED + ) try: - # add a byte of headroom to max size as function errs at >= - d = read_body_with_max_size(response, output_stream, expected_size + 1) - d.addTimeout(self.default_timeout_seconds, self.reactor) - length = await make_deferred_yieldable(d) + async with self.remote_download_linearizer.queue(ip_address): + # add a byte of headroom to max size as function errs at >= + d = read_body_with_max_size(response, output_stream, expected_size + 1) + d.addTimeout(self.default_timeout_seconds, self.reactor) + length = await make_deferred_yieldable(d) except BodyExceededMaxSize: msg = "Requested file is too large > %r bytes" % (expected_size,) logger.warning( @@ -1560,6 +1571,13 @@ class MatrixFederationHttpClient: request.method, request.uri.decode("ascii"), ) + + # if we didn't know the length upfront, decrement the actual size from ratelimiter + if response.length == UNKNOWN_LENGTH: + download_ratelimiter.record_action( + requester=None, key=ip_address, n_actions=length + ) + return length, headers async def federation_get_file( @@ -1630,29 +1648,37 @@ class MatrixFederationHttpClient: ) headers = dict(response.headers.getAllRawHeaders()) - expected_size = response.length - # if we don't get an expected length then use the max length + if expected_size == UNKNOWN_LENGTH: expected_size = max_size - logger.debug( - f"File size unknown, assuming file is max allowable size: {max_size}" - ) + else: + if int(expected_size) > max_size: + msg = "Requested file is too large > %r bytes" % (max_size,) + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError(HTTPStatus.BAD_GATEWAY, msg, Codes.TOO_LARGE) - read_body, _ = await download_ratelimiter.can_do_action( - requester=None, - key=ip_address, - n_actions=expected_size, - ) - if not read_body: - msg = "Requested file size exceeds ratelimits" - logger.warning( - "{%s} [%s] %s", - request.txn_id, - request.destination, - msg, + read_body, _ = await download_ratelimiter.can_do_action( + requester=None, + key=ip_address, + n_actions=expected_size, ) - raise SynapseError(HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED) + if not read_body: + msg = "Requested file size exceeds ratelimits" + logger.warning( + "{%s} [%s] %s", + request.txn_id, + request.destination, + msg, + ) + raise SynapseError( + HTTPStatus.TOO_MANY_REQUESTS, msg, Codes.LIMIT_EXCEEDED + ) # this should be a multipart/mixed response with the boundary string in the header try: @@ -1672,11 +1698,12 @@ class MatrixFederationHttpClient: raise SynapseError(HTTPStatus.BAD_GATEWAY, msg) try: - # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >= - deferred = read_multipart_response( - response, output_stream, boundary, expected_size + 1 - ) - deferred.addTimeout(self.default_timeout_seconds, self.reactor) + async with self.remote_download_linearizer.queue(ip_address): + # add a byte of headroom to max size as `_MultipartParserProtocol.dataReceived` errs at >= + deferred = read_multipart_response( + response, output_stream, boundary, expected_size + 1 + ) + deferred.addTimeout(self.default_timeout_seconds, self.reactor) except BodyExceededMaxSize: msg = "Requested file is too large > %r bytes" % (expected_size,) logger.warning( @@ -1743,6 +1770,13 @@ class MatrixFederationHttpClient: request.method, request.uri.decode("ascii"), ) + + # if we didn't know the length upfront, decrement the actual size from ratelimiter + if response.length == UNKNOWN_LENGTH: + download_ratelimiter.record_action( + requester=None, key=ip_address, n_actions=length + ) + return length, headers, multipart_response.json diff --git a/tests/media/test_media_storage.py b/tests/media/test_media_storage.py index 70912e22f8..e55001fb40 100644 --- a/tests/media/test_media_storage.py +++ b/tests/media/test_media_storage.py @@ -1057,13 +1057,15 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): ) assert channel.code == 200 + @override_config({"remote_media_download_burst_count": "87M"}) @patch( "synapse.http.matrixfederationclient.read_body_with_max_size", read_body_with_max_size_30MiB, ) - def test_download_ratelimit_max_size_sub(self) -> None: + def test_download_ratelimit_unknown_length(self) -> None: """ - Test that if no content-length is provided, the default max size is applied instead + Test that if no content-length is provided, ratelimit will still be applied after + download once length is known """ # mock out actually sending the request @@ -1077,19 +1079,48 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): self.client._send_request = _send_request # type: ignore - # ten requests should go through using the max size (500MB/50MB) - for i in range(10): - channel2 = self.make_request( + # 3 requests should go through (note 3rd one would technically violate ratelimit but + # is applied *after* download - the next one will be ratelimited) + for i in range(3): + channel = self.make_request( "GET", f"/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxy{i}", shorthand=False, ) - assert channel2.code == 200 + assert channel.code == 200 - # eleventh will hit ratelimit - channel3 = self.make_request( + # 4th will hit ratelimit + channel2 = self.make_request( "GET", "/_matrix/media/v3/download/remote.org/abcdefghijklmnopqrstuvwxyx", shorthand=False, ) - assert channel3.code == 429 + assert channel2.code == 429 + + @override_config({"max_upload_size": "29M"}) + @patch( + "synapse.http.matrixfederationclient.read_body_with_max_size", + read_body_with_max_size_30MiB, + ) + def test_max_download_respected(self) -> None: + """ + Test that the max download size is enforced - note that max download size is determined + by the max_upload_size + """ + + # mock out actually sending the request + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 31457280 + resp.headers = Headers({"Content-Type": ["application/octet-stream"]}) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + channel = self.make_request( + "GET", "/_matrix/media/v3/download/remote.org/abcd", shorthand=False + ) + assert channel.code == 502 + assert channel.json_body["errcode"] == "M_TOO_LARGE" diff --git a/tests/rest/client/test_media.py b/tests/rest/client/test_media.py index 7f2caed7d5..466c5a0b70 100644 --- a/tests/rest/client/test_media.py +++ b/tests/rest/client/test_media.py @@ -1809,13 +1809,19 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): ) assert channel.code == 200 + @override_config( + { + "remote_media_download_burst_count": "87M", + } + ) @patch( "synapse.http.matrixfederationclient.read_multipart_response", read_multipart_response_30MiB, ) - def test_download_ratelimit_max_size_sub(self) -> None: + def test_download_ratelimit_unknown_length(self) -> None: """ - Test that if no content-length is provided, the default max size is applied instead + Test that if no content-length is provided, ratelimiting is still applied after + media is downloaded and length is known """ # mock out actually sending the request @@ -1831,8 +1837,9 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): self.client._send_request = _send_request # type: ignore - # ten requests should go through using the max size (500MB/50MB) - for i in range(10): + # first 3 will go through (note that 3rd request technically violates rate limit but + # that since the ratelimiting is applied *after* download it goes through, but next one fails) + for i in range(3): channel2 = self.make_request( "GET", f"/_matrix/client/v1/media/download/remote.org/abc{i}", @@ -1841,7 +1848,7 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): ) assert channel2.code == 200 - # eleventh will hit ratelimit + # 4th will hit ratelimit channel3 = self.make_request( "GET", "/_matrix/client/v1/media/download/remote.org/abcd", @@ -1850,6 +1857,39 @@ class RemoteDownloadLimiterTestCase(unittest.HomeserverTestCase): ) assert channel3.code == 429 + @override_config({"max_upload_size": "29M"}) + @patch( + "synapse.http.matrixfederationclient.read_multipart_response", + read_multipart_response_30MiB, + ) + def test_max_download_respected(self) -> None: + """ + Test that the max download size is enforced - note that max download size is determined + by the max_upload_size + """ + + # mock out actually sending the request, returns a 30MiB response + async def _send_request(*args: Any, **kwargs: Any) -> IResponse: + resp = MagicMock(spec=IResponse) + resp.code = 200 + resp.length = 31457280 + resp.headers = Headers( + {"Content-Type": ["multipart/mixed; boundary=gc0p4Jq0M2Yt08jU534c0p"]} + ) + resp.phrase = b"OK" + return resp + + self.client._send_request = _send_request # type: ignore + + channel = self.make_request( + "GET", + "/_matrix/client/v1/media/download/remote.org/abcd", + shorthand=False, + access_token=self.tok, + ) + assert channel.code == 502 + assert channel.json_body["errcode"] == "M_TOO_LARGE" + def test_file_download(self) -> None: content = io.BytesIO(b"file_to_stream") content_uri = self.get_success( -- cgit 1.5.1 From 30263b43c29620244e277f613b9bdcff9c7ed00b Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Jul 2024 11:20:23 +0100 Subject: Add SlidingSyncStreamToken --- synapse/types/__init__.py | 37 +++++++++++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) (limited to 'synapse') diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 3962ecc996..23ac1842f8 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1137,6 +1137,43 @@ StreamToken.START = StreamToken( ) +@attr.s(slots=True, frozen=True, auto_attribs=True) +class SlidingSyncStreamToken: + """The same as a `StreamToken`, but includes an extra field at the start for + the sliding sync connection token (separated by a '/'). This is used to + store per-connection state. + + This then looks something like: + 5/s2633508_17_338_6732159_1082514_541479_274711_265584_1_379 + """ + + stream_token: StreamToken + connection_token: int + + @staticmethod + @cancellable + async def from_string(store: "DataStore", string: str) -> "SlidingSyncStreamToken": + """Creates a SlidingSyncStreamToken from its textual representation.""" + try: + connection_token_str, stream_token_str = string.split("/", 1) + connection_token = int(connection_token_str) + stream_token = await StreamToken.from_string(store, stream_token_str) + + return SlidingSyncStreamToken( + stream_token=stream_token, + connection_token=connection_token, + ) + except CancelledError: + raise + except Exception: + raise SynapseError(400, "Invalid stream token") + + async def to_string(self, store: "DataStore") -> str: + """Serializes the token to a string""" + stream_token_str = await self.stream_token.to_string(store) + return f"{self.connection_token}/{stream_token_str}" + + @attr.s(slots=True, frozen=True, auto_attribs=True) class PersistedPosition: """Position of a newly persisted row with instance that persisted it.""" -- cgit 1.5.1 From 1ad1cce3f2137f6be2c8d48121de27ea194fef4f Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Jul 2024 12:13:40 +0100 Subject: Pass throught SlidingSyncStreamToken --- synapse/handlers/sliding_sync.py | 24 +++++++++++++++--------- synapse/rest/client/sync.py | 6 ++++-- synapse/types/handlers/__init__.py | 12 +++++++++--- 3 files changed, 28 insertions(+), 14 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 1b5262d667..3aa139ff1c 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -36,6 +36,7 @@ from synapse.types import ( PersistedEventPosition, Requester, RoomStreamToken, + SlidingSyncStreamToken, StateMap, StreamKeyType, StreamToken, @@ -343,7 +344,7 @@ class SlidingSyncHandler: self, requester: Requester, sync_config: SlidingSyncConfig, - from_token: Optional[StreamToken] = None, + from_token: Optional[SlidingSyncStreamToken] = None, timeout_ms: int = 0, ) -> SlidingSyncResult: """ @@ -378,7 +379,7 @@ class SlidingSyncHandler: # this returns false, it means we timed out waiting, and we should # just return an empty response. before_wait_ts = self.clock.time_msec() - if not await self.notifier.wait_for_stream_token(from_token): + if not await self.notifier.wait_for_stream_token(from_token.stream_token): logger.warning( "Timed out waiting for worker to catch up. Returning empty response" ) @@ -416,7 +417,7 @@ class SlidingSyncHandler: sync_config.user.to_string(), timeout_ms, current_sync_callback, - from_token=from_token, + from_token=from_token.stream_token, ) return result @@ -425,7 +426,7 @@ class SlidingSyncHandler: self, sync_config: SlidingSyncConfig, to_token: StreamToken, - from_token: Optional[StreamToken] = None, + from_token: Optional[SlidingSyncStreamToken] = None, ) -> SlidingSyncResult: """ Generates the response body of a Sliding Sync result, represented as a @@ -458,7 +459,7 @@ class SlidingSyncHandler: await self.get_room_membership_for_user_at_to_token( user=sync_config.user, to_token=to_token, - from_token=from_token, + from_token=from_token.stream_token if from_token else None, ) ) @@ -609,8 +610,11 @@ class SlidingSyncHandler: sync_config=sync_config, to_token=to_token ) + # TODO: Update this when we implement per-connection state + connection_token = 0 + return SlidingSyncResult( - next_pos=to_token, + next_pos=SlidingSyncStreamToken(to_token, connection_token), lists=lists, rooms=rooms, extensions=extensions, @@ -1346,7 +1350,7 @@ class SlidingSyncHandler: room_id: str, room_sync_config: RoomSyncConfig, room_membership_for_user_at_to_token: _RoomMembershipForUser, - from_token: Optional[StreamToken], + from_token: Optional[SlidingSyncStreamToken], to_token: StreamToken, ) -> SlidingSyncResult.RoomResult: """ @@ -1410,7 +1414,7 @@ class SlidingSyncHandler: # - TODO: For an incremental sync where we haven't sent it down this # connection before to_bound = ( - from_token.room_key + from_token.stream_token.room_key if from_token is not None and not room_membership_for_user_at_to_token.newly_joined else None @@ -1477,7 +1481,9 @@ class SlidingSyncHandler: instance_name=timeline_event.internal_metadata.instance_name, stream=timeline_event.internal_metadata.stream_ordering, ) - if persisted_position.persisted_after(from_token.room_key): + if persisted_position.persisted_after( + from_token.stream_token.room_key + ): num_live += 1 else: # Since we're iterating over the timeline events in diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 1d8cbfdf00..b3967db18d 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -54,7 +54,7 @@ from synapse.http.servlet import ( from synapse.http.site import SynapseRequest from synapse.logging.opentracing import trace_with_opname from synapse.rest.admin.experimental_features import ExperimentalFeature -from synapse.types import JsonDict, Requester, StreamToken +from synapse.types import JsonDict, Requester, SlidingSyncStreamToken, StreamToken from synapse.types.rest.client import SlidingSyncBody from synapse.util import json_decoder from synapse.util.caches.lrucache import LruCache @@ -889,7 +889,9 @@ class SlidingSyncRestServlet(RestServlet): from_token = None if from_token_string is not None: - from_token = await StreamToken.from_string(self.store, from_token_string) + from_token = await SlidingSyncStreamToken.from_string( + self.store, from_token_string + ) # TODO: We currently don't know whether we're going to use sticky params or # maybe some filters like sync v2 where they are built up once and referenced diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 409120470a..642e5483a9 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -31,7 +31,13 @@ else: from pydantic import Extra from synapse.events import EventBase -from synapse.types import JsonDict, JsonMapping, StreamToken, UserID +from synapse.types import ( + JsonDict, + JsonMapping, + SlidingSyncStreamToken, + StreamToken, + UserID, +) from synapse.types.rest.client import SlidingSyncBody if TYPE_CHECKING: @@ -287,7 +293,7 @@ class SlidingSyncResult: def __bool__(self) -> bool: return bool(self.to_device) - next_pos: StreamToken + next_pos: SlidingSyncStreamToken lists: Dict[str, SlidingWindowList] rooms: Dict[str, RoomResult] extensions: Extensions @@ -300,7 +306,7 @@ class SlidingSyncResult: return bool(self.lists or self.rooms or self.extensions) @staticmethod - def empty(next_pos: StreamToken) -> "SlidingSyncResult": + def empty(next_pos: SlidingSyncStreamToken) -> "SlidingSyncResult": "Return a new empty result" return SlidingSyncResult( next_pos=next_pos, -- cgit 1.5.1 From d44f7e12b1a206b23e4f70f390834e88aba67561 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Mon, 15 Jul 2024 14:12:01 +0100 Subject: WIP/PoC of storing whether we have sent rooms down to clients --- synapse/handlers/sliding_sync.py | 197 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 197 insertions(+) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 3aa139ff1c..1c2ed95660 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -18,6 +18,7 @@ # # import logging +from enum import Enum from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple @@ -38,6 +39,7 @@ from synapse.types import ( RoomStreamToken, SlidingSyncStreamToken, StateMap, + StrCollection, StreamKeyType, StreamToken, UserID, @@ -1861,3 +1863,198 @@ class SlidingSyncHandler: next_batch=f"{stream_id}", events=messages, ) + + +class HaveSentRoomFlag(Enum): + """Flag for whether we have sent the room down a sliding sync connection. + + The valid state changes here are: + NEVER -> LIVE + LIVE -> PREVIOUSLY + PREVIOUSLY -> LIVE + """ + + # The room has never been sent down (or we have forgotten we have sent it + # down). + NEVER = 1 + + # We have previously sent the room down, but there are updates that we + # haven't sent down. + PREVIOUSLY = 2 + + # We have sent the room down and the client has received all updates. + LIVE = 3 + + +@attr.s(auto_attribs=True, slots=True, frozen=True) +class HaveSentRoom: + """Whether we have sent the room down a sliding sync connection. + + Attributes: + status: Flag of if we have or haven't sent down the room + last_token: If the flag is `PREVIOUSLY` then this is non-null and + contains the last stream token of the last updates we sent down + the room, i.e. we still need to send everything since then to the + client. + """ + + status: HaveSentRoomFlag + last_token: Optional[RoomStreamToken] + + @staticmethod + def previously(last_token: RoomStreamToken) -> "HaveSentRoom": + """Constructor for `PREVIOUSLY` flag.""" + return HaveSentRoom(HaveSentRoomFlag.PREVIOUSLY, last_token) + + +HAVE_SENT_ROOM_NEVER = HaveSentRoom(HaveSentRoomFlag.NEVER, None) +HAVE_SENT_ROOM_LIVE = HaveSentRoom(HaveSentRoomFlag.LIVE, None) + + +@attr.s(auto_attribs=True) +class SlidingSyncConnectionStore: + """In-memory store of per-connection state, including what rooms we have + previously sent down a sliding sync connection. + + Note: This is NOT safe to run in a worker setup. + + The complication here is that we need to handle requests being resent, i.e. + if we sent down a room in a response that the client received, we must + consider the room *not* sent when we get the request again. + + This is handled by using an integer "token", which is returned to the client + as part of the sync token. For each connection we store a mapping from + tokens to the room states, and create a new entry when we send down new + rooms. + + Note that for any given sliding sync connection we will only store a maximum + of two different tokens: the previous token from the request and a new token + sent in the response. When we receive a request with a given token, we then + clear out all other entries with a different token. + + Attributes: + _connections: Mapping from `(user_id, conn_id)` to mapping of `token` + to mapping of room ID to `HaveSentRoom`. + """ + + # `(user_id, conn_id)` -> `token` -> `room_id` -> `HaveSentRoom` + _connections: Dict[Tuple[str, str], Dict[int, Dict[str, HaveSentRoom]]] = ( + attr.Factory(dict) + ) + + async def have_sent_room( + self, user_id: str, conn_id: str, connection_token: int, room_id: str + ) -> HaveSentRoom: + """Whether for the given user_id/conn_id/token, return whether we have + previously sent the room down + """ + + sync_statuses = self._connections.setdefault((user_id, conn_id), {}) + room_status = sync_statuses.get(connection_token, {}).get( + room_id, HAVE_SENT_ROOM_NEVER + ) + + return room_status + + async def record_rooms( + self, + user_id: str, + conn_id: str, + from_token: Optional[SlidingSyncStreamToken], + *, + sent_room_ids: StrCollection, + unsent_room_ids: StrCollection, + ) -> int: + """Record which rooms we have/haven't sent down in a new response + + Attributes: + user_id + conn_id + from_token: The since token from the request, if any + sent_room_ids: The set of room IDs that we have sent down as + part of this request (only needs to be ones we didn't + previously sent down). + unsent_room_ids: The set of room IDs that have had updates + since the `last_room_token`, but which were not included in + this request + """ + prev_connection_token = 0 + if from_token is not None: + prev_connection_token = from_token.connection_token + + # If there are no changes then this is a noop. + if not sent_room_ids and not unsent_room_ids: + return prev_connection_token + + sync_statuses = self._connections.setdefault((user_id, conn_id), {}) + + # Generate a new token, removing any existing entries in that token + # (which can happen if requests get resent). + new_store_token = prev_connection_token + 1 + sync_statuses.pop(new_store_token, None) + + # Copy over and update the room mappings. + new_room_statuses = dict(sync_statuses.get(prev_connection_token, {})) + + # Whether we have updated the `new_room_statuses`, if we don't by the + # end we can treat this as a noop. + have_updated = False + for room_id in sent_room_ids: + new_room_statuses[room_id] = HAVE_SENT_ROOM_LIVE + have_updated = True + + # Whether we add/update the entries for unsent rooms depends on the + # existing entry: + # - LIVE: We have previously sent down everything up to + # `last_room_token, so we update the entry to be `PREVIOUSLY` with + # `last_room_token`. + # - PREVIOUSLY: We have previously sent down everything up to *a* + # given token, so we don't need to update the entry. + # - NEVER: We have never previously sent down the room, and we haven't + # sent anything down this time either so we leave it as NEVER. + + # Work out the new state for unsent rooms that were `LIVE`. + if from_token: + new_unsent_state = HaveSentRoom.previously(from_token.stream_token.room_key) + else: + new_unsent_state = HAVE_SENT_ROOM_NEVER + + for room_id in unsent_room_ids: + prev_state = new_room_statuses.get(room_id) + if prev_state is not None and prev_state.status == HaveSentRoomFlag.LIVE: + new_room_statuses[room_id] = new_unsent_state + have_updated = True + + if not have_updated: + return prev_connection_token + + sync_statuses[new_store_token] = new_room_statuses + + return new_store_token + + async def mark_token_seen( + self, + user_id: str, + conn_id: str, + from_token: Optional[SlidingSyncStreamToken], + ) -> None: + """We have received a request with the given token, so we can clear out + any other tokens associated with the connection. + + If there is no from token then we have started afresh, and so we delete + all tokens associated with the device. + """ + # Clear out any tokens for the connection that doesn't match the one + # from the request. + + sync_statuses = self._connections.pop((user_id, conn_id), {}) + if from_token is None: + return + + sync_statuses = { + i: room_statuses + for i, room_statuses in sync_statuses.items() + if i == from_token.connection_token + } + if sync_statuses: + self._connections[(user_id, conn_id)] = sync_statuses -- cgit 1.5.1 From 53273db3e8f21c302cec6f0d6e699c49e5ef28a2 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Jul 2024 13:54:00 +0100 Subject: Add conn_id field --- synapse/handlers/sliding_sync.py | 2 +- synapse/rest/client/sync.py | 6 +++--- synapse/types/handlers/__init__.py | 28 +++++++++++++++++++++++++++- synapse/types/rest/client/__init__.py | 5 +++++ 4 files changed, 36 insertions(+), 5 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 1c2ed95660..b108e8a4a2 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -1807,7 +1807,7 @@ class SlidingSyncHandler: """ user_id = sync_config.user.to_string() - device_id = sync_config.device_id + device_id = sync_config.requester.device_id # Check that this request has a valid device ID (not all requests have # to belong to a device, and so device_id is None), and that the diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index b3967db18d..2909052bbf 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -881,7 +881,6 @@ class SlidingSyncRestServlet(RestServlet): ) user = requester.user - device_id = requester.device_id timeout = parse_integer(request, "timeout", default=0) # Position in the stream @@ -902,11 +901,12 @@ class SlidingSyncRestServlet(RestServlet): sync_config = SlidingSyncConfig( user=user, - device_id=device_id, + requester=requester, # FIXME: Currently, we're just manually copying the fields from the - # `SlidingSyncBody` into the config. How can we gurantee into the future + # `SlidingSyncBody` into the config. How can we guarantee into the future # that we don't forget any? I would like something more structured like # `copy_attributes(from=body, to=config)` + conn_id=body.conn_id, lists=body.lists, room_subscriptions=body.room_subscriptions, extensions=body.extensions, diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 642e5483a9..0c2ab13c93 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -34,6 +34,7 @@ from synapse.events import EventBase from synapse.types import ( JsonDict, JsonMapping, + Requester, SlidingSyncStreamToken, StreamToken, UserID, @@ -108,7 +109,7 @@ class SlidingSyncConfig(SlidingSyncBody): """ user: UserID - device_id: Optional[str] + requester: Requester # Pydantic config class Config: @@ -119,6 +120,31 @@ class SlidingSyncConfig(SlidingSyncBody): # Allow custom types like `UserID` to be used in the model arbitrary_types_allowed = True + def connection_id(self) -> str: + """Return a string identifier for this connection. May clash with + connection IDs from different users. + + This is generally a combination of device ID and conn_id. However, both + these two are optional (e.g. puppet access tokens don't have device + IDs), so this handles those edge cases. + """ + + # `conn_id` can be null, in which case we default to the empty string + # (if conn ID is empty then the client can't have multiple sync loops) + conn_id = self.conn_id or "" + + if self.requester.device_id: + return f"D/{self.requester.device_id}/{conn_id}" + + if self.requester.access_token_id: + # If we don't have a device, then the access token ID should be a + # stable ID. + return f"A/{self.requester.access_token_id}/{conn_id}" + + # If we have neither then its likely an AS or some weird token. Either + # way we can just fail here. + raise Exception("Cannot use sliding sync with access token type") + class OperationType(Enum): """ diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index dbe37bc712..5be8cf5389 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -120,6 +120,9 @@ class SlidingSyncBody(RequestBodyModel): Sliding Sync API request body. Attributes: + conn_id: An optional string to identify this connection to the server. If this + is missing, only 1 sliding sync connection can be made to the server at + any one time. lists: Sliding window API. A map of list key to list information (:class:`SlidingSyncList`). Max lists: 100. The list keys should be arbitrary strings which the client is using to refer to the list. Keep this @@ -315,6 +318,8 @@ class SlidingSyncBody(RequestBodyModel): to_device: Optional[ToDeviceExtension] = None + conn_id: Optional[str] + # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: lists: Optional[Dict[str, SlidingSyncList]] = None -- cgit 1.5.1 From e2a88e44ef6b33a691411ff7ae93d68c2011297e Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 16 Jul 2024 12:32:24 +0100 Subject: Use new room store to track if we've sent a room down --- synapse/handlers/sliding_sync.py | 55 ++++++++++++++++++++++++++++++++-------- 1 file changed, 44 insertions(+), 11 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index b108e8a4a2..618c69d5fe 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set import attr from immutabledict import immutabledict +from typing_extensions import assert_never from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership from synapse.events import EventBase @@ -342,6 +343,8 @@ class SlidingSyncHandler: self.relations_handler = hs.get_relations_handler() self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync + self.connection_store = SlidingSyncConnectionStore() + async def wait_for_sync_for_user( self, requester: Requester, @@ -449,6 +452,12 @@ class SlidingSyncHandler: # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() + await self.connection_store.mark_token_seen( + user_id, + conn_id=sync_config.connection_id(), + from_token=from_token, + ) + # Get all of the room IDs that the user should be able to see in the sync # response has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 @@ -596,7 +605,7 @@ class SlidingSyncHandler: rooms: Dict[str, SlidingSyncResult.RoomResult] = {} for room_id, room_sync_config in relevant_room_map.items(): room_sync_result = await self.get_room_sync_data( - user=sync_config.user, + sync_config=sync_config, room_id=room_id, room_sync_config=room_sync_config, room_membership_for_user_at_to_token=room_membership_for_user_map[ @@ -612,8 +621,18 @@ class SlidingSyncHandler: sync_config=sync_config, to_token=to_token ) - # TODO: Update this when we implement per-connection state - connection_token = 0 + if has_lists or has_room_subscriptions: + connection_token = await self.connection_store.record_rooms( + user_id, + conn_id=sync_config.connection_id(), + from_token=from_token, + sent_room_ids=relevant_room_map.keys(), + unsent_room_ids=[], # TODO: We currently ssume that we have sent down all updates. + ) + elif from_token: + connection_token = from_token.connection_token + else: + connection_token = 0 return SlidingSyncResult( next_pos=SlidingSyncStreamToken(to_token, connection_token), @@ -1348,7 +1367,7 @@ class SlidingSyncHandler: async def get_room_sync_data( self, - user: UserID, + sync_config: SlidingSyncConfig, room_id: str, room_sync_config: RoomSyncConfig, room_membership_for_user_at_to_token: _RoomMembershipForUser, @@ -1370,6 +1389,7 @@ class SlidingSyncHandler: from_token: The point in the stream to sync from. to_token: The point in the stream to sync up to. """ + user = sync_config.user # Assemble the list of timeline events # @@ -1413,14 +1433,27 @@ class SlidingSyncHandler: # screen of information: # - Initial sync (because no `from_token` to limit us anyway) # - When users `newly_joined` - # - TODO: For an incremental sync where we haven't sent it down this + # - For an incremental sync where we haven't sent it down this # connection before - to_bound = ( - from_token.stream_token.room_key - if from_token is not None - and not room_membership_for_user_at_to_token.newly_joined - else None - ) + + if from_token and not room_membership_for_user_at_to_token.newly_joined: + room_status = await self.connection_store.have_sent_room( + user_id=user.to_string(), + conn_id=sync_config.connection_id(), + connection_token=from_token.connection_token, + room_id=room_id, + ) + if room_status.status == HaveSentRoomFlag.LIVE: + to_bound = from_token.stream_token.room_key + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + to_bound = room_status.last_token + elif room_status.status == HaveSentRoomFlag.NEVER: + to_bound = None + else: + assert_never(room_status.status) + else: + to_bound = None timeline_events, new_room_key = await self.store.paginate_room_events( room_id=room_id, -- cgit 1.5.1 From 185831754ead63f1728674497b0dcb91059091ee Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Jul 2024 10:32:31 +0100 Subject: Handle initial flag correctly --- synapse/handlers/sliding_sync.py | 66 +++++++++++++++++++--------------------- 1 file changed, 31 insertions(+), 35 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 618c69d5fe..32686036b4 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -1391,6 +1391,37 @@ class SlidingSyncHandler: """ user = sync_config.user + # Determine whether we should limit the timeline to the token range. + # + # We should return historical messages (before token range) in the + # following cases because we want clients to be able to show a basic + # screen of information: + # - Initial sync (because no `from_token` to limit us anyway) + # - When users `newly_joined` + # - For an incremental sync where we haven't sent it down this + # connection before + to_bound = None + initial = True + if from_token and not room_membership_for_user_at_to_token.newly_joined: + room_status = await self.connection_store.have_sent_room( + user_id=user.to_string(), + conn_id=sync_config.connection_id(), + connection_token=from_token.connection_token, + room_id=room_id, + ) + if room_status.status == HaveSentRoomFlag.LIVE: + to_bound = from_token.stream_token.room_key + initial = False + elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: + assert room_status.last_token is not None + to_bound = room_status.last_token + initial = False + elif room_status.status == HaveSentRoomFlag.NEVER: + to_bound = None + initial = True + else: + assert_never(room_status.status) + # Assemble the list of timeline events # # FIXME: It would be nice to make the `rooms` response more uniform regardless of @@ -1426,35 +1457,6 @@ class SlidingSyncHandler: room_membership_for_user_at_to_token.event_pos.to_room_stream_token() ) - # Determine whether we should limit the timeline to the token range. - # - # We should return historical messages (before token range) in the - # following cases because we want clients to be able to show a basic - # screen of information: - # - Initial sync (because no `from_token` to limit us anyway) - # - When users `newly_joined` - # - For an incremental sync where we haven't sent it down this - # connection before - - if from_token and not room_membership_for_user_at_to_token.newly_joined: - room_status = await self.connection_store.have_sent_room( - user_id=user.to_string(), - conn_id=sync_config.connection_id(), - connection_token=from_token.connection_token, - room_id=room_id, - ) - if room_status.status == HaveSentRoomFlag.LIVE: - to_bound = from_token.stream_token.room_key - elif room_status.status == HaveSentRoomFlag.PREVIOUSLY: - assert room_status.last_token is not None - to_bound = room_status.last_token - elif room_status.status == HaveSentRoomFlag.NEVER: - to_bound = None - else: - assert_never(room_status.status) - else: - to_bound = None - timeline_events, new_room_key = await self.store.paginate_room_events( room_id=room_id, from_key=from_bound, @@ -1575,12 +1577,6 @@ class SlidingSyncHandler: # indicate to the client that a state reset happened. Perhaps we should indicate # this by setting `initial: True` and empty `required_state`. - # TODO: Since we can't determine whether we've already sent a room down this - # Sliding Sync connection before (we plan to add this optimization in the - # future), we're always returning the requested room state instead of - # updates. - initial = True - # Check whether the room has a name set name_state_ids = await self.get_current_state_ids_at( room_id=room_id, -- cgit 1.5.1 From de6e3bdee8715669a76b40eaacdc8728cc02d4eb Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 17 Jul 2024 10:32:50 +0100 Subject: Handle state deltas in non-initial rooms --- synapse/handlers/sliding_sync.py | 14 ++++++++-- synapse/storage/databases/main/state_deltas.py | 38 ++++++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 32686036b4..6d0d9b425c 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -1720,9 +1720,17 @@ class SlidingSyncHandler: to_token=to_token, ) else: - # TODO: Once we can figure out if we've sent a room down this connection before, - # we can return updates instead of the full required state. - raise NotImplementedError() + assert to_bound is not None + + deltas = await self.store.get_current_state_deltas_for_room( + room_id, to_bound, to_token.room_key + ) + # TODO: Filter room state before fetching events + # TODO: Handle state resets where event_id is None + events = await self.store.get_events( + [d.event_id for d in deltas if d.event_id] + ) + room_state = {(s.type, s.state_key): s for s in events.values()} required_room_state: StateMap[EventBase] = {} if required_state_filter != StateFilter.none(): diff --git a/synapse/storage/databases/main/state_deltas.py b/synapse/storage/databases/main/state_deltas.py index 036972ac25..cd6cb2c7a9 100644 --- a/synapse/storage/databases/main/state_deltas.py +++ b/synapse/storage/databases/main/state_deltas.py @@ -26,6 +26,8 @@ import attr from synapse.storage._base import SQLBaseStore from synapse.storage.database import LoggingTransaction +from synapse.storage.databases.main.stream import _filter_results_by_stream +from synapse.types import RoomStreamToken from synapse.util.caches.stream_change_cache import StreamChangeCache logger = logging.getLogger(__name__) @@ -156,3 +158,39 @@ class StateDeltasStore(SQLBaseStore): "get_max_stream_id_in_current_state_deltas", self._get_max_stream_id_in_current_state_deltas_txn, ) + + async def get_current_state_deltas_for_room( + self, room_id: str, from_token: RoomStreamToken, to_token: RoomStreamToken + ) -> List[StateDelta]: + """Get the state deltas between that have happened between two + tokens.""" + + def get_current_state_deltas_for_room_txn( + txn: LoggingTransaction, + ) -> List[StateDelta]: + sql = """ + SELECT instance_name, stream_id, type, state_key, event_id, prev_event_id + FROM current_state_delta_stream + WHERE room_id = ? AND ? < stream_id AND stream_id <= ? + ORDER BY stream_id ASC + """ + txn.execute( + sql, (room_id, from_token.stream, to_token.get_max_stream_pos()) + ) + + return [ + StateDelta( + stream_id=row[1], + room_id=room_id, + event_type=row[2], + state_key=row[3], + event_id=row[4], + prev_event_id=row[5], + ) + for row in txn + if _filter_results_by_stream(from_token, to_token, row[0], row[1]) + ] + + return await self.db_pool.runInteraction( + "get_current_state_deltas_for_room", get_current_state_deltas_for_room_txn + ) -- cgit 1.5.1