diff options
Diffstat (limited to 'tests/rest/client/test_sync.py')
-rw-r--r-- | tests/rest/client/test_sync.py | 148 |
1 files changed, 123 insertions, 25 deletions
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py index 5195659ec2..bfb26139d3 100644 --- a/tests/rest/client/test_sync.py +++ b/tests/rest/client/test_sync.py @@ -19,7 +19,8 @@ # # import json -from typing import List +import logging +from typing import Dict, List from parameterized import parameterized, parameterized_class @@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import ( ) from tests.server import TimedOutException +logger = logging.getLogger(__name__) + class FilterTestCase(unittest.HomeserverTestCase): user_id = "@apple:test" @@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.store = hs.get_datastores().main self.event_sources = hs.get_event_sources() + def _add_new_dm_to_global_account_data( + self, source_user_id: str, target_user_id: str, target_room_id: str + ) -> None: + """ + Helper to handle inserting a new DM for the source user into global account data + (handles all of the list merging). + + Args: + source_user_id: The user ID of the DM mapping we're going to update + target_user_id: User ID of the person the DM is with + target_room_id: Room ID of the DM + """ + + # Get the current DM map + existing_dm_map = self.get_success( + self.store.get_global_account_data_by_type_for_user( + source_user_id, AccountDataTypes.DIRECT + ) + ) + # Scrutinize the account data since it has no concrete type. We're just copying + # everything into a known type. It should be a mapping from user ID to a list of + # room IDs. Ignore anything else. + new_dm_map: Dict[str, List[str]] = {} + if isinstance(existing_dm_map, dict): + for user_id, room_ids in existing_dm_map.items(): + if isinstance(user_id, str) and isinstance(room_ids, list): + for room_id in room_ids: + if isinstance(room_id, str): + new_dm_map[user_id] = new_dm_map.get(user_id, []) + [ + room_id + ] + + # Add the new DM to the map + new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [ + target_room_id + ] + # Save the DM map to global account data + self.get_success( + self.store.add_account_data_for_user( + source_user_id, + AccountDataTypes.DIRECT, + new_dm_map, + ) + ) + def _create_dm_room( self, inviter_user_id: str, inviter_tok: str, invitee_user_id: str, invitee_tok: str, + should_join_room: bool = True, ) -> str: """ Helper to create a DM room as the "inviter" and invite the "invitee" user to the @@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): tok=inviter_tok, extra_data={"is_direct": True}, ) - # Person that was invited joins the room - self.helper.join(room_id, invitee_user_id, tok=invitee_tok) + if should_join_room: + # Person that was invited joins the room + self.helper.join(room_id, invitee_user_id, tok=invitee_tok) # Mimic the client setting the room as a direct message in the global account - # data - self.get_success( - self.store.add_account_data_for_user( - invitee_user_id, - AccountDataTypes.DIRECT, - {inviter_user_id: [room_id]}, - ) + # data for both users. + self._add_new_dm_to_global_account_data( + invitee_user_id, inviter_user_id, room_id ) - self.get_success( - self.store.add_account_data_for_user( - inviter_user_id, - AccountDataTypes.DIRECT, - {invitee_user_id: [room_id]}, - ) + self._add_new_dm_to_global_account_data( + inviter_user_id, invitee_user_id, room_id ) return room_id @@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): user2_tok = self.login(user2_id, "pass") # Create a DM room - dm_room_id = self._create_dm_room( + joined_dm_room_id = self._create_dm_room( inviter_user_id=user1_id, inviter_tok=user1_tok, invitee_user_id=user2_id, invitee_tok=user2_tok, + should_join_room=True, + ) + invited_dm_room_id = self._create_dm_room( + inviter_user_id=user1_id, + inviter_tok=user1_tok, + invitee_user_id=user2_id, + invitee_tok=user2_tok, + should_join_room=False, ) # Create a normal room - room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True) + room_id = self.helper.create_room_as(user1_id, tok=user2_tok) + self.helper.join(room_id, user1_id, tok=user1_tok) + + # Create a room that user1 is invited to + invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok) + self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok) # Make the Sliding Sync request channel = self.make_request( @@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): self.sync_endpoint, { "lists": { + # Absense of filters does not imply "False" values + "all": { + "ranges": [[0, 99]], + "required_state": [], + "timeline_limit": 1, + "filters": {}, + }, + # Test single truthy filter "dms": { "ranges": [[0, 99]], "required_state": [], "timeline_limit": 1, "filters": {"is_dm": True}, }, - "foo-list": { + # Test single falsy filter + "non-dms": { "ranges": [[0, 99]], "required_state": [], "timeline_limit": 1, "filters": {"is_dm": False}, }, + # Test how multiple filters should stack (AND'd together) + "room-invites": { + "ranges": [[0, 99]], + "required_state": [], + "timeline_limit": 1, + "filters": {"is_dm": False, "is_invite": True}, + }, } }, access_token=user1_tok, @@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase): # Make sure it has the foo-list we requested self.assertListEqual( list(channel.json_body["lists"].keys()), - ["dms", "foo-list"], + ["all", "dms", "non-dms", "room-invites"], channel.json_body["lists"].keys(), ) - # Make sure the list includes the room we are joined to + # Make sure the lists have the correct rooms + self.assertListEqual( + list(channel.json_body["lists"]["all"]["ops"]), + [ + { + "op": "SYNC", + "range": [0, 99], + "room_ids": [ + invite_room_id, + room_id, + invited_dm_room_id, + joined_dm_room_id, + ], + } + ], + list(channel.json_body["lists"]["all"]), + ) self.assertListEqual( list(channel.json_body["lists"]["dms"]["ops"]), [ { "op": "SYNC", "range": [0, 99], - "room_ids": [dm_room_id], + "room_ids": [invited_dm_room_id, joined_dm_room_id], } ], list(channel.json_body["lists"]["dms"]), ) self.assertListEqual( - list(channel.json_body["lists"]["foo-list"]["ops"]), + list(channel.json_body["lists"]["non-dms"]["ops"]), [ { "op": "SYNC", "range": [0, 99], - "room_ids": [room_id], + "room_ids": [invite_room_id, room_id], + } + ], + list(channel.json_body["lists"]["non-dms"]), + ) + self.assertListEqual( + list(channel.json_body["lists"]["room-invites"]["ops"]), + [ + { + "op": "SYNC", + "range": [0, 99], + "room_ids": [invite_room_id], } ], - list(channel.json_body["lists"]["foo-list"]), + list(channel.json_body["lists"]["room-invites"]), ) def test_sort_list(self) -> None: |