summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17335.feature1
-rw-r--r--synapse/handlers/sliding_sync.py19
-rw-r--r--tests/handlers/test_sliding_sync.py74
-rw-r--r--tests/rest/client/test_sync.py148
4 files changed, 199 insertions, 43 deletions
diff --git a/changelog.d/17335.feature b/changelog.d/17335.feature
new file mode 100644
index 0000000000..c6beed42ed
--- /dev/null
+++ b/changelog.d/17335.feature
@@ -0,0 +1 @@
+Add `is_invite` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 16d94925f5..847a638bba 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -554,7 +554,7 @@ class SlidingSyncHandler:
 
             # Flatten out the map
             dm_room_id_set = set()
-            if dm_map:
+            if isinstance(dm_map, dict):
                 for room_ids in dm_map.values():
                     # Account data should be a list of room IDs. Ignore anything else
                     if isinstance(room_ids, list):
@@ -593,8 +593,21 @@ class SlidingSyncHandler:
                 ):
                     filtered_room_id_set.remove(room_id)
 
-        if filters.is_invite:
-            raise NotImplementedError()
+        # Filter for rooms that the user has been invited to
+        if filters.is_invite is not None:
+            # Make a copy so we don't run into an error: `Set changed size during
+            # iteration`, when we filter out and remove items
+            for room_id in list(filtered_room_id_set):
+                room_for_user = sync_room_map[room_id]
+                # If we're looking for invite rooms, filter out rooms that the user is
+                # not invited to and vice versa
+                if (
+                    filters.is_invite and room_for_user.membership != Membership.INVITE
+                ) or (
+                    not filters.is_invite
+                    and room_for_user.membership == Membership.INVITE
+                ):
+                    filtered_room_id_set.remove(room_id)
 
         if filters.room_types:
             raise NotImplementedError()
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 0358239c7f..8dd4521b18 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -1200,11 +1200,7 @@ class FilterRoomsTestCase(HomeserverTestCase):
         user2_tok = self.login(user2_id, "pass")
 
         # Create a normal room
-        room_id = self.helper.create_room_as(
-            user1_id,
-            is_public=False,
-            tok=user1_tok,
-        )
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
 
         # Create a DM room
         dm_room_id = self._create_dm_room(
@@ -1261,18 +1257,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
         user1_tok = self.login(user1_id, "pass")
 
         # Create a normal room
-        room_id = self.helper.create_room_as(
-            user1_id,
-            is_public=False,
-            tok=user1_tok,
-        )
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
 
         # Create an encrypted room
-        encrypted_room_id = self.helper.create_room_as(
-            user1_id,
-            is_public=False,
-            tok=user1_tok,
-        )
+        encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
         self.helper.send_state(
             encrypted_room_id,
             EventTypes.RoomEncryption,
@@ -1319,6 +1307,62 @@ class FilterRoomsTestCase(HomeserverTestCase):
 
         self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
 
+    def test_filter_invite_rooms(self) -> None:
+        """
+        Test `filter.is_invite` for rooms that the user has been invited to
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+        user2_id = self.register_user("user2", "pass")
+        user2_tok = self.login(user2_id, "pass")
+
+        # Create a normal room
+        room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        # Create a room that user1 is invited to
+        invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+        self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+        after_rooms_token = self.event_sources.get_current_token()
+
+        # Get the rooms the user should be syncing with
+        sync_room_map = self.get_success(
+            self.sliding_sync_handler.get_sync_room_ids_for_user(
+                UserID.from_string(user1_id),
+                from_token=None,
+                to_token=after_rooms_token,
+            )
+        )
+
+        # Try with `is_invite=True`
+        truthy_filtered_room_map = self.get_success(
+            self.sliding_sync_handler.filter_rooms(
+                UserID.from_string(user1_id),
+                sync_room_map,
+                SlidingSyncConfig.SlidingSyncList.Filters(
+                    is_invite=True,
+                ),
+                after_rooms_token,
+            )
+        )
+
+        self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})
+
+        # Try with `is_invite=False`
+        falsy_filtered_room_map = self.get_success(
+            self.sliding_sync_handler.filter_rooms(
+                UserID.from_string(user1_id),
+                sync_room_map,
+                SlidingSyncConfig.SlidingSyncList.Filters(
+                    is_invite=False,
+                ),
+                after_rooms_token,
+            )
+        )
+
+        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+
 
 class SortRoomsTestCase(HomeserverTestCase):
     """
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 5195659ec2..bfb26139d3 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -19,7 +19,8 @@
 #
 #
 import json
-from typing import List
+import logging
+from typing import Dict, List
 
 from parameterized import parameterized, parameterized_class
 
@@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import (
 )
 from tests.server import TimedOutException
 
+logger = logging.getLogger(__name__)
+
 
 class FilterTestCase(unittest.HomeserverTestCase):
     user_id = "@apple:test"
@@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         self.store = hs.get_datastores().main
         self.event_sources = hs.get_event_sources()
 
+    def _add_new_dm_to_global_account_data(
+        self, source_user_id: str, target_user_id: str, target_room_id: str
+    ) -> None:
+        """
+        Helper to handle inserting a new DM for the source user into global account data
+        (handles all of the list merging).
+
+        Args:
+            source_user_id: The user ID of the DM mapping we're going to update
+            target_user_id: User ID of the person the DM is with
+            target_room_id: Room ID of the DM
+        """
+
+        # Get the current DM map
+        existing_dm_map = self.get_success(
+            self.store.get_global_account_data_by_type_for_user(
+                source_user_id, AccountDataTypes.DIRECT
+            )
+        )
+        # Scrutinize the account data since it has no concrete type. We're just copying
+        # everything into a known type. It should be a mapping from user ID to a list of
+        # room IDs. Ignore anything else.
+        new_dm_map: Dict[str, List[str]] = {}
+        if isinstance(existing_dm_map, dict):
+            for user_id, room_ids in existing_dm_map.items():
+                if isinstance(user_id, str) and isinstance(room_ids, list):
+                    for room_id in room_ids:
+                        if isinstance(room_id, str):
+                            new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
+                                room_id
+                            ]
+
+        # Add the new DM to the map
+        new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
+            target_room_id
+        ]
+        # Save the DM map to global account data
+        self.get_success(
+            self.store.add_account_data_for_user(
+                source_user_id,
+                AccountDataTypes.DIRECT,
+                new_dm_map,
+            )
+        )
+
     def _create_dm_room(
         self,
         inviter_user_id: str,
         inviter_tok: str,
         invitee_user_id: str,
         invitee_tok: str,
+        should_join_room: bool = True,
     ) -> str:
         """
         Helper to create a DM room as the "inviter" and invite the "invitee" user to the
@@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             tok=inviter_tok,
             extra_data={"is_direct": True},
         )
-        # Person that was invited joins the room
-        self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
+        if should_join_room:
+            # Person that was invited joins the room
+            self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
 
         # Mimic the client setting the room as a direct message in the global account
-        # data
-        self.get_success(
-            self.store.add_account_data_for_user(
-                invitee_user_id,
-                AccountDataTypes.DIRECT,
-                {inviter_user_id: [room_id]},
-            )
+        # data for both users.
+        self._add_new_dm_to_global_account_data(
+            invitee_user_id, inviter_user_id, room_id
         )
-        self.get_success(
-            self.store.add_account_data_for_user(
-                inviter_user_id,
-                AccountDataTypes.DIRECT,
-                {invitee_user_id: [room_id]},
-            )
+        self._add_new_dm_to_global_account_data(
+            inviter_user_id, invitee_user_id, room_id
         )
 
         return room_id
@@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         user2_tok = self.login(user2_id, "pass")
 
         # Create a DM room
-        dm_room_id = self._create_dm_room(
+        joined_dm_room_id = self._create_dm_room(
             inviter_user_id=user1_id,
             inviter_tok=user1_tok,
             invitee_user_id=user2_id,
             invitee_tok=user2_tok,
+            should_join_room=True,
+        )
+        invited_dm_room_id = self._create_dm_room(
+            inviter_user_id=user1_id,
+            inviter_tok=user1_tok,
+            invitee_user_id=user2_id,
+            invitee_tok=user2_tok,
+            should_join_room=False,
         )
 
         # Create a normal room
-        room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+        room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+        self.helper.join(room_id, user1_id, tok=user1_tok)
+
+        # Create a room that user1 is invited to
+        invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+        self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
 
         # Make the Sliding Sync request
         channel = self.make_request(
@@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             self.sync_endpoint,
             {
                 "lists": {
+                    # Absense of filters does not imply "False" values
+                    "all": {
+                        "ranges": [[0, 99]],
+                        "required_state": [],
+                        "timeline_limit": 1,
+                        "filters": {},
+                    },
+                    # Test single truthy filter
                     "dms": {
                         "ranges": [[0, 99]],
                         "required_state": [],
                         "timeline_limit": 1,
                         "filters": {"is_dm": True},
                     },
-                    "foo-list": {
+                    # Test single falsy filter
+                    "non-dms": {
                         "ranges": [[0, 99]],
                         "required_state": [],
                         "timeline_limit": 1,
                         "filters": {"is_dm": False},
                     },
+                    # Test how multiple filters should stack (AND'd together)
+                    "room-invites": {
+                        "ranges": [[0, 99]],
+                        "required_state": [],
+                        "timeline_limit": 1,
+                        "filters": {"is_dm": False, "is_invite": True},
+                    },
                 }
             },
             access_token=user1_tok,
@@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
         # Make sure it has the foo-list we requested
         self.assertListEqual(
             list(channel.json_body["lists"].keys()),
-            ["dms", "foo-list"],
+            ["all", "dms", "non-dms", "room-invites"],
             channel.json_body["lists"].keys(),
         )
 
-        # Make sure the list includes the room we are joined to
+        # Make sure the lists have the correct rooms
+        self.assertListEqual(
+            list(channel.json_body["lists"]["all"]["ops"]),
+            [
+                {
+                    "op": "SYNC",
+                    "range": [0, 99],
+                    "room_ids": [
+                        invite_room_id,
+                        room_id,
+                        invited_dm_room_id,
+                        joined_dm_room_id,
+                    ],
+                }
+            ],
+            list(channel.json_body["lists"]["all"]),
+        )
         self.assertListEqual(
             list(channel.json_body["lists"]["dms"]["ops"]),
             [
                 {
                     "op": "SYNC",
                     "range": [0, 99],
-                    "room_ids": [dm_room_id],
+                    "room_ids": [invited_dm_room_id, joined_dm_room_id],
                 }
             ],
             list(channel.json_body["lists"]["dms"]),
         )
         self.assertListEqual(
-            list(channel.json_body["lists"]["foo-list"]["ops"]),
+            list(channel.json_body["lists"]["non-dms"]["ops"]),
             [
                 {
                     "op": "SYNC",
                     "range": [0, 99],
-                    "room_ids": [room_id],
+                    "room_ids": [invite_room_id, room_id],
+                }
+            ],
+            list(channel.json_body["lists"]["non-dms"]),
+        )
+        self.assertListEqual(
+            list(channel.json_body["lists"]["room-invites"]["ops"]),
+            [
+                {
+                    "op": "SYNC",
+                    "range": [0, 99],
+                    "room_ids": [invite_room_id],
                 }
             ],
-            list(channel.json_body["lists"]["foo-list"]),
+            list(channel.json_body["lists"]["room-invites"]),
         )
 
     def test_sort_list(self) -> None: