From 1cf3ff6b40a9f0e72c39e471e921a46f56e4511f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 9 Jul 2024 12:26:45 -0500 Subject: Add `rooms` `name` and `avatar` to Sliding Sync `/sync` (#17418) Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync --- synapse/handlers/sliding_sync.py | 151 +++++++++++++++++++++++++-------------- 1 file changed, 96 insertions(+), 55 deletions(-) (limited to 'synapse') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 8e2f751c02..bb81ca9d97 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -18,6 +18,7 @@ # # import logging +from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple import attr @@ -464,6 +465,7 @@ class SlidingSyncHandler: membership_state_keys = room_sync_config.required_state_map.get( EventTypes.Member ) + # Also see `StateFilter.must_await_full_state(...)` for comparison lazy_loading = ( membership_state_keys is not None and len(membership_state_keys) == 1 @@ -1202,7 +1204,7 @@ class SlidingSyncHandler: # Figure out any stripped state events for invite/knocks. This allows the # potential joiner to identify the room. - stripped_state: List[JsonDict] = [] + stripped_state: Optional[List[JsonDict]] = None if room_membership_for_user_at_to_token.membership in ( Membership.INVITE, Membership.KNOCK, @@ -1239,7 +1241,7 @@ class SlidingSyncHandler: # updates. initial = True - # Fetch the required state for the room + # Fetch the `required_state` for the room # # No `required_state` for invite/knock rooms (just `stripped_state`) # @@ -1247,13 +1249,15 @@ class SlidingSyncHandler: # of membership. Currently, we have to make this optional because # `invite`/`knock` rooms only have `stripped_state`. See # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 + # + # Calculate the `StateFilter` based on the `required_state` for the room room_state: Optional[StateMap[EventBase]] = None + required_room_state: Optional[StateMap[EventBase]] = None if room_membership_for_user_at_to_token.membership not in ( Membership.INVITE, Membership.KNOCK, ): - # Calculate the `StateFilter` based on the `required_state` for the room - state_filter: Optional[StateFilter] = StateFilter.none() + required_state_filter = StateFilter.none() # If we have a double wildcard ("*", "*") in the `required_state`, we need # to fetch all state for the room # @@ -1276,7 +1280,7 @@ class SlidingSyncHandler: if StateValues.WILDCARD in room_sync_config.required_state_map.get( StateValues.WILDCARD, set() ): - state_filter = StateFilter.all() + required_state_filter = StateFilter.all() # TODO: `StateFilter` currently doesn't support wildcard event types. We're # currently working around this by returning all state to the client but it # would be nice to fetch less from the database and return just what the @@ -1285,7 +1289,7 @@ class SlidingSyncHandler: room_sync_config.required_state_map.get(StateValues.WILDCARD) is not None ): - state_filter = StateFilter.all() + required_state_filter = StateFilter.all() else: required_state_types: List[Tuple[str, Optional[str]]] = [] for ( @@ -1317,51 +1321,88 @@ class SlidingSyncHandler: else: required_state_types.append((state_type, state_key)) - state_filter = StateFilter.from_types(required_state_types) - - # We can skip fetching state if we don't need any - if state_filter != StateFilter.none(): - # We can return all of the state that was requested if we're doing an - # initial sync - if initial: - # People shouldn't see past their leave/ban event - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - room_state = await self.storage_controllers.state.get_state_at( - room_id, - stream_position=to_token.copy_and_replace( - StreamKeyType.ROOM, - room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), - ), - state_filter=state_filter, - # Partially-stated rooms should have all state events except for - # the membership events and since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able - # to retrieve everything requested. Plus we don't want to block - # the whole sync waiting for this one room. - await_full_state=False, - ) - # Otherwise, we can get the latest current state in the room - else: - room_state = await self.storage_controllers.state.get_current_state( - room_id, - state_filter, - # Partially-stated rooms should have all state events except for - # the membership events and since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able - # to retrieve everything requested. Plus we don't want to block - # the whole sync waiting for this one room. - await_full_state=False, - ) - # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + required_state_filter = StateFilter.from_types(required_state_types) + + # We need this base set of info for the response so let's just fetch it along + # with the `required_state` for the room + META_ROOM_STATE = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + state_filter = StateFilter( + types=StateFilter.from_types( + chain(META_ROOM_STATE, required_state_filter.to_types()) + ).types, + include_others=required_state_filter.include_others, + ) + + # We can return all of the state that was requested if this was the first + # time we've sent the room down this connection. + if initial: + # People shouldn't see past their leave/ban event + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + room_state = await self.storage_controllers.state.get_state_at( + room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), + ), + state_filter=state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # Otherwise, we can get the latest current state in the room else: - # TODO: Once we can figure out if we've sent a room down this connection before, - # we can return updates instead of the full required state. - raise NotImplementedError() + room_state = await self.storage_controllers.state.get_current_state( + room_id, + state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + else: + # TODO: Once we can figure out if we've sent a room down this connection before, + # we can return updates instead of the full required state. + raise NotImplementedError() + + if required_state_filter != StateFilter.none(): + required_room_state = required_state_filter.filter_state(room_state) + + # Find the room name and avatar from the state + room_name: Optional[str] = None + room_avatar: Optional[str] = None + if room_state is not None: + name_event = room_state.get((EventTypes.Name, "")) + if name_event is not None: + room_name = name_event.content.get("name") + + avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if avatar_event is not None: + room_avatar = avatar_event.content.get("url") + elif stripped_state is not None: + for event in stripped_state: + if event["type"] == EventTypes.Name: + room_name = event.get("content", {}).get("name") + elif event["type"] == EventTypes.RoomAvatar: + room_avatar = event.get("content", {}).get("url") + + # Found everything so we can stop looking + if room_name is not None and room_avatar is not None: + break # Figure out the last bump event in the room last_bump_event_result = ( @@ -1378,16 +1419,16 @@ class SlidingSyncHandler: bump_stamp = bump_event_pos.stream return SlidingSyncResult.RoomResult( - # TODO: Dummy value - name=None, - # TODO: Dummy value - avatar=None, + name=room_name, + avatar=room_avatar, # TODO: Dummy value heroes=None, # TODO: Dummy value is_dm=False, initial=initial, - required_state=list(room_state.values()) if room_state else None, + required_state=( + list(required_room_state.values()) if required_room_state else None + ), timeline_events=timeline_events, bundled_aggregations=bundled_aggregations, stripped_state=stripped_state, -- cgit 1.5.1 From 4ca13ce0dd6d1dc931cfde7e06191200ca0ec066 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jul 2024 11:58:42 +0100 Subject: Handle to-device extensions to Sliding Sync (#17416) Implements MSC3885 --------- Co-authored-by: Eric Eastwood --- changelog.d/17416.feature | 1 + synapse/handlers/sliding_sync.py | 103 ++++++++++++++++- synapse/rest/client/sync.py | 17 ++- synapse/types/handlers/__init__.py | 35 +++++- synapse/types/rest/client/__init__.py | 48 +++++++- tests/rest/client/test_sync.py | 200 +++++++++++++++++++++++++++++++++- 6 files changed, 392 insertions(+), 12 deletions(-) create mode 100644 changelog.d/17416.feature (limited to 'synapse') diff --git a/changelog.d/17416.feature b/changelog.d/17416.feature new file mode 100644 index 0000000000..1d119cf48f --- /dev/null +++ b/changelog.d/17416.feature @@ -0,0 +1 @@ +Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index bb81ca9d97..818b13621c 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -542,11 +542,15 @@ class SlidingSyncHandler: rooms[room_id] = room_sync_result + extensions = await self.get_extensions_response( + sync_config=sync_config, to_token=to_token + ) + return SlidingSyncResult( next_pos=to_token, lists=lists, rooms=rooms, - extensions={}, + extensions=extensions, ) async def get_sync_room_ids_for_user( @@ -1445,3 +1449,100 @@ class SlidingSyncHandler: notification_count=0, highlight_count=0, ) + + async def get_extensions_response( + self, + sync_config: SlidingSyncConfig, + to_token: StreamToken, + ) -> SlidingSyncResult.Extensions: + """Handle extension requests. + + Args: + sync_config: Sync configuration + to_token: The point in the stream to sync up to. + """ + + if sync_config.extensions is None: + return SlidingSyncResult.Extensions() + + to_device_response = None + if sync_config.extensions.to_device: + to_device_response = await self.get_to_device_extensions_response( + sync_config=sync_config, + to_device_request=sync_config.extensions.to_device, + to_token=to_token, + ) + + return SlidingSyncResult.Extensions(to_device=to_device_response) + + async def get_to_device_extensions_response( + self, + sync_config: SlidingSyncConfig, + to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, + to_token: StreamToken, + ) -> SlidingSyncResult.Extensions.ToDeviceExtension: + """Handle to-device extension (MSC3885) + + Args: + sync_config: Sync configuration + to_device_request: The to-device extension from the request + to_token: The point in the stream to sync up to. + """ + + user_id = sync_config.user.to_string() + device_id = sync_config.device_id + + # Check that this request has a valid device ID (not all requests have + # to belong to a device, and so device_id is None), and that the + # extension is enabled. + if device_id is None or not to_device_request.enabled: + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=f"{to_token.to_device_key}", + events=[], + ) + + since_stream_id = 0 + if to_device_request.since is not None: + # We've already validated this is an int. + since_stream_id = int(to_device_request.since) + + if to_token.to_device_key < since_stream_id: + # The since token is ahead of our current token, so we return an + # empty response. + logger.warning( + "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r", + since_stream_id, + to_token.to_device_key, + ) + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=to_device_request.since, + events=[], + ) + + # Delete everything before the given since token, as we know the + # device must have received them. + deleted = await self.store.delete_messages_for_device( + user_id=user_id, + device_id=device_id, + up_to_stream_id=since_stream_id, + ) + + logger.debug( + "Deleted %d to-device messages up to %d for %s", + deleted, + since_stream_id, + user_id, + ) + + messages, stream_id = await self.store.get_messages_for_device( + user_id=user_id, + device_id=device_id, + from_stream_id=since_stream_id, + to_stream_id=to_token.to_device_key, + limit=min(to_device_request.limit, 100), # Limit to at most 100 events + ) + + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=f"{stream_id}", + events=messages, + ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 13aed1dc85..94d5faf9f7 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet): response["rooms"] = await self.encode_rooms( requester, sliding_sync_result.rooms ) - response["extensions"] = {} # TODO: sliding_sync_result.extensions + response["extensions"] = await self.encode_extensions( + requester, sliding_sync_result.extensions + ) return response @@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet): return serialized_rooms + async def encode_extensions( + self, requester: Requester, extensions: SlidingSyncResult.Extensions + ) -> JsonDict: + result = {} + + if extensions.to_device is not None: + result["to_device"] = { + "next_batch": extensions.to_device.next_batch, + "events": extensions.to_device.events, + } + + return result + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 43dcdf20dd..a8a3a8f242 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -18,7 +18,7 @@ # # from enum import Enum -from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple import attr from typing_extensions import TypedDict @@ -252,10 +252,39 @@ class SlidingSyncResult: count: int ops: List[Operation] + @attr.s(slots=True, frozen=True, auto_attribs=True) + class Extensions: + """Responses for extensions + + Attributes: + to_device: The to-device extension (MSC3885) + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceExtension: + """The to-device extension (MSC3885) + + Attributes: + next_batch: The to-device stream token the client should use + to get more results + events: A list of to-device messages for the client + """ + + next_batch: str + events: Sequence[JsonMapping] + + def __bool__(self) -> bool: + return bool(self.events) + + to_device: Optional[ToDeviceExtension] = None + + def __bool__(self) -> bool: + return bool(self.to_device) + next_pos: StreamToken lists: Dict[str, SlidingWindowList] rooms: Dict[str, RoomResult] - extensions: JsonMapping + extensions: Extensions def __bool__(self) -> bool: """Make the result appear empty if there are no updates. This is used @@ -271,5 +300,5 @@ class SlidingSyncResult: next_pos=next_pos, lists={}, rooms={}, - extensions={}, + extensions=SlidingSyncResult.Extensions(), ) diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 55f6b44053..1e8fe76c99 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel): class RoomSubscription(CommonRoomParameters): pass - class Extension(RequestBodyModel): - enabled: Optional[StrictBool] = False - lists: Optional[List[StrictStr]] = None - rooms: Optional[List[StrictStr]] = None + class Extensions(RequestBodyModel): + """The extensions section of the request. + + Extensions MUST have an `enabled` flag which defaults to `false`. If a client + sends an unknown extension name, the server MUST ignore it (or else backwards + compatibility between clients and servers is broken when a newer client tries to + communicate with an older server). + """ + + class ToDeviceExtension(RequestBodyModel): + """The to-device extension (MSC3885) + + Attributes: + enabled + limit: Maximum number of to-device messages to return + since: The `next_batch` from the previous sync response + """ + + enabled: Optional[StrictBool] = False + limit: StrictInt = 100 + since: Optional[StrictStr] = None + + @validator("since") + def since_token_check( + cls, value: Optional[StrictStr] + ) -> Optional[StrictStr]: + # `since` comes in as an opaque string token but we know that it's just + # an integer representing the position in the device inbox stream. We + # want to pre-validate it to make sure it works fine in downstream code. + if value is None: + return value + + try: + int(value) + except ValueError: + raise ValueError( + "'extensions.to_device.since' is invalid (should look like an int)" + ) + + return value + + to_device: Optional[ToDeviceExtension] = None # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: @@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel): else: lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type] room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None - extensions: Optional[Dict[StrictStr, Extension]] = None + extensions: Optional[Extensions] = None @validator("lists") def lists_length_check( diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index f7852562b1..304c0d4d3d 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -38,7 +38,16 @@ from synapse.api.constants import ( ) from synapse.events import EventBase from synapse.handlers.sliding_sync import StateValues -from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync +from synapse.rest.client import ( + devices, + knock, + login, + read_marker, + receipts, + room, + sendtodevice, + sync, +) from synapse.server import HomeServer from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID from synapse.util import Clock @@ -47,7 +56,7 @@ from tests import unittest from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) -from tests.server import TimedOutException +from tests.server import FakeChannel, TimedOutException from tests.test_utils.event_injection import mark_event_as_partial_state logger = logging.getLogger(__name__) @@ -3696,3 +3705,190 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): ], channel.json_body["lists"]["foo-list"], ) + + +class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase): + """Tests for the to-device sliding sync extension""" + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + sendtodevice.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.sync_endpoint = ( + "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync" + ) + + def _assert_to_device_response( + self, channel: FakeChannel, expected_messages: List[JsonDict] + ) -> str: + """Assert the sliding sync response was successful and has the expected + to-device messages. + + Returns the next_batch token from the to-device section. + """ + self.assertEqual(channel.code, 200, channel.json_body) + extensions = channel.json_body["extensions"] + to_device = extensions["to_device"] + self.assertIsInstance(to_device["next_batch"], str) + self.assertEqual(to_device["events"], expected_messages) + + return to_device["next_batch"] + + def test_no_data(self) -> None: + """Test that enabling to-device extension works, even if there is + no-data + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + + # We expect no to-device messages + self._assert_to_device_response(channel, []) + + def test_data_initial_sync(self) -> None: + """Test that we get to-device messages when we don't specify a since + token""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass", "d1") + user2_id = self.register_user("u2", "pass") + user2_tok = self.login(user2_id, "pass", "d2") + + # Send the to-device message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user1_id: {"d1": test_msg}}}, + access_token=user2_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + self._assert_to_device_response( + channel, + [{"content": test_msg, "sender": user2_id, "type": "m.test"}], + ) + + def test_data_incremental_sync(self) -> None: + """Test that we get to-device messages over incremental syncs""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass", "d1") + user2_id = self.register_user("u2", "pass") + user2_tok = self.login(user2_id, "pass", "d2") + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + # No to-device messages yet. + next_batch = self._assert_to_device_response(channel, []) + + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user1_id: {"d1": test_msg}}}, + access_token=user2_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + "since": next_batch, + } + }, + }, + access_token=user1_tok, + ) + next_batch = self._assert_to_device_response( + channel, + [{"content": test_msg, "sender": user2_id, "type": "m.test"}], + ) + + # The next sliding sync request should not include the to-device + # message. + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + "since": next_batch, + } + }, + }, + access_token=user1_tok, + ) + self._assert_to_device_response(channel, []) + + # An initial sliding sync request should not include the to-device + # message, as it should have been deleted + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + self._assert_to_device_response(channel, []) -- cgit 1.5.1 From 677142b6a965baf36a3cc59c4e997bbd67891d42 Mon Sep 17 00:00:00 2001 From: Travis Ralston Date: Thu, 11 Jul 2024 07:03:13 -0600 Subject: Fix docs on `record_action` to clarify the actions are applied (#17426) This looks like a copy/paste error: the function doesn't reject anything, but instead allows the action count to go through regardless. The remainder of the function's documentation appears correct. --- changelog.d/17426.misc | 1 + synapse/api/ratelimiting.py | 5 ++--- 2 files changed, 3 insertions(+), 3 deletions(-) create mode 100644 changelog.d/17426.misc (limited to 'synapse') diff --git a/changelog.d/17426.misc b/changelog.d/17426.misc new file mode 100644 index 0000000000..886e5d4389 --- /dev/null +++ b/changelog.d/17426.misc @@ -0,0 +1 @@ +Fix documentation on `RateLimiter#record_action`. \ No newline at end of file diff --git a/synapse/api/ratelimiting.py b/synapse/api/ratelimiting.py index 26b8711851..b80630c5d3 100644 --- a/synapse/api/ratelimiting.py +++ b/synapse/api/ratelimiting.py @@ -236,9 +236,8 @@ class Ratelimiter: requester: The requester that is doing the action, if any. key: An arbitrary key used to classify an action. Defaults to the requester's user ID. - n_actions: The number of times the user wants to do this action. If the user - cannot do all of the actions, the user's action count is not incremented - at all. + n_actions: The number of times the user performed the action. May be negative + to "refund" the rate limit. _time_now_s: The current time. Optional, defaults to the current time according to self.clock. Only used by tests. """ -- cgit 1.5.1 From 606da398fc4c693f2e75b9520264e0fc51d03581 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Jul 2024 16:00:44 +0100 Subject: Fix filtering room types on remote rooms (#17434) We can only fetch room types for rooms the server is in, so we need to only filter rooms that we're joined to. Also includes a perf fix to bulk fetch room types. --- changelog.d/17434.bugfix | 1 + synapse/handlers/sliding_sync.py | 22 +++++------ synapse/storage/databases/main/state.py | 52 ++++++++++++++++++++++++- tests/handlers/test_sliding_sync.py | 68 +++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 13 deletions(-) create mode 100644 changelog.d/17434.bugfix (limited to 'synapse') diff --git a/changelog.d/17434.bugfix b/changelog.d/17434.bugfix new file mode 100644 index 0000000000..c7cce52397 --- /dev/null +++ b/changelog.d/17434.bugfix @@ -0,0 +1 @@ +Fix bug in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint when using room type filters and the user has one or more remote invites. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 818b13621c..8e6c2fb860 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple import attr from immutabledict import immutabledict -from synapse.api.constants import ( - AccountDataTypes, - Direction, - EventContentFields, - EventTypes, - Membership, -) +from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership from synapse.events import EventBase from synapse.events.utils import strip_event from synapse.handlers.relations import BundledAggregations @@ -959,11 +953,15 @@ class SlidingSyncHandler: # provided in the list. `None` is a valid type for rooms which do not have a # room type. if filters.room_types is not None or filters.not_room_types is not None: - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - create_event = await self.store.get_create_event_for_room(room_id) - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + room_to_type = await self.store.bulk_get_room_type( + { + room_id + for room_id in filtered_room_id_set + # We only know the room types for joined rooms + if sync_room_map[room_id].membership == Membership.JOIN + } + ) + for room_id, room_type in room_to_type.items(): if ( filters.room_types is not None and room_type not in filters.room_types diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index b2a67aff89..5188b2f7a4 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -41,7 +41,7 @@ from typing import ( import attr -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase @@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_event(create_id) return create_event + @cached(max_entries=10000) + async def get_room_type(self, room_id: str) -> Optional[str]: + """Get the room type for a given room. The server must be joined to the + given room. + """ + + row = await self.db_pool.simple_select_one( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcols=("room_type",), + allow_none=True, + desc="get_room_type", + ) + + if row is not None: + return row[0] + + # If we haven't updated `room_stats_state` with the room yet, query the + # create event directly. + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + return room_type + + @cachedList(cached_method_name="get_room_type", list_name="room_ids") + async def bulk_get_room_type( + self, room_ids: Set[str] + ) -> Mapping[str, Optional[str]]: + """Bulk fetch room types for the given rooms, the server must be in all + the rooms given. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="room_stats_state", + column="room_id", + iterable=room_ids, + retcols=("room_id", "room_type"), + desc="bulk_get_room_type", + ) + + # If we haven't updated `room_stats_state` with the room yet, query the + # create events directly. This should happen only rarely so we don't + # mind if we do this in a loop. + results = dict(rows) + for room_id in room_ids - results.keys(): + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + results[room_id] = room_type + + return results + @cached(max_entries=100000, iterable=True) async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 9dd2363adc..eb4b0a05c7 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -35,6 +35,8 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.handlers.sliding_sync import RoomSyncConfig, StateValues from synapse.rest import admin from synapse.rest.client import knock, login, room @@ -2791,6 +2793,72 @@ class FilterRoomsTestCase(HomeserverTestCase): self.assertEqual(filtered_room_map.keys(), {space_room_id}) + def test_filter_room_types_with_invite_remote_room(self) -> None: + """Test that we can apply a room type filter, even if we have an invite + for a remote room. + + This is a regression test. + """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create a fake remote invite and persist it. + invite_room_id = "!some:room" + invite_event = make_event_from_dict( + { + "room_id": invite_room_id, + "sender": "@user:test.serv", + "state_key": user1_id, + "depth": 1, + "origin_server_ts": 1, + "type": EventTypes.Member, + "content": {"membership": Membership.INVITE}, + "auth_events": [], + "prev_events": [], + }, + room_version=RoomVersions.V10, + ) + invite_event.internal_metadata.outlier = True + invite_event.internal_metadata.out_of_band_membership = True + + self.get_success( + self.store.maybe_store_room_on_outlier_membership( + room_id=invite_room_id, room_version=invite_event.room_version + ) + ) + context = EventContext.for_outlier(self.hs.get_storage_controllers()) + persist_controller = self.hs.get_storage_controllers().persistence + assert persist_controller is not None + self.get_success(persist_controller.persist_event(invite_event, context)) + + # Create a normal room (no room type) + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + after_rooms_token = self.event_sources.get_current_token() + + # Get the rooms the user should be syncing with + sync_room_map = self.get_success( + self.sliding_sync_handler.get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, + ) + ) + + filtered_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), + sync_room_map, + SlidingSyncConfig.SlidingSyncList.Filters( + room_types=[None, RoomTypes.SPACE], + ), + after_rooms_token, + ) + ) + + self.assertEqual(filtered_room_map.keys(), {room_id, invite_room_id}) + class SortRoomsTestCase(HomeserverTestCase): """ -- cgit 1.5.1