diff options
-rw-r--r-- | changelog.d/17335.feature | 1 | ||||
-rw-r--r-- | synapse/handlers/sliding_sync.py | 19 | ||||
-rw-r--r-- | tests/handlers/test_sliding_sync.py | 74 | ||||
-rw-r--r-- | tests/rest/client/test_sync.py | 148 |
4 files changed, 199 insertions, 43 deletions
diff --git a/changelog.d/17335.feature b/changelog.d/17335.feature new file mode 100644 index 0000000000..c6beed42ed --- /dev/null +++ b/changelog.d/17335.feature @@ -0,0 +1 @@ +Add `is_invite` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint. diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py index 16d94925f5..847a638bba 100644 --- a/synapse/handlers/sliding_sync.py +++ b/synapse/handlers/sliding_sync.py @@ -554,7 +554,7 @@ class SlidingSyncHandler: # Flatten out the map dm_room_id_set = set() - if dm_map: + if isinstance(dm_map, dict): for room_ids in dm_map.values(): # Account data should be a list of room IDs. Ignore anything else if isinstance(room_ids, list): @@ -593,8 +593,21 @@ class SlidingSyncHandler: ): filtered_room_id_set.remove(room_id) - if filters.is_invite: - raise NotImplementedError() + # Filter for rooms that the user has been invited to + if filters.is_invite is not None: + # Make a copy so we don't run into an error: `Set changed size during + # iteration`, when we filter out and remove items + for room_id in list(filtered_room_id_set): + room_for_user = sync_room_map[room_id] + # If we're looking for invite rooms, filter out rooms that the user is + # not invited to and vice versa + if ( + filters.is_invite and room_for_user.membership != Membership.INVITE + ) or ( + not filters.is_invite + and room_for_user.membership == Membership.INVITE + ): + filtered_room_id_set.remove(room_id) if filters.room_types: raise NotImplementedError() diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py index 0358239c7f..8dd4521b18 100644 --- a/tests/handlers/test_sliding_sync.py +++ b/tests/handlers/test_sliding_sync.py @@ -1200,11 +1200,7 @@ class FilterRoomsTestCase(HomeserverTestCase): user2_tok = self.login(user2_id, "pass") # Create a normal room - room_id = self.helper.create_room_as( - user1_id, - is_public=False, - tok=user1_tok, - ) + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) # Create a DM room dm_room_id = self._create_dm_room( @@ -1261,18 +1257,10 @@ class FilterRoomsTestCase(HomeserverTestCase): user1_tok = self.login(user1_id, "pass") # Create a normal room - room_id = self.helper.create_room_as( - user1_id, - is_public=False, - tok=user1_tok, - ) + room_id = self.helper.create_room_as(user1_id, tok=user1_tok) # Create an encrypted room - encrypted_room_id = self.helper.create_room_as( - user1_id, - is_public=False, - tok=user1_tok, - ) + encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok) self.helper.send_state( encrypted_room_id, EventTypes.RoomEncryption, @@ -1319,6 +1307,62 @@ class FilterRoomsTestCase(HomeserverTestCase): self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) + def test_filter_invite_rooms(self) -> None: + """ + Test `filter.is_invite` for rooms that the user has been invited to + """ + user1_id = self.register_user("user1", "pass") + user1_tok = self.login(user1_id, "pass") + user2_id = self.register_user("user2", "pass") + user2_tok = self.login(user2_id, "pass") + + # Create a normal room + room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + + # Create a room that user1 is invited to + invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok) + self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) + + after_rooms_token = self.event_sources.get_current_token() + + # Get the rooms the user should be syncing with + sync_room_map = self.get_success( + self.sliding_sync_handler.get_sync_room_ids_for_user( + UserID.from_string(user1_id), + from_token=None, + to_token=after_rooms_token, + ) + ) + + # Try with `is_invite=True` + truthy_filtered_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), + sync_room_map, + SlidingSyncConfig.SlidingSyncList.Filters( + is_invite=True, + ), + after_rooms_token, + ) + ) + + self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id}) + + # Try with `is_invite=False` + falsy_filtered_room_map = self.get_success( + self.sliding_sync_handler.filter_rooms( + UserID.from_string(user1_id), + sync_room_map, + SlidingSyncConfig.SlidingSyncList.Filters( + is_invite=False, + ), + after_rooms_token, + ) + ) + + self.assertEqual(falsy_filtered_room_map.keys(), {room_id}) + class SortRoomsTestCase(HomeserverTestCase): """ diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 5195659ec2..bfb26139d3 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -19,7 +19,8 @@ # # import json -from typing import List +import logging +from typing import Dict, List from parameterized import parameterized, parameterized_class @@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import ( ) from tests.server import TimedOutException +logger = logging.getLogger(__name__) + class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" @@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.event_sources = hs.get_event_sources() + def _add_new_dm_to_global_account_data( + self, source_user_id: str, target_user_id: str, target_room_id: str + ) -> None: + """ + Helper to handle inserting a new DM for the source user into global account data + (handles all of the list merging). + + Args: + source_user_id: The user ID of the DM mapping we're going to update + target_user_id: User ID of the person the DM is with + target_room_id: Room ID of the DM + """ + + # Get the current DM map + existing_dm_map = self.get_success( + self.store.get_global_account_data_by_type_for_user( + source_user_id, AccountDataTypes.DIRECT + ) + ) + # Scrutinize the account data since it has no concrete type. We're just copying + # everything into a known type. It should be a mapping from user ID to a list of + # room IDs. Ignore anything else. + new_dm_map: Dict[str, List[str]] = {} + if isinstance(existing_dm_map, dict): + for user_id, room_ids in existing_dm_map.items(): + if isinstance(user_id, str) and isinstance(room_ids, list): + for room_id in room_ids: + if isinstance(room_id, str): + new_dm_map[user_id] = new_dm_map.get(user_id, []) + [ + room_id + ] + + # Add the new DM to the map + new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [ + target_room_id + ] + # Save the DM map to global account data + self.get_success( + self.store.add_account_data_for_user( + source_user_id, + AccountDataTypes.DIRECT, + new_dm_map, + ) + ) + def _create_dm_room( self, inviter_user_id: str, inviter_tok: str, invitee_user_id: str, invitee_tok: str, + should_join_room: bool = True, ) -> str: """ Helper to create a DM room as the "inviter" and invite the "invitee" user to the @@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): tok=inviter_tok, extra_data={"is_direct": True}, ) - # Person that was invited joins the room - self.helper.join(room_id, invitee_user_id, tok=invitee_tok) + if should_join_room: + # Person that was invited joins the room + self.helper.join(room_id, invitee_user_id, tok=invitee_tok) # Mimic the client setting the room as a direct message in the global account - # data - self.get_success( - self.store.add_account_data_for_user( - invitee_user_id, - AccountDataTypes.DIRECT, - {inviter_user_id: [room_id]}, - ) + # data for both users. + self._add_new_dm_to_global_account_data( + invitee_user_id, inviter_user_id, room_id ) - self.get_success( - self.store.add_account_data_for_user( - inviter_user_id, - AccountDataTypes.DIRECT, - {invitee_user_id: [room_id]}, - ) + self._add_new_dm_to_global_account_data( + inviter_user_id, invitee_user_id, room_id ) return room_id @@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): user2_tok = self.login(user2_id, "pass") # Create a DM room - dm_room_id = self._create_dm_room( + joined_dm_room_id = self._create_dm_room( inviter_user_id=user1_id, inviter_tok=user1_tok, invitee_user_id=user2_id, invitee_tok=user2_tok, + should_join_room=True, + ) + invited_dm_room_id = self._create_dm_room( + inviter_user_id=user1_id, + inviter_tok=user1_tok, + invitee_user_id=user2_id, + invitee_tok=user2_tok, + should_join_room=False, ) # Create a normal room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) + room_id = self.helper.create_room_as(user1_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + + # Create a room that user1 is invited to + invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok) + self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) # Make the Sliding Sync request channel = self.make_request( @@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.sync_endpoint, { "lists": { + # Absense of filters does not imply "False" values + "all": { + "ranges": [[0, 99]], + "required_state": [], + "timeline_limit": 1, + "filters": {}, + }, + # Test single truthy filter "dms": { "ranges": [[0, 99]], "required_state": [], "timeline_limit": 1, "filters": {"is_dm": True}, }, - "foo-list": { + # Test single falsy filter + "non-dms": { "ranges": [[0, 99]], "required_state": [], "timeline_limit": 1, "filters": {"is_dm": False}, }, + # Test how multiple filters should stack (AND'd together) + "room-invites": { + "ranges": [[0, 99]], + "required_state": [], + "timeline_limit": 1, + "filters": {"is_dm": False, "is_invite": True}, + }, } }, access_token=user1_tok, @@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make sure it has the foo-list we requested self.assertListEqual( list(channel.json_body["lists"].keys()), - ["dms", "foo-list"], + ["all", "dms", "non-dms", "room-invites"], channel.json_body["lists"].keys(), ) - # Make sure the list includes the room we are joined to + # Make sure the lists have the correct rooms + self.assertListEqual( + list(channel.json_body["lists"]["all"]["ops"]), + [ + { + "op": "SYNC", + "range": [0, 99], + "room_ids": [ + invite_room_id, + room_id, + invited_dm_room_id, + joined_dm_room_id, + ], + } + ], + list(channel.json_body["lists"]["all"]), + ) self.assertListEqual( list(channel.json_body["lists"]["dms"]["ops"]), [ { "op": "SYNC", "range": [0, 99], - "room_ids": [dm_room_id], + "room_ids": [invited_dm_room_id, joined_dm_room_id], } ], list(channel.json_body["lists"]["dms"]), ) self.assertListEqual( - list(channel.json_body["lists"]["foo-list"]["ops"]), + list(channel.json_body["lists"]["non-dms"]["ops"]), [ { "op": "SYNC", "range": [0, 99], - "room_ids": [room_id], + "room_ids": [invite_room_id, room_id], + } + ], + list(channel.json_body["lists"]["non-dms"]), + ) + self.assertListEqual( + list(channel.json_body["lists"]["room-invites"]["ops"]), + [ + { + "op": "SYNC", + "range": [0, 99], + "room_ids": [invite_room_id], } ], - list(channel.json_body["lists"]["foo-list"]), + list(channel.json_body["lists"]["room-invites"]), ) def test_sort_list(self) -> None: |