summary refs log tree commit diff
path: root/tests/rest/client/test_sync.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_sync.py')
-rw-r--r--tests/rest/client/test_sync.py148
1 files changed, 123 insertions, 25 deletions
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 5195659ec2..bfb26139d3 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -19,7 +19,8 @@
 #
 #
 import json
-from typing import List
+import logging
+from typing import Dict, List
 
 from parameterized import parameterized, parameterized_class
 
@@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import (
 )
 from tests.server import TimedOutException
 
+logger = logging.getLogger(__name__)
+
 
 class FilterTestCase(unittest.HomeserverTestCase):
     user_id = "@apple:test"
@@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.event_sources = hs.get_event_sources()
 
+    def _add_new_dm_to_global_account_data(
+        self, source_user_id: str, target_user_id: str, target_room_id: str
+    ) -> None:
+        """
+        Helper to handle inserting a new DM for the source user into global account data
+        (handles all of the list merging).
+
+        Args:
+            source_user_id: The user ID of the DM mapping we're going to update
+            target_user_id: User ID of the person the DM is with
+            target_room_id: Room ID of the DM
+        """
+
+        # Get the current DM map
+        existing_dm_map = self.get_success(
+            self.store.get_global_account_data_by_type_for_user(
+                source_user_id, AccountDataTypes.DIRECT
+            )
+        )
+        # Scrutinize the account data since it has no concrete type. We're just copying
+        # everything into a known type. It should be a mapping from user ID to a list of
+        # room IDs. Ignore anything else.
+        new_dm_map: Dict[str, List[str]] = {}
+        if isinstance(existing_dm_map, dict):
+            for user_id, room_ids in existing_dm_map.items():
+                if isinstance(user_id, str) and isinstance(room_ids, list):
+                    for room_id in room_ids:
+                        if isinstance(room_id, str):
+                            new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
+                                room_id
+                            ]
+
+        # Add the new DM to the map
+        new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
+            target_room_id
+        ]
+        # Save the DM map to global account data
+        self.get_success(
+            self.store.add_account_data_for_user(
+                source_user_id,
+                AccountDataTypes.DIRECT,
+                new_dm_map,
+            )
+        )
+
     def _create_dm_room(
         self,
         inviter_user_id: str,
         inviter_tok: str,
         invitee_user_id: str,
         invitee_tok: str,
+        should_join_room: bool = True,
     ) -> str:
         """
         Helper to create a DM room as the "inviter" and invite the "invitee" user to the
@@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             tok=inviter_tok,
             extra_data={"is_direct": True},
         )
-        # Person that was invited joins the room
-        self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
+        if should_join_room:
+            # Person that was invited joins the room
+            self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
 
         # Mimic the client setting the room as a direct message in the global account
-        # data
-        self.get_success(
-            self.store.add_account_data_for_user(
-                invitee_user_id,
-                AccountDataTypes.DIRECT,
-                {inviter_user_id: [room_id]},
-            )
+        # data for both users.
+        self._add_new_dm_to_global_account_data(
+            invitee_user_id, inviter_user_id, room_id
         )
-        self.get_success(
-            self.store.add_account_data_for_user(
-                inviter_user_id,
-                AccountDataTypes.DIRECT,
-                {invitee_user_id: [room_id]},
-            )
+        self._add_new_dm_to_global_account_data(
+            inviter_user_id, invitee_user_id, room_id
         )
 
         return room_id
@@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         user2_tok = self.login(user2_id, "pass")
 
         # Create a DM room
-        dm_room_id = self._create_dm_room(
+        joined_dm_room_id = self._create_dm_room(
             inviter_user_id=user1_id,
             inviter_tok=user1_tok,
             invitee_user_id=user2_id,
             invitee_tok=user2_tok,
+            should_join_room=True,
+        )
+        invited_dm_room_id = self._create_dm_room(
+            inviter_user_id=user1_id,
+            inviter_tok=user1_tok,
+            invitee_user_id=user2_id,
+            invitee_tok=user2_tok,
+            should_join_room=False,
         )
 
         # Create a normal room
-        room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+        room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        # Create a room that user1 is invited to
+        invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+        self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
@@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             self.sync_endpoint,
             {
                 "lists": {
+                    # Absense of filters does not imply "False" values
+                    "all": {
+                        "ranges": [[0, 99]],
+                        "required_state": [],
+                        "timeline_limit": 1,
+                        "filters": {},
+                    },
+                    # Test single truthy filter
                     "dms": {
                         "ranges": [[0, 99]],
                         "required_state": [],
                         "timeline_limit": 1,
                         "filters": {"is_dm": True},
                     },
-                    "foo-list": {
+                    # Test single falsy filter
+                    "non-dms": {
                         "ranges": [[0, 99]],
                         "required_state": [],
                         "timeline_limit": 1,
                         "filters": {"is_dm": False},
                     },
+                    # Test how multiple filters should stack (AND'd together)
+                    "room-invites": {
+                        "ranges": [[0, 99]],
+                        "required_state": [],
+                        "timeline_limit": 1,
+                        "filters": {"is_dm": False, "is_invite": True},
+                    },
                 }
             },
             access_token=user1_tok,
@@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Make sure it has the foo-list we requested
         self.assertListEqual(
             list(channel.json_body["lists"].keys()),
-            ["dms", "foo-list"],
+            ["all", "dms", "non-dms", "room-invites"],
             channel.json_body["lists"].keys(),
         )
 
-        # Make sure the list includes the room we are joined to
+        # Make sure the lists have the correct rooms
+        self.assertListEqual(
+            list(channel.json_body["lists"]["all"]["ops"]),
+            [
+                {
+                    "op": "SYNC",
+                    "range": [0, 99],
+                    "room_ids": [
+                        invite_room_id,
+                        room_id,
+                        invited_dm_room_id,
+                        joined_dm_room_id,
+                    ],
+                }
+            ],
+            list(channel.json_body["lists"]["all"]),
+        )
         self.assertListEqual(
             list(channel.json_body["lists"]["dms"]["ops"]),
             [
                 {
                     "op": "SYNC",
                     "range": [0, 99],
-                    "room_ids": [dm_room_id],
+                    "room_ids": [invited_dm_room_id, joined_dm_room_id],
                 }
             ],
             list(channel.json_body["lists"]["dms"]),
         )
         self.assertListEqual(
-            list(channel.json_body["lists"]["foo-list"]["ops"]),
+            list(channel.json_body["lists"]["non-dms"]["ops"]),
             [
                 {
                     "op": "SYNC",
                     "range": [0, 99],
-                    "room_ids": [room_id],
+                    "room_ids": [invite_room_id, room_id],
+                }
+            ],
+            list(channel.json_body["lists"]["non-dms"]),
+        )
+        self.assertListEqual(
+            list(channel.json_body["lists"]["room-invites"]["ops"]),
+            [
+                {
+                    "op": "SYNC",
+                    "range": [0, 99],
+                    "room_ids": [invite_room_id],
                 }
             ],
-            list(channel.json_body["lists"]["foo-list"]),
+            list(channel.json_body["lists"]["room-invites"]),
         )
 
     def test_sort_list(self) -> None: