diff --git a/changelog.d/17335.feature b/changelog.d/17335.feature
new file mode 100644
index 0000000000..c6beed42ed
--- /dev/null
+++ b/changelog.d/17335.feature
@@ -0,0 +1 @@
+Add `is_invite` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 16d94925f5..847a638bba 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -554,7 +554,7 @@ class SlidingSyncHandler:
# Flatten out the map
dm_room_id_set = set()
- if dm_map:
+ if isinstance(dm_map, dict):
for room_ids in dm_map.values():
# Account data should be a list of room IDs. Ignore anything else
if isinstance(room_ids, list):
@@ -593,8 +593,21 @@ class SlidingSyncHandler:
):
filtered_room_id_set.remove(room_id)
- if filters.is_invite:
- raise NotImplementedError()
+ # Filter for rooms that the user has been invited to
+ if filters.is_invite is not None:
+ # Make a copy so we don't run into an error: `Set changed size during
+ # iteration`, when we filter out and remove items
+ for room_id in list(filtered_room_id_set):
+ room_for_user = sync_room_map[room_id]
+ # If we're looking for invite rooms, filter out rooms that the user is
+ # not invited to and vice versa
+ if (
+ filters.is_invite and room_for_user.membership != Membership.INVITE
+ ) or (
+ not filters.is_invite
+ and room_for_user.membership == Membership.INVITE
+ ):
+ filtered_room_id_set.remove(room_id)
if filters.room_types:
raise NotImplementedError()
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 0358239c7f..8dd4521b18 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -1200,11 +1200,7 @@ class FilterRoomsTestCase(HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# Create a normal room
- room_id = self.helper.create_room_as(
- user1_id,
- is_public=False,
- tok=user1_tok,
- )
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create a DM room
dm_room_id = self._create_dm_room(
@@ -1261,18 +1257,10 @@ class FilterRoomsTestCase(HomeserverTestCase):
user1_tok = self.login(user1_id, "pass")
# Create a normal room
- room_id = self.helper.create_room_as(
- user1_id,
- is_public=False,
- tok=user1_tok,
- )
+ room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
# Create an encrypted room
- encrypted_room_id = self.helper.create_room_as(
- user1_id,
- is_public=False,
- tok=user1_tok,
- )
+ encrypted_room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
self.helper.send_state(
encrypted_room_id,
EventTypes.RoomEncryption,
@@ -1319,6 +1307,62 @@ class FilterRoomsTestCase(HomeserverTestCase):
self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+ def test_filter_invite_rooms(self) -> None:
+ """
+ Test `filter.is_invite` for rooms that the user has been invited to
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+ user2_id = self.register_user("user2", "pass")
+ user2_tok = self.login(user2_id, "pass")
+
+ # Create a normal room
+ room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user2_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
+
+ after_rooms_token = self.event_sources.get_current_token()
+
+ # Get the rooms the user should be syncing with
+ sync_room_map = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
+ )
+ )
+
+ # Try with `is_invite=True`
+ truthy_filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ is_invite=True,
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(truthy_filtered_room_map.keys(), {invite_room_id})
+
+ # Try with `is_invite=False`
+ falsy_filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ is_invite=False,
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+
class SortRoomsTestCase(HomeserverTestCase):
"""
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 5195659ec2..bfb26139d3 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -19,7 +19,8 @@
#
#
import json
-from typing import List
+import logging
+from typing import Dict, List
from parameterized import parameterized, parameterized_class
@@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import (
)
from tests.server import TimedOutException
+logger = logging.getLogger(__name__)
+
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
@@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources()
+ def _add_new_dm_to_global_account_data(
+ self, source_user_id: str, target_user_id: str, target_room_id: str
+ ) -> None:
+ """
+ Helper to handle inserting a new DM for the source user into global account data
+ (handles all of the list merging).
+
+ Args:
+ source_user_id: The user ID of the DM mapping we're going to update
+ target_user_id: User ID of the person the DM is with
+ target_room_id: Room ID of the DM
+ """
+
+ # Get the current DM map
+ existing_dm_map = self.get_success(
+ self.store.get_global_account_data_by_type_for_user(
+ source_user_id, AccountDataTypes.DIRECT
+ )
+ )
+ # Scrutinize the account data since it has no concrete type. We're just copying
+ # everything into a known type. It should be a mapping from user ID to a list of
+ # room IDs. Ignore anything else.
+ new_dm_map: Dict[str, List[str]] = {}
+ if isinstance(existing_dm_map, dict):
+ for user_id, room_ids in existing_dm_map.items():
+ if isinstance(user_id, str) and isinstance(room_ids, list):
+ for room_id in room_ids:
+ if isinstance(room_id, str):
+ new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
+ room_id
+ ]
+
+ # Add the new DM to the map
+ new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
+ target_room_id
+ ]
+ # Save the DM map to global account data
+ self.get_success(
+ self.store.add_account_data_for_user(
+ source_user_id,
+ AccountDataTypes.DIRECT,
+ new_dm_map,
+ )
+ )
+
def _create_dm_room(
self,
inviter_user_id: str,
inviter_tok: str,
invitee_user_id: str,
invitee_tok: str,
+ should_join_room: bool = True,
) -> str:
"""
Helper to create a DM room as the "inviter" and invite the "invitee" user to the
@@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
tok=inviter_tok,
extra_data={"is_direct": True},
)
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
+ if should_join_room:
+ # Person that was invited joins the room
+ self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
# Mimic the client setting the room as a direct message in the global account
- # data
- self.get_success(
- self.store.add_account_data_for_user(
- invitee_user_id,
- AccountDataTypes.DIRECT,
- {inviter_user_id: [room_id]},
- )
+ # data for both users.
+ self._add_new_dm_to_global_account_data(
+ invitee_user_id, inviter_user_id, room_id
)
- self.get_success(
- self.store.add_account_data_for_user(
- inviter_user_id,
- AccountDataTypes.DIRECT,
- {invitee_user_id: [room_id]},
- )
+ self._add_new_dm_to_global_account_data(
+ inviter_user_id, invitee_user_id, room_id
)
return room_id
@@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# Create a DM room
- dm_room_id = self._create_dm_room(
+ joined_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
+ should_join_room=True,
+ )
+ invited_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=False,
)
# Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
# Make the Sliding Sync request
channel = self.make_request(
@@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.sync_endpoint,
{
"lists": {
+ # Absense of filters does not imply "False" values
+ "all": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {},
+ },
+ # Test single truthy filter
"dms": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": True},
},
- "foo-list": {
+ # Test single falsy filter
+ "non-dms": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": False},
},
+ # Test how multiple filters should stack (AND'd together)
+ "room-invites": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": False, "is_invite": True},
+ },
}
},
access_token=user1_tok,
@@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# Make sure it has the foo-list we requested
self.assertListEqual(
list(channel.json_body["lists"].keys()),
- ["dms", "foo-list"],
+ ["all", "dms", "non-dms", "room-invites"],
channel.json_body["lists"].keys(),
)
- # Make sure the list includes the room we are joined to
+ # Make sure the lists have the correct rooms
+ self.assertListEqual(
+ list(channel.json_body["lists"]["all"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 99],
+ "room_ids": [
+ invite_room_id,
+ room_id,
+ invited_dm_room_id,
+ joined_dm_room_id,
+ ],
+ }
+ ],
+ list(channel.json_body["lists"]["all"]),
+ )
self.assertListEqual(
list(channel.json_body["lists"]["dms"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
- "room_ids": [dm_room_id],
+ "room_ids": [invited_dm_room_id, joined_dm_room_id],
}
],
list(channel.json_body["lists"]["dms"]),
)
self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
+ list(channel.json_body["lists"]["non-dms"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
- "room_ids": [room_id],
+ "room_ids": [invite_room_id, room_id],
+ }
+ ],
+ list(channel.json_body["lists"]["non-dms"]),
+ )
+ self.assertListEqual(
+ list(channel.json_body["lists"]["room-invites"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 99],
+ "room_ids": [invite_room_id],
}
],
- list(channel.json_body["lists"]["foo-list"]),
+ list(channel.json_body["lists"]["room-invites"]),
)
def test_sort_list(self) -> None:
|