From 1cf3ff6b40a9f0e72c39e471e921a46f56e4511f Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Tue, 9 Jul 2024 12:26:45 -0500 Subject: Add `rooms` `name` and `avatar` to Sliding Sync `/sync` (#17418) Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync --- synapse/handlers/sliding_sync.py | 151 +++++++++++++++++++++++++-------------- 1 file changed, 96 insertions(+), 55 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 8e2f751c02..bb81ca9d97 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -18,6 +18,7 @@ # # import logging +from itertools import chain from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple import attr @@ -464,6 +465,7 @@ class SlidingSyncHandler: membership_state_keys = room_sync_config.required_state_map.get( EventTypes.Member ) + # Also see `StateFilter.must_await_full_state(...)` for comparison lazy_loading = ( membership_state_keys is not None and len(membership_state_keys) == 1 @@ -1202,7 +1204,7 @@ class SlidingSyncHandler: # Figure out any stripped state events for invite/knocks. This allows the # potential joiner to identify the room. - stripped_state: List[JsonDict] = [] + stripped_state: Optional[List[JsonDict]] = None if room_membership_for_user_at_to_token.membership in ( Membership.INVITE, Membership.KNOCK, @@ -1239,7 +1241,7 @@ class SlidingSyncHandler: # updates. initial = True - # Fetch the required state for the room + # Fetch the `required_state` for the room # # No `required_state` for invite/knock rooms (just `stripped_state`) # @@ -1247,13 +1249,15 @@ class SlidingSyncHandler: # of membership. Currently, we have to make this optional because # `invite`/`knock` rooms only have `stripped_state`. See # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 + # + # Calculate the `StateFilter` based on the `required_state` for the room room_state: Optional[StateMap[EventBase]] = None + required_room_state: Optional[StateMap[EventBase]] = None if room_membership_for_user_at_to_token.membership not in ( Membership.INVITE, Membership.KNOCK, ): - # Calculate the `StateFilter` based on the `required_state` for the room - state_filter: Optional[StateFilter] = StateFilter.none() + required_state_filter = StateFilter.none() # If we have a double wildcard ("*", "*") in the `required_state`, we need # to fetch all state for the room # @@ -1276,7 +1280,7 @@ class SlidingSyncHandler: if StateValues.WILDCARD in room_sync_config.required_state_map.get( StateValues.WILDCARD, set() ): - state_filter = StateFilter.all() + required_state_filter = StateFilter.all() # TODO: `StateFilter` currently doesn't support wildcard event types. 
We're # currently working around this by returning all state to the client but it # would be nice to fetch less from the database and return just what the @@ -1285,7 +1289,7 @@ class SlidingSyncHandler: room_sync_config.required_state_map.get(StateValues.WILDCARD) is not None ): - state_filter = StateFilter.all() + required_state_filter = StateFilter.all() else: required_state_types: List[Tuple[str, Optional[str]]] = [] for ( @@ -1317,51 +1321,88 @@ class SlidingSyncHandler: else: required_state_types.append((state_type, state_key)) - state_filter = StateFilter.from_types(required_state_types) - - # We can skip fetching state if we don't need any - if state_filter != StateFilter.none(): - # We can return all of the state that was requested if we're doing an - # initial sync - if initial: - # People shouldn't see past their leave/ban event - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - room_state = await self.storage_controllers.state.get_state_at( - room_id, - stream_position=to_token.copy_and_replace( - StreamKeyType.ROOM, - room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), - ), - state_filter=state_filter, - # Partially-stated rooms should have all state events except for - # the membership events and since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able - # to retrieve everything requested. Plus we don't want to block - # the whole sync waiting for this one room. - await_full_state=False, - ) - # Otherwise, we can get the latest current state in the room - else: - room_state = await self.storage_controllers.state.get_current_state( - room_id, - state_filter, - # Partially-stated rooms should have all state events except for - # the membership events and since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able - # to retrieve everything requested. Plus we don't want to block - # the whole sync waiting for this one room. - await_full_state=False, - ) - # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + required_state_filter = StateFilter.from_types(required_state_types) + + # We need this base set of info for the response so let's just fetch it along + # with the `required_state` for the room + META_ROOM_STATE = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + state_filter = StateFilter( + types=StateFilter.from_types( + chain(META_ROOM_STATE, required_state_filter.to_types()) + ).types, + include_others=required_state_filter.include_others, + ) + + # We can return all of the state that was requested if this was the first + # time we've sent the room down this connection. + if initial: + # People shouldn't see past their leave/ban event + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + room_state = await self.storage_controllers.state.get_state_at( + room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), + ), + state_filter=state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. 
When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # Otherwise, we can get the latest current state in the room else: - # TODO: Once we can figure out if we've sent a room down this connection before, - # we can return updates instead of the full required state. - raise NotImplementedError() + room_state = await self.storage_controllers.state.get_current_state( + room_id, + state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + else: + # TODO: Once we can figure out if we've sent a room down this connection before, + # we can return updates instead of the full required state. + raise NotImplementedError() + + if required_state_filter != StateFilter.none(): + required_room_state = required_state_filter.filter_state(room_state) + + # Find the room name and avatar from the state + room_name: Optional[str] = None + room_avatar: Optional[str] = None + if room_state is not None: + name_event = room_state.get((EventTypes.Name, "")) + if name_event is not None: + room_name = name_event.content.get("name") + + avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if avatar_event is not None: + room_avatar = avatar_event.content.get("url") + elif stripped_state is not None: + for event in stripped_state: + if event["type"] == EventTypes.Name: + room_name = event.get("content", {}).get("name") + elif event["type"] == EventTypes.RoomAvatar: + room_avatar = event.get("content", {}).get("url") + + # Found everything so we can stop looking + if room_name is not None and room_avatar is not None: + break # Figure out the last bump event in the room last_bump_event_result = ( @@ -1378,16 +1419,16 @@ class SlidingSyncHandler: bump_stamp = bump_event_pos.stream return SlidingSyncResult.RoomResult( - # TODO: Dummy value - name=None, - # TODO: Dummy value - avatar=None, + name=room_name, + avatar=room_avatar, # TODO: Dummy value heroes=None, # TODO: Dummy value is_dm=False, initial=initial, - required_state=list(room_state.values()) if room_state else None, + required_state=( + list(required_room_state.values()) if required_room_state else None + ), timeline_events=timeline_events, bundled_aggregations=bundled_aggregations, stripped_state=stripped_state, -- cgit 1.5.1 From 4ca13ce0dd6d1dc931cfde7e06191200ca0ec066 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 10 Jul 2024 11:58:42 +0100 Subject: Handle to-device extensions to Sliding Sync (#17416) Implements MSC3885 --------- Co-authored-by: Eric Eastwood --- changelog.d/17416.feature | 1 + synapse/handlers/sliding_sync.py | 103 ++++++++++++++++- synapse/rest/client/sync.py | 17 ++- synapse/types/handlers/__init__.py | 35 +++++- synapse/types/rest/client/__init__.py | 48 +++++++- tests/rest/client/test_sync.py | 200 
+++++++++++++++++++++++++++++++++- 6 files changed, 392 insertions(+), 12 deletions(-) create mode 100644 changelog.d/17416.feature (limited to 'synapse/handlers') diff --git a/changelog.d/17416.feature b/changelog.d/17416.feature new file mode 100644 index 0000000000..1d119cf48f --- /dev/null +++ b/changelog.d/17416.feature @@ -0,0 +1 @@ +Add to-device extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index bb81ca9d97..818b13621c 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -542,11 +542,15 @@ class SlidingSyncHandler: rooms[room_id] = room_sync_result + extensions = await self.get_extensions_response( + sync_config=sync_config, to_token=to_token + ) + return SlidingSyncResult( next_pos=to_token, lists=lists, rooms=rooms, - extensions={}, + extensions=extensions, ) async def get_sync_room_ids_for_user( @@ -1445,3 +1449,100 @@ class SlidingSyncHandler: notification_count=0, highlight_count=0, ) + + async def get_extensions_response( + self, + sync_config: SlidingSyncConfig, + to_token: StreamToken, + ) -> SlidingSyncResult.Extensions: + """Handle extension requests. + + Args: + sync_config: Sync configuration + to_token: The point in the stream to sync up to. + """ + + if sync_config.extensions is None: + return SlidingSyncResult.Extensions() + + to_device_response = None + if sync_config.extensions.to_device: + to_device_response = await self.get_to_device_extensions_response( + sync_config=sync_config, + to_device_request=sync_config.extensions.to_device, + to_token=to_token, + ) + + return SlidingSyncResult.Extensions(to_device=to_device_response) + + async def get_to_device_extensions_response( + self, + sync_config: SlidingSyncConfig, + to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, + to_token: StreamToken, + ) -> SlidingSyncResult.Extensions.ToDeviceExtension: + """Handle to-device extension (MSC3885) + + Args: + sync_config: Sync configuration + to_device_request: The to-device extension from the request + to_token: The point in the stream to sync up to. + """ + + user_id = sync_config.user.to_string() + device_id = sync_config.device_id + + # Check that this request has a valid device ID (not all requests have + # to belong to a device, and so device_id is None), and that the + # extension is enabled. + if device_id is None or not to_device_request.enabled: + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=f"{to_token.to_device_key}", + events=[], + ) + + since_stream_id = 0 + if to_device_request.since is not None: + # We've already validated this is an int. + since_stream_id = int(to_device_request.since) + + if to_token.to_device_key < since_stream_id: + # The since token is ahead of our current token, so we return an + # empty response. + logger.warning( + "Got to-device.since from the future. since token: %r is ahead of our current to_device stream position: %r", + since_stream_id, + to_token.to_device_key, + ) + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=to_device_request.since, + events=[], + ) + + # Delete everything before the given since token, as we know the + # device must have received them. 
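# As a worked illustration (hypothetical stream ids): if the previous
# response returned `next_batch: "5"`, the client echoes `since: "5"` here;
# stream ids 1..5 are then treated as acknowledged and deleted, and only
# messages in the half-open range (5, to_token.to_device_key] are fetched
# and returned below.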
+ deleted = await self.store.delete_messages_for_device( + user_id=user_id, + device_id=device_id, + up_to_stream_id=since_stream_id, + ) + + logger.debug( + "Deleted %d to-device messages up to %d for %s", + deleted, + since_stream_id, + user_id, + ) + + messages, stream_id = await self.store.get_messages_for_device( + user_id=user_id, + device_id=device_id, + from_stream_id=since_stream_id, + to_stream_id=to_token.to_device_key, + limit=min(to_device_request.limit, 100), # Limit to at most 100 events + ) + + return SlidingSyncResult.Extensions.ToDeviceExtension( + next_batch=f"{stream_id}", + events=messages, + ) diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 13aed1dc85..94d5faf9f7 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -942,7 +942,9 @@ class SlidingSyncRestServlet(RestServlet): response["rooms"] = await self.encode_rooms( requester, sliding_sync_result.rooms ) - response["extensions"] = {} # TODO: sliding_sync_result.extensions + response["extensions"] = await self.encode_extensions( + requester, sliding_sync_result.extensions + ) return response @@ -1054,6 +1056,19 @@ class SlidingSyncRestServlet(RestServlet): return serialized_rooms + async def encode_extensions( + self, requester: Requester, extensions: SlidingSyncResult.Extensions + ) -> JsonDict: + result = {} + + if extensions.to_device is not None: + result["to_device"] = { + "next_batch": extensions.to_device.next_batch, + "events": extensions.to_device.events, + } + + return result + def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: SyncRestServlet(hs).register(http_server) diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 43dcdf20dd..a8a3a8f242 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -18,7 +18,7 @@ # # from enum import Enum -from typing import TYPE_CHECKING, Dict, Final, List, Optional, Tuple +from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple import attr from typing_extensions import TypedDict @@ -252,10 +252,39 @@ class SlidingSyncResult: count: int ops: List[Operation] + @attr.s(slots=True, frozen=True, auto_attribs=True) + class Extensions: + """Responses for extensions + + Attributes: + to_device: The to-device extension (MSC3885) + """ + + @attr.s(slots=True, frozen=True, auto_attribs=True) + class ToDeviceExtension: + """The to-device extension (MSC3885) + + Attributes: + next_batch: The to-device stream token the client should use + to get more results + events: A list of to-device messages for the client + """ + + next_batch: str + events: Sequence[JsonMapping] + + def __bool__(self) -> bool: + return bool(self.events) + + to_device: Optional[ToDeviceExtension] = None + + def __bool__(self) -> bool: + return bool(self.to_device) + next_pos: StreamToken lists: Dict[str, SlidingWindowList] rooms: Dict[str, RoomResult] - extensions: JsonMapping + extensions: Extensions def __bool__(self) -> bool: """Make the result appear empty if there are no updates. 
This is used @@ -271,5 +300,5 @@ class SlidingSyncResult: next_pos=next_pos, lists={}, rooms={}, - extensions={}, + extensions=SlidingSyncResult.Extensions(), ) diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 55f6b44053..1e8fe76c99 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -276,10 +276,48 @@ class SlidingSyncBody(RequestBodyModel): class RoomSubscription(CommonRoomParameters): pass - class Extension(RequestBodyModel): - enabled: Optional[StrictBool] = False - lists: Optional[List[StrictStr]] = None - rooms: Optional[List[StrictStr]] = None + class Extensions(RequestBodyModel): + """The extensions section of the request. + + Extensions MUST have an `enabled` flag which defaults to `false`. If a client + sends an unknown extension name, the server MUST ignore it (or else backwards + compatibility between clients and servers is broken when a newer client tries to + communicate with an older server). + """ + + class ToDeviceExtension(RequestBodyModel): + """The to-device extension (MSC3885) + + Attributes: + enabled + limit: Maximum number of to-device messages to return + since: The `next_batch` from the previous sync response + """ + + enabled: Optional[StrictBool] = False + limit: StrictInt = 100 + since: Optional[StrictStr] = None + + @validator("since") + def since_token_check( + cls, value: Optional[StrictStr] + ) -> Optional[StrictStr]: + # `since` comes in as an opaque string token but we know that it's just + # an integer representing the position in the device inbox stream. We + # want to pre-validate it to make sure it works fine in downstream code. + if value is None: + return value + + try: + int(value) + except ValueError: + raise ValueError( + "'extensions.to_device.since' is invalid (should look like an int)" + ) + + return value + + to_device: Optional[ToDeviceExtension] = None # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: @@ -287,7 +325,7 @@ class SlidingSyncBody(RequestBodyModel): else: lists: Optional[Dict[constr(max_length=64, strict=True), SlidingSyncList]] = None # type: ignore[valid-type] room_subscriptions: Optional[Dict[StrictStr, RoomSubscription]] = None - extensions: Optional[Dict[StrictStr, Extension]] = None + extensions: Optional[Extensions] = None @validator("lists") def lists_length_check( diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index f7852562b1..304c0d4d3d 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -38,7 +38,16 @@ from synapse.api.constants import ( ) from synapse.events import EventBase from synapse.handlers.sliding_sync import StateValues -from synapse.rest.client import devices, knock, login, read_marker, receipts, room, sync +from synapse.rest.client import ( + devices, + knock, + login, + read_marker, + receipts, + room, + sendtodevice, + sync, +) from synapse.server import HomeServer from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken, UserID from synapse.util import Clock @@ -47,7 +56,7 @@ from tests import unittest from tests.federation.transport.test_knocking import ( KnockingStrippedStateEventHelperMixin, ) -from tests.server import TimedOutException +from tests.server import FakeChannel, TimedOutException from tests.test_utils.event_injection import mark_event_as_partial_state logger = logging.getLogger(__name__) @@ -3696,3 +3705,190 @@ class 
SlidingSyncTestCase(unittest.HomeserverTestCase): ], channel.json_body["lists"]["foo-list"], ) + + +class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase): + """Tests for the to-device sliding sync extension""" + + servlets = [ + synapse.rest.admin.register_servlets, + login.register_servlets, + sync.register_servlets, + sendtodevice.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.store = hs.get_datastores().main + self.sync_endpoint = ( + "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync" + ) + + def _assert_to_device_response( + self, channel: FakeChannel, expected_messages: List[JsonDict] + ) -> str: + """Assert the sliding sync response was successful and has the expected + to-device messages. + + Returns the next_batch token from the to-device section. + """ + self.assertEqual(channel.code, 200, channel.json_body) + extensions = channel.json_body["extensions"] + to_device = extensions["to_device"] + self.assertIsInstance(to_device["next_batch"], str) + self.assertEqual(to_device["events"], expected_messages) + + return to_device["next_batch"] + + def test_no_data(self) -> None: + """Test that enabling to-device extension works, even if there is + no-data + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + + # We expect no to-device messages + self._assert_to_device_response(channel, []) + + def test_data_initial_sync(self) -> None: + """Test that we get to-device messages when we don't specify a since + token""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass", "d1") + user2_id = self.register_user("u2", "pass") + user2_tok = self.login(user2_id, "pass", "d2") + + # Send the to-device message + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user1_id: {"d1": test_msg}}}, + access_token=user2_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + self._assert_to_device_response( + channel, + [{"content": test_msg, "sender": user2_id, "type": "m.test"}], + ) + + def test_data_incremental_sync(self) -> None: + """Test that we get to-device messages over incremental syncs""" + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass", "d1") + user2_id = self.register_user("u2", "pass") + user2_tok = self.login(user2_id, "pass", "d2") + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + # No to-device messages yet. 
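# Even with no to-device messages, the response still carries a usable
# `next_batch` (the current to-device stream position), which the follow-up
# incremental request below passes back as `since`.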
+ next_batch = self._assert_to_device_response(channel, []) + + test_msg = {"foo": "bar"} + chan = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user1_id: {"d1": test_msg}}}, + access_token=user2_tok, + ) + self.assertEqual(chan.code, 200, chan.result) + + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + "since": next_batch, + } + }, + }, + access_token=user1_tok, + ) + next_batch = self._assert_to_device_response( + channel, + [{"content": test_msg, "sender": user2_id, "type": "m.test"}], + ) + + # The next sliding sync request should not include the to-device + # message. + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + "since": next_batch, + } + }, + }, + access_token=user1_tok, + ) + self._assert_to_device_response(channel, []) + + # An initial sliding sync request should not include the to-device + # message, as it should have been deleted + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + self._assert_to_device_response(channel, []) -- cgit 1.5.1 From 606da398fc4c693f2e75b9520264e0fc51d03581 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Thu, 11 Jul 2024 16:00:44 +0100 Subject: Fix filtering room types on remote rooms (#17434) We can only fetch room types for rooms the server is in, so we need to only filter rooms that we're joined to. Also includes a perf fix to bulk fetch room types. --- changelog.d/17434.bugfix | 1 + synapse/handlers/sliding_sync.py | 22 +++++------ synapse/storage/databases/main/state.py | 52 ++++++++++++++++++++++++- tests/handlers/test_sliding_sync.py | 68 +++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 13 deletions(-) create mode 100644 changelog.d/17434.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/17434.bugfix b/changelog.d/17434.bugfix new file mode 100644 index 0000000000..c7cce52397 --- /dev/null +++ b/changelog.d/17434.bugfix @@ -0,0 +1 @@ +Fix bug in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint when using room type filters and the user has one or more remote invites. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 818b13621c..8e6c2fb860 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple import attr from immutabledict import immutabledict -from synapse.api.constants import ( - AccountDataTypes, - Direction, - EventContentFields, - EventTypes, - Membership, -) +from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership from synapse.events import EventBase from synapse.events.utils import strip_event from synapse.handlers.relations import BundledAggregations @@ -959,11 +953,15 @@ class SlidingSyncHandler: # provided in the list. `None` is a valid type for rooms which do not have a # room type. 
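# For example, `room_types: [None, "m.space"]` matches rooms with no `type`
# in their create event content as well as spaces, while
# `not_room_types: ["m.space"]` filters spaces out.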
if filters.room_types is not None or filters.not_room_types is not None: - # Make a copy so we don't run into an error: `Set changed size during - # iteration`, when we filter out and remove items - for room_id in filtered_room_id_set.copy(): - create_event = await self.store.get_create_event_for_room(room_id) - room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + room_to_type = await self.store.bulk_get_room_type( + { + room_id + for room_id in filtered_room_id_set + # We only know the room types for joined rooms + if sync_room_map[room_id].membership == Membership.JOIN + } + ) + for room_id, room_type in room_to_type.items(): if ( filters.room_types is not None and room_type not in filters.room_types diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py index b2a67aff89..5188b2f7a4 100644 --- a/synapse/storage/databases/main/state.py +++ b/synapse/storage/databases/main/state.py @@ -41,7 +41,7 @@ from typing import ( import attr -from synapse.api.constants import EventTypes, Membership +from synapse.api.constants import EventContentFields, EventTypes, Membership from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion from synapse.events import EventBase @@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore): create_event = await self.get_event(create_id) return create_event + @cached(max_entries=10000) + async def get_room_type(self, room_id: str) -> Optional[str]: + """Get the room type for a given room. The server must be joined to the + given room. + """ + + row = await self.db_pool.simple_select_one( + table="room_stats_state", + keyvalues={"room_id": room_id}, + retcols=("room_type",), + allow_none=True, + desc="get_room_type", + ) + + if row is not None: + return row[0] + + # If we haven't updated `room_stats_state` with the room yet, query the + # create event directly. + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + return room_type + + @cachedList(cached_method_name="get_room_type", list_name="room_ids") + async def bulk_get_room_type( + self, room_ids: Set[str] + ) -> Mapping[str, Optional[str]]: + """Bulk fetch room types for the given rooms, the server must be in all + the rooms given. + """ + + rows = await self.db_pool.simple_select_many_batch( + table="room_stats_state", + column="room_id", + iterable=room_ids, + retcols=("room_id", "room_type"), + desc="bulk_get_room_type", + ) + + # If we haven't updated `room_stats_state` with the room yet, query the + # create events directly. This should happen only rarely so we don't + # mind if we do this in a loop. 
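# `rows` comes back as a list of (room_id, room_type) tuples here, so
# `dict(rows)` builds the result map directly; rooms missing from
# `room_stats_state` are back-filled from their create event below.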
+ results = dict(rows) + for room_id in room_ids - results.keys(): + create_event = await self.get_create_event_for_room(room_id) + room_type = create_event.content.get(EventContentFields.ROOM_TYPE) + results[room_id] = room_type + + return results + @cached(max_entries=100000, iterable=True) async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]: """Get the current state event ids for a room based on the diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 9dd2363adc..eb4b0a05c7 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -35,6 +35,8 @@ from synapse.api.constants import ( RoomTypes, ) from synapse.api.room_versions import RoomVersions +from synapse.events import make_event_from_dict +from synapse.events.snapshot import EventContext from synapse.handlers.sliding_sync import RoomSyncConfig, StateValues from synapse.rest import admin from synapse.rest.client import knock, login, room @@ -2791,6 +2793,72 @@ class FilterRoomsTestCase(HomeserverTestCase): self.assertEqual(filtered_room_map.keys(), {space_room_id}) + def test_filter_room_types_with_invite_remote_room(self) -> None: + """Test that we can apply a room type filter, even if we have an invite + for a remote room. + + This is a regression test. + """ + + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Create a fake remote invite and persist it. + invite_room_id = "!some:room" + invite_event = make_event_from_dict( + { + "room_id": invite_room_id, + "sender": "@user:test.serv", + "state_key": user1_id, + "depth": 1, + "origin_server_ts": 1, + "type": EventTypes.Member, + "content": {"membership": Membership.INVITE}, + "auth_events": [], + "prev_events": [], + }, + room_version=RoomVersions.V10, + ) + invite_event.internal_metadata.outlier = True + invite_event.internal_metadata.out_of_band_membership = True + + self.get_success( + self.store.maybe_store_room_on_outlier_membership( + room_id=invite_room_id, room_version=invite_event.room_version + ) + ) + context = EventContext.for_outlier(self.hs.get_storage_controllers()) + persist_controller = self.hs.get_storage_controllers().persistence + assert persist_controller is not None + self.get_success(persist_controller.persist_event(invite_event, context)) + + # Create a normal room (no room type) + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) + + after_rooms_token = self.event_sources.get_current_token() + + # Get the rooms the user should be syncing with + sync_room_map = self.get_success( + self.sliding_sync_handler.get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, + ) + ) + + filtered_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), + sync_room_map, + SlidingSyncConfig.SlidingSyncList.Filters( + room_types=[None, RoomTypes.SPACE], + ), + after_rooms_token, + ) + ) + + self.assertEqual(filtered_room_map.keys(), {room_id, invite_room_id}) + class SortRoomsTestCase(HomeserverTestCase): """ -- cgit 1.5.1 From 5a97bbd8958548f16461cfa3fde201f3e032d6b8 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 11 Jul 2024 14:05:38 -0500 Subject: Add `heroes` and room summary fields to Sliding Sync `/sync` (#17419) Additional room summary fields: `joined_count`, `invited_count` Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync --- changelog.d/17419.feature | 1 + 
synapse/handlers/sliding_sync.py | 280 +++++++++++++++++++++---------- synapse/rest/client/sync.py | 32 +++- synapse/types/handlers/__init__.py | 18 +- synapse/types/rest/client/__init__.py | 4 - tests/rest/client/test_sync.py | 304 ++++++++++++++++++++++++++++++++-- 6 files changed, 529 insertions(+), 110 deletions(-) create mode 100644 changelog.d/17419.feature (limited to 'synapse/handlers') diff --git a/changelog.d/17419.feature b/changelog.d/17419.feature new file mode 100644 index 0000000000..186a27c470 --- /dev/null +++ b/changelog.d/17419.feature @@ -0,0 +1 @@ +Populate `heroes` and room summary fields (`joined_count`, `invited_count`) in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 8e6c2fb860..e3230d28e6 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -19,7 +19,7 @@ # import logging from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple import attr from immutabledict import immutabledict @@ -28,7 +28,9 @@ from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membe from synapse.events import EventBase from synapse.events.utils import strip_event from synapse.handlers.relations import BundledAggregations +from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.databases.main.stream import CurrentStateDeltaMembership +from synapse.storage.roommember import MemberSummary from synapse.types import ( JsonDict, PersistedEventPosition, @@ -1043,6 +1045,103 @@ class SlidingSyncHandler: reverse=True, ) + async def get_current_state_ids_at( + self, + room_id: str, + room_membership_for_user_at_to_token: _RoomMembershipForUser, + state_filter: StateFilter, + to_token: StreamToken, + ) -> StateMap[str]: + """ + Get current state IDs for the user in the room according to their membership. This + will be the current state at the time of their LEAVE/BAN, otherwise will be the + current state <= to_token. + + Args: + room_id: The room ID to fetch data for + room_membership_for_user_at_token: Membership information for the user + in the room at the time of `to_token`. + to_token: The point in the stream to sync up to. + """ + + room_state_ids: StateMap[str] + # People shouldn't see past their leave/ban event + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + # TODO: `get_state_ids_at(...)` doesn't take into account the "current state" + room_state_ids = await self.storage_controllers.state.get_state_ids_at( + room_id, + stream_position=to_token.copy_and_replace( + StreamKeyType.ROOM, + room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), + ), + state_filter=state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. 
+ await_full_state=False, + ) + # Otherwise, we can get the latest current state in the room + else: + room_state_ids = await self.storage_controllers.state.get_current_state_ids( + room_id, + state_filter, + # Partially-stated rooms should have all state events except for + # remote membership events. Since we've already excluded + # partially-stated rooms unless `required_state` only has + # `["m.room.member", "$LAZY"]` for membership, we should be able to + # retrieve everything requested. When we're lazy-loading, if there + # are some remote senders in the timeline, we should also have their + # membership event because we had to auth that timeline event. Plus + # we don't want to block the whole sync waiting for this one room. + await_full_state=False, + ) + # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` + + return room_state_ids + + async def get_current_state_at( + self, + room_id: str, + room_membership_for_user_at_to_token: _RoomMembershipForUser, + state_filter: StateFilter, + to_token: StreamToken, + ) -> StateMap[EventBase]: + """ + Get current state for the user in the room according to their membership. This + will be the current state at the time of their LEAVE/BAN, otherwise will be the + current state <= to_token. + + Args: + room_id: The room ID to fetch data for + room_membership_for_user_at_token: Membership information for the user + in the room at the time of `to_token`. + to_token: The point in the stream to sync up to. + """ + room_state_ids = await self.get_current_state_ids_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=state_filter, + to_token=to_token, + ) + + event_map = await self.store.get_events(list(room_state_ids.values())) + + state_map = {} + for key, event_id in room_state_ids.items(): + event = event_map.get(event_id) + if event: + state_map[key] = event + + return state_map + async def get_room_sync_data( self, user: UserID, @@ -1074,7 +1173,7 @@ class SlidingSyncHandler: # membership. Currently, we have to make all of these optional because # `invite`/`knock` rooms only have `stripped_state`. See # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 - timeline_events: Optional[List[EventBase]] = None + timeline_events: List[EventBase] = [] bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None limited: Optional[bool] = None prev_batch_token: Optional[StreamToken] = None @@ -1206,7 +1305,7 @@ class SlidingSyncHandler: # Figure out any stripped state events for invite/knocks. This allows the # potential joiner to identify the room. - stripped_state: Optional[List[JsonDict]] = None + stripped_state: List[JsonDict] = [] if room_membership_for_user_at_to_token.membership in ( Membership.INVITE, Membership.KNOCK, @@ -1243,6 +1342,44 @@ class SlidingSyncHandler: # updates. 
initial = True + # Check whether the room has a name set + name_state_ids = await self.get_current_state_ids_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=StateFilter.from_types([(EventTypes.Name, "")]), + to_token=to_token, + ) + name_event_id = name_state_ids.get((EventTypes.Name, "")) + + room_membership_summary: Mapping[str, MemberSummary] + empty_membership_summary = MemberSummary([], 0) + if room_membership_for_user_at_to_token.membership in ( + Membership.LEAVE, + Membership.BAN, + ): + # TODO: Figure out how to get the membership summary for left/banned rooms + room_membership_summary = {} + else: + room_membership_summary = await self.store.get_room_summary(room_id) + # TODO: Reverse/rewind back to the `to_token` + + # `heroes` are required if the room name is not set. + # + # Note: When you're the first one on your server to be invited to a new room + # over federation, we only have access to some stripped state in + # `event.unsigned.invite_room_state` which currently doesn't include `heroes`, + # see https://github.com/matrix-org/matrix-spec/issues/380. This means that + # clients won't be able to calculate the room name when necessary and just a + # pitfall we have to deal with until that spec issue is resolved. + hero_user_ids: List[str] = [] + # TODO: Should we also check for `EventTypes.CanonicalAlias` + # (`m.room.canonical_alias`) as a fallback for the room name? see + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 + if name_event_id is None: + hero_user_ids = extract_heroes_from_room_summary( + room_membership_summary, me=user.to_string() + ) + # Fetch the `required_state` for the room # # No `required_state` for invite/knock rooms (just `stripped_state`) @@ -1253,13 +1390,11 @@ class SlidingSyncHandler: # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 # # Calculate the `StateFilter` based on the `required_state` for the room - room_state: Optional[StateMap[EventBase]] = None - required_room_state: Optional[StateMap[EventBase]] = None + required_state_filter = StateFilter.none() if room_membership_for_user_at_to_token.membership not in ( Membership.INVITE, Membership.KNOCK, ): - required_state_filter = StateFilter.none() # If we have a double wildcard ("*", "*") in the `required_state`, we need # to fetch all state for the room # @@ -1325,86 +1460,65 @@ class SlidingSyncHandler: required_state_filter = StateFilter.from_types(required_state_types) - # We need this base set of info for the response so let's just fetch it along - # with the `required_state` for the room - META_ROOM_STATE = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + # We need this base set of info for the response so let's just fetch it along + # with the `required_state` for the room + meta_room_state = [(EventTypes.Name, ""), (EventTypes.RoomAvatar, "")] + [ + (EventTypes.Member, hero_user_id) for hero_user_id in hero_user_ids + ] + state_filter = StateFilter.all() + if required_state_filter != StateFilter.all(): state_filter = StateFilter( types=StateFilter.from_types( - chain(META_ROOM_STATE, required_state_filter.to_types()) + chain(meta_room_state, required_state_filter.to_types()) ).types, include_others=required_state_filter.include_others, ) - # We can return all of the state that was requested if this was the first - # time we've sent the room down this connection. 
- if initial: - # People shouldn't see past their leave/ban event - if room_membership_for_user_at_to_token.membership in ( - Membership.LEAVE, - Membership.BAN, - ): - room_state = await self.storage_controllers.state.get_state_at( - room_id, - stream_position=to_token.copy_and_replace( - StreamKeyType.ROOM, - room_membership_for_user_at_to_token.event_pos.to_room_stream_token(), - ), - state_filter=state_filter, - # Partially-stated rooms should have all state events except for - # remote membership events. Since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able to - # retrieve everything requested. When we're lazy-loading, if there - # are some remote senders in the timeline, we should also have their - # membership event because we had to auth that timeline event. Plus - # we don't want to block the whole sync waiting for this one room. - await_full_state=False, - ) - # Otherwise, we can get the latest current state in the room - else: - room_state = await self.storage_controllers.state.get_current_state( - room_id, - state_filter, - # Partially-stated rooms should have all state events except for - # remote membership events. Since we've already excluded - # partially-stated rooms unless `required_state` only has - # `["m.room.member", "$LAZY"]` for membership, we should be able to - # retrieve everything requested. When we're lazy-loading, if there - # are some remote senders in the timeline, we should also have their - # membership event because we had to auth that timeline event. Plus - # we don't want to block the whole sync waiting for this one room. - await_full_state=False, - ) - # TODO: Query `current_state_delta_stream` and reverse/rewind back to the `to_token` - else: - # TODO: Once we can figure out if we've sent a room down this connection before, - # we can return updates instead of the full required state. - raise NotImplementedError() + # We can return all of the state that was requested if this was the first + # time we've sent the room down this connection. + room_state: StateMap[EventBase] = {} + if initial: + room_state = await self.get_current_state_at( + room_id=room_id, + room_membership_for_user_at_to_token=room_membership_for_user_at_to_token, + state_filter=state_filter, + to_token=to_token, + ) + else: + # TODO: Once we can figure out if we've sent a room down this connection before, + # we can return updates instead of the full required state. + raise NotImplementedError() - if required_state_filter != StateFilter.none(): - required_room_state = required_state_filter.filter_state(room_state) + required_room_state: StateMap[EventBase] = {} + if required_state_filter != StateFilter.none(): + required_room_state = required_state_filter.filter_state(room_state) # Find the room name and avatar from the state room_name: Optional[str] = None + # TODO: Should we also check for `EventTypes.CanonicalAlias` + # (`m.room.canonical_alias`) as a fallback for the room name? 
see + # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1671260153 + name_event = room_state.get((EventTypes.Name, "")) + if name_event is not None: + room_name = name_event.content.get("name") + room_avatar: Optional[str] = None - if room_state is not None: - name_event = room_state.get((EventTypes.Name, "")) - if name_event is not None: - room_name = name_event.content.get("name") - - avatar_event = room_state.get((EventTypes.RoomAvatar, "")) - if avatar_event is not None: - room_avatar = avatar_event.content.get("url") - elif stripped_state is not None: - for event in stripped_state: - if event["type"] == EventTypes.Name: - room_name = event.get("content", {}).get("name") - elif event["type"] == EventTypes.RoomAvatar: - room_avatar = event.get("content", {}).get("url") - - # Found everything so we can stop looking - if room_name is not None and room_avatar is not None: - break + avatar_event = room_state.get((EventTypes.RoomAvatar, "")) + if avatar_event is not None: + room_avatar = avatar_event.content.get("url") + + # Assemble heroes: extract the info from the state we just fetched + heroes: List[SlidingSyncResult.RoomResult.StrippedHero] = [] + for hero_user_id in hero_user_ids: + member_event = room_state.get((EventTypes.Member, hero_user_id)) + if member_event is not None: + heroes.append( + SlidingSyncResult.RoomResult.StrippedHero( + user_id=hero_user_id, + display_name=member_event.content.get("displayname"), + avatar_url=member_event.content.get("avatar_url"), + ) + ) # Figure out the last bump event in the room last_bump_event_result = ( @@ -1423,14 +1537,11 @@ class SlidingSyncHandler: return SlidingSyncResult.RoomResult( name=room_name, avatar=room_avatar, - # TODO: Dummy value - heroes=None, + heroes=heroes, # TODO: Dummy value is_dm=False, initial=initial, - required_state=( - list(required_room_state.values()) if required_room_state else None - ), + required_state=list(required_room_state.values()), timeline_events=timeline_events, bundled_aggregations=bundled_aggregations, stripped_state=stripped_state, @@ -1438,9 +1549,12 @@ class SlidingSyncHandler: limited=limited, num_live=num_live, bump_stamp=bump_stamp, - # TODO: Dummy values - joined_count=0, - invited_count=0, + joined_count=room_membership_summary.get( + Membership.JOIN, empty_membership_summary + ).count, + invited_count=room_membership_summary.get( + Membership.INVITE, empty_membership_summary + ).count, # TODO: These are just dummy values. We could potentially just remove these # since notifications can only really be done correctly on the client anyway # (encrypted rooms). 
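For context on how the new meta fields fit together on the receiving end: given `name`, `heroes`, `joined_count`, and `invited_count` from this response, a client can derive a display name roughly as sketched below. This is a minimal illustration of the client-side calculation implied by MSC3575; the function name and the exact truncation/tie-breaking rules are assumptions, not part of this patch.

from typing import List, Optional


def calculate_room_name(
    name: Optional[str],
    heroes: List[dict],
    joined_count: int,
    invited_count: int,
) -> str:
    # An explicit m.room.name always wins.
    if name:
        return name

    # Fall back to the heroes, preferring display names over user IDs. Note
    # that the serialized field is spelled "displayname", per the spec.
    labels = [hero.get("displayname") or hero["user_id"] for hero in heroes]

    # Heroes exclude the requesting user, so subtract one for "us" when
    # working out how many members are left unnamed.
    remainder = joined_count + invited_count - 1 - len(labels)

    if not labels:
        return "Empty Room"
    if remainder > 0:
        return ", ".join(labels) + f" and {remainder} other(s)"
    return ", ".join(labels)

# e.g. calculate_room_name(None, [{"user_id": "@user2:test"}], 2, 0)
# returns "@user2:test".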
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 94d5faf9f7..1d8cbfdf00 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -997,8 +997,21 @@ class SlidingSyncRestServlet(RestServlet): if room_result.avatar: serialized_rooms[room_id]["avatar"] = room_result.avatar - if room_result.heroes: - serialized_rooms[room_id]["heroes"] = room_result.heroes + if room_result.heroes is not None and len(room_result.heroes) > 0: + serialized_heroes = [] + for hero in room_result.heroes: + serialized_hero = { + "user_id": hero.user_id, + } + if hero.display_name is not None: + # Not a typo, just how "displayname" is spelled in the spec + serialized_hero["displayname"] = hero.display_name + + if hero.avatar_url is not None: + serialized_hero["avatar_url"] = hero.avatar_url + + serialized_heroes.append(serialized_hero) + serialized_rooms[room_id]["heroes"] = serialized_heroes # We should only include the `initial` key if it's `True` to save bandwidth. # The absense of this flag means `False`. @@ -1006,7 +1019,10 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms[room_id]["initial"] = room_result.initial # This will be omitted for invite/knock rooms with `stripped_state` - if room_result.required_state is not None: + if ( + room_result.required_state is not None + and len(room_result.required_state) > 0 + ): serialized_required_state = ( await self.event_serializer.serialize_events( room_result.required_state, @@ -1017,7 +1033,10 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms[room_id]["required_state"] = serialized_required_state # This will be omitted for invite/knock rooms with `stripped_state` - if room_result.timeline_events is not None: + if ( + room_result.timeline_events is not None + and len(room_result.timeline_events) > 0 + ): serialized_timeline = await self.event_serializer.serialize_events( room_result.timeline_events, time_now, @@ -1045,7 +1064,10 @@ class SlidingSyncRestServlet(RestServlet): serialized_rooms[room_id]["is_dm"] = room_result.is_dm # Stripped state only applies to invite/knock rooms - if room_result.stripped_state is not None: + if ( + room_result.stripped_state is not None + and len(room_result.stripped_state) > 0 + ): # TODO: `knocked_state` but that isn't specced yet. # # TODO: Instead of adding `knocked_state`, it would be good to rename diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index a8a3a8f242..409120470a 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -200,18 +200,24 @@ class SlidingSyncResult: flag set. 
(same as sync v2) """ + @attr.s(slots=True, frozen=True, auto_attribs=True) + class StrippedHero: + user_id: str + display_name: Optional[str] + avatar_url: Optional[str] + name: Optional[str] avatar: Optional[str] - heroes: Optional[List[EventBase]] + heroes: Optional[List[StrippedHero]] is_dm: bool initial: bool - # Only optional because it won't be included for invite/knock rooms with `stripped_state` - required_state: Optional[List[EventBase]] - # Only optional because it won't be included for invite/knock rooms with `stripped_state` - timeline_events: Optional[List[EventBase]] + # Should be empty for invite/knock rooms with `stripped_state` + required_state: List[EventBase] + # Should be empty for invite/knock rooms with `stripped_state` + timeline_events: List[EventBase] bundled_aggregations: Optional[Dict[str, "BundledAggregations"]] # Optional because it's only relevant to invite/knock rooms - stripped_state: Optional[List[JsonDict]] + stripped_state: List[JsonDict] # Only optional because it won't be included for invite/knock rooms with `stripped_state` prev_batch: Optional[StreamToken] # Only optional because it won't be included for invite/knock rooms with `stripped_state` diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index 1e8fe76c99..dbe37bc712 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -200,9 +200,6 @@ class SlidingSyncBody(RequestBodyModel): } timeline_limit: The maximum number of timeline events to return per response. - include_heroes: Return a stripped variant of membership events (containing - `user_id` and optionally `avatar_url` and `displayname`) for the users used - to calculate the room name. filters: Filters to apply to the list before sorting. """ @@ -270,7 +267,6 @@ class SlidingSyncBody(RequestBodyModel): else: ranges: Optional[List[Tuple[conint(ge=0, strict=True), conint(ge=0, strict=True)]]] = None # type: ignore[valid-type] slow_get_all_rooms: Optional[StrictBool] = False - include_heroes: Optional[StrictBool] = False filters: Optional[Filters] = None class RoomSubscription(CommonRoomParameters): diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 304c0d4d3d..0d0bea538b 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -1813,8 +1813,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): def test_rooms_meta_when_joined(self) -> None: """ - Test that the `rooms` `name` and `avatar` (soon to test `heroes`) are included - in the response when the user is joined to the room. + Test that the `rooms` `name` and `avatar` are included in the response and + reflect the current state of the room when the user is joined to the room. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1866,11 +1866,19 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): "mxc://DUMMY_MEDIA_ID", channel.json_body["rooms"][room_id1], ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["joined_count"], + 2, + ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["invited_count"], + 0, + ) def test_rooms_meta_when_invited(self) -> None: """ - Test that the `rooms` `name` and `avatar` (soon to test `heroes`) are included - in the response when the user is invited to the room. + Test that the `rooms` `name` and `avatar` are included in the response and + reflect the current state of the room when the user is invited to the room. 
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1892,7 +1900,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): tok=user2_tok, ) - self.helper.join(room_id1, user1_id, tok=user1_tok) + # User1 is invited to the room + self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) # Update the room name after user1 has left self.helper.send_state( @@ -1938,11 +1947,19 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): "mxc://UPDATED_DUMMY_MEDIA_ID", channel.json_body["rooms"][room_id1], ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["joined_count"], + 1, + ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["invited_count"], + 1, + ) def test_rooms_meta_when_banned(self) -> None: """ - Test that the `rooms` `name` and `avatar` (soon to test `heroes`) reflect the - state of the room when the user was banned (do not leak current state). + Test that the `rooms` `name` and `avatar` reflect the state of the room when the + user was banned (do not leak current state). """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -2010,6 +2027,273 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): "mxc://DUMMY_MEDIA_ID", channel.json_body["rooms"][room_id1], ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["joined_count"], + # FIXME: The actual number should be "1" (user2) but we currently don't + # support this for rooms where the user has left/been banned. + 0, + ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["invited_count"], + 0, + ) + + def test_rooms_meta_heroes(self) -> None: + """ + Test that the `rooms` `heroes` are included in the response when the room + doesn't have a room name set. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + user3_id = self.register_user("user3", "pass") + _user3_tok = self.login(user3_id, "pass") + + room_id1 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + "name": "my super room", + }, + ) + self.helper.join(room_id1, user1_id, tok=user1_tok) + # User3 is invited + self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok) + + room_id2 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + # No room name set so that `heroes` is populated + # + # "name": "my super room2", + }, + ) + self.helper.join(room_id2, user1_id, tok=user1_tok) + # User3 is invited + self.helper.invite(room_id2, src=user2_id, targ=user3_id, tok=user2_tok) + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Room1 has a name so we shouldn't see any `heroes` which the client would use + # the calculate the room name themselves. 
+ self.assertEqual( + channel.json_body["rooms"][room_id1]["name"], + "my super room", + channel.json_body["rooms"][room_id1], + ) + self.assertIsNone(channel.json_body["rooms"][room_id1].get("heroes")) + self.assertEqual( + channel.json_body["rooms"][room_id1]["joined_count"], + 2, + ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["invited_count"], + 1, + ) + + # Room2 doesn't have a name so we should see `heroes` populated + self.assertIsNone(channel.json_body["rooms"][room_id2].get("name")) + self.assertCountEqual( + [ + hero["user_id"] + for hero in channel.json_body["rooms"][room_id2].get("heroes", []) + ], + # Heroes shouldn't include the user themselves (we shouldn't see user1) + [user2_id, user3_id], + ) + self.assertEqual( + channel.json_body["rooms"][room_id2]["joined_count"], + 2, + ) + self.assertEqual( + channel.json_body["rooms"][room_id2]["invited_count"], + 1, + ) + + # We didn't request any state so we shouldn't see any `required_state` + self.assertIsNone(channel.json_body["rooms"][room_id1].get("required_state")) + self.assertIsNone(channel.json_body["rooms"][room_id2].get("required_state")) + + def test_rooms_meta_heroes_max(self) -> None: + """ + Test that the `rooms` `heroes` only includes the first 5 users (not including + yourself). + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + user3_id = self.register_user("user3", "pass") + user3_tok = self.login(user3_id, "pass") + user4_id = self.register_user("user4", "pass") + user4_tok = self.login(user4_id, "pass") + user5_id = self.register_user("user5", "pass") + user5_tok = self.login(user5_id, "pass") + user6_id = self.register_user("user6", "pass") + user6_tok = self.login(user6_id, "pass") + user7_id = self.register_user("user7", "pass") + user7_tok = self.login(user7_id, "pass") + + room_id1 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + # No room name set so that `heroes` is populated + # + # "name": "my super room", + }, + ) + self.helper.join(room_id1, user1_id, tok=user1_tok) + self.helper.join(room_id1, user3_id, tok=user3_tok) + self.helper.join(room_id1, user4_id, tok=user4_tok) + self.helper.join(room_id1, user5_id, tok=user5_tok) + self.helper.join(room_id1, user6_id, tok=user6_tok) + self.helper.join(room_id1, user7_id, tok=user7_tok) + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Room1 doesn't have a name so we should see `heroes` populated + self.assertIsNone(channel.json_body["rooms"][room_id1].get("name")) + # FIXME: Remove this basic assertion and uncomment the better assertion below + # after https://github.com/element-hq/synapse/pull/17435 merges + self.assertEqual(len(channel.json_body["rooms"][room_id1].get("heroes", [])), 5) + # self.assertCountEqual( + # [ + # hero["user_id"] + # for hero in channel.json_body["rooms"][room_id1].get("heroes", []) + # ], + # # Heroes should be the first 5 users in the room (excluding the user + # # themselves, we shouldn't see `user1`) + # [user2_id, user3_id, user4_id, user5_id, user6_id], + # ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["joined_count"], + 7, + ) + self.assertEqual( +
channel.json_body["rooms"][room_id1]["invited_count"], + 0, + ) + + # We didn't request any state so we shouldn't see any `required_state` + self.assertIsNone(channel.json_body["rooms"][room_id1].get("required_state")) + + def test_rooms_meta_heroes_when_banned(self) -> None: + """ + Test that the `rooms` `heroes` are included in the response when the room + doesn't have a room name set but doesn't leak information past their ban. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + user3_id = self.register_user("user3", "pass") + _user3_tok = self.login(user3_id, "pass") + user4_id = self.register_user("user4", "pass") + user4_tok = self.login(user4_id, "pass") + user5_id = self.register_user("user5", "pass") + _user5_tok = self.login(user5_id, "pass") + + room_id1 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + # No room name set so that `heroes` is populated + # + # "name": "my super room", + }, + ) + # User1 joins the room + self.helper.join(room_id1, user1_id, tok=user1_tok) + # User3 is invited + self.helper.invite(room_id1, src=user2_id, targ=user3_id, tok=user2_tok) + + # User1 is banned from the room + self.helper.ban(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) + + # User4 joins the room after user1 is banned + self.helper.join(room_id1, user4_id, tok=user4_tok) + # User5 is invited after user1 is banned + self.helper.invite(room_id1, src=user2_id, targ=user5_id, tok=user2_tok) + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": { + "foo-list": { + "ranges": [[0, 1]], + "required_state": [], + "timeline_limit": 0, + } + } + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Room2 doesn't have a name so we should see `heroes` populated + self.assertIsNone(channel.json_body["rooms"][room_id1].get("name")) + self.assertCountEqual( + [ + hero["user_id"] + for hero in channel.json_body["rooms"][room_id1].get("heroes", []) + ], + # Heroes shouldn't include the user themselves (we shouldn't see user1). We + # also shouldn't see user4 since they joined after user1 was banned. + # + # FIXME: The actual result should be `[user2_id, user3_id]` but we currently + # don't support this for rooms where the user has left/been banned. + [], + ) + + self.assertEqual( + channel.json_body["rooms"][room_id1]["joined_count"], + # FIXME: The actual number should be "1" (user2) but we currently don't + # support this for rooms where the user has left/been banned. + 0, + ) + self.assertEqual( + channel.json_body["rooms"][room_id1]["invited_count"], + # We shouldn't see user5 since they were invited after user1 was banned. + # + # FIXME: The actual number should be "1" (user3) but we currently don't + # support this for rooms where the user has left/been banned. 
+ 0, + ) def test_rooms_limited_initial_sync(self) -> None: """ @@ -3081,11 +3365,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.assertEqual(channel.code, 200, channel.json_body) # Nothing to see for this banned user in the room in the token range - self.assertEqual( - channel.json_body["rooms"][room_id1]["timeline"], - [], - channel.json_body["rooms"][room_id1]["timeline"], - ) + self.assertIsNone(channel.json_body["rooms"][room_id1].get("timeline")) # No events returned in the timeline so nothing is "live" self.assertEqual( channel.json_body["rooms"][room_id1]["num_live"], -- cgit 1.5.1 From fb66e938b26e96384af5a72c71ed7d9dec12f1a2 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 11 Jul 2024 18:19:26 -0500 Subject: Add `is_dm` room field to Sliding Sync `/sync` (#17429) Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync --- changelog.d/17429.feature | 1 + synapse/handlers/sliding_sync.py | 75 ++++++++++++++++++++++++---------------- tests/rest/client/test_sync.py | 23 ++++++++++++ 3 files changed, 70 insertions(+), 29 deletions(-) create mode 100644 changelog.d/17429.feature (limited to 'synapse/handlers') diff --git a/changelog.d/17429.feature b/changelog.d/17429.feature new file mode 100644 index 0000000000..608b75d632 --- /dev/null +++ b/changelog.d/17429.feature @@ -0,0 +1 @@ +Populate `is_dm` room field in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index e3230d28e6..904787ced3 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -291,6 +291,7 @@ class _RoomMembershipForUser: sender: The person who sent the membership event newly_joined: Whether the user newly joined the room during the given token range + is_dm: Whether this user considers this room as a direct-message (DM) room """ room_id: str @@ -299,6 +300,7 @@ class _RoomMembershipForUser: membership: str sender: Optional[str] newly_joined: bool + is_dm: bool def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser": return attr.evolve(self, **kwds) @@ -613,6 +615,7 @@ class SlidingSyncHandler: membership=room_for_user.membership, sender=room_for_user.sender, newly_joined=False, + is_dm=False, ) for room_for_user in room_for_user_list } @@ -652,6 +655,7 @@ class SlidingSyncHandler: # - 1c) Update room membership events to the point in time of the `to_token` # - 2) Add back newly_left rooms (> `from_token` and <= `to_token`) # - 3) Figure out which rooms are `newly_joined` + # - 4) Figure out which rooms are DM's # 1) ----------------------------------------------------- @@ -714,6 +718,7 @@ class SlidingSyncHandler: membership=first_membership_change_after_to_token.prev_membership, sender=first_membership_change_after_to_token.prev_sender, newly_joined=False, + is_dm=False, ) else: # If we can't find the previous membership event, we shouldn't @@ -809,6 +814,7 @@ class SlidingSyncHandler: membership=last_membership_change_in_from_to_range.membership, sender=last_membership_change_in_from_to_range.sender, newly_joined=False, + is_dm=False, ) # 3) Figure out `newly_joined` @@ -846,6 +852,35 @@ class SlidingSyncHandler: room_id ].copy_and_replace(newly_joined=True) + # 4) Figure out which rooms the user considers to be direct-message (DM) rooms + # + # We're using global account data (`m.direct`) instead of checking for + # `is_direct` on membership events because 
that property only appears for + # the invitee membership event (doesn't show up for the inviter). + # + # We're unable to take `to_token` into account for global account data since + # we only keep track of the latest account data for the user. + dm_map = await self.store.get_global_account_data_by_type_for_user( + user_id, AccountDataTypes.DIRECT + ) + + # Flatten out the map. Account data is set by the client so it needs to be + # scrutinized. + dm_room_id_set = set() + if isinstance(dm_map, dict): + for room_ids in dm_map.values(): + # Account data should be a list of room IDs. Ignore anything else + if isinstance(room_ids, list): + for room_id in room_ids: + if isinstance(room_id, str): + dm_room_id_set.add(room_id) + + # 4) Fixup + for room_id in filtered_sync_room_id_set: + filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[ + room_id + ].copy_and_replace(is_dm=room_id in dm_room_id_set) + return filtered_sync_room_id_set async def filter_rooms( @@ -869,41 +904,24 @@ class SlidingSyncHandler: A filtered dictionary of room IDs along with membership information in the room at the time of `to_token`. """ - user_id = user.to_string() - - # TODO: Apply filters - filtered_room_id_set = set(sync_room_map.keys()) # Filter for Direct-Message (DM) rooms if filters.is_dm is not None: - # We're using global account data (`m.direct`) instead of checking for - # `is_direct` on membership events because that property only appears for - # the invitee membership event (doesn't show up for the inviter). Account - # data is set by the client so it needs to be scrutinized. - # - # We're unable to take `to_token` into account for global account data since - # we only keep track of the latest account data for the user. - dm_map = await self.store.get_global_account_data_by_type_for_user( - user_id, AccountDataTypes.DIRECT - ) - - # Flatten out the map - dm_room_id_set = set() - if isinstance(dm_map, dict): - for room_ids in dm_map.values(): - # Account data should be a list of room IDs. 
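For reference, the `m.direct` global account data that feeds this is keyed on the other user and maps to lists of room IDs; a tiny self-contained sketch of the scrutinized flatten, with made-up values:

    dm_map = {
        "@bob:example.org": ["!abc:example.org"],
        "@carol:example.org": ["!def:example.org", "!ghi:example.org"],
        "@mallory:example.org": "not-a-list",  # client-supplied junk is ignored
    }

    dm_room_id_set = {
        room_id
        for room_ids in dm_map.values()
        if isinstance(room_ids, list)
        for room_id in room_ids
        if isinstance(room_id, str)
    }
    assert dm_room_id_set == {
        "!abc:example.org",
        "!def:example.org",
        "!ghi:example.org",
    }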
Ignore anything else - if isinstance(room_ids, list): - for room_id in room_ids: - if isinstance(room_id, str): - dm_room_id_set.add(room_id) - if filters.is_dm: # Only DM rooms please - filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set) + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + if sync_room_map[room_id].is_dm + } else: # Only non-DM rooms please - filtered_room_id_set = filtered_room_id_set.difference(dm_room_id_set) + filtered_room_id_set = { + room_id + for room_id in filtered_room_id_set + if not sync_room_map[room_id].is_dm + } if filters.spaces: raise NotImplementedError() @@ -1538,8 +1556,7 @@ class SlidingSyncHandler: name=room_name, avatar=room_avatar, heroes=heroes, - # TODO: Dummy value - is_dm=False, + is_dm=room_membership_for_user_at_to_token.is_dm, initial=initial, required_state=list(required_room_state.values()), timeline_events=timeline_events, diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 0d0bea538b..4236812db5 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -1662,6 +1662,20 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): list(channel.json_body["lists"]["room-invites"]), ) + # Ensure DM's are correctly marked + self.assertDictEqual( + { + room_id: room.get("is_dm") + for room_id, room in channel.json_body["rooms"].items() + }, + { + invite_room_id: None, + room_id: None, + invited_dm_room_id: True, + joined_dm_room_id: True, + }, + ) + def test_sort_list(self) -> None: """ Test that the `lists` are sorted by `stream_ordering` @@ -1874,6 +1888,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): channel.json_body["rooms"][room_id1]["invited_count"], 0, ) + self.assertIsNone( + channel.json_body["rooms"][room_id1].get("is_dm"), + ) def test_rooms_meta_when_invited(self) -> None: """ @@ -1955,6 +1972,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): channel.json_body["rooms"][room_id1]["invited_count"], 1, ) + self.assertIsNone( + channel.json_body["rooms"][room_id1].get("is_dm"), + ) def test_rooms_meta_when_banned(self) -> None: """ @@ -2037,6 +2057,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): channel.json_body["rooms"][room_id1]["invited_count"], 0, ) + self.assertIsNone( + channel.json_body["rooms"][room_id1].get("is_dm"), + ) def test_rooms_meta_heroes(self) -> None: """ -- cgit 1.5.1 From ab62aa09da4a3c4444d80a9d3a899c685f1bb798 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Mon, 15 Jul 2024 04:37:10 -0500 Subject: Add room subscriptions to Sliding Sync `/sync` (#17432) Add room subscriptions to Sliding Sync `/sync` Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync Currently, you can only subscribe to rooms you have had *any* membership in before. In the future, we will allow `world_readable` rooms to be subscribed to without joining. 
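For a sense of the request shape this commit enables, a request body can now mix sliding window lists with explicit per-room subscriptions; a sketch following the MSC3575 shape used in the tests below (the room ID and values are placeholders):

    sync_body = {
        "lists": {
            "foo-list": {
                "ranges": [[0, 99]],
                "required_state": [["m.room.member", "$LAZY"]],
                "timeline_limit": 1,
            }
        },
        "room_subscriptions": {
            "!example:example.org": {
                "required_state": [["m.room.name", ""]],
                "timeline_limit": 10,
            }
        },
    }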
--- changelog.d/17432.feature | 1 + synapse/handlers/sliding_sync.py | 403 ++++++++++---- tests/handlers/test_sliding_sync.py | 1045 +++++++++++++++++++++++++++++------ tests/rest/client/test_sync.py | 347 +++++++++--- tests/unittest.py | 51 ++ 5 files changed, 1489 insertions(+), 358 deletions(-) create mode 100644 changelog.d/17432.feature (limited to 'synapse/handlers') diff --git a/changelog.d/17432.feature b/changelog.d/17432.feature new file mode 100644 index 0000000000..c86f04c118 --- /dev/null +++ b/changelog.d/17432.feature @@ -0,0 +1 @@ +Add room subscriptions to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 904787ced3..be98b379eb 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -62,32 +62,79 @@ DEFAULT_BUMP_EVENT_TYPES = { } +@attr.s(slots=True, frozen=True, auto_attribs=True) +class _RoomMembershipForUser: + """ + Attributes: + room_id: The room ID of the membership event + event_id: The event ID of the membership event + event_pos: The stream position of the membership event + membership: The membership state of the user in the room + sender: The person who sent the membership event + newly_joined: Whether the user newly joined the room during the given token + range and is still joined to the room at the end of this range. + newly_left: Whether the user newly left (or was kicked from) the room during + the given token range and is still "leave" at the end of this range. + is_dm: Whether this user considers this room as a direct-message (DM) room + """ + + room_id: str + # Optional because state resets can affect room membership without a corresponding event. + event_id: Optional[str] + # Even during a state reset which removes the user from the room, we expect this to + # be set because `current_state_delta_stream` will note the position that the reset + # happened. + event_pos: PersistedEventPosition + # Even during a state reset which removes the user from the room, we expect this to + # be set to `LEAVE` because we can make that assumption based on the situation (see + # `get_current_state_delta_membership_changes_for_user(...)`) + membership: str + # Optional because state resets can affect room membership without a corresponding event. + sender: Optional[str] + newly_joined: bool + newly_left: bool + is_dm: bool + + def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser": + return attr.evolve(self, **kwds) + + def filter_membership_for_sync( - *, membership: str, user_id: str, sender: Optional[str] + *, user_id: str, room_membership_for_user: _RoomMembershipForUser ) -> bool: """ Returns True if the membership event should be included in the sync response, otherwise False. Attributes: - membership: The membership state of the user in the room. user_id: The user ID that the membership applies to - sender: The person who sent the membership event + room_membership_for_user: Membership information for the user in the room """ - # Everything except `Membership.LEAVE` because we want everything that's *still* - # relevant to the user. There are few more things to include in the sync response - # (newly_left) but those are handled separately.
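The replacement predicate (added just below) keeps everything except the user's own stale leaves; a few worked cases as a self-contained sketch (mirroring, not reproducing, the handler logic; the user ID is a placeholder):

    from typing import Optional

    USER_ID = "@user1:test"

    def should_include(
        membership: str, sender: Optional[str], newly_left: bool
    ) -> bool:
        # Drop only leaves that are the user's own (or sender-less state
        # resets) and that are not `newly_left` within the token range
        return (
            membership != "leave"
            or newly_left
            or (membership == "leave" and sender not in (USER_ID, None))
        )

    assert should_include("join", USER_ID, False)          # still joined
    assert should_include("leave", "@admin:test", False)   # kicked by someone else
    assert should_include("leave", USER_ID, True)          # newly_left self-leave
    assert not should_include("leave", USER_ID, False)     # long-ago self-leave
    assert not should_include("leave", None, False)        # old state reset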
+ membership = room_membership_for_user.membership + sender = room_membership_for_user.sender + newly_left = room_membership_for_user.newly_left + + # We want to allow everything except rooms the user has left unless `newly_left` + # because we want everything that's *still* relevant to the user. We include + # `newly_left` rooms because the last event that the user should see is their own + # leave event. # - # This logic includes kicks (leave events where the sender is not the same user) and - # can be read as "anything that isn't a leave or a leave with a different sender". + # A leave != kick. This logic includes kicks (leave events where the sender is not + # the same user). # - # When `sender=None` and `membership=Membership.LEAVE`, it means that a state reset - # happened that removed the user from the room, or the user was the last person - # locally to leave the room which caused the server to leave the room. In both - # cases, we can just remove the rooms since they are no longer relevant to the user. - # They could still be added back later if they are `newly_left`. - return membership != Membership.LEAVE or sender not in (user_id, None) + # When `sender=None`, it means that a state reset happened that removed the user + # from the room without a corresponding leave event. We can just remove the rooms + # since they are no longer relevant to the user but will still appear if they are + # `newly_left`. + return ( + # Anything except leave events + membership != Membership.LEAVE + # Unless... + or newly_left + # Allow kicks + or (membership == Membership.LEAVE and sender not in (user_id, None)) + ) # We can't freeze this class because we want to update it in place with the @@ -281,31 +328,6 @@ class StateValues: LAZY: Final = "$LAZY" -@attr.s(slots=True, frozen=True, auto_attribs=True) -class _RoomMembershipForUser: - """ - Attributes: - event_id: The event ID of the membership event - event_pos: The stream position of the membership event - membership: The membership state of the user in the room - sender: The person who sent the membership event - newly_joined: Whether the user newly joined the room during the given token - range - is_dm: Whether this user considers this room as a direct-message (DM) room - """ - - room_id: str - event_id: Optional[str] - event_pos: PersistedEventPosition - membership: str - sender: Optional[str] - newly_joined: bool - is_dm: bool - - def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser": - return attr.evolve(self, **kwds) - - class SlidingSyncHandler: def __init__(self, hs: "HomeServer"): self.clock = hs.get_clock() @@ -424,18 +446,31 @@ class SlidingSyncHandler: # See https://github.com/matrix-org/matrix-doc/issues/1144 raise NotImplementedError() + # Get all of the room IDs that the user should be able to see in the sync + # response + has_lists = sync_config.lists is not None and len(sync_config.lists) > 0 + has_room_subscriptions = ( + sync_config.room_subscriptions is not None + and len(sync_config.room_subscriptions) > 0 + ) + if has_lists or has_room_subscriptions: + room_membership_for_user_map = ( + await self.get_room_membership_for_user_at_to_token( + user=sync_config.user, + to_token=to_token, + from_token=from_token, + ) + ) + # Assemble sliding window lists lists: Dict[str, SlidingSyncResult.SlidingWindowList] = {} # Keep track of the rooms that we're going to display and need to fetch more # info about relevant_room_map: Dict[str, RoomSyncConfig] = {} - if sync_config.lists: - # Get all of the room IDs that the user 
should be able to see in the sync - # response - sync_room_map = await self.get_sync_room_ids_for_user( - sync_config.user, - from_token=from_token, - to_token=to_token, + if has_lists and sync_config.lists is not None: + sync_room_map = await self.filter_rooms_relevant_for_sync( + user=sync_config.user, + room_membership_for_user_map=room_membership_for_user_map, ) for list_key, list_config in sync_config.lists.items(): @@ -524,7 +559,35 @@ class SlidingSyncHandler: ops=ops, ) - # TODO: if (sync_config.room_subscriptions): + # Handle room subscriptions + if has_room_subscriptions and sync_config.room_subscriptions is not None: + for room_id, room_subscription in sync_config.room_subscriptions.items(): + room_membership_for_user_at_to_token = ( + await self.check_room_subscription_allowed_for_user( + room_id=room_id, + room_membership_for_user_map=room_membership_for_user_map, + to_token=to_token, + ) + ) + + # Skip this room if the user isn't allowed to see it + if not room_membership_for_user_at_to_token: + continue + + room_membership_for_user_map[room_id] = ( + room_membership_for_user_at_to_token + ) + + # Take the superset of the `RoomSyncConfig` for each room. + # + # Update our `relevant_room_map` with the room we're going to display + # and need to fetch more info about. + room_sync_config = RoomSyncConfig.from_room_config(room_subscription) + existing_room_sync_config = relevant_room_map.get(room_id) + if existing_room_sync_config is not None: + existing_room_sync_config.combine_room_sync_config(room_sync_config) + else: + relevant_room_map[room_id] = room_sync_config # Fetch room data rooms: Dict[str, SlidingSyncResult.RoomResult] = {} @@ -533,7 +596,9 @@ class SlidingSyncHandler: user=sync_config.user, room_id=room_id, room_sync_config=room_sync_config, - room_membership_for_user_at_to_token=sync_room_map[room_id], + room_membership_for_user_at_to_token=room_membership_for_user_map[ + room_id + ], from_token=from_token, to_token=to_token, ) @@ -551,28 +616,23 @@ class SlidingSyncHandler: extensions=extensions, ) - async def get_sync_room_ids_for_user( + async def get_room_membership_for_user_at_to_token( self, user: UserID, to_token: StreamToken, - from_token: Optional[StreamToken] = None, + from_token: Optional[StreamToken], ) -> Dict[str, _RoomMembershipForUser]: """ - Fetch room IDs that should be listed for this user in the sync response (the - full room list that will be filtered, sorted, and sliced). + Fetch room IDs that the user has had membership in (the full room list including + long-lost left rooms that will be filtered, sorted, and sliced). - We're looking for rooms where the user has the following state in the token - range (> `from_token` and <= `to_token`): + We're looking for rooms where the user has had any sort of membership in the + token range (> `from_token` and <= `to_token`) - - `invite`, `join`, `knock`, `ban` membership events - - Kicks (`leave` membership events where `sender` is different from the - `user_id`/`state_key`) - - `newly_left` (rooms that were left during the given token range) - - In order for bans/kicks to not show up in sync, you need to `/forget` those - rooms. This doesn't modify the event itself though and only adds the - `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way - to tell when a room was forgotten at the moment so we can't factor it into the - from/to range. + In order for bans/kicks to not show up, you need to `/forget` those rooms. 
This + doesn't modify the event itself though and only adds the `forgotten` flag to the + `room_memberships` table in Synapse. There isn't a way to tell when a room was + forgotten at the moment so we can't factor it into the token range. Args: user: User to fetch rooms for @@ -580,8 +640,8 @@ class SlidingSyncHandler: from_token: The point in the stream to sync from. Returns: - A dictionary of room IDs that should be listed in the sync response along - with membership information in that room at the time of `to_token`. + A dictionary of room IDs that the user has had membership in along with + membership information in that room at the time of `to_token`. """ user_id = user.to_string() @@ -592,9 +652,6 @@ class SlidingSyncHandler: # We want to fetch any kind of membership (joined and left rooms) in order # to get the `event_pos` of the latest room membership event for the # user. - # - # We will filter out the rooms that don't belong below (see - # `filter_membership_for_sync`) membership_list=Membership.LIST, excluded_rooms=self.rooms_to_exclude_globally, ) @@ -614,7 +671,9 @@ class SlidingSyncHandler: event_pos=room_for_user.event_pos, membership=room_for_user.membership, sender=room_for_user.sender, + # We will update these fields below to be accurate newly_joined=False, + newly_left=False, is_dm=False, ) for room_for_user in room_for_user_list @@ -653,12 +712,10 @@ class SlidingSyncHandler: # - 1a) Remove rooms that the user joined after the `to_token` # - 1b) Add back rooms that the user left after the `to_token` # - 1c) Update room membership events to the point in time of the `to_token` - # - 2) Add back newly_left rooms (> `from_token` and <= `to_token`) - # - 3) Figure out which rooms are `newly_joined` + # - 2) Figure out which rooms are `newly_left` rooms (> `from_token` and <= `to_token`) + # - 3) Figure out which rooms are `newly_joined` (> `from_token` and <= `to_token`) # - 4) Figure out which rooms are DM's - # 1) ----------------------------------------------------- - # 1) Fetch membership changes that fall in the range from `to_token` up to # `membership_snapshot_token` # @@ -717,7 +774,9 @@ class SlidingSyncHandler: event_pos=first_membership_change_after_to_token.prev_event_pos, membership=first_membership_change_after_to_token.prev_membership, sender=first_membership_change_after_to_token.prev_sender, + # We will update these fields below to be accurate newly_joined=False, + newly_left=False, is_dm=False, ) else: @@ -726,22 +785,6 @@ class SlidingSyncHandler: # exact membership state and shouldn't rely on the current snapshot. 
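The net effect of the "1)" fixups is to rewind the membership snapshot to `to_token`: the membership at `to_token` is the `prev_*` state of the first change after it. A condensed, self-contained sketch of that idea (the real code also carries event IDs, stream positions, and senders):

    from dataclasses import dataclass
    from typing import Dict, Optional

    @dataclass
    class Change:
        prev_membership: Optional[str]  # None: no membership before the change

    def rewind_to_token(
        snapshot: Dict[str, str],  # room_id -> membership in the current snapshot
        first_change_after_to_token: Dict[str, Change],
    ) -> Dict[str, str]:
        result = dict(snapshot)
        for room_id, change in first_change_after_to_token.items():
            if change.prev_membership is None:
                # 1a) The user only joined after `to_token`
                result.pop(room_id, None)
            else:
                # 1b/1c) Restore the membership as it was at `to_token`
                result[room_id] = change.prev_membership
        return result

    # Joined only after `to_token` -> dropped; left after `to_token` -> rewound
    assert rewind_to_token(
        {"!a:test": "join", "!b:test": "leave"},
        {"!a:test": Change(None), "!b:test": Change("join")},
    ) == {"!b:test": "join"}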
sync_room_id_set.pop(room_id, None) - # Filter the rooms that that we have updated room membership events to the point - # in time of the `to_token` (from the "1)" fixups) - filtered_sync_room_id_set = { - room_id: room_membership_for_user - for room_id, room_membership_for_user in sync_room_id_set.items() - if filter_membership_for_sync( - membership=room_membership_for_user.membership, - user_id=user_id, - sender=room_membership_for_user.sender, - ) - } - - # 2) ----------------------------------------------------- - # We fix-up newly_left rooms after the first fixup because it may have removed - # some left rooms that we can figure out are newly_left in the following code - # 2) Fetch membership changes that fall in the range from `from_token` up to `to_token` current_state_delta_membership_changes_in_from_to_range = [] if from_token: @@ -803,19 +846,40 @@ if last_membership_change_in_from_to_range.membership == Membership.JOIN: possibly_newly_joined_room_ids.add(room_id) - # 2) Add back newly_left rooms (> `from_token` and <= `to_token`). We - # include newly_left rooms because the last event that the user should see - # is their own leave event + # 2) Figure out newly_left rooms (> `from_token` and <= `to_token`). if last_membership_change_in_from_to_range.membership == Membership.LEAVE: - filtered_sync_room_id_set[room_id] = _RoomMembershipForUser( - room_id=room_id, - event_id=last_membership_change_in_from_to_range.event_id, - event_pos=last_membership_change_in_from_to_range.event_pos, - membership=last_membership_change_in_from_to_range.membership, - sender=last_membership_change_in_from_to_range.sender, - newly_joined=False, - is_dm=False, - ) + # 2) Mark this room as `newly_left` + + # If we're seeing a membership change here, we should expect to already + # have it in our snapshot but if a state reset happens, it wouldn't have + # shown up in our snapshot but would appear as a change here. + existing_sync_entry = sync_room_id_set.get(room_id) + if existing_sync_entry is not None: + # Normal expected case + sync_room_id_set[room_id] = existing_sync_entry.copy_and_replace( + newly_left=True + ) + else: + # State reset! + logger.warning( + "State reset detected for room_id %s with %s who is no longer in the room", + room_id, + user_id, + ) + # Even though a state reset happened which removed the person from + # the room, we still add it to the list so the user knows they left the + # room. Downstream code can check for a state reset by looking for + # `event_id=None and membership is not None`. + sync_room_id_set[room_id] = _RoomMembershipForUser( + room_id=room_id, + event_id=last_membership_change_in_from_to_range.event_id, + event_pos=last_membership_change_in_from_to_range.event_pos, + membership=last_membership_change_in_from_to_range.membership, + sender=last_membership_change_in_from_to_range.sender, + newly_joined=False, + newly_left=True, + is_dm=False, + ) # 3) Figure out `newly_joined` for room_id in possibly_newly_joined_room_ids: @@ -826,9 +890,9 @@ # also some non-join in the range, we know they `newly_joined`.
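Given the convention noted above (`event_id=None` with a membership still set), downstream code can detect the state-reset case with a cheap check; a minimal sketch, with attribute names following `_RoomMembershipForUser`:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MembershipEntry:
        event_id: Optional[str]
        membership: Optional[str]

    def is_state_reset(entry: MembershipEntry) -> bool:
        # A state reset records a membership without a corresponding event
        return entry.event_id is None and entry.membership is not None

    assert is_state_reset(MembershipEntry(event_id=None, membership="leave"))
    assert not is_state_reset(MembershipEntry(event_id="$evt:test", membership="leave"))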
if has_non_join_in_from_to_range: # We found a `newly_joined` room (we left and joined within the token range) - filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[ - room_id - ].copy_and_replace(newly_joined=True) + sync_room_id_set[room_id] = sync_room_id_set[room_id].copy_and_replace( + newly_joined=True + ) else: prev_event_id = first_membership_change_by_room_id_in_from_to_range[ room_id @@ -840,7 +904,7 @@ class SlidingSyncHandler: if prev_event_id is None: # We found a `newly_joined` room (we are joining the room for the # first time within the token range) - filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[ + sync_room_id_set[room_id] = sync_room_id_set[ room_id ].copy_and_replace(newly_joined=True) # Last resort, we need to step back to the previous membership event @@ -848,7 +912,7 @@ class SlidingSyncHandler: elif prev_membership != Membership.JOIN: # We found a `newly_joined` room (we left before the token range # and joined within the token range) - filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[ + sync_room_id_set[room_id] = sync_room_id_set[ room_id ].copy_and_replace(newly_joined=True) @@ -876,12 +940,122 @@ class SlidingSyncHandler: dm_room_id_set.add(room_id) # 4) Fixup - for room_id in filtered_sync_room_id_set: - filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[ - room_id - ].copy_and_replace(is_dm=room_id in dm_room_id_set) + for room_id in sync_room_id_set: + sync_room_id_set[room_id] = sync_room_id_set[room_id].copy_and_replace( + is_dm=room_id in dm_room_id_set + ) + + return sync_room_id_set + + async def filter_rooms_relevant_for_sync( + self, + user: UserID, + room_membership_for_user_map: Dict[str, _RoomMembershipForUser], + ) -> Dict[str, _RoomMembershipForUser]: + """ + Filter room IDs that should/can be listed for this user in the sync response (the + full room list that will be further filtered, sorted, and sliced). + + We're looking for rooms where the user has the following state in the token + range (> `from_token` and <= `to_token`): + + - `invite`, `join`, `knock`, `ban` membership events + - Kicks (`leave` membership events where `sender` is different from the + `user_id`/`state_key`) + - `newly_left` (rooms that were left during the given token range) + - In order for bans/kicks to not show up in sync, you need to `/forget` those + rooms. This doesn't modify the event itself though and only adds the + `forgotten` flag to the `room_memberships` table in Synapse. There isn't a way + to tell when a room was forgotten at the moment so we can't factor it into the + from/to range. + + Args: + user: User that is syncing + room_membership_for_user_map: Room membership for the user + + Returns: + A dictionary of room IDs that should be listed in the sync response along + with membership information in that room at the time of `to_token`. 
+ """ + user_id = user.to_string() + + # Filter rooms to only what we're interested to sync with + filtered_sync_room_map = { + room_id: room_membership_for_user + for room_id, room_membership_for_user in room_membership_for_user_map.items() + if filter_membership_for_sync( + user_id=user_id, + room_membership_for_user=room_membership_for_user, + ) + } + + return filtered_sync_room_map + + async def check_room_subscription_allowed_for_user( + self, + room_id: str, + room_membership_for_user_map: Dict[str, _RoomMembershipForUser], + to_token: StreamToken, + ) -> Optional[_RoomMembershipForUser]: + """ + Check whether the user is allowed to see the room based on whether they have + ever had membership in the room or if the room is `world_readable`. - return filtered_sync_room_id_set + Similar to `check_user_in_room_or_world_readable(...)` + + Args: + room_id: Room to check + room_membership_for_user_map: Room membership for the user at the time of + the `to_token` (<= `to_token`). + to_token: The token to fetch rooms up to. + + Returns: + The room membership for the user if they are allowed to subscribe to the + room else `None`. + """ + + # We can first check if they are already allowed to see the room based + # on our previous work to assemble the `room_membership_for_user_map`. + # + # If they have had any membership in the room over time (up to the `to_token`), + # let them subscribe and see what they can. + existing_membership_for_user = room_membership_for_user_map.get(room_id) + if existing_membership_for_user is not None: + return existing_membership_for_user + + # TODO: Handle `world_readable` rooms + return None + + # If the room is `world_readable`, it doesn't matter whether they can join, + # everyone can see the room. + # not_in_room_membership_for_user = _RoomMembershipForUser( + # room_id=room_id, + # event_id=None, + # event_pos=None, + # membership=None, + # sender=None, + # newly_joined=False, + # newly_left=False, + # is_dm=False, + # ) + # room_state = await self.get_current_state_at( + # room_id=room_id, + # room_membership_for_user_at_to_token=not_in_room_membership_for_user, + # state_filter=StateFilter.from_types( + # [(EventTypes.RoomHistoryVisibility, "")] + # ), + # to_token=to_token, + # ) + + # visibility_event = room_state.get((EventTypes.RoomHistoryVisibility, "")) + # if ( + # visibility_event is not None + # and visibility_event.content.get("history_visibility") + # == HistoryVisibility.WORLD_READABLE + # ): + # return not_in_room_membership_for_user + + # return None async def filter_rooms( self, @@ -1081,7 +1255,6 @@ class SlidingSyncHandler: in the room at the time of `to_token`. to_token: The point in the stream to sync up to. """ - room_state_ids: StateMap[str] # People shouldn't see past their leave/ban event if room_membership_for_user_at_to_token.membership in ( @@ -1349,10 +1522,10 @@ class SlidingSyncHandler: stripped_state.append(strip_event(invite_or_knock_event)) # TODO: Handle state resets. For example, if we see - # `room_membership_for_user_at_to_token.membership = Membership.LEAVE` but - # `required_state` doesn't include it, we should indicate to the client that a - # state reset happened. Perhaps we should indicate this by setting `initial: - # True` and empty `required_state`. + # `room_membership_for_user_at_to_token.event_id=None and + # room_membership_for_user_at_to_token.membership is not None`, we should + # indicate to the client that a state reset happened. 
Perhaps we should indicate + # this by setting `initial: True` and empty `required_state`. # TODO: Since we can't determine whether we've already sent a room down this # Sliding Sync connection before (we plan to add this optimization in the diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index eb4b0a05c7..a7aa9bb8af 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -19,7 +19,7 @@ # import logging from copy import deepcopy -from typing import Optional +from typing import Dict, Optional from unittest.mock import patch from parameterized import parameterized @@ -37,12 +37,16 @@ from synapse.api.constants import ( from synapse.api.room_versions import RoomVersions from synapse.events import make_event_from_dict from synapse.events.snapshot import EventContext -from synapse.handlers.sliding_sync import RoomSyncConfig, StateValues +from synapse.handlers.sliding_sync import ( + RoomSyncConfig, + StateValues, + _RoomMembershipForUser, +) from synapse.rest import admin from synapse.rest.client import knock, login, room from synapse.server import HomeServer from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import JsonDict, UserID +from synapse.types import JsonDict, StreamToken, UserID from synapse.types.handlers import SlidingSyncConfig from synapse.util import Clock @@ -581,9 +585,9 @@ class RoomSyncConfigTestCase(TestCase): self._assert_room_config_equal(room_sync_config_b, expected, "A into B") -class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): +class GetRoomMembershipForUserAtToTokenTestCase(HomeserverTestCase): """ - Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it returns + Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it returns the correct list of room IDs.
""" @@ -616,7 +620,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): now_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=now_token, to_token=now_token, @@ -643,7 +647,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room_token, to_token=after_room_token, @@ -657,9 +661,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id_results[room_id].event_id, join_response["event_id"], ) + self.assertEqual(room_id_results[room_id].membership, Membership.JOIN) # We should be considered `newly_joined` because we joined during the token # range self.assertEqual(room_id_results[room_id].newly_joined, True) + self.assertEqual(room_id_results[room_id].newly_left, False) def test_get_already_joined_room(self) -> None: """ @@ -676,7 +682,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room_token, to_token=after_room_token, @@ -690,8 +696,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id_results[room_id].event_id, join_response["event_id"], ) + self.assertEqual(room_id_results[room_id].membership, Membership.JOIN) # We should *NOT* be `newly_joined` because we joined before the token range self.assertEqual(room_id_results[room_id].newly_joined, False) + self.assertEqual(room_id_results[room_id].newly_left, False) def test_get_invited_banned_knocked_room(self) -> None: """ @@ -748,7 +756,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): after_room_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room_token, to_token=after_room_token, @@ -770,19 +778,25 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id_results[invited_room_id].event_id, invite_response["event_id"], ) + self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE) + self.assertEqual(room_id_results[invited_room_id].newly_joined, False) + self.assertEqual(room_id_results[invited_room_id].newly_left, False) + self.assertEqual( room_id_results[ban_room_id].event_id, ban_response["event_id"], ) + self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN) + self.assertEqual(room_id_results[ban_room_id].newly_joined, False) + self.assertEqual(room_id_results[ban_room_id].newly_left, False) + self.assertEqual( room_id_results[knock_room_id].event_id, knock_room_membership_state_event.event_id, ) - # We should *NOT* be `newly_joined` because we were not joined at the the time - # of the `to_token`. 
- self.assertEqual(room_id_results[invited_room_id].newly_joined, False) - self.assertEqual(room_id_results[ban_room_id].newly_joined, False) + self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK) self.assertEqual(room_id_results[knock_room_id].newly_joined, False) + self.assertEqual(room_id_results[knock_room_id].newly_left, False) def test_get_kicked_room(self) -> None: """ @@ -814,7 +828,7 @@ after_kick_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_kick_token, to_token=after_kick_token, @@ -828,9 +842,12 @@ room_id_results[kick_room_id].event_id, kick_response["event_id"], ) + self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) + self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) # We should *NOT* be `newly_joined` because we were not joined at the time # of the `to_token`. self.assertEqual(room_id_results[kick_room_id].newly_joined, False) + self.assertEqual(room_id_results[kick_room_id].newly_left, False) def test_forgotten_rooms(self) -> None: """ @@ -904,7 +921,7 @@ self.assertEqual(channel.code, 200, channel.result) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room_forgets, to_token=before_room_forgets, ) ) # We shouldn't see the room because it was forgotten self.assertEqual(room_id_results.keys(), set()) - def test_only_newly_left_rooms_show_up(self) -> None: + def test_newly_left_rooms(self) -> None: """ - Test that newly_left rooms still show up in the sync response but rooms that - were left before the `from_token` don't show up. See condition "2)" comments in - the `get_sync_room_ids_for_user` method.
+ Test that newly_left rooms are marked properly """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") # Leave before we calculate the `from_token` room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) + leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() # Leave during the from_token/to_token range (newly_left) room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) - _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) + leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) after_room2_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room2_token, ) ) - # Only the newly_left room should show up - self.assertEqual(room_id_results.keys(), {room_id2}) - # It should be pointing to the latest membership event in the from/to range but - # the `event_id` is `None` because we left the room causing the server to leave - # the room because no other local users are in it (quirk of the - # `current_state_delta_stream` table that we source things from) + self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + + self.assertEqual( + room_id_results[room_id1].event_id, + leave_response1["event_id"], + ) + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined` or `newly_left` because that happened before + # the from/to range + self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) + self.assertEqual( room_id_results[room_id2].event_id, - None, # _leave_response2["event_id"], + leave_response2["event_id"], ) + self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) # We should *NOT* be `newly_joined` because we are instead `newly_left` self.assertEqual(room_id_results[room_id2].newly_joined, False) + self.assertEqual(room_id_results[room_id2].newly_left, True) def test_no_joins_after_to_token(self) -> None: """ Rooms we join after the `to_token` should *not* show up. See condition "1b)" - comments in the `get_sync_room_ids_for_user()` method. + comments in the `get_room_membership_for_user_at_to_token()` method.
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -978,7 +1001,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.join(room_id2, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -991,14 +1014,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id_results[room_id1].event_id, join_response1["event_id"], ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be `newly_joined` because we joined during the token range self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_join_during_range_and_left_room_after_to_token(self) -> None: """ Room still shows up if we left the room but were joined during the from_token/to_token. See condition "1a)" comments in the - `get_sync_room_ids_for_user()` method. + `get_room_membership_for_user_at_to_token()` method. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1016,7 +1041,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1038,14 +1063,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be `newly_joined` because we joined during the token range self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_join_before_range_and_left_room_after_to_token(self) -> None: """ Room still shows up if we left the room but were joined before the `from_token` so it should show up. See condition "1a)" comments in the - `get_sync_room_ids_for_user()` method. + `get_room_membership_for_user_at_to_token()` method. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1061,7 +1088,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room1_token, @@ -1082,14 +1109,16 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should *NOT* be `newly_joined` because we joined before the token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_kicked_before_range_and_left_after_to_token(self) -> None: """ Room still shows up if we left the room but were kicked before the `from_token` so it should show up. See condition "1a)" comments in the - `get_sync_room_ids_for_user()` method. + `get_room_membership_for_user_at_to_token()` method. 
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1123,7 +1152,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response = self.helper.leave(kick_room_id, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_kick_token, to_token=after_kick_token, @@ -1146,14 +1175,17 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) + self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) # We should *NOT* be `newly_joined` because we were kicked self.assertEqual(room_id_results[kick_room_id].newly_joined, False) + self.assertEqual(room_id_results[kick_room_id].newly_left, False) def test_newly_left_during_range_and_join_leave_after_to_token(self) -> None: """ Newly left room should show up. But we're also testing that joining and leaving after the `to_token` doesn't mess with the results. See condition "2)" and "1a)" - comments in the `get_sync_room_ids_for_user()` method. + comments in the `get_room_membership_for_user_at_to_token()` method. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1176,7 +1208,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response2 = self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1199,14 +1231,17 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) - # We should *NOT* be `newly_joined` because we left during the token range + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined` because we are actually `newly_left` during + # the token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, True) def test_newly_left_during_range_and_join_after_to_token(self) -> None: """ Newly left room should show up. But we're also testing that joining after the `to_token` doesn't mess with the results. See condition "2)" and "1b)" comments - in the `get_sync_room_ids_for_user()` method. + in the `get_room_membership_for_user_at_to_token()` method. 
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1228,7 +1263,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1250,16 +1285,19 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) - # We should *NOT* be `newly_joined` because we left during the token range + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined` because we are actually `newly_left` during + # the token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, True) def test_no_from_token(self) -> None: """ - Test that if we don't provide a `from_token`, we get all the rooms that we we're - joined up to the `to_token`. + Test that if we don't provide a `from_token`, we get all the rooms that we had + membership in up to the `to_token`. - Providing `from_token` only really has the effect that it adds `newly_left` - rooms to the response. + Providing `from_token` only really has the effect that it marks rooms as + `newly_left` in the response. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1276,7 +1314,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # Join and leave the room2 before the `to_token` self.helper.join(room_id2, user1_id, tok=user1_tok) - self.helper.leave(room_id2, user1_id, tok=user1_tok) + leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() @@ -1284,7 +1322,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.join(room_id2, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=None, to_token=after_room1_token, @@ -1292,15 +1330,31 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) # Only rooms we were joined to before the `to_token` should show up - self.assertEqual(room_id_results.keys(), {room_id1}) + self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + + # Room1 # It should be pointing to the latest membership event in the from/to range self.assertEqual( room_id_results[room_id1].event_id, join_response1["event_id"], ) - # We should *NOT* be `newly_joined` because there is no `from_token` to - # define a "live" range to compare against + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + # We should *NOT* be `newly_joined`/`newly_left` because there is no + # `from_token` to define a "live" range to compare against self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) + + # Room2 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id2].event_id, + leave_response2["event_id"], + ) + self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined`/`newly_left` because there is no + # `from_token` to define a "live" range to compare 
against + self.assertEqual(room_id_results[room_id2].newly_joined, False) + self.assertEqual(room_id_results[room_id2].newly_left, False) def test_from_token_ahead_of_to_token(self) -> None: """ @@ -1319,28 +1373,28 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) room_id4 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join room1 before `before_room_token` - join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) + # Join room1 before `to_token` + join_room1_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) - # Join and leave the room2 before `before_room_token` - self.helper.join(room_id2, user1_id, tok=user1_tok) - self.helper.leave(room_id2, user1_id, tok=user1_tok) + # Join and leave the room2 before `to_token` + _join_room2_response1 = self.helper.join(room_id2, user1_id, tok=user1_tok) + leave_room2_response1 = self.helper.leave(room_id2, user1_id, tok=user1_tok) # Note: These are purposely swapped. The `from_token` should come after # the `to_token` in this test to_token = self.event_sources.get_current_token() - # Join room2 after `before_room_token` - self.helper.join(room_id2, user1_id, tok=user1_tok) + # Join room2 after `to_token` + _join_room2_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) # -------- - # Join room3 after `before_room_token` - self.helper.join(room_id3, user1_id, tok=user1_tok) + # Join room3 after `to_token` + _join_room3_response1 = self.helper.join(room_id3, user1_id, tok=user1_tok) - # Join and leave the room4 after `before_room_token` - self.helper.join(room_id4, user1_id, tok=user1_tok) - self.helper.leave(room_id4, user1_id, tok=user1_tok) + # Join and leave the room4 after `to_token` + _join_room4_response1 = self.helper.join(room_id4, user1_id, tok=user1_tok) + _leave_room4_response1 = self.helper.leave(room_id4, user1_id, tok=user1_tok) # Note: These are purposely swapped. The `from_token` should come after the # `to_token` in this test @@ -1350,31 +1404,59 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.join(room_id4, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=from_token, to_token=to_token, ) ) - # Only rooms we were joined to before the `to_token` should show up - # - # There won't be any newly_left rooms because the `from_token` is ahead of the - # `to_token` and that range will give no membership changes to check. - self.assertEqual(room_id_results.keys(), {room_id1}) + # In the "current" state snapshot, we're joined to all of the rooms but in the + # from/to token range... 
+ self.assertIncludes( + room_id_results.keys(), + { + # Included because we were joined before both tokens + room_id1, + # Included because we had membership before the to_token + room_id2, + # Excluded because we joined after the `to_token` + # room_id3, + # Excluded because we joined after the `to_token` + # room_id4, + }, + exact=True, + ) + + # Room1 # It should be pointing to the latest membership event in the from/to range self.assertEqual( room_id_results[room_id1].event_id, - join_response1["event_id"], + join_room1_response1["event_id"], ) - # We should *NOT* be `newly_joined` because we joined `room1` before either of the tokens + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + # We should *NOT* be `newly_joined`/`newly_left` because we joined `room1` + # before either of the tokens self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) + + # Room2 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id2].event_id, + leave_room2_response1["event_id"], + ) + self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined`/`newly_left` because we joined and left + # `room2` before either of the tokens + self.assertEqual(room_id_results[room_id2].newly_joined, False) + self.assertEqual(room_id_results[room_id2].newly_left, False) def test_leave_before_range_and_join_leave_after_to_token(self) -> None: """ - Old left room shouldn't show up. But we're also testing that joining and leaving - after the `to_token` doesn't mess with the results. See condition "1a)" comments - in the `get_sync_room_ids_for_user()` method. + Test old left rooms. But we're also testing that joining and leaving after the + `to_token` doesn't mess with the results. See condition "1a)" comments in the + `get_room_membership_for_user_at_to_token()` method.
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1386,7 +1468,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) # Join and leave the room before the from/to range self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) + leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() @@ -1395,21 +1477,30 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room1_token, ) ) - # Room shouldn't show up because it was left before the `from_token` - self.assertEqual(room_id_results.keys(), set()) + self.assertEqual(room_id_results.keys(), {room_id1}) + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id1].event_id, + leave_response["event_id"], + ) + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined`/`newly_left` because we joined and left + # `room1` before either of the tokens + self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_leave_before_range_and_join_after_to_token(self) -> None: """ - Old left room shouldn't show up. But we're also testing that joining after the - `to_token` doesn't mess with the results. See condition "1b)" comments in the - `get_sync_room_ids_for_user()` method. + Test old left room. But we're also testing that joining after the `to_token` + doesn't mess with the results. See condition "1b)" comments in the + `get_room_membership_for_user_at_to_token()` method. 
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1421,7 +1512,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) # Join and leave the room before the from/to range self.helper.join(room_id1, user1_id, tok=user1_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) + leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) after_room1_token = self.event_sources.get_current_token() @@ -1429,24 +1520,32 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.join(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room1_token, ) ) - # Room shouldn't show up because it was left before the `from_token` - self.assertEqual(room_id_results.keys(), set()) + self.assertEqual(room_id_results.keys(), {room_id1}) + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id1].event_id, + leave_response["event_id"], + ) + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined`/`newly_left` because we joined and left + # `room1` before either of the tokens + self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_join_leave_multiple_times_during_range_and_after_to_token( self, ) -> None: """ Join and leave multiple times shouldn't affect rooms from showing up. It just - matters that we were joined or newly_left in the from/to range. But we're also - testing that joining and leaving after the `to_token` doesn't mess with the - results. + matters that we had membership in the from/to range. But we're also testing that + joining and leaving after the `to_token` doesn't mess with the results. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1458,7 +1557,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # We create the room with user2 so the room isn't left with no members when we # leave and can still re-join. 
room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) - # Join, leave, join back to the room before the from/to range + # Join, leave, join back to the room during the from/to range join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) leave_response1 = self.helper.leave(room_id1, user1_id, tok=user1_tok) join_response2 = self.helper.join(room_id1, user1_id, tok=user1_tok) @@ -1471,7 +1570,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1496,15 +1595,19 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be `newly_joined` because we joined during the token range self.assertEqual(room_id_results[room_id1].newly_joined, True) + # We should *NOT* be `newly_left` because we joined during the token range and + # were still joined at the end of the range + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_join_leave_multiple_times_before_range_and_after_to_token( self, ) -> None: """ Join and leave multiple times before the from/to range shouldn't affect rooms - from showing up. It just matters that we were joined or newly_left in the + from showing up. It just matters that we had membership in the from/to range. But we're also testing that joining and leaving after the `to_token` doesn't mess with the results. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1529,7 +1632,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response3 = self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room1_token, @@ -1554,8 +1657,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should *NOT* be `newly_joined` because we joined before the token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_invite_before_range_and_join_leave_after_to_token( self, ) -> None: """ Make it look like we joined after the token range but we were invited before the from/to range so the room should still show up. See condition "1a)" comments in - the `get_sync_room_ids_for_user()` method. + the `get_room_membership_for_user_at_to_token()` method.
""" user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1586,7 +1691,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room1_token, @@ -1608,9 +1713,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.INVITE) # We should *NOT* be `newly_joined` because we were only invited before the # token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_join_and_display_name_changes_in_token_range( self, @@ -1658,7 +1765,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1684,8 +1791,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be `newly_joined` because we joined during the token range self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_display_name_changes_in_token_range( self, @@ -1721,7 +1830,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): after_change1_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_change1_token, @@ -1744,8 +1853,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should *NOT* be `newly_joined` because we joined before the token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_display_name_changes_before_and_after_token_range( self, @@ -1791,7 +1902,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_room1_token, @@ -1817,8 +1928,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should *NOT* be `newly_joined` because we joined before the token range self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_display_name_changes_leave_after_token_range( self, @@ -1828,7 +1941,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): if there are multiple `join` membership events in a row indicating `displayname`/`avatar_url` updates and we leave after the `to_token`. - See condition "1a)" comments in the `get_sync_room_ids_for_user()` method. 
+ See condition "1a)" comments in the `get_room_membership_for_user_at_to_token()` method. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1871,7 +1984,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.leave(room_id1, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1897,8 +2010,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be `newly_joined` because we joined during the token range self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_display_name_changes_join_after_token_range( self, @@ -1908,7 +2023,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): indicating `displayname`/`avatar_url` updates doesn't affect the results (we joined after the token range so it shouldn't show up) - See condition "1b)" comments in the `get_sync_room_ids_for_user()` method. + See condition "1b)" comments in the `get_room_membership_for_user_at_to_token()` method. """ user1_id = self.register_user("user1", "pass") user1_tok = self.login(user1_id, "pass") @@ -1937,7 +2052,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): ) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -1973,7 +2088,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): after_more_changes_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=after_room1_token, to_token=after_more_changes_token, @@ -1987,9 +2102,11 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): room_id_results[room_id1].event_id, join_response2["event_id"], ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be considered `newly_joined` because there is some non-join event in # between our latest join event. 
self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_newly_joined_only_joins_during_token_range( self, @@ -2036,7 +2153,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): after_room1_token = self.event_sources.get_current_token() room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room1_token, to_token=after_room1_token, @@ -2062,8 +2179,10 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): } ), ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) # We should be `newly_joined` because we first joined during the token range self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) def test_multiple_rooms_are_not_confused( self, @@ -2086,16 +2205,18 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): # Invited and left the room before the token self.helper.invite(room_id1, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.leave(room_id1, user1_id, tok=user1_tok) + leave_room1_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) # Invited to room2 - self.helper.invite(room_id2, src=user2_id, targ=user1_id, tok=user2_tok) + invite_room2_response = self.helper.invite( + room_id2, src=user2_id, targ=user1_id, tok=user2_tok + ) before_room3_token = self.event_sources.get_current_token() # Invited and left room3 during the from/to range room_id3 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=True) self.helper.invite(room_id3, src=user2_id, targ=user1_id, tok=user2_tok) - self.helper.leave(room_id3, user1_id, tok=user1_tok) + leave_room3_response = self.helper.leave(room_id3, user1_id, tok=user1_tok) after_room3_token = self.event_sources.get_current_token() @@ -2108,7 +2229,7 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.helper.leave(room_id3, user1_id, tok=user1_tok) room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_room3_token, to_token=after_room3_token, @@ -2118,19 +2239,158 @@ class GetSyncRoomIdsForUserTestCase(HomeserverTestCase): self.assertEqual( room_id_results.keys(), { - # `room_id1` shouldn't show up because we left before the from/to range - # - # Room should show up because we were invited before the from/to range + # Left before the from/to range + room_id1, + # Invited before the from/to range room_id2, - # Room should show up because it was newly_left during the from/to range + # `newly_left` during the from/to range room_id3, }, ) + # Room1 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id1].event_id, + leave_room1_response["event_id"], + ) + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined`/`newly_left` because we were invited and left + # before the token range + self.assertEqual(room_id_results[room_id1].newly_joined, False) + self.assertEqual(room_id_results[room_id1].newly_left, False) + + # Room2 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id2].event_id, + invite_room2_response["event_id"], + ) + 
self.assertEqual(room_id_results[room_id2].membership, Membership.INVITE) + # We should *NOT* be `newly_joined`/`newly_left` because we were invited before + # the token range + self.assertEqual(room_id_results[room_id2].newly_joined, False) + self.assertEqual(room_id_results[room_id2].newly_left, False) + + # Room3 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id3].event_id, + leave_room3_response["event_id"], + ) + self.assertEqual(room_id_results[room_id3].membership, Membership.LEAVE) + # We should be `newly_left` because we were invited and left during + # the token range + self.assertEqual(room_id_results[room_id3].newly_joined, False) + self.assertEqual(room_id_results[room_id3].newly_left, True) + + def test_state_reset(self) -> None: + """ + Test a state reset scenario where the user gets removed from the room (when + there is no corresponding leave event) + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # The room where the state reset will happen + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) + + # Join another room so we don't hit the short-circuit and return early if they + # have no room membership + room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id2, user1_id, tok=user1_tok) + + before_reset_token = self.event_sources.get_current_token() + + # Send another state event to make a position for the state reset to happen at + dummy_state_response = self.helper.send_state( + room_id1, + event_type="foobarbaz", + state_key="", + body={"foo": "bar"}, + tok=user2_tok, + ) + dummy_state_pos = self.get_success( + self.store.get_position_for_event(dummy_state_response["event_id"]) + ) + + # Mock a state reset removing the membership for user1 in the current state + self.get_success( + self.store.db_pool.simple_delete( + table="current_state_events", + keyvalues={ + "room_id": room_id1, + "type": EventTypes.Member, + "state_key": user1_id, + }, + desc="state reset user in current_state_events", + ) + ) + self.get_success( + self.store.db_pool.simple_delete( + table="local_current_membership", + keyvalues={ + "room_id": room_id1, + "user_id": user1_id, + }, + desc="state reset user in local_current_membership", + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + table="current_state_delta_stream", + values={ + "stream_id": dummy_state_pos.stream, + "room_id": room_id1, + "type": EventTypes.Member, + "state_key": user1_id, + "event_id": None, + "prev_event_id": join_response1["event_id"], + "instance_name": dummy_state_pos.instance_name, + }, + desc="state reset user in current_state_delta_stream", + ) + ) + + # Manually bust the cache since we're just manually messing with the database + # and not causing an actual state reset.
+ self.store._membership_stream_cache.entity_has_changed( + user1_id, dummy_state_pos.stream + ) + + after_reset_token = self.event_sources.get_current_token() + + # The function under test + room_id_results = self.get_success( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( + UserID.from_string(user1_id), + from_token=before_reset_token, + to_token=after_reset_token, + ) + ) -class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): + # Room1 should show up because it was `newly_left` via state reset during the from/to range + self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + # It should be pointing to no event because we were removed from the room + # without a corresponding leave event + self.assertEqual( + room_id_results[room_id1].event_id, + None, + ) + # State reset caused us to leave the room and there is no corresponding leave event + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined` because we joined before the token range + self.assertEqual(room_id_results[room_id1].newly_joined, False) + # We should be `newly_left` because we were removed via state reset during the from/to range + self.assertEqual(room_id_results[room_id1].newly_left, True) + + +class GetRoomMembershipForUserAtToTokenShardTestCase(BaseMultiWorkerStreamTestCase): """ - Tests Sliding Sync handler `get_sync_room_ids_for_user()` to make sure it works with + Tests Sliding Sync handler `get_room_membership_for_user_at_to_token()` to make sure it works with sharded event stream_writers enabled """ @@ -2189,7 +2449,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): We then send some events to advance the stream positions of worker1 and worker3 but worker2 is lagging behind because it's stuck. We are specifically testing - that `get_sync_room_ids_for_user(from_token=xxx, to_token=xxx)` should work + that `get_room_membership_for_user_at_to_token(from_token=xxx, to_token=xxx)` should work correctly in these adverse conditions. """ user1_id = self.register_user("user1", "pass") @@ -2228,7 +2488,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) join_response2 = self.helper.join(room_id2, user1_id, tok=user1_tok) # Leave room2 - self.helper.leave(room_id2, user1_id, tok=user1_tok) + leave_room2_response = self.helper.leave(room_id2, user1_id, tok=user1_tok) join_response3 = self.helper.join(room_id3, user1_id, tok=user1_tok) # Leave room3 self.helper.leave(room_id3, user1_id, tok=user1_tok) @@ -2265,7 +2525,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): # For room_id1/worker1: leave and join the room to advance the stream position # and generate membership changes. self.helper.leave(room_id1, user1_id, tok=user1_tok) - self.helper.join(room_id1, user1_id, tok=user1_tok) + join_room1_response = self.helper.join(room_id1, user1_id, tok=user1_tok) # For room_id2/worker2: which is currently stuck, join the room. 
join_on_worker2_response = self.helper.join(room_id2, user1_id, tok=user1_tok) # For room_id3/worker3: leave and join the room to advance the stream position @@ -2319,7 +2579,7 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): # The function under test room_id_results = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( UserID.from_string(user1_id), from_token=before_stuck_activity_token, to_token=stuck_activity_token, @@ -2330,18 +2590,411 @@ class GetSyncRoomIdsForUserEventShardTestCase(BaseMultiWorkerStreamTestCase): room_id_results.keys(), { room_id1, - # room_id2 shouldn't show up because we left before the from/to range - # and the join event during the range happened while worker2 was stuck. - # This means that from the perspective of the master, where the - # `stuck_activity_token` is generated, the stream position for worker2 - # wasn't advanced to the join yet. Looking at the `instance_map`, the - # join technically comes after `stuck_activity_token``. - # - # room_id2, + room_id2, room_id3, }, ) + # Room1 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id1].event_id, + join_room1_response["event_id"], + ) + self.assertEqual(room_id_results[room_id1].membership, Membership.JOIN) + # We should be `newly_joined` because we joined during the token range + self.assertEqual(room_id_results[room_id1].newly_joined, True) + self.assertEqual(room_id_results[room_id1].newly_left, False) + + # Room2 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id2].event_id, + leave_room2_response["event_id"], + ) + self.assertEqual(room_id_results[room_id2].membership, Membership.LEAVE) + # room_id2 should *NOT* be considered `newly_left` because we left before the + # from/to range and the join event during the range happened while worker2 was + # stuck. This means that from the perspective of the master, where the + # `stuck_activity_token` is generated, the stream position for worker2 wasn't + # advanced to the join yet. Looking at the `instance_map`, the join technically + # comes after `stuck_activity_token`. + self.assertEqual(room_id_results[room_id2].newly_joined, False) + self.assertEqual(room_id_results[room_id2].newly_left, False) + + # Room3 + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[room_id3].event_id, + join_on_worker3_response["event_id"], + ) + self.assertEqual(room_id_results[room_id3].membership, Membership.JOIN) + # We should be `newly_joined` because we joined during the token range + self.assertEqual(room_id_results[room_id3].newly_joined, True) + self.assertEqual(room_id_results[room_id3].newly_left, False) + + +class FilterRoomsRelevantForSyncTestCase(HomeserverTestCase): + """ + Tests Sliding Sync handler `filter_rooms_relevant_for_sync()` to make sure it returns + the correct list of room IDs.
+ """ + + servlets = [ + admin.register_servlets, + knock.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def default_config(self) -> JsonDict: + config = super().default_config() + # Enable sliding sync + config["experimental_features"] = {"msc3575_enabled": True} + return config + + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: + self.sliding_sync_handler = self.hs.get_sliding_sync_handler() + self.store = self.hs.get_datastores().main + self.event_sources = hs.get_event_sources() + self.storage_controllers = hs.get_storage_controllers() + + def _get_sync_room_ids_for_user( + self, + user: UserID, + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> Dict[str, _RoomMembershipForUser]: + """ + Get the rooms the user should be syncing with + """ + room_membership_for_user_map = self.get_success( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( + user=user, + from_token=from_token, + to_token=to_token, + ) + ) + filtered_sync_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms_relevant_for_sync( + user=user, + room_membership_for_user_map=room_membership_for_user_map, + ) + ) + + return filtered_sync_room_map + + def test_no_rooms(self) -> None: + """ + Test when the user has never joined any rooms before + """ + user1_id = self.register_user("user1", "pass") + # user1_tok = self.login(user1_id, "pass") + + now_token = self.event_sources.get_current_token() + + room_id_results = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=now_token, + to_token=now_token, + ) + + self.assertEqual(room_id_results.keys(), set()) + + def test_basic_rooms(self) -> None: + """ + Test that rooms that the user is joined to, invited to, banned from, and knocked + on show up. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + before_room_token = self.event_sources.get_current_token() + + join_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + join_response = self.helper.join(join_room_id, user1_id, tok=user1_tok) + + # Setup the invited room (user2 invites user1 to the room) + invited_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + invite_response = self.helper.invite( + invited_room_id, targ=user1_id, tok=user2_tok + ) + + # Setup the ban room (user2 bans user1 from the room) + ban_room_id = self.helper.create_room_as( + user2_id, tok=user2_tok, is_public=True + ) + self.helper.join(ban_room_id, user1_id, tok=user1_tok) + ban_response = self.helper.ban( + ban_room_id, src=user2_id, targ=user1_id, tok=user2_tok + ) + + # Setup the knock room (user1 knocks on the room) + knock_room_id = self.helper.create_room_as( + user2_id, tok=user2_tok, room_version=RoomVersions.V7.identifier + ) + self.helper.send_state( + knock_room_id, + EventTypes.JoinRules, + {"join_rule": JoinRules.KNOCK}, + tok=user2_tok, + ) + # User1 knocks on the room + knock_channel = self.make_request( + "POST", + "/_matrix/client/r0/knock/%s" % (knock_room_id,), + b"{}", + user1_tok, + ) + self.assertEqual(knock_channel.code, 200, knock_channel.result) + knock_room_membership_state_event = self.get_success( + self.storage_controllers.state.get_current_state_event( + knock_room_id, EventTypes.Member, user1_id + ) + ) + assert knock_room_membership_state_event is not None + + after_room_token = self.event_sources.get_current_token() + + room_id_results = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=before_room_token, + to_token=after_room_token, + ) + + # Ensure that the invited, ban, and knock rooms show up + self.assertEqual( + room_id_results.keys(), + { + join_room_id, + invited_room_id, + ban_room_id, + knock_room_id, + }, + ) + # It should be pointing to the the respective membership event (latest + # membership event in the from/to range) + self.assertEqual( + room_id_results[join_room_id].event_id, + join_response["event_id"], + ) + self.assertEqual(room_id_results[join_room_id].membership, Membership.JOIN) + self.assertEqual(room_id_results[join_room_id].newly_joined, True) + self.assertEqual(room_id_results[join_room_id].newly_left, False) + + self.assertEqual( + room_id_results[invited_room_id].event_id, + invite_response["event_id"], + ) + self.assertEqual(room_id_results[invited_room_id].membership, Membership.INVITE) + self.assertEqual(room_id_results[invited_room_id].newly_joined, False) + self.assertEqual(room_id_results[invited_room_id].newly_left, False) + + self.assertEqual( + room_id_results[ban_room_id].event_id, + ban_response["event_id"], + ) + self.assertEqual(room_id_results[ban_room_id].membership, Membership.BAN) + self.assertEqual(room_id_results[ban_room_id].newly_joined, False) + self.assertEqual(room_id_results[ban_room_id].newly_left, False) + + self.assertEqual( + room_id_results[knock_room_id].event_id, + knock_room_membership_state_event.event_id, + ) + self.assertEqual(room_id_results[knock_room_id].membership, Membership.KNOCK) + self.assertEqual(room_id_results[knock_room_id].newly_joined, False) + self.assertEqual(room_id_results[knock_room_id].newly_left, False) + + def test_only_newly_left_rooms_show_up(self) -> None: + """ + Test that `newly_left` rooms 
still show up in the sync response but rooms that + were left before the `from_token` don't show up. See condition "2)" comments in + the `get_room_membership_for_user_at_to_token()` method. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + + # Leave before we calculate the `from_token` + room_id1 = self.helper.create_room_as(user1_id, tok=user1_tok) + self.helper.leave(room_id1, user1_id, tok=user1_tok) + + after_room1_token = self.event_sources.get_current_token() + + # Leave during the from_token/to_token range (newly_left) + room_id2 = self.helper.create_room_as(user1_id, tok=user1_tok) + _leave_response2 = self.helper.leave(room_id2, user1_id, tok=user1_tok) + + after_room2_token = self.event_sources.get_current_token() + + room_id_results = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=after_room1_token, + to_token=after_room2_token, + ) + + # Only the `newly_left` room should show up + self.assertEqual(room_id_results.keys(), {room_id2}) + self.assertEqual( + room_id_results[room_id2].event_id, + _leave_response2["event_id"], + ) + # We should *NOT* be `newly_joined` because we are instead `newly_left` + self.assertEqual(room_id_results[room_id2].newly_joined, False) + self.assertEqual(room_id_results[room_id2].newly_left, True) + + def test_get_kicked_room(self) -> None: + """ + Test that a room that the user was kicked from still shows up. When the user + comes back to their client, they should see that they were kicked. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Setup the kick room (user2 kicks user1 from the room) + kick_room_id = self.helper.create_room_as( + user2_id, tok=user2_tok, is_public=True + ) + self.helper.join(kick_room_id, user1_id, tok=user1_tok) + # Kick user1 from the room + kick_response = self.helper.change_membership( + room=kick_room_id, + src=user2_id, + targ=user1_id, + tok=user2_tok, + membership=Membership.LEAVE, + extra_data={ + "reason": "Bad manners", + }, + ) + + after_kick_token = self.event_sources.get_current_token() + + room_id_results = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=after_kick_token, + to_token=after_kick_token, + ) + + # The kicked room should show up + self.assertEqual(room_id_results.keys(), {kick_room_id}) + # It should be pointing to the latest membership event in the from/to range + self.assertEqual( + room_id_results[kick_room_id].event_id, + kick_response["event_id"], + ) + self.assertEqual(room_id_results[kick_room_id].membership, Membership.LEAVE) + self.assertNotEqual(room_id_results[kick_room_id].sender, user1_id) + # We should *NOT* be `newly_joined` because we were not joined at the time + # of the `to_token`.
+ self.assertEqual(room_id_results[kick_room_id].newly_joined, False) + self.assertEqual(room_id_results[kick_room_id].newly_left, False) + + def test_state_reset(self) -> None: + """ + Test a state reset scenario where the user gets removed from the room (when + there is no corresponding leave event) + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # The room where the state reset will happen + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + join_response1 = self.helper.join(room_id1, user1_id, tok=user1_tok) + + # Join another room so we don't hit the short-circuit and return early if they + # have no room membership + room_id2 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id2, user1_id, tok=user1_tok) + + before_reset_token = self.event_sources.get_current_token() + + # Send another state event to make a position for the state reset to happen at + dummy_state_response = self.helper.send_state( + room_id1, + event_type="foobarbaz", + state_key="", + body={"foo": "bar"}, + tok=user2_tok, + ) + dummy_state_pos = self.get_success( + self.store.get_position_for_event(dummy_state_response["event_id"]) + ) + + # Mock a state reset removing the membership for user1 in the current state + self.get_success( + self.store.db_pool.simple_delete( + table="current_state_events", + keyvalues={ + "room_id": room_id1, + "type": EventTypes.Member, + "state_key": user1_id, + }, + desc="state reset user in current_state_events", + ) + ) + self.get_success( + self.store.db_pool.simple_delete( + table="local_current_membership", + keyvalues={ + "room_id": room_id1, + "user_id": user1_id, + }, + desc="state reset user in local_current_membership", + ) + ) + self.get_success( + self.store.db_pool.simple_insert( + table="current_state_delta_stream", + values={ + "stream_id": dummy_state_pos.stream, + "room_id": room_id1, + "type": EventTypes.Member, + "state_key": user1_id, + "event_id": None, + "prev_event_id": join_response1["event_id"], + "instance_name": dummy_state_pos.instance_name, + }, + desc="state reset user in current_state_delta_stream", + ) + ) + + # Manually bust the cache since we're just manually messing with the database + # and not causing an actual state reset.
+ self.store._membership_stream_cache.entity_has_changed( + user1_id, dummy_state_pos.stream + ) + + after_reset_token = self.event_sources.get_current_token() + + # The function under test + room_id_results = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=before_reset_token, + to_token=after_reset_token, + ) + + # Room1 should show up because it was `newly_left` via state reset during the from/to range + self.assertEqual(room_id_results.keys(), {room_id1, room_id2}) + # It should be pointing to no event because we were removed from the room + # without a corresponding leave event + self.assertEqual( + room_id_results[room_id1].event_id, + None, + ) + # State reset caused us to leave the room and there is no corresponding leave event + self.assertEqual(room_id_results[room_id1].membership, Membership.LEAVE) + # We should *NOT* be `newly_joined` because we joined before the token range + self.assertEqual(room_id_results[room_id1].newly_joined, False) + # We should be `newly_left` because we were removed via state reset during the from/to range + self.assertEqual(room_id_results[room_id1].newly_left, True) + class FilterRoomsTestCase(HomeserverTestCase): """ @@ -2367,6 +3020,31 @@ class FilterRoomsTestCase(HomeserverTestCase): self.store = self.hs.get_datastores().main self.event_sources = hs.get_event_sources() + def _get_sync_room_ids_for_user( + self, + user: UserID, + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> Dict[str, _RoomMembershipForUser]: + """ + Get the rooms the user should be syncing with + """ + room_membership_for_user_map = self.get_success( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( + user=user, + from_token=from_token, + to_token=to_token, + ) + ) + filtered_sync_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms_relevant_for_sync( + user=user, + room_membership_for_user_map=room_membership_for_user_map, + ) + ) + + return filtered_sync_room_map + def _create_dm_room( self, inviter_user_id: str, @@ -2438,12 +3116,10 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) # Try with `is_dm=True` @@ -2496,12 +3172,10 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) # Try with `is_encrypted=True` @@ -2552,12 +3226,10 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + 
to_token=after_rooms_token, ) # Try with `is_invite=True` @@ -2621,12 +3293,10 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) # Try finding only normal rooms @@ -2714,12 +3384,10 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) # Try finding *NOT* normal rooms @@ -2838,12 +3506,10 @@ class FilterRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) filtered_room_map = self.get_success( @@ -2884,6 +3550,31 @@ class SortRoomsTestCase(HomeserverTestCase): self.store = self.hs.get_datastores().main self.event_sources = hs.get_event_sources() + def _get_sync_room_ids_for_user( + self, + user: UserID, + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> Dict[str, _RoomMembershipForUser]: + """ + Get the rooms the user should be syncing with + """ + room_membership_for_user_map = self.get_success( + self.sliding_sync_handler.get_room_membership_for_user_at_to_token( + user=user, + from_token=from_token, + to_token=to_token, + ) + ) + filtered_sync_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms_relevant_for_sync( + user=user, + room_membership_for_user_map=room_membership_for_user_map, + ) + ) + + return filtered_sync_room_map + def test_sort_activity_basic(self) -> None: """ Rooms with newer activity are sorted first. 
@@ -2903,12 +3594,10 @@ class SortRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) # Sort the rooms (what we're testing) @@ -2986,12 +3675,10 @@ class SortRoomsTestCase(HomeserverTestCase): self.helper.send(room_id3, "activity in room3", tok=user2_tok) # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=before_rooms_token, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=before_rooms_token, + to_token=after_rooms_token, ) # Sort the rooms (what we're testing) @@ -3052,12 +3739,10 @@ class SortRoomsTestCase(HomeserverTestCase): after_rooms_token = self.event_sources.get_current_token() # Get the rooms the user should be syncing with - sync_room_map = self.get_success( - self.sliding_sync_handler.get_sync_room_ids_for_user( - UserID.from_string(user1_id), - from_token=None, - to_token=after_rooms_token, - ) + sync_room_map = self._get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, ) # Sort the rooms (what we're testing) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 4236812db5..f5d57e689c 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -20,7 +20,8 @@ # import json import logging -from typing import AbstractSet, Any, Dict, Iterable, List, Optional +from http import HTTPStatus +from typing import Any, Dict, Iterable, List from parameterized import parameterized, parameterized_class @@ -1259,7 +1260,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): exact: bool = False, ) -> None: """ - Wrapper around `_assertIncludes` to give slightly better looking diff error + Wrapper around `assertIncludes` to give slightly better looking diff error messages that include some context "$event_id (type, state_key)". Args: @@ -1275,7 +1276,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): for event in actual_required_state: assert isinstance(event, dict) - self._assertIncludes( + self.assertIncludes( { f'{event["event_id"]} ("{event["type"]}", "{event["state_key"]}")' for event in actual_required_state @@ -1289,56 +1290,6 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): message=str(actual_required_state), ) - def _assertIncludes( - self, - actual_items: AbstractSet[str], - expected_items: AbstractSet[str], - exact: bool = False, - message: Optional[str] = None, - ) -> None: - """ - Assert that all of the `expected_items` are included in the `actual_items`. - - This assert could also be called `assertContains`, `assertItemsInSet` - - Args: - actual_items: The container - expected_items: The items to check for in the container - exact: Whether the actual state should be exactly equal to the expected - state (no extras). - message: Optional message to include in the failure message. 
- """ - # Check that each set has the same items - if exact and actual_items == expected_items: - return - # Check for a superset - elif not exact and actual_items >= expected_items: - return - - expected_lines: List[str] = [] - for expected_item in expected_items: - is_expected_in_actual = expected_item in actual_items - expected_lines.append( - "{} {}".format(" " if is_expected_in_actual else "?", expected_item) - ) - - actual_lines: List[str] = [] - for actual_item in actual_items: - is_actual_in_expected = actual_item in expected_items - actual_lines.append( - "{} {}".format("+" if is_actual_in_expected else " ", actual_item) - ) - - newline = "\n" - expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}" - actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}" - first_message = ( - "Items must match exactly" if exact else "Some expected items are missing." - ) - diff_message = f"{first_message}\n{expected_string}\n{actual_string}" - - self.fail(f"{diff_message}\n{message}") - def _add_new_dm_to_global_account_data( self, source_user_id: str, target_user_id: str, target_room_id: str ) -> None: @@ -3868,6 +3819,13 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): body={"foo": "bar"}, tok=user2_tok, ) + self.helper.send_state( + room_id1, + event_type="org.matrix.bar_state", + state_key="", + body={"bar": "qux"}, + tok=user2_tok, + ) # Make the Sliding Sync request with wildcards for the `state_key` channel = self.make_request( @@ -3891,16 +3849,13 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): ], "timeline_limit": 0, }, - } - # TODO: Room subscription should also combine with the `required_state` - # "room_subscriptions": { - # room_id1: { - # "required_state": [ - # ["org.matrix.bar_state", ""] - # ], - # "timeline_limit": 0, - # } - # } + }, + "room_subscriptions": { + room_id1: { + "required_state": [["org.matrix.bar_state", ""]], + "timeline_limit": 0, + } + }, }, access_token=user1_tok, ) @@ -3917,6 +3872,7 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): state_map[(EventTypes.Member, user1_id)], state_map[(EventTypes.Member, user2_id)], state_map[("org.matrix.foo_state", "")], + state_map[("org.matrix.bar_state", "")], }, exact=True, ) @@ -4009,6 +3965,271 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): channel.json_body["lists"]["foo-list"], ) + def test_room_subscriptions_with_join_membership(self) -> None: + """ + Test `room_subscriptions` with a joined room should give us timeline and current + state events. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) + + # Make the Sliding Sync request with just the room subscription + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "room_subscriptions": { + room_id1: { + "required_state": [ + [EventTypes.Create, ""], + ], + "timeline_limit": 1, + } + }, + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + # We should see some state + self._assertRequiredStateIncludes( + channel.json_body["rooms"][room_id1]["required_state"], + { + state_map[(EventTypes.Create, "")], + }, + exact=True, + ) + self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state")) + + # We should see some events + self.assertEqual( + [ + event["event_id"] + for event in channel.json_body["rooms"][room_id1]["timeline"] + ], + [ + join_response["event_id"], + ], + channel.json_body["rooms"][room_id1]["timeline"], + ) + # No "live" events in an initial sync (no `from_token` to define the "live" + # range) + self.assertEqual( + channel.json_body["rooms"][room_id1]["num_live"], + 0, + channel.json_body["rooms"][room_id1], + ) + # There are more events to paginate to + self.assertEqual( + channel.json_body["rooms"][room_id1]["limited"], + True, + channel.json_body["rooms"][room_id1], + ) + + def test_room_subscriptions_with_leave_membership(self) -> None: + """ + Test `room_subscriptions` with a leave room should give us timeline and state + events up to the leave event. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.send_state( + room_id1, + event_type="org.matrix.foo_state", + state_key="", + body={"foo": "bar"}, + tok=user2_tok, + ) + + join_response = self.helper.join(room_id1, user1_id, tok=user1_tok) + leave_response = self.helper.leave(room_id1, user1_id, tok=user1_tok) + + state_map = self.get_success( + self.storage_controllers.state.get_current_state(room_id1) + ) + + # Send some events after user1 leaves + self.helper.send(room_id1, "activity after leave", tok=user2_tok) + # Update state after user1 leaves + self.helper.send_state( + room_id1, + event_type="org.matrix.foo_state", + state_key="", + body={"foo": "qux"}, + tok=user2_tok, + ) + + # Make the Sliding Sync request with just the room subscription + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "room_subscriptions": { + room_id1: { + "required_state": [ + ["org.matrix.foo_state", ""], + ], + "timeline_limit": 2, + } + }, + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # We should see the state at the time of the leave + self._assertRequiredStateIncludes( + channel.json_body["rooms"][room_id1]["required_state"], + { + state_map[("org.matrix.foo_state", "")], + }, + exact=True, + ) + self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state")) + + # We should see some before we left (nothing after) + self.assertEqual( + [ + event["event_id"] + for event in channel.json_body["rooms"][room_id1]["timeline"] + ], + [ + join_response["event_id"], + leave_response["event_id"], + ], + channel.json_body["rooms"][room_id1]["timeline"], + ) + # No "live" events in an initial sync (no `from_token` to define the "live" + # range) + self.assertEqual( + channel.json_body["rooms"][room_id1]["num_live"], + 0, + channel.json_body["rooms"][room_id1], + ) + # There are more events to paginate to + self.assertEqual( + channel.json_body["rooms"][room_id1]["limited"], + True, + channel.json_body["rooms"][room_id1], + ) + + def test_room_subscriptions_no_leak_private_room(self) -> None: + """ + Test `room_subscriptions` with a private room we have never been in should not + leak any data to the user. 
+ """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok, is_public=False) + + # We should not be able to join the private room + self.helper.join( + room_id1, user1_id, tok=user1_tok, expect_code=HTTPStatus.FORBIDDEN + ) + + # Make the Sliding Sync request with just the room subscription + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "room_subscriptions": { + room_id1: { + "required_state": [ + [EventTypes.Create, ""], + ], + "timeline_limit": 1, + } + }, + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # We should not see the room at all (we're not in it) + self.assertIsNone( + channel.json_body["rooms"].get(room_id1), channel.json_body["rooms"] + ) + + def test_room_subscriptions_world_readable(self) -> None: + """ + Test `room_subscriptions` with a room that has `world_readable` history visibility + + FIXME: We should be able to see the room timeline and state + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create a room with `world_readable` history visibility + room_id1 = self.helper.create_room_as( + user2_id, + tok=user2_tok, + extra_content={ + "preset": "public_chat", + "initial_state": [ + { + "content": { + "history_visibility": HistoryVisibility.WORLD_READABLE + }, + "state_key": "", + "type": EventTypes.RoomHistoryVisibility, + } + ], + }, + ) + # Ensure we're testing with a room with `world_readable` history visibility + # which means events are visible to anyone even without membership. + history_visibility_response = self.helper.get_state( + room_id1, EventTypes.RoomHistoryVisibility, tok=user2_tok + ) + self.assertEqual( + history_visibility_response.get("history_visibility"), + HistoryVisibility.WORLD_READABLE, + ) + + # Note: We never join the room + + # Make the Sliding Sync request with just the room subscription + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "room_subscriptions": { + room_id1: { + "required_state": [ + [EventTypes.Create, ""], + ], + "timeline_limit": 1, + } + }, + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # FIXME: In the future, we should be able to see the room because it's + # `world_readable` but currently we don't support this. + self.assertIsNone( + channel.json_body["rooms"].get(room_id1), channel.json_body["rooms"] + ) + class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase): """Tests for the to-device sliding sync extension""" diff --git a/tests/unittest.py b/tests/unittest.py index a7c20556a0..4aa7f56106 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -28,6 +28,7 @@ import logging import secrets import time from typing import ( + AbstractSet, Any, Awaitable, Callable, @@ -269,6 +270,56 @@ class TestCase(unittest.TestCase): required[key], actual[key], msg="%s mismatch. %s" % (key, actual) ) + def assertIncludes( + self, + actual_items: AbstractSet[str], + expected_items: AbstractSet[str], + exact: bool = False, + message: Optional[str] = None, + ) -> None: + """ + Assert that all of the `expected_items` are included in the `actual_items`. 
+
+        This assert could also have been called `assertContains` or `assertItemsInSet`.
+
+        Args:
+            actual_items: The container
+            expected_items: The items to check for in the container
+            exact: Whether the actual items should be exactly equal to the expected
+                items (no extras).
+            message: Optional message to include in the failure message.
+        """
+        # Check that each set has the same items
+        if exact and actual_items == expected_items:
+            return
+        # Check for a superset
+        elif not exact and actual_items >= expected_items:
+            return
+
+        expected_lines: List[str] = []
+        for expected_item in expected_items:
+            is_expected_in_actual = expected_item in actual_items
+            expected_lines.append(
+                "{}  {}".format(" " if is_expected_in_actual else "?", expected_item)
+            )
+
+        actual_lines: List[str] = []
+        for actual_item in actual_items:
+            is_actual_in_expected = actual_item in expected_items
+            actual_lines.append(
+                "{}  {}".format("+" if is_actual_in_expected else " ", actual_item)
+            )
+
+        newline = "\n"
+        expected_string = f"Expected items to be in actual ('?' = missing expected items):\n {{\n{newline.join(expected_lines)}\n }}"
+        actual_string = f"Actual ('+' = found expected items):\n {{\n{newline.join(actual_lines)}\n }}"
+        first_message = (
+            "Items must match exactly." if exact else "Some expected items are missing."
+        )
+        diff_message = f"{first_message}\n{expected_string}\n{actual_string}"
+
+        self.fail(f"{diff_message}\n{message or ''}")
+
 def DEBUG(target: TV) -> TV:
     """A decorator to set the .loglevel attribute to logging.DEBUG.
-- 
cgit 1.5.1


From df11af14dbd2faad916924cab96e75bd7c95a66a Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 15 Jul 2024 16:13:04 +0100
Subject: Fix bug where sync could get stuck when using workers (#17438)

This is because we serialized the token wrong if the instance map
contained entries from before the minimum token.
---
 changelog.d/17438.bugfix         |  1 +
 synapse/handlers/sliding_sync.py | 11 +++++--
 synapse/types/__init__.py        | 65 +++++++++++++++++++++++++++++-----
 tests/test_types.py              | 71 ++++++++++++++++++++++++++++++++++++++++
 4 files changed, 138 insertions(+), 10 deletions(-)
 create mode 100644 changelog.d/17438.bugfix

(limited to 'synapse/handlers')

diff --git a/changelog.d/17438.bugfix b/changelog.d/17438.bugfix
new file mode 100644
index 0000000000..cff6eecd48
--- /dev/null
+++ b/changelog.d/17438.bugfix
@@ -0,0 +1 @@
+Fix rare bug where `/sync` would break for a user when using workers with multiple stream writers.
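[Editorial note] To make the serialization bug concrete: a multi-writer token such as `m5~2.6` means "every writer is at stream position 5, except writer `2`, which is at 6". An instance-map entry at or below the minimum position adds no information, and serializing such entries is what produced the broken `m5~` tokens handled below. The following is a sketch of the corrected serialization logic, not part of the patch; `serialize_token_sketch` is a hypothetical helper that assumes instance names are already resolved to integer IDs (the real `RoomStreamToken.to_string` looks them up from the database):

```python
from typing import Dict


def serialize_token_sketch(stream: int, instance_map: Dict[int, int]) -> str:
    # Drop entries at or below the minimum stream position: they are implied
    # by `stream` itself. This also guarantees we never emit the broken `m5~`
    # form (an "m" token with an empty instance map).
    entries = {
        instance_id: pos for instance_id, pos in instance_map.items() if pos > stream
    }
    if entries:
        encoded_map = "~".join(
            f"{instance_id}.{pos}" for instance_id, pos in sorted(entries.items())
        )
        return f"m{stream}~{encoded_map}"
    return f"s{stream}"


assert serialize_token_sketch(5, {}) == "s5"
assert serialize_token_sketch(5, {2: 5}) == "s5"  # redundant entry dropped
assert serialize_token_sketch(5, {2: 6}) == "m5~2.6"
```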
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index be98b379eb..1b5262d667 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -699,10 +699,17 @@ class SlidingSyncHandler: instance_to_max_stream_ordering_map[instance_name] = stream_ordering # Then assemble the `RoomStreamToken` + min_stream_pos = min(instance_to_max_stream_ordering_map.values()) membership_snapshot_token = RoomStreamToken( # Minimum position in the `instance_map` - stream=min(instance_to_max_stream_ordering_map.values()), - instance_map=immutabledict(instance_to_max_stream_ordering_map), + stream=min_stream_pos, + instance_map=immutabledict( + { + instance_name: stream_pos + for instance_name, stream_pos in instance_to_max_stream_ordering_map.items() + if stream_pos > min_stream_pos + } + ), ) # Since we fetched the users room list at some point in time after the from/to diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index b22a13ef01..3962ecc996 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -20,6 +20,7 @@ # # import abc +import logging import re import string from enum import Enum @@ -74,6 +75,9 @@ if TYPE_CHECKING: from synapse.storage.databases.main import DataStore, PurgeEventsStore from synapse.storage.databases.main.appservice import ApplicationServiceWorkerStore + +logger = logging.getLogger(__name__) + # Define a state map type from type/state_key to T (usually an event ID or # event) T = TypeVar("T") @@ -454,6 +458,8 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): represented by a default `stream` attribute and a map of instance name to stream position of any writers that are ahead of the default stream position. + + The values in `instance_map` must be greater than the `stream` attribute. """ stream: int = attr.ib(validator=attr.validators.instance_of(int), kw_only=True) @@ -468,6 +474,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): kw_only=True, ) + def __attrs_post_init__(self) -> None: + # Enforce that all instances have a value greater than the min stream + # position. + for i, v in self.instance_map.items(): + if v <= self.stream: + raise ValueError( + f"'instance_map' includes a stream position before the main 'stream' attribute. Instance: {i}" + ) + @classmethod @abc.abstractmethod async def parse(cls, store: "DataStore", string: str) -> "Self": @@ -494,6 +509,9 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): for instance in set(self.instance_map).union(other.instance_map) } + # Filter out any redundant entries. + instance_map = {i: s for i, s in instance_map.items() if s > max_stream} + return attr.evolve( self, stream=max_stream, instance_map=immutabledict(instance_map) ) @@ -539,10 +557,15 @@ class AbstractMultiWriterStreamToken(metaclass=abc.ABCMeta): def bound_stream_token(self, max_stream: int) -> "Self": """Bound the stream positions to a maximum value""" + min_pos = min(self.stream, max_stream) return type(self)( - stream=min(self.stream, max_stream), + stream=min_pos, instance_map=immutabledict( - {k: min(s, max_stream) for k, s in self.instance_map.items()} + { + k: min(s, max_stream) + for k, s in self.instance_map.items() + if min(s, max_stream) > min_pos + } ), ) @@ -637,6 +660,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken): "Cannot set both 'topological' and 'instance_map' on 'RoomStreamToken'." 
         )
 
+        super().__attrs_post_init__()
+
     @classmethod
     async def parse(cls, store: "PurgeEventsStore", string: str) -> "RoomStreamToken":
         try:
@@ -651,6 +676,11 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
 
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug.
+                        continue
+
                     key, value = part.split(".")
                     instance_id = int(key)
                     pos = int(value)
@@ -666,7 +696,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here because, even though this *might* be a client
+            # handing us a bad token, it's more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+            logger.exception("Failed to parse stream token: %r", string)
         raise SynapseError(400, "Invalid room stream token %r" % (string,))
 
     @classmethod
@@ -713,6 +746,8 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
         return self.instance_map.get(instance_name, self.stream)
 
     async def to_string(self, store: "DataStore") -> str:
+        """See class level docstring for information about the format."""
+
         if self.topological is not None:
             return "t%d-%d" % (self.topological, self.stream)
         elif self.instance_map:
@@ -727,8 +762,10 @@ class RoomStreamToken(AbstractMultiWriterStreamToken):
                 instance_id = await store.get_id_for_instance(name)
                 entries.append(f"{instance_id}.{pos}")
 
-            encoded_map = "~".join(entries)
-            return f"m{self.stream}~{encoded_map}"
+            if entries:
+                encoded_map = "~".join(entries)
+                return f"m{self.stream}~{encoded_map}"
+            return f"s{self.stream}"
         else:
             return "s%d" % (self.stream,)
 
@@ -756,6 +793,11 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
 
                 instance_map = {}
                 for part in parts[1:]:
+                    if not part:
+                        # Handle tokens of the form `m5~`, which were created by
+                        # a bug.
+                        continue
+
                     key, value = part.split(".")
 
                     instance_id = int(key)
                     pos = int(value)
@@ -770,10 +812,15 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken):
         except CancelledError:
             raise
         except Exception:
-            pass
+            # We log an exception here because, even though this *might* be a client
+            # handing us a bad token, it's more likely that Synapse returned a bad
+            # token (and we really want to catch those!).
+ logger.exception("Failed to parse stream token: %r", string) raise SynapseError(400, "Invalid stream token %r" % (string,)) async def to_string(self, store: "DataStore") -> str: + """See class level docstring for information about the format.""" + if self.instance_map: entries = [] for name, pos in self.instance_map.items(): @@ -786,8 +833,10 @@ class MultiWriterStreamToken(AbstractMultiWriterStreamToken): instance_id = await store.get_id_for_instance(name) entries.append(f"{instance_id}.{pos}") - encoded_map = "~".join(entries) - return f"m{self.stream}~{encoded_map}" + if entries: + encoded_map = "~".join(entries) + return f"m{self.stream}~{encoded_map}" + return str(self.stream) else: return str(self.stream) diff --git a/tests/test_types.py b/tests/test_types.py index 944aa784fc..00adc65a5a 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -19,9 +19,18 @@ # # +from typing import Type +from unittest import skipUnless + +from immutabledict import immutabledict +from parameterized import parameterized_class + from synapse.api.errors import SynapseError from synapse.types import ( + AbstractMultiWriterStreamToken, + MultiWriterStreamToken, RoomAlias, + RoomStreamToken, UserID, get_domain_from_id, get_localpart_from_id, @@ -29,6 +38,7 @@ from synapse.types import ( ) from tests import unittest +from tests.utils import USE_POSTGRES_FOR_TESTS class IsMineIDTests(unittest.HomeserverTestCase): @@ -127,3 +137,64 @@ class MapUsernameTestCase(unittest.TestCase): # this should work with either a unicode or a bytes self.assertEqual(map_username_to_mxid_localpart("têst"), "t=c3=aast") self.assertEqual(map_username_to_mxid_localpart("têst".encode()), "t=c3=aast") + + +@parameterized_class( + ("token_type",), + [ + (MultiWriterStreamToken,), + (RoomStreamToken,), + ], + class_name_func=lambda cls, num, params_dict: f"{cls.__name__}_{params_dict['token_type'].__name__}", +) +class MultiWriterTokenTestCase(unittest.HomeserverTestCase): + """Tests for the different types of multi writer tokens.""" + + token_type: Type[AbstractMultiWriterStreamToken] + + def test_basic_token(self) -> None: + """Test that a simple stream token can be serialized and unserialized""" + store = self.hs.get_datastores().main + + token = self.token_type(stream=5) + + string_token = self.get_success(token.to_string(store)) + + if isinstance(token, RoomStreamToken): + self.assertEqual(string_token, "s5") + else: + self.assertEqual(string_token, "5") + + parsed_token = self.get_success(self.token_type.parse(store, string_token)) + self.assertEqual(parsed_token, token) + + @skipUnless(USE_POSTGRES_FOR_TESTS, "Requires Postgres") + def test_instance_map(self) -> None: + """Test for stream token with instance map""" + store = self.hs.get_datastores().main + + token = self.token_type(stream=5, instance_map=immutabledict({"foo": 6})) + + string_token = self.get_success(token.to_string(store)) + self.assertEqual(string_token, "m5~1.6") + + parsed_token = self.get_success(self.token_type.parse(store, string_token)) + self.assertEqual(parsed_token, token) + + def test_instance_map_assertion(self) -> None: + """Test that we assert values in the instance map are greater than the + min stream position""" + + with self.assertRaises(ValueError): + self.token_type(stream=5, instance_map=immutabledict({"foo": 4})) + + with self.assertRaises(ValueError): + self.token_type(stream=5, instance_map=immutabledict({"foo": 5})) + + def test_parse_bad_token(self) -> None: + """Test that we can parse tokens produced by a bug in Synapse of the + 
form `m5~`""" + store = self.hs.get_datastores().main + + parsed_token = self.get_success(self.token_type.parse(store, "m5~")) + self.assertEqual(parsed_token, self.token_type(stream=5)) -- cgit 1.5.1 From a574de006286352e150b397f34fecc0d83e68148 Mon Sep 17 00:00:00 2001 From: Eric Eastwood Date: Thu, 18 Jul 2024 06:49:53 -0500 Subject: Add `m.room.create` to default bump event types (#17453) Add `m.room.create` to default bump event types This probably helps when no messages have been sent in the room and it was just created. --- changelog.d/17453.misc | 1 + synapse/handlers/sliding_sync.py | 1 + 2 files changed, 2 insertions(+) create mode 100644 changelog.d/17453.misc (limited to 'synapse/handlers') diff --git a/changelog.d/17453.misc b/changelog.d/17453.misc new file mode 100644 index 0000000000..2978a52477 --- /dev/null +++ b/changelog.d/17453.misc @@ -0,0 +1 @@ +Update experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint to bump room when it is created. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 1b5262d667..a23a6b9dd9 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -53,6 +53,7 @@ logger = logging.getLogger(__name__) # The event types that clients should consider as new activity. DEFAULT_BUMP_EVENT_TYPES = { + EventTypes.Create, EventTypes.Message, EventTypes.Encrypted, EventTypes.Sticker, -- cgit 1.5.1 From 6a01af59e14b67960e290c26131249fb8c833293 Mon Sep 17 00:00:00 2001 From: Ben Banfield-Zanin Date: Thu, 18 Jul 2024 13:32:32 +0100 Subject: Improve default_power_level_content_override documentation (#17451) Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/17451.doc | 1 + docs/usage/configuration/config_documentation.md | 32 ++++++++++++++++++++++++ synapse/handlers/room.py | 2 ++ 3 files changed, 35 insertions(+) create mode 100644 changelog.d/17451.doc (limited to 'synapse/handlers') diff --git a/changelog.d/17451.doc b/changelog.d/17451.doc new file mode 100644 index 0000000000..357ac2c906 --- /dev/null +++ b/changelog.d/17451.doc @@ -0,0 +1 @@ +Improve documentation for the [`default_power_level_content_override`](https://element-hq.github.io/synapse/latest/usage/configuration/config_documentation.html#default_power_level_content_override) config option. 
diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md
index 65b03ad0f8..38b24b5044 100644
--- a/docs/usage/configuration/config_documentation.md
+++ b/docs/usage/configuration/config_documentation.md
@@ -4134,6 +4134,38 @@ default_power_level_content_override:
    trusted_private_chat: null
    public_chat: null
 ```
+
+The default power levels for each preset are:
+```yaml
+"m.room.name": 50
+"m.room.power_levels": 100
+"m.room.history_visibility": 100
+"m.room.canonical_alias": 50
+"m.room.avatar": 50
+"m.room.tombstone": 100
+"m.room.server_acl": 100
+"m.room.encryption": 100
+```
+
+So, a complete example that maintains the default power levels for a preset
+while also setting the power level for a new key is:
+```yaml
+default_power_level_content_override:
+   private_chat:
+    events:
+      "com.example.foo": 0
+      "m.room.name": 50
+      "m.room.power_levels": 100
+      "m.room.history_visibility": 100
+      "m.room.canonical_alias": 50
+      "m.room.avatar": 50
+      "m.room.tombstone": 100
+      "m.room.server_acl": 100
+      "m.room.encryption": 100
+   trusted_private_chat: null
+   public_chat: null
+```
+
 ---
 ### `forget_rooms_on_leave`
 
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 2302d283a7..262d9f4044 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1188,6 +1188,8 @@ class RoomCreationHandler:
             )
             events_to_send.append((power_event, power_context))
         else:
+            # Please update the docs for `default_power_level_content_override` when
+            # updating the `events` dict below
             power_level_content: JsonDict = {
                 "users": {creator_id: 100},
                 "users_default": 0,
-- 
cgit 1.5.1


From 43c865f7c98a1e31d6eb8d3979d1e199fadcb950 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Fri, 19 Jul 2024 12:09:39 +0100
Subject: Generate room sync data concurrently (#17458)

This is also what we do for standard `/sync`.
---
 changelog.d/17458.misc           |  1 +
 synapse/handlers/sliding_sync.py | 12 ++++++++++--
 2 files changed, 11 insertions(+), 2 deletions(-)
 create mode 100644 changelog.d/17458.misc

(limited to 'synapse/handlers')

diff --git a/changelog.d/17458.misc b/changelog.d/17458.misc
new file mode 100644
index 0000000000..09cce15d0d
--- /dev/null
+++ b/changelog.d/17458.misc
@@ -0,0 +1 @@
+Speed up generating sliding sync responses.
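[Editorial note] The diff below replaces the sequential per-room loop with Synapse's `concurrently_execute(handle_room, relevant_room_map, 10)`, i.e. at most ten rooms are processed in flight at once. As a rough sketch of the pattern (not Synapse's actual helper, which I understand drives a shared iterator with a fixed pool of workers), the same bound can be expressed with a plain asyncio semaphore:

```python
import asyncio
from typing import Awaitable, Callable, Iterable, TypeVar

T = TypeVar("T")


async def concurrently_execute_sketch(
    func: Callable[[T], Awaitable[None]],
    args: Iterable[T],
    limit: int,
) -> None:
    semaphore = asyncio.Semaphore(limit)

    async def run_one(arg: T) -> None:
        async with semaphore:
            await func(arg)

    # Everything is scheduled up front; the semaphore caps how many run at once.
    await asyncio.gather(*(run_one(arg) for arg in args))
```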
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index a23a6b9dd9..423f0329d6 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -28,6 +28,7 @@ from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membe
 from synapse.events import EventBase
 from synapse.events.utils import strip_event
 from synapse.handlers.relations import BundledAggregations
+from synapse.logging.opentracing import start_active_span, tag_args, trace
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
 from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
 from synapse.storage.roommember import MemberSummary
@@ -43,6 +44,7 @@ from synapse.types import (
 )
 from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
 from synapse.types.state import StateFilter
+from synapse.util.async_helpers import concurrently_execute
 from synapse.visibility import filter_events_for_client
 
 if TYPE_CHECKING:
@@ -592,11 +594,14 @@ class SlidingSyncHandler:
 
         # Fetch room data
         rooms: Dict[str, SlidingSyncResult.RoomResult] = {}
-        for room_id, room_sync_config in relevant_room_map.items():
+
+        @trace
+        @tag_args
+        async def handle_room(room_id: str) -> None:
             room_sync_result = await self.get_room_sync_data(
                 user=sync_config.user,
                 room_id=room_id,
-                room_sync_config=room_sync_config,
+                room_sync_config=relevant_room_map[room_id],
                 room_membership_for_user_at_to_token=room_membership_for_user_map[
                     room_id
                 ],
@@ -606,6 +611,9 @@ class SlidingSyncHandler:
 
             rooms[room_id] = room_sync_result
 
+        with start_active_span("sliding_sync.generate_room_entries"):
+            await concurrently_execute(handle_room, relevant_room_map, 10)
+
         extensions = await self.get_extensions_response(
             sync_config=sync_config, to_token=to_token
         )
-- 
cgit 1.5.1


From ed0face8ada72731eba70f9cc56a9b482b7bef1e Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 22 Jul 2024 14:51:17 +0100
Subject: Speed up room keys query by using read/write lock (#17461)

Linearizing all access slows things down when devices try to fetch lots
of keys on login
---
 changelog.d/17461.misc            |  1 +
 synapse/handlers/e2e_room_keys.py | 18 +++++++++---------
 2 files changed, 10 insertions(+), 9 deletions(-)
 create mode 100644 changelog.d/17461.misc

(limited to 'synapse/handlers')

diff --git a/changelog.d/17461.misc b/changelog.d/17461.misc
new file mode 100644
index 0000000000..80f7144baa
--- /dev/null
+++ b/changelog.d/17461.misc
@@ -0,0 +1 @@
+Speed up fetching room keys from backup.
diff --git a/synapse/handlers/e2e_room_keys.py b/synapse/handlers/e2e_room_keys.py
index 99f9f6e64a..f397911f28 100644
--- a/synapse/handlers/e2e_room_keys.py
+++ b/synapse/handlers/e2e_room_keys.py
@@ -34,7 +34,7 @@ from synapse.api.errors import (
 from synapse.logging.opentracing import log_kv, trace
 from synapse.storage.databases.main.e2e_room_keys import RoomKey
 from synapse.types import JsonDict
-from synapse.util.async_helpers import Linearizer
+from synapse.util.async_helpers import ReadWriteLock
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -58,7 +58,7 @@ class E2eRoomKeysHandler:
         # clients belonging to a user will receive and try to upload a new session at
         # roughly the same time.  Also used to lock out uploads when the key is being
         # changed.
-        self._upload_linearizer = Linearizer("upload_room_keys_lock")
+        self._upload_lock = ReadWriteLock()
 
     @trace
     async def get_room_keys(
@@ -89,7 +89,7 @@ class E2eRoomKeysHandler:
 
         # we deliberately take the lock to get keys so that changing the version
         # works atomically
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.read(user_id):
             # make sure the backup version exists
             try:
                 await self.store.get_e2e_room_keys_version_info(user_id, version)
@@ -132,7 +132,7 @@ class E2eRoomKeysHandler:
         """
 
         # lock for consistency with uploading
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.write(user_id):
             # make sure the backup version exists
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(
@@ -193,7 +193,7 @@ class E2eRoomKeysHandler:
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # XXX: perhaps we should use a finer grained lock here?
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.write(user_id):
             # Check that the version we're trying to upload is the current version
             try:
                 version_info = await self.store.get_e2e_room_keys_version_info(user_id)
@@ -355,7 +355,7 @@ class E2eRoomKeysHandler:
         # TODO: Validate the JSON to make sure it has the right keys.
 
         # lock everyone out until we've switched version
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.write(user_id):
             new_version = await self.store.create_e2e_room_keys_version(
                 user_id, version_info
             )
@@ -382,7 +382,7 @@ class E2eRoomKeysHandler:
             }
         """
 
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.read(user_id):
             try:
                 res = await self.store.get_e2e_room_keys_version_info(user_id, version)
             except StoreError as e:
@@ -407,7 +407,7 @@ class E2eRoomKeysHandler:
             NotFoundError: if this backup version doesn't exist
         """
 
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.write(user_id):
             try:
                 await self.store.delete_e2e_room_keys_version(user_id, version)
             except StoreError as e:
@@ -437,7 +437,7 @@ class E2eRoomKeysHandler:
             raise SynapseError(
                 400, "Version in body does not match", Codes.INVALID_PARAM
             )
-        async with self._upload_linearizer.queue(user_id):
+        async with self._upload_lock.write(user_id):
             try:
                 old_info = await self.store.get_e2e_room_keys_version_info(
                     user_id, version
-- 
cgit 1.5.1


From d221512498f7e3267916a289dd2ef4f3e00728e8 Mon Sep 17 00:00:00 2001
From: Erik Johnston 
Date: Mon, 22 Jul 2024 17:48:09 +0100
Subject: SS: Implement `$ME` support (#17469)

`$ME` can be used as a substitute for the requester's user ID.
---
 changelog.d/17469.misc           |  1 +
 synapse/handlers/sliding_sync.py |  6 +++-
 tests/rest/client/test_sync.py   | 74 ++++++++++++++++++++++++++++++++++++++++
 3 files changed, 80 insertions(+), 1 deletion(-)
 create mode 100644 changelog.d/17469.misc

(limited to 'synapse/handlers')

diff --git a/changelog.d/17469.misc b/changelog.d/17469.misc
new file mode 100644
index 0000000000..ba0419355b
--- /dev/null
+++ b/changelog.d/17469.misc
@@ -0,0 +1 @@
+Implement handling of `$ME` as a state key in sliding sync.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 423f0329d6..c362afa6e2 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -329,6 +329,9 @@ class StateValues:
     # `sender` in the timeline). We only give special meaning to this value when it's a
     # `state_key`.
     LAZY: Final = "$LAZY"
+    # Substituted with the requester's user ID.
Typically used by clients to get + # the user's membership. + ME: Final = "$ME" class SlidingSyncHandler: @@ -504,7 +507,6 @@ class SlidingSyncHandler: # Also see `StateFilter.must_await_full_state(...)` for comparison lazy_loading = ( membership_state_keys is not None - and len(membership_state_keys) == 1 and StateValues.LAZY in membership_state_keys ) @@ -1662,6 +1664,8 @@ class SlidingSyncHandler: # FIXME: We probably also care about invite, ban, kick, targets, etc # but the spec only mentions "senders". + elif state_key == StateValues.ME: + required_state_types.append((state_type, user.to_string())) else: required_state_types.append((state_type, state_key)) diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index a008ee465b..a88bdb5c14 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -3714,6 +3714,80 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): ) self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state")) + def test_rooms_required_state_me(self) -> None: + """ + Test `rooms.required_state` correctly handles $ME. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id1 = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id1, user1_id, tok=user1_tok) + + self.helper.send(room_id1, "1", tok=user2_tok) + + # Also send normal state events with state keys of the users, first + # change the power levels to allow this. + self.helper.send_state( + room_id1, + event_type=EventTypes.PowerLevels, + body={"users": {user1_id: 50, user2_id: 100}}, + tok=user2_tok, + ) + self.helper.send_state( + room_id1, + event_type="org.matrix.foo", + state_key=user1_id, + body={}, + tok=user1_tok, + ) + self.helper.send_state( + room_id1, + event_type="org.matrix.foo", + state_key=user2_id, + body={}, + tok=user2_tok, + ) + + # Make the Sliding Sync request with a request for '$ME'. 
+        channel = self.make_request(
+            "POST",
+            self.sync_endpoint,
+            {
+                "lists": {
+                    "foo-list": {
+                        "ranges": [[0, 1]],
+                        "required_state": [
+                            [EventTypes.Create, ""],
+                            [EventTypes.Member, StateValues.ME],
+                            ["org.matrix.foo", StateValues.ME],
+                        ],
+                        "timeline_limit": 3,
+                    }
+                }
+            },
+            access_token=user1_tok,
+        )
+        self.assertEqual(channel.code, 200, channel.json_body)
+
+        state_map = self.get_success(
+            self.storage_controllers.state.get_current_state(room_id1)
+        )
+
+        # We should see the create event, plus our own membership and our
+        # `org.matrix.foo` state via the `$ME` substitution
+        self._assertRequiredStateIncludes(
+            channel.json_body["rooms"][room_id1]["required_state"],
+            {
+                state_map[(EventTypes.Create, "")],
+                state_map[(EventTypes.Member, user1_id)],
+                state_map[("org.matrix.foo", user1_id)],
+            },
+            exact=True,
+        )
+        self.assertIsNone(channel.json_body["rooms"][room_id1].get("invite_state"))
+
     @parameterized.expand([(Membership.LEAVE,), (Membership.BAN,)])
     def test_rooms_required_state_leave_ban(self, stop_membership: str) -> None:
         """
-- 
cgit 1.5.1


From de05a642460fa04f6e279fa166f032e9ff94b4b0 Mon Sep 17 00:00:00 2001
From: Eric Eastwood 
Date: Mon, 22 Jul 2024 15:40:06 -0500
Subject: Sliding Sync: Add E2EE extension (MSC3884) (#17454)

Spec: [MSC3884](https://github.com/matrix-org/matrix-spec-proposals/pull/3884)

Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
---
 changelog.d/17454.feature             |   1 +
 synapse/handlers/device.py            |  17 +-
 synapse/handlers/sliding_sync.py      | 107 ++++-
 synapse/rest/client/keys.py           |  10 +-
 synapse/rest/client/sync.py           |  32 +-
 synapse/types/__init__.py             |   7 +-
 synapse/types/handlers/__init__.py    |  48 +-
 synapse/types/rest/client/__init__.py |  10 +
 tests/rest/client/test_sync.py        | 825 +++++++++++++++++++++++++++++++++-
 9 files changed, 1023 insertions(+), 34 deletions(-)
 create mode 100644 changelog.d/17454.feature

(limited to 'synapse/handlers')

diff --git a/changelog.d/17454.feature b/changelog.d/17454.feature
new file mode 100644
index 0000000000..bb088371bf
--- /dev/null
+++ b/changelog.d/17454.feature
@@ -0,0 +1 @@
+Add E2EE extension support to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/device.py b/synapse/handlers/device.py
index 0432d97109..4fc6fcd7ae 100644
--- a/synapse/handlers/device.py
+++ b/synapse/handlers/device.py
@@ -39,6 +39,7 @@ from synapse.metrics.background_process_metrics import (
 )
 from synapse.storage.databases.main.client_ips import DeviceLastConnectionInfo
 from synapse.types import (
+    DeviceListUpdates,
     JsonDict,
     JsonMapping,
     ScheduledTask,
@@ -214,7 +215,7 @@ class DeviceWorkerHandler:
     @cancellable
     async def get_user_ids_changed(
         self, user_id: str, from_token: StreamToken
-    ) -> JsonDict:
+    ) -> DeviceListUpdates:
         """Get list of users that have had the devices updated, or have newly
         joined a room, that `user_id` may be interested in.
""" @@ -341,11 +342,19 @@ class DeviceWorkerHandler: possibly_joined = set() possibly_left = set() - result = {"changed": list(possibly_joined), "left": list(possibly_left)} + device_list_updates = DeviceListUpdates( + changed=possibly_joined, + left=possibly_left, + ) - log_kv(result) + log_kv( + { + "changed": device_list_updates.changed, + "left": device_list_updates.left, + } + ) - return result + return device_list_updates async def on_federation_query_user_devices(self, user_id: str) -> JsonDict: if not self.hs.is_mine(UserID.from_string(user_id)): diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index c362afa6e2..886d7c7159 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -19,7 +19,18 @@ # import logging from itertools import chain -from typing import TYPE_CHECKING, Any, Dict, Final, List, Mapping, Optional, Set, Tuple +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Final, + List, + Mapping, + Optional, + Sequence, + Set, + Tuple, +) import attr from immutabledict import immutabledict @@ -33,6 +44,7 @@ from synapse.storage.databases.main.roommember import extract_heroes_from_room_s from synapse.storage.databases.main.stream import CurrentStateDeltaMembership from synapse.storage.roommember import MemberSummary from synapse.types import ( + DeviceListUpdates, JsonDict, PersistedEventPosition, Requester, @@ -343,6 +355,7 @@ class SlidingSyncHandler: self.notifier = hs.get_notifier() self.event_sources = hs.get_event_sources() self.relations_handler = hs.get_relations_handler() + self.device_handler = hs.get_device_handler() self.rooms_to_exclude_globally = hs.config.server.rooms_to_exclude_from_sync async def wait_for_sync_for_user( @@ -371,10 +384,6 @@ class SlidingSyncHandler: # auth_blocking will occur) await self.auth_blocking.check_auth_blocking(requester=requester) - # TODO: If the To-Device extension is enabled and we have a `from_token`, delete - # any to-device messages before that token (since we now know that the device - # has received them). (see sync v2 for how to do this) - # If we're working with a user-provided token, we need to make sure to wait for # this worker to catch up with the token so we don't skip past any incoming # events or future events if the user is nefariously, manually modifying the @@ -617,7 +626,9 @@ class SlidingSyncHandler: await concurrently_execute(handle_room, relevant_room_map, 10) extensions = await self.get_extensions_response( - sync_config=sync_config, to_token=to_token + sync_config=sync_config, + from_token=from_token, + to_token=to_token, ) return SlidingSyncResult( @@ -1776,33 +1787,47 @@ class SlidingSyncHandler: self, sync_config: SlidingSyncConfig, to_token: StreamToken, + from_token: Optional[StreamToken], ) -> SlidingSyncResult.Extensions: """Handle extension requests. Args: sync_config: Sync configuration to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. 
""" if sync_config.extensions is None: return SlidingSyncResult.Extensions() to_device_response = None - if sync_config.extensions.to_device: - to_device_response = await self.get_to_device_extensions_response( + if sync_config.extensions.to_device is not None: + to_device_response = await self.get_to_device_extension_response( sync_config=sync_config, to_device_request=sync_config.extensions.to_device, to_token=to_token, ) - return SlidingSyncResult.Extensions(to_device=to_device_response) + e2ee_response = None + if sync_config.extensions.e2ee is not None: + e2ee_response = await self.get_e2ee_extension_response( + sync_config=sync_config, + e2ee_request=sync_config.extensions.e2ee, + to_token=to_token, + from_token=from_token, + ) - async def get_to_device_extensions_response( + return SlidingSyncResult.Extensions( + to_device=to_device_response, + e2ee=e2ee_response, + ) + + async def get_to_device_extension_response( self, sync_config: SlidingSyncConfig, to_device_request: SlidingSyncConfig.Extensions.ToDeviceExtension, to_token: StreamToken, - ) -> SlidingSyncResult.Extensions.ToDeviceExtension: + ) -> Optional[SlidingSyncResult.Extensions.ToDeviceExtension]: """Handle to-device extension (MSC3885) Args: @@ -1810,14 +1835,16 @@ class SlidingSyncHandler: to_device_request: The to-device extension from the request to_token: The point in the stream to sync up to. """ - user_id = sync_config.user.to_string() device_id = sync_config.device_id + # Skip if the extension is not enabled + if not to_device_request.enabled: + return None + # Check that this request has a valid device ID (not all requests have - # to belong to a device, and so device_id is None), and that the - # extension is enabled. - if device_id is None or not to_device_request.enabled: + # to belong to a device, and so device_id is None) + if device_id is None: return SlidingSyncResult.Extensions.ToDeviceExtension( next_batch=f"{to_token.to_device_key}", events=[], @@ -1868,3 +1895,53 @@ class SlidingSyncHandler: next_batch=f"{stream_id}", events=messages, ) + + async def get_e2ee_extension_response( + self, + sync_config: SlidingSyncConfig, + e2ee_request: SlidingSyncConfig.Extensions.E2eeExtension, + to_token: StreamToken, + from_token: Optional[StreamToken], + ) -> Optional[SlidingSyncResult.Extensions.E2eeExtension]: + """Handle E2EE device extension (MSC3884) + + Args: + sync_config: Sync configuration + e2ee_request: The e2ee extension from the request + to_token: The point in the stream to sync up to. + from_token: The point in the stream to sync from. 
+ """ + user_id = sync_config.user.to_string() + device_id = sync_config.device_id + + # Skip if the extension is not enabled + if not e2ee_request.enabled: + return None + + device_list_updates: Optional[DeviceListUpdates] = None + if from_token is not None: + # TODO: This should take into account the `from_token` and `to_token` + device_list_updates = await self.device_handler.get_user_ids_changed( + user_id=user_id, + from_token=from_token, + ) + + device_one_time_keys_count: Mapping[str, int] = {} + device_unused_fallback_key_types: Sequence[str] = [] + if device_id: + # TODO: We should have a way to let clients differentiate between the states of: + # * no change in OTK count since the provided since token + # * the server has zero OTKs left for this device + # Spec issue: https://github.com/matrix-org/matrix-doc/issues/3298 + device_one_time_keys_count = await self.store.count_e2e_one_time_keys( + user_id, device_id + ) + device_unused_fallback_key_types = ( + await self.store.get_e2e_unused_fallback_key_types(user_id, device_id) + ) + + return SlidingSyncResult.Extensions.E2eeExtension( + device_list_updates=device_list_updates, + device_one_time_keys_count=device_one_time_keys_count, + device_unused_fallback_key_types=device_unused_fallback_key_types, + ) diff --git a/synapse/rest/client/keys.py b/synapse/rest/client/keys.py index 67de634eab..eddad7d5b8 100644 --- a/synapse/rest/client/keys.py +++ b/synapse/rest/client/keys.py @@ -256,9 +256,15 @@ class KeyChangesServlet(RestServlet): user_id = requester.user.to_string() - results = await self.device_handler.get_user_ids_changed(user_id, from_token) + device_list_updates = await self.device_handler.get_user_ids_changed( + user_id, from_token + ) + + response: JsonDict = {} + response["changed"] = list(device_list_updates.changed) + response["left"] = list(device_list_updates.left) - return 200, results + return 200, response class OneTimeKeyServlet(RestServlet): diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py index 1d8cbfdf00..93fe1d439e 100644 --- a/synapse/rest/client/sync.py +++ b/synapse/rest/client/sync.py @@ -1081,15 +1081,41 @@ class SlidingSyncRestServlet(RestServlet): async def encode_extensions( self, requester: Requester, extensions: SlidingSyncResult.Extensions ) -> JsonDict: - result = {} + serialized_extensions: JsonDict = {} if extensions.to_device is not None: - result["to_device"] = { + serialized_extensions["to_device"] = { "next_batch": extensions.to_device.next_batch, "events": extensions.to_device.events, } - return result + if extensions.e2ee is not None: + serialized_extensions["e2ee"] = { + # We always include this because + # https://github.com/vector-im/element-android/issues/3725. The spec + # isn't terribly clear on when this can be omitted and how a client + # would tell the difference between "no keys present" and "nothing + # changed" in terms of whole field absent / individual key type entry + # absent Corresponding synapse issue: + # https://github.com/matrix-org/synapse/issues/10456 + "device_one_time_keys_count": extensions.e2ee.device_one_time_keys_count, + # https://github.com/matrix-org/matrix-doc/blob/54255851f642f84a4f1aaf7bc063eebe3d76752b/proposals/2732-olm-fallback-keys.md + # states that this field should always be included, as long as the + # server supports the feature. 
+ "device_unused_fallback_key_types": extensions.e2ee.device_unused_fallback_key_types, + } + + if extensions.e2ee.device_list_updates is not None: + serialized_extensions["e2ee"]["device_lists"] = {} + + serialized_extensions["e2ee"]["device_lists"]["changed"] = list( + extensions.e2ee.device_list_updates.changed + ) + serialized_extensions["e2ee"]["device_lists"]["left"] = list( + extensions.e2ee.device_list_updates.left + ) + + return serialized_extensions def register_servlets(hs: "HomeServer", http_server: HttpServer) -> None: diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 3962ecc996..046cdc29cd 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1219,11 +1219,12 @@ class ReadReceipt: @attr.s(slots=True, frozen=True, auto_attribs=True) class DeviceListUpdates: """ - An object containing a diff of information regarding other users' device lists, intended for - a recipient to carry out device list tracking. + An object containing a diff of information regarding other users' device lists, + intended for a recipient to carry out device list tracking. Attributes: - changed: A set of users whose device lists have changed recently. + changed: A set of users who have updated their device identity or + cross-signing keys, or who now share an encrypted room with. left: A set of users who the recipient no longer needs to track the device lists of. Typically when those users no longer share any end-to-end encryption enabled rooms. """ diff --git a/synapse/types/handlers/__init__.py b/synapse/types/handlers/__init__.py index 409120470a..4c6c42db04 100644 --- a/synapse/types/handlers/__init__.py +++ b/synapse/types/handlers/__init__.py @@ -18,7 +18,7 @@ # # from enum import Enum -from typing import TYPE_CHECKING, Dict, Final, List, Optional, Sequence, Tuple +from typing import TYPE_CHECKING, Dict, Final, List, Mapping, Optional, Sequence, Tuple import attr from typing_extensions import TypedDict @@ -31,7 +31,7 @@ else: from pydantic import Extra from synapse.events import EventBase -from synapse.types import JsonDict, JsonMapping, StreamToken, UserID +from synapse.types import DeviceListUpdates, JsonDict, JsonMapping, StreamToken, UserID from synapse.types.rest.client import SlidingSyncBody if TYPE_CHECKING: @@ -264,6 +264,7 @@ class SlidingSyncResult: Attributes: to_device: The to-device extension (MSC3885) + e2ee: The E2EE device extension (MSC3884) """ @attr.s(slots=True, frozen=True, auto_attribs=True) @@ -282,10 +283,51 @@ class SlidingSyncResult: def __bool__(self) -> bool: return bool(self.events) + @attr.s(slots=True, frozen=True, auto_attribs=True) + class E2eeExtension: + """The E2EE device extension (MSC3884) + + Attributes: + device_list_updates: List of user_ids whose devices have changed or left (only + present on incremental syncs). + device_one_time_keys_count: Map from key algorithm to the number of + unclaimed one-time keys currently held on the server for this device. If + an algorithm is unlisted, the count for that algorithm is assumed to be + zero. If this entire parameter is missing, the count for all algorithms + is assumed to be zero. + device_unused_fallback_key_types: List of unused fallback key algorithms + for this device. 
+ """ + + # Only present on incremental syncs + device_list_updates: Optional[DeviceListUpdates] + device_one_time_keys_count: Mapping[str, int] + device_unused_fallback_key_types: Sequence[str] + + def __bool__(self) -> bool: + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + # + # Also related: + # https://github.com/element-hq/element-android/issues/3725 and + # https://github.com/matrix-org/synapse/issues/10456 + default_otk = self.device_one_time_keys_count.get("signed_curve25519") + more_than_default_otk = len(self.device_one_time_keys_count) > 1 or ( + default_otk is not None and default_otk > 0 + ) + + return bool( + more_than_default_otk + or self.device_list_updates + or self.device_unused_fallback_key_types + ) + to_device: Optional[ToDeviceExtension] = None + e2ee: Optional[E2eeExtension] = None def __bool__(self) -> bool: - return bool(self.to_device) + return bool(self.to_device or self.e2ee) next_pos: StreamToken lists: Dict[str, SlidingWindowList] diff --git a/synapse/types/rest/client/__init__.py b/synapse/types/rest/client/__init__.py index dbe37bc712..f3c45a0d6a 100644 --- a/synapse/types/rest/client/__init__.py +++ b/synapse/types/rest/client/__init__.py @@ -313,7 +313,17 @@ class SlidingSyncBody(RequestBodyModel): return value + class E2eeExtension(RequestBodyModel): + """The E2EE device extension (MSC3884) + + Attributes: + enabled + """ + + enabled: Optional[StrictBool] = False + to_device: Optional[ToDeviceExtension] = None + e2ee: Optional[E2eeExtension] = None # mypy workaround via https://github.com/pydantic/pydantic/issues/156#issuecomment-1130883884 if TYPE_CHECKING: diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index a88bdb5c14..2628869de6 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -59,6 +59,7 @@ from tests.federation.transport.test_knocking import ( ) from tests.server import FakeChannel, TimedOutException from tests.test_utils.event_injection import mark_event_as_partial_state +from tests.unittest import skip_unless logger = logging.getLogger(__name__) @@ -1113,12 +1114,11 @@ class DeviceUnusedFallbackKeySyncTestCase(unittest.HomeserverTestCase): self.assertEqual(res, []) # Upload a fallback key for the user/device - fallback_key = {"alg1:k1": "fallback_key1"} self.get_success( self.e2e_keys_handler.upload_keys_for_user( alice_user_id, test_device_id, - {"fallback_keys": fallback_key}, + {"fallback_keys": {"alg1:k1": "fallback_key1"}}, ) ) # We should now have an unused alg1 key @@ -1252,6 +1252,8 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.event_sources = hs.get_event_sources() self.storage_controllers = hs.get_storage_controllers() + self.account_data_handler = hs.get_account_data_handler() + self.notifier = hs.get_notifier() def _assertRequiredStateIncludes( self, @@ -1377,6 +1379,52 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): return room_id + def _bump_notifier_wait_for_events(self, user_id: str) -> None: + """ + Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding + Sync results. 
+ """ + # We're expecting some new activity from this point onwards + from_token = self.event_sources.get_current_token() + + triggered_notifier_wait_for_events = False + + async def _on_new_acivity( + before_token: StreamToken, after_token: StreamToken + ) -> bool: + nonlocal triggered_notifier_wait_for_events + triggered_notifier_wait_for_events = True + return True + + # Listen for some new activity for the user. We're just trying to confirm that + # our bump below actually does what we think it does (triggers new activity for + # the user). + result_awaitable = self.notifier.wait_for_events( + user_id, + 1000, + _on_new_acivity, + from_token=from_token, + ) + + # Update the account data so that `notifier.wait_for_events(...)` wakes up. + # We're bumping account data because it won't show up in the Sliding Sync + # response so it won't affect whether we have results. + self.get_success( + self.account_data_handler.add_account_data_for_user( + user_id, + "org.matrix.foobarbaz", + {"foo": "bar"}, + ) + ) + + # Wait for our notifier result + self.get_success(result_awaitable) + + if not triggered_notifier_wait_for_events: + raise AssertionError( + "Expected `notifier.wait_for_events(...)` to be triggered" + ) + def test_sync_list(self) -> None: """ Test that room IDs show up in the Sliding Sync `lists` @@ -1482,6 +1530,124 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # with because we weren't able to find anything new yet. self.assertEqual(channel.json_body["pos"], future_position_token_serialized) + def test_wait_for_new_data(self) -> None: + """ + Test to make sure that the Sliding Sync request waits for new data to arrive. + + (Only applies to incremental syncs with a `timeout` specified) + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + + from_token = self.event_sources.get_current_token() + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint + + "?timeout=10000" + + f"&pos={self.get_success(from_token.to_string(self.store))}", + { + "lists": { + "foo-list": { + "ranges": [[0, 0]], + "required_state": [], + "timeline_limit": 1, + } + } + }, + access_token=user1_tok, + await_result=False, + ) + # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=5000) + # Bump the room with new events to trigger new results + event_response1 = self.helper.send( + room_id, "new activity in room", tok=user1_tok + ) + # Should respond before the 10 second timeout + channel.await_result(timeout_ms=3000) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check to make sure the new event is returned + self.assertEqual( + [ + event["event_id"] + for event in channel.json_body["rooms"][room_id]["timeline"] + ], + [ + event_response1["event_id"], + ], + channel.json_body["rooms"][room_id]["timeline"], + ) + + # TODO: Once we remove `ops`, we should be able to add a `RoomResult.__bool__` to + # check if there are any updates since the `from_token`. 
+ @skip_unless( + False, + "Once we remove ops from the Sliding Sync response, this test should pass", + ) + def test_wait_for_new_data_timeout(self) -> None: + """ + Test to make sure that the Sliding Sync request waits for new data to arrive but + no data ever arrives so we timeout. We're also making sure that the default data + doesn't trigger a false-positive for new data. + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + + from_token = self.event_sources.get_current_token() + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint + + "?timeout=10000" + + f"&pos={self.get_success(from_token.to_string(self.store))}", + { + "lists": { + "foo-list": { + "ranges": [[0, 0]], + "required_state": [], + "timeline_limit": 1, + } + } + }, + access_token=user1_tok, + await_result=False, + ) + # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=5000) + # Wake-up `notifier.wait_for_events(...)` that will cause us test + # `SlidingSyncResult.__bool__` for new results. + self._bump_notifier_wait_for_events(user1_id) + # Block for a little bit more to ensure we don't see any new results. + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=4000) + # Wait for the sync to complete (wait for the rest of the 10 second timeout, + # 5000 + 4000 + 1200 > 10000) + channel.await_result(timeout_ms=1200) + self.assertEqual(channel.code, 200, channel.json_body) + + # We still see rooms because that's how Sliding Sync lists work but we reached + # the timeout before seeing them + self.assertEqual( + [event["event_id"] for event in channel.json_body["rooms"].keys()], + [room_id], + ) + def test_filter_list(self) -> None: """ Test that filters apply to `lists` @@ -1508,11 +1674,11 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): ) # Create a normal room - room_id = self.helper.create_room_as(user1_id, tok=user2_tok) + room_id = self.helper.create_room_as(user2_id, tok=user2_tok) self.helper.join(room_id, user1_id, tok=user1_tok) # Create a room that user1 is invited to - invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok) + invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) # Make the Sliding Sync request @@ -4320,10 +4486,59 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase): def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: self.store = hs.get_datastores().main + self.event_sources = hs.get_event_sources() + self.account_data_handler = hs.get_account_data_handler() + self.notifier = hs.get_notifier() self.sync_endpoint = ( "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync" ) + def _bump_notifier_wait_for_events(self, user_id: str) -> None: + """ + Wake-up a `notifier.wait_for_events(user_id)` call without affecting the Sliding + Sync results. 
+ """ + # We're expecting some new activity from this point onwards + from_token = self.event_sources.get_current_token() + + triggered_notifier_wait_for_events = False + + async def _on_new_acivity( + before_token: StreamToken, after_token: StreamToken + ) -> bool: + nonlocal triggered_notifier_wait_for_events + triggered_notifier_wait_for_events = True + return True + + # Listen for some new activity for the user. We're just trying to confirm that + # our bump below actually does what we think it does (triggers new activity for + # the user). + result_awaitable = self.notifier.wait_for_events( + user_id, + 1000, + _on_new_acivity, + from_token=from_token, + ) + + # Update the account data so that `notifier.wait_for_events(...)` wakes up. + # We're bumping account data because it won't show up in the Sliding Sync + # response so it won't affect whether we have results. + self.get_success( + self.account_data_handler.add_account_data_for_user( + user_id, + "org.matrix.foobarbaz", + {"foo": "bar"}, + ) + ) + + # Wait for our notifier result + self.get_success(result_awaitable) + + if not triggered_notifier_wait_for_events: + raise AssertionError( + "Expected `notifier.wait_for_events(...)` to be triggered" + ) + def _assert_to_device_response( self, channel: FakeChannel, expected_messages: List[JsonDict] ) -> str: @@ -4487,3 +4702,605 @@ class SlidingSyncToDeviceExtensionTestCase(unittest.HomeserverTestCase): access_token=user1_tok, ) self._assert_to_device_response(channel, []) + + def test_wait_for_new_data(self) -> None: + """ + Test to make sure that the Sliding Sync request waits for new data to arrive. + + (Only applies to incremental syncs with a `timeout` specified) + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass", "d1") + user2_id = self.register_user("u2", "pass") + user2_tok = self.login(user2_id, "pass", "d2") + + from_token = self.event_sources.get_current_token() + + # Make the Sliding Sync request + channel = self.make_request( + "POST", + self.sync_endpoint + + "?timeout=10000" + + f"&pos={self.get_success(from_token.to_string(self.store))}", + { + "lists": {}, + "extensions": { + "to_device": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + await_result=False, + ) + # Block for 5 seconds to make sure we are `notifier.wait_for_events(...)` + with self.assertRaises(TimedOutException): + channel.await_result(timeout_ms=5000) + # Bump the to-device messages to trigger new results + test_msg = {"foo": "bar"} + send_to_device_channel = self.make_request( + "PUT", + "/_matrix/client/r0/sendToDevice/m.test/1234", + content={"messages": {user1_id: {"d1": test_msg}}}, + access_token=user2_tok, + ) + self.assertEqual( + send_to_device_channel.code, 200, send_to_device_channel.result + ) + # Should respond before the 10 second timeout + channel.await_result(timeout_ms=3000) + self.assertEqual(channel.code, 200, channel.json_body) + + self._assert_to_device_response( + channel, + [{"content": test_msg, "sender": user2_id, "type": "m.test"}], + ) + + def test_wait_for_new_data_timeout(self) -> None: + """ + Test to make sure that the Sliding Sync request waits for new data to arrive but + no data ever arrives so we timeout. We're also making sure that the default data + from the To-Device extension doesn't trigger a false-positive for new data. 
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ from_token = self.event_sources.get_current_token()
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + "?timeout=10000"
+ + f"&pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {},
+ "extensions": {
+ "to_device": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ await_result=False,
+ )
+ # Block for 5 seconds to make sure we are in `notifier.wait_for_events(...)`
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=5000)
+ # Wake up `notifier.wait_for_events(...)`, which will cause us to test
+ # `SlidingSyncResult.__bool__` for new results.
+ self._bump_notifier_wait_for_events(user1_id)
+ # Block for a little bit more to ensure we don't see any new results.
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=4000)
+ # Wait for the sync to complete (wait for the rest of the 10 second timeout,
+ # 5000 + 4000 + 1200 > 10000)
+ channel.await_result(timeout_ms=1200)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ self._assert_to_device_response(channel, [])
+
+
+class SlidingSyncE2eeExtensionTestCase(unittest.HomeserverTestCase):
+ """Tests for the e2ee sliding sync extension"""
+
+ servlets = [
+ synapse.rest.admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ sync.register_servlets,
+ devices.register_servlets,
+ ]
+
+ def default_config(self) -> JsonDict:
+ config = super().default_config()
+ # Enable sliding sync
+ config["experimental_features"] = {"msc3575_enabled": True}
+ return config
+
+ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+ self.store = hs.get_datastores().main
+ self.event_sources = hs.get_event_sources()
+ self.e2e_keys_handler = hs.get_e2e_keys_handler()
+ self.account_data_handler = hs.get_account_data_handler()
+ self.notifier = hs.get_notifier()
+ self.sync_endpoint = (
+ "/_matrix/client/unstable/org.matrix.simplified_msc3575/sync"
+ )
+
+ def _bump_notifier_wait_for_events(self, user_id: str) -> None:
+ """
+ Wake up a `notifier.wait_for_events(user_id)` call without affecting the Sliding
+ Sync results.
+
+ We do this by bumping the user's account data: that wakes the notifier, but
+ account data is not part of the Sliding Sync response, so it can't register
+ as a new result.
+ """
+ # We're expecting some new activity from this point onwards
+ from_token = self.event_sources.get_current_token()
+
+ triggered_notifier_wait_for_events = False
+
+ async def _on_new_activity(
+ before_token: StreamToken, after_token: StreamToken
+ ) -> bool:
+ nonlocal triggered_notifier_wait_for_events
+ triggered_notifier_wait_for_events = True
+ return True
+
+ # Listen for some new activity for the user. We're just trying to confirm that
+ # our bump below actually does what we think it does (triggers new activity for
+ # the user).
+ result_awaitable = self.notifier.wait_for_events(
+ user_id,
+ 1000,
+ _on_new_activity,
+ from_token=from_token,
+ )
+
+ # Update the account data so that `notifier.wait_for_events(...)` wakes up.
+ # We're bumping account data because it won't show up in the Sliding Sync
+ # response, so it won't affect whether we have results.
+ self.get_success(
+ self.account_data_handler.add_account_data_for_user(
+ user_id,
+ "org.matrix.foobarbaz",
+ {"foo": "bar"},
+ )
+ )
+
+ # Wait for our notifier result
+ self.get_success(result_awaitable)
+
+ if not triggered_notifier_wait_for_events:
+ raise AssertionError(
+ "Expected `notifier.wait_for_events(...)` to be triggered"
+ )
+
+ def test_no_data_initial_sync(self) -> None:
+ """
+ Test that enabling the e2ee extension works during an initial sync, even if
+ there is no data
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Make an initial Sliding Sync request with the e2ee extension enabled
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint,
+ {
+ "lists": {},
+ "extensions": {
+ "e2ee": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Device list updates are only present for incremental syncs
+ self.assertIsNone(channel.json_body["extensions"]["e2ee"].get("device_lists"))
+
+ # Both of these should be present even when empty
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
+ {
+ # This is always present because of
+ # https://github.com/element-hq/element-android/issues/3725 and
+ # https://github.com/matrix-org/synapse/issues/10456
+ "signed_curve25519": 0
+ },
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
+ [],
+ )
+
+ def test_no_data_incremental_sync(self) -> None:
+ """
+ Test that enabling the e2ee extension works during an incremental sync, even
+ if there is no data
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ from_token = self.event_sources.get_current_token()
+
+ # Make an incremental Sliding Sync request with the e2ee extension enabled
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {},
+ "extensions": {
+ "e2ee": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Device list shows up for incremental syncs
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]
+ .get("device_lists", {})
+ .get("changed"),
+ [],
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+ [],
+ )
+
+ # Both of these should be present even when empty
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
+ {
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ #
+ # Also related:
+ # https://github.com/element-hq/element-android/issues/3725 and
+ # https://github.com/matrix-org/synapse/issues/10456
+ "signed_curve25519": 0
+ },
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
+ [],
+ )
+
+ def test_wait_for_new_data(self) -> None:
+ """
+ Test to make sure that the Sliding Sync request waits for new data to arrive.
+
+ (Only applies to incremental syncs with a `timeout` specified)
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+ test_device_id = "TESTDEVICE"
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+
+ from_token = self.event_sources.get_current_token()
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + "?timeout=10000"
+ + f"&pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {},
+ "extensions": {
+ "e2ee": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ await_result=False,
+ )
+ # Block for 5 seconds to make sure we are in `notifier.wait_for_events(...)`
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=5000)
+ # Bump the device lists to trigger new results
+ # Have user3 update their device list
+ device_update_channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=user3_tok,
+ )
+ self.assertEqual(
+ device_update_channel.code, 200, device_update_channel.json_body
+ )
+ # Should respond before the 10 second timeout
+ channel.await_result(timeout_ms=3000)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # We should see the device list update
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]
+ .get("device_lists", {})
+ .get("changed"),
+ [user3_id],
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+ [],
+ )
+
+ def test_wait_for_new_data_timeout(self) -> None:
+ """
+ Test to make sure that the Sliding Sync request waits for new data to arrive but
+ no data ever arrives, so we time out. We're also making sure that the default
+ data from the E2EE extension doesn't trigger a false-positive for new data (see
+ `device_one_time_keys_count.signed_curve25519`).
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ from_token = self.event_sources.get_current_token()
+
+ # Make the Sliding Sync request
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + "?timeout=10000"
+ + f"&pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {},
+ "extensions": {
+ "e2ee": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ await_result=False,
+ )
+ # Block for 5 seconds to make sure we are in `notifier.wait_for_events(...)`
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=5000)
+ # Wake up `notifier.wait_for_events(...)`, which will cause us to test
+ # `SlidingSyncResult.__bool__` for new results.
+ self._bump_notifier_wait_for_events(user1_id)
+ # Block for a little bit more to ensure we don't see any new results.
+ with self.assertRaises(TimedOutException):
+ channel.await_result(timeout_ms=4000)
+ # Wait for the sync to complete (wait for the rest of the 10 second timeout,
+ # 5000 + 4000 + 1200 > 10000)
+ channel.await_result(timeout_ms=1200)
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Device lists are present for incremental syncs but empty because there
+ # are no device changes
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]
+ .get("device_lists", {})
+ .get("changed"),
+ [],
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+ [],
+ )
+
+ # Both of these should be present even when empty
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]["device_one_time_keys_count"],
+ {
+ # Note that "signed_curve25519" is always returned in key count responses
+ # regardless of whether we uploaded any keys for it. This is necessary until
+ # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed.
+ #
+ # Also related:
+ # https://github.com/element-hq/element-android/issues/3725 and
+ # https://github.com/matrix-org/synapse/issues/10456
+ "signed_curve25519": 0
+ },
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]["device_unused_fallback_key_types"],
+ [],
+ )
+
+ def test_device_lists(self) -> None:
+ """
+ Test that device list updates are included in the response
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ test_device_id = "TESTDEVICE"
+ user3_id = self.register_user("user3", "pass")
+ user3_tok = self.login(user3_id, "pass", device_id=test_device_id)
+
+ user4_id = self.register_user("user4", "pass")
+ user4_tok = self.login(user4_id, "pass")
+
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+ self.helper.join(room_id, user3_id, tok=user3_tok)
+ self.helper.join(room_id, user4_id, tok=user4_tok)
+
+ from_token = self.event_sources.get_current_token()
+
+ # Have user3 update their device list
+ channel = self.make_request(
+ "PUT",
+ f"/devices/{test_device_id}",
+ {
+ "display_name": "New Device Name",
+ },
+ access_token=user3_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # User4 leaves the room
+ self.helper.leave(room_id, user4_id, tok=user4_tok)
+
+ # Make an incremental Sliding Sync request with the e2ee extension enabled
+ channel = self.make_request(
+ "POST",
+ self.sync_endpoint
+ + f"?pos={self.get_success(from_token.to_string(self.store))}",
+ {
+ "lists": {},
+ "extensions": {
+ "e2ee": {
+ "enabled": True,
+ }
+ },
+ },
+ access_token=user1_tok,
+ )
+ self.assertEqual(channel.code, 200, channel.json_body)
+
+ # Device list updates show up
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"]
+ .get("device_lists", {})
+ .get("changed"),
+ [user3_id],
+ )
+ self.assertEqual(
+ channel.json_body["extensions"]["e2ee"].get("device_lists", {}).get("left"),
+ [user4_id],
+ )
+
+ def test_device_one_time_keys_count(self) -> None:
+ """
+ Test that `device_one_time_keys_count` is included in the response
+ """
+ test_device_id = "TESTDEVICE"
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass", device_id=test_device_id)
+
+ # Upload one-time keys for the user/device
+ keys: JsonDict = {
+ "alg1:k1": "key1",
+ "alg2:k2": {"key": "key2", "signatures": {"k1": "sig1"}},
+ "alg2:k3": {"key": "key3"},
+ }
+
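+ # One-time keys are identified as `<algorithm>:<key_id>`, so the upload
+ # response below counts one `alg1` key and two `alg2` keys.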
upload_keys_response = self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + user1_id, test_device_id, {"one_time_keys": keys} + ) + ) + self.assertDictEqual( + upload_keys_response, + { + "one_time_key_counts": { + "alg1": 1, + "alg2": 2, + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + # + # Also related: + # https://github.com/element-hq/element-android/issues/3725 and + # https://github.com/matrix-org/synapse/issues/10456 + "signed_curve25519": 0, + } + }, + ) + + # Make a Sliding Sync request with the e2ee extension enabled + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "e2ee": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check for those one time key counts + self.assertEqual( + channel.json_body["extensions"]["e2ee"].get("device_one_time_keys_count"), + { + "alg1": 1, + "alg2": 2, + # Note that "signed_curve25519" is always returned in key count responses + # regardless of whether we uploaded any keys for it. This is necessary until + # https://github.com/matrix-org/matrix-doc/issues/3298 is fixed. + # + # Also related: + # https://github.com/element-hq/element-android/issues/3725 and + # https://github.com/matrix-org/synapse/issues/10456 + "signed_curve25519": 0, + }, + ) + + def test_device_unused_fallback_key_types(self) -> None: + """ + Test that `device_unused_fallback_key_types` are included in the response + """ + test_device_id = "TESTDEVICE" + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass", device_id=test_device_id) + + # We shouldn't have any unused fallback keys yet + res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id) + ) + self.assertEqual(res, []) + + # Upload a fallback key for the user/device + self.get_success( + self.e2e_keys_handler.upload_keys_for_user( + user1_id, + test_device_id, + {"fallback_keys": {"alg1:k1": "fallback_key1"}}, + ) + ) + # We should now have an unused alg1 key + fallback_res = self.get_success( + self.store.get_e2e_unused_fallback_key_types(user1_id, test_device_id) + ) + self.assertEqual(fallback_res, ["alg1"], fallback_res) + + # Make a Sliding Sync request with the e2ee extension enabled + channel = self.make_request( + "POST", + self.sync_endpoint, + { + "lists": {}, + "extensions": { + "e2ee": { + "enabled": True, + } + }, + }, + access_token=user1_tok, + ) + self.assertEqual(channel.code, 200, channel.json_body) + + # Check for the unused fallback key types + self.assertListEqual( + channel.json_body["extensions"]["e2ee"].get( + "device_unused_fallback_key_types" + ), + ["alg1"], + ) -- cgit 1.5.1 From a9ee832e4899fefcfd27fba475091e1ffaa069b8 Mon Sep 17 00:00:00 2001 From: Michael Hollister Date: Tue, 23 Jul 2024 04:59:24 -0500 Subject: Fixed presence results not returning offline users on initial sync (#17231) This is to address an issue in which `m.presence` results on initial sync are not returning entries of users who are currently offline. The original behaviour was from https://github.com/element-hq/synapse/issues/1535 This change is useful for applications that use the presence system for tracking user profile information/updates (e.g. https://github.com/element-hq/synapse/pull/16992 or for profile status messages). 
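
For example, such a deployment would enable the new behaviour like so (a
minimal sketch; the option itself is defined in the config changes below):

    presence:
      enabled: true
      include_offline_users_on_sync: true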
This is gated behind a new configuration option to avoid performance impact for applications that don't need this, as a pragmatic solution for now. --- changelog.d/17231.bugfix | 1 + docs/usage/configuration/config_documentation.md | 5 +++++ synapse/config/server.py | 5 +++++ synapse/handlers/sync.py | 6 +++++- 4 files changed, 16 insertions(+), 1 deletion(-) create mode 100644 changelog.d/17231.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/17231.bugfix b/changelog.d/17231.bugfix new file mode 100644 index 0000000000..d09b455654 --- /dev/null +++ b/changelog.d/17231.bugfix @@ -0,0 +1 @@ +Added configurable option to always include offline users in presence sync results. Contributed by @Michael-Hollister. diff --git a/docs/usage/configuration/config_documentation.md b/docs/usage/configuration/config_documentation.md index e8bc2df798..649f4f71c7 100644 --- a/docs/usage/configuration/config_documentation.md +++ b/docs/usage/configuration/config_documentation.md @@ -246,6 +246,7 @@ Example configuration: ```yaml presence: enabled: false + include_offline_users_on_sync: false ``` `enabled` can also be set to a special value of "untracked" which ignores updates @@ -254,6 +255,10 @@ received via clients and federation, while still accepting updates from the *The "untracked" option was added in Synapse 1.96.0.* +When clients perform an initial or `full_state` sync, presence results for offline users are +not included by default. Setting `include_offline_users_on_sync` to `true` will always include +offline users in the results. Defaults to false. + --- ### `require_auth_for_profile_requests` diff --git a/synapse/config/server.py b/synapse/config/server.py index 8bb97df175..fd52c0475c 100644 --- a/synapse/config/server.py +++ b/synapse/config/server.py @@ -384,6 +384,11 @@ class ServerConfig(Config): # Whether to internally track presence, requires that presence is enabled, self.track_presence = self.presence_enabled and presence_enabled != "untracked" + # Determines if presence results for offline users are included on initial/full sync + self.presence_include_offline_users_on_sync = presence_config.get( + "include_offline_users_on_sync", False + ) + # Custom presence router module # This is the legacy way of configuring it (the config should now be put in the modules section) self.presence_router_module_class = None diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py index de227faec3..ede014180c 100644 --- a/synapse/handlers/sync.py +++ b/synapse/handlers/sync.py @@ -2270,7 +2270,11 @@ class SyncHandler: user=user, from_key=presence_key, is_guest=sync_config.is_guest, - include_offline=include_offline, + include_offline=( + True + if self.hs_config.server.presence_include_offline_users_on_sync + else include_offline + ), ) assert presence_key sync_result_builder.now_token = now_token.copy_and_replace( -- cgit 1.5.1 From d225b6b3ebea419bdf0e6c0f1476544053f2dcbf Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Tue, 23 Jul 2024 14:03:14 +0100 Subject: Speed up SS room sorting (#17468) We do this by bulk fetching the latest stream ordering. 
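
Roughly, the approach is as follows (an illustrative sketch with hypothetical
helper names, not the exact implementation):

    # Phase 1: one cheap query fetching, for each room, the latest event
    # stream ordering at or before the *max* stream position of the token.
    positions = fetch_latest_positions(room_ids, max_token)

    # Positions at or before the *min* stream position are unambiguous.
    # Rooms whose latest position falls between min and max may contain
    # events persisted after the token by another writer, so only those
    # rooms get rechecked with a more precise (but more expensive) query.
    recheck_rooms = {r for r, pos in positions.items() if pos > min_token}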
--------- Co-authored-by: Andrew Morgan <1342360+anoadragon453@users.noreply.github.com> --- changelog.d/17468.misc | 1 + synapse/handlers/sliding_sync.py | 43 ++++--- synapse/storage/databases/main/event_federation.py | 5 + synapse/storage/databases/main/stream.py | 123 ++++++++++++++++++++- synapse/util/caches/stream_change_cache.py | 12 +- tests/util/test_stream_change_cache.py | 4 +- 6 files changed, 159 insertions(+), 29 deletions(-) create mode 100644 changelog.d/17468.misc (limited to 'synapse/handlers') diff --git a/changelog.d/17468.misc b/changelog.d/17468.misc new file mode 100644 index 0000000000..d908776204 --- /dev/null +++ b/changelog.d/17468.misc @@ -0,0 +1 @@ +Speed up sorting of the room list in sliding sync. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 886d7c7159..554ab59bf3 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -1230,34 +1230,33 @@ class SlidingSyncHandler: # Assemble a map of room ID to the `stream_ordering` of the last activity that the # user should see in the room (<= `to_token`) last_activity_in_room_map: Dict[str, int] = {} - for room_id, room_for_user in sync_room_map.items(): - # If they are fully-joined to the room, let's find the latest activity - # at/before the `to_token`. - if room_for_user.membership == Membership.JOIN: - last_event_result = ( - await self.store.get_last_event_pos_in_room_before_stream_ordering( - room_id, to_token.room_key - ) - ) - - # If the room has no events at/before the `to_token`, this is probably a - # mistake in the code that generates the `sync_room_map` since that should - # only give us rooms that the user had membership in during the token range. - assert last_event_result is not None - _, event_pos = last_event_result - - last_activity_in_room_map[room_id] = event_pos.stream - else: - # Otherwise, if the user has left/been invited/knocked/been banned from - # a room, they shouldn't see anything past that point. + for room_id, room_for_user in sync_room_map.items(): + if room_for_user.membership != Membership.JOIN: + # If the user has left/been invited/knocked/been banned from a + # room, they shouldn't see anything past that point. # - # FIXME: It's possible that people should see beyond this point in - # invited/knocked cases if for example the room has + # FIXME: It's possible that people should see beyond this point + # in invited/knocked cases if for example the room has # `invite`/`world_readable` history visibility, see # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932 last_activity_in_room_map[room_id] = room_for_user.event_pos.stream + # For fully-joined rooms, we find the latest activity at/before the + # `to_token`. 
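+ #
+ # We do this in bulk (one storage call covering every joined room) rather
+ # than one lookup per room, which is where the speed-up in this change
+ # comes from.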
+ joined_room_positions = ( + await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering( + [ + room_id + for room_id, room_for_user in sync_room_map.items() + if room_for_user.membership == Membership.JOIN + ], + to_token.room_key, + ) + ) + + last_activity_in_room_map.update(joined_room_positions) + return sorted( sync_room_map.values(), # Sort by the last activity (stream_ordering) in the room diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index 24abab4a23..715846865b 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas # We want to make the cache more effective, so we clamp to the last # change before the given ordering. last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined] + if last_change is None: + # If the room isn't in the cache we know that the last change was + # somewhere before the earliest known position of the cache, so we + # can clamp to that. + last_change = self._events_stream_cache.get_earliest_known_position() # type: ignore[attr-defined] # We don't always have a full stream_to_exterm_id table, e.g. after # the upgrade that introduced it, so we make sure we never ask for a diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py index e74e0d2e91..b034361aec 100644 --- a/synapse/storage/databases/main/stream.py +++ b/synapse/storage/databases/main/stream.py @@ -78,10 +78,11 @@ from synapse.storage.database import ( from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine from synapse.storage.util.id_generators import MultiWriterIdGenerator -from synapse.types import PersistedEventPosition, RoomStreamToken +from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection from synapse.util.caches.descriptors import cached from synapse.util.caches.stream_change_cache import StreamChangeCache from synapse.util.cancellation import cancellable +from synapse.util.iterutils import batch_iter if TYPE_CHECKING: from synapse.server import HomeServer @@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore): get_last_event_pos_in_room_before_stream_ordering_txn, ) + async def bulk_get_last_event_pos_in_room_before_stream_ordering( + self, + room_ids: StrCollection, + end_token: RoomStreamToken, + ) -> Dict[str, int]: + """Bulk fetch the stream position of the latest events in the given + rooms + """ + + min_token = end_token.stream + max_token = end_token.get_max_stream_pos() + results: Dict[str, int] = {} + + # First, we check for the rooms in the stream change cache to see if we + # can just use the latest position from it. + missing_room_ids: Set[str] = set() + for room_id in room_ids: + stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id) + if stream_pos and stream_pos <= min_token: + results[room_id] = stream_pos + else: + missing_room_ids.add(room_id) + + # Next, we query the stream position from the DB. At first we fetch all + # positions less than the *max* stream pos in the token, then filter + # them down. We do this as a) this is a cheaper query, and b) the vast + # majority of rooms will have a latest token from before the min stream + # pos. 
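+ #
+ # Any rooms whose latest position lands between the min and max stream
+ # positions get rechecked below with a more precise (but more expensive)
+ # query.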
+ + def bulk_get_last_event_pos_txn( + txn: LoggingTransaction, batch_room_ids: StrCollection + ) -> Dict[str, int]: + # This query fetches the latest stream position in the rooms before + # the given max position. + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + sql = f""" + SELECT room_id, ( + SELECT stream_ordering FROM events AS e + LEFT JOIN rejections USING (event_id) + WHERE e.room_id = r.room_id + AND stream_ordering <= ? + AND NOT outlier + AND rejection_reason IS NULL + ORDER BY stream_ordering DESC + LIMIT 1 + ) + FROM rooms AS r + WHERE {clause} + """ + txn.execute(sql, [max_token] + args) + return {row[0]: row[1] for row in txn} + + recheck_rooms: Set[str] = set() + for batched in batch_iter(missing_room_ids, 1000): + result = await self.db_pool.runInteraction( + "bulk_get_last_event_pos_in_room_before_stream_ordering", + bulk_get_last_event_pos_txn, + batched, + ) + + # Check that the stream position for the rooms are from before the + # minimum position of the token. If not then we need to fetch more + # rows. + for room_id, stream in result.items(): + if stream <= min_token: + results[room_id] = stream + else: + recheck_rooms.add(room_id) + + if not recheck_rooms: + return results + + # For the remaining rooms we need to fetch all rows between the min and + # max stream positions in the end token, and filter out the rows that + # are after the end token. + # + # This query should be fast as the range between the min and max should + # be small. + + def bulk_get_last_event_pos_recheck_txn( + txn: LoggingTransaction, batch_room_ids: StrCollection + ) -> Dict[str, int]: + clause, args = make_in_list_sql_clause( + self.database_engine, "room_id", batch_room_ids + ) + sql = f""" + SELECT room_id, instance_name, stream_ordering + FROM events + WHERE ? < stream_ordering AND stream_ordering <= ? + AND NOT outlier + AND rejection_reason IS NULL + AND {clause} + ORDER BY stream_ordering ASC + """ + txn.execute(sql, [min_token, max_token] + args) + + # We take the max stream ordering that is less than the token. Since + # we ordered by stream ordering we just need to iterate through and + # take the last matching stream ordering. + txn_results: Dict[str, int] = {} + for row in txn: + room_id = row[0] + event_pos = PersistedEventPosition(row[1], row[2]) + if not event_pos.persisted_after(end_token): + txn_results[room_id] = event_pos.stream + + return txn_results + + for batched in batch_iter(recheck_rooms, 1000): + recheck_result = await self.db_pool.runInteraction( + "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck", + bulk_get_last_event_pos_recheck_txn, + batched, + ) + results.update(recheck_result) + + return results + async def get_current_room_stream_token_for_room_id( self, room_id: str ) -> RoomStreamToken: diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py index 91c335f85b..16fcb00206 100644 --- a/synapse/util/caches/stream_change_cache.py +++ b/synapse/util/caches/stream_change_cache.py @@ -327,7 +327,7 @@ class StreamChangeCache: for entity in r: self._entity_to_key.pop(entity, None) - def get_max_pos_of_last_change(self, entity: EntityType) -> int: + def get_max_pos_of_last_change(self, entity: EntityType) -> Optional[int]: """Returns an upper bound of the stream id of the last change to an entity. @@ -335,7 +335,11 @@ class StreamChangeCache: entity: The entity to check. 
Return: - The stream position of the latest change for the given entity or - the earliest known stream position if the entitiy is unknown. + The stream position of the latest change for the given entity, if + known """ - return self._entity_to_key.get(entity, self._earliest_known_stream_pos) + return self._entity_to_key.get(entity) + + def get_earliest_known_position(self) -> int: + """Returns the earliest position in the cache.""" + return self._earliest_known_stream_pos diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py index 5d38718a50..af1199ef8a 100644 --- a/tests/util/test_stream_change_cache.py +++ b/tests/util/test_stream_change_cache.py @@ -249,5 +249,5 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase): self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3) self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4) - # Unknown entities will return the stream start position. - self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1) + # Unknown entities will return None + self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None) -- cgit 1.5.1