diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 5195659ec2..bfb26139d3 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -19,7 +19,8 @@
#
#
import json
-from typing import List
+import logging
+from typing import Dict, List
from parameterized import parameterized, parameterized_class
@@ -44,6 +45,8 @@ from tests.federation.transport.test_knocking import (
)
from tests.server import TimedOutException
+logger = logging.getLogger(__name__)
+
class FilterTestCase(unittest.HomeserverTestCase):
user_id = "@apple:test"
@@ -1234,12 +1237,58 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.store = hs.get_datastores().main
self.event_sources = hs.get_event_sources()
+ def _add_new_dm_to_global_account_data(
+ self, source_user_id: str, target_user_id: str, target_room_id: str
+ ) -> None:
+ """
+ Helper to handle inserting a new DM for the source user into global account data
+ (handles all of the list merging).
+
+ Args:
+ source_user_id: The user ID of the DM mapping we're going to update
+ target_user_id: User ID of the person the DM is with
+ target_room_id: Room ID of the DM
+ """
+
+ # Get the current DM map
+ existing_dm_map = self.get_success(
+ self.store.get_global_account_data_by_type_for_user(
+ source_user_id, AccountDataTypes.DIRECT
+ )
+ )
+ # Scrutinize the account data since it has no concrete type. We're just copying
+ # everything into a known type. It should be a mapping from user ID to a list of
+ # room IDs. Ignore anything else.
+ new_dm_map: Dict[str, List[str]] = {}
+ if isinstance(existing_dm_map, dict):
+ for user_id, room_ids in existing_dm_map.items():
+ if isinstance(user_id, str) and isinstance(room_ids, list):
+ for room_id in room_ids:
+ if isinstance(room_id, str):
+ new_dm_map[user_id] = new_dm_map.get(user_id, []) + [
+ room_id
+ ]
+
+ # Add the new DM to the map
+ new_dm_map[target_user_id] = new_dm_map.get(target_user_id, []) + [
+ target_room_id
+ ]
+ # Save the DM map to global account data
+ self.get_success(
+ self.store.add_account_data_for_user(
+ source_user_id,
+ AccountDataTypes.DIRECT,
+ new_dm_map,
+ )
+ )
+
def _create_dm_room(
self,
inviter_user_id: str,
inviter_tok: str,
invitee_user_id: str,
invitee_tok: str,
+ should_join_room: bool = True,
) -> str:
"""
Helper to create a DM room as the "inviter" and invite the "invitee" user to the
@@ -1260,24 +1309,17 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
tok=inviter_tok,
extra_data={"is_direct": True},
)
- # Person that was invited joins the room
- self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
+ if should_join_room:
+ # Person that was invited joins the room
+ self.helper.join(room_id, invitee_user_id, tok=invitee_tok)
# Mimic the client setting the room as a direct message in the global account
- # data
- self.get_success(
- self.store.add_account_data_for_user(
- invitee_user_id,
- AccountDataTypes.DIRECT,
- {inviter_user_id: [room_id]},
- )
+ # data for both users.
+ self._add_new_dm_to_global_account_data(
+ invitee_user_id, inviter_user_id, room_id
)
- self.get_success(
- self.store.add_account_data_for_user(
- inviter_user_id,
- AccountDataTypes.DIRECT,
- {invitee_user_id: [room_id]},
- )
+ self._add_new_dm_to_global_account_data(
+ inviter_user_id, invitee_user_id, room_id
)
return room_id
@@ -1397,15 +1439,28 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
user2_tok = self.login(user2_id, "pass")
# Create a DM room
- dm_room_id = self._create_dm_room(
+ joined_dm_room_id = self._create_dm_room(
inviter_user_id=user1_id,
inviter_tok=user1_tok,
invitee_user_id=user2_id,
invitee_tok=user2_tok,
+ should_join_room=True,
+ )
+ invited_dm_room_id = self._create_dm_room(
+ inviter_user_id=user1_id,
+ inviter_tok=user1_tok,
+ invitee_user_id=user2_id,
+ invitee_tok=user2_tok,
+ should_join_room=False,
)
# Create a normal room
- room_id = self.helper.create_room_as(user1_id, tok=user1_tok, is_public=True)
+ room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ self.helper.join(room_id, user1_id, tok=user1_tok)
+
+ # Create a room that user1 is invited to
+ invite_room_id = self.helper.create_room_as(user1_id, tok=user2_tok)
+ self.helper.invite(invite_room_id, src=user2_id, targ=user1_id, tok=user2_tok)
# Make the Sliding Sync request
channel = self.make_request(
@@ -1413,18 +1468,34 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
self.sync_endpoint,
{
"lists": {
+ # Absense of filters does not imply "False" values
+ "all": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {},
+ },
+ # Test single truthy filter
"dms": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": True},
},
- "foo-list": {
+ # Test single falsy filter
+ "non-dms": {
"ranges": [[0, 99]],
"required_state": [],
"timeline_limit": 1,
"filters": {"is_dm": False},
},
+ # Test how multiple filters should stack (AND'd together)
+ "room-invites": {
+ "ranges": [[0, 99]],
+ "required_state": [],
+ "timeline_limit": 1,
+ "filters": {"is_dm": False, "is_invite": True},
+ },
}
},
access_token=user1_tok,
@@ -1434,32 +1505,59 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
# Make sure it has the foo-list we requested
self.assertListEqual(
list(channel.json_body["lists"].keys()),
- ["dms", "foo-list"],
+ ["all", "dms", "non-dms", "room-invites"],
channel.json_body["lists"].keys(),
)
- # Make sure the list includes the room we are joined to
+ # Make sure the lists have the correct rooms
+ self.assertListEqual(
+ list(channel.json_body["lists"]["all"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 99],
+ "room_ids": [
+ invite_room_id,
+ room_id,
+ invited_dm_room_id,
+ joined_dm_room_id,
+ ],
+ }
+ ],
+ list(channel.json_body["lists"]["all"]),
+ )
self.assertListEqual(
list(channel.json_body["lists"]["dms"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
- "room_ids": [dm_room_id],
+ "room_ids": [invited_dm_room_id, joined_dm_room_id],
}
],
list(channel.json_body["lists"]["dms"]),
)
self.assertListEqual(
- list(channel.json_body["lists"]["foo-list"]["ops"]),
+ list(channel.json_body["lists"]["non-dms"]["ops"]),
[
{
"op": "SYNC",
"range": [0, 99],
- "room_ids": [room_id],
+ "room_ids": [invite_room_id, room_id],
+ }
+ ],
+ list(channel.json_body["lists"]["non-dms"]),
+ )
+ self.assertListEqual(
+ list(channel.json_body["lists"]["room-invites"]["ops"]),
+ [
+ {
+ "op": "SYNC",
+ "range": [0, 99],
+ "room_ids": [invite_room_id],
}
],
- list(channel.json_body["lists"]["foo-list"]),
+ list(channel.json_body["lists"]["room-invites"]),
)
def test_sort_list(self) -> None:
|