summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17429.feature1
-rw-r--r--synapse/handlers/sliding_sync.py75
-rw-r--r--tests/rest/client/test_sync.py23
3 files changed, 70 insertions, 29 deletions
diff --git a/changelog.d/17429.feature b/changelog.d/17429.feature
new file mode 100644
index 0000000000..608b75d632
--- /dev/null
+++ b/changelog.d/17429.feature
@@ -0,0 +1 @@
+Populate `is_dm` room field in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index e3230d28e6..904787ced3 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -291,6 +291,7 @@ class _RoomMembershipForUser:
         sender: The person who sent the membership event
         newly_joined: Whether the user newly joined the room during the given token
             range
+        is_dm: Whether this user considers this room as a direct-message (DM) room
     """
 
     room_id: str
@@ -299,6 +300,7 @@ class _RoomMembershipForUser:
     membership: str
     sender: Optional[str]
     newly_joined: bool
+    is_dm: bool
 
     def copy_and_replace(self, **kwds: Any) -> "_RoomMembershipForUser":
         return attr.evolve(self, **kwds)
@@ -613,6 +615,7 @@ class SlidingSyncHandler:
                 membership=room_for_user.membership,
                 sender=room_for_user.sender,
                 newly_joined=False,
+                is_dm=False,
             )
             for room_for_user in room_for_user_list
         }
@@ -652,6 +655,7 @@ class SlidingSyncHandler:
         # - 1c) Update room membership events to the point in time of the `to_token`
         # - 2) Add back newly_left rooms (> `from_token` and <= `to_token`)
         # - 3) Figure out which rooms are `newly_joined`
+        # - 4) Figure out which rooms are DM's
 
         # 1) -----------------------------------------------------
 
@@ -714,6 +718,7 @@ class SlidingSyncHandler:
                         membership=first_membership_change_after_to_token.prev_membership,
                         sender=first_membership_change_after_to_token.prev_sender,
                         newly_joined=False,
+                        is_dm=False,
                     )
                 else:
                     # If we can't find the previous membership event, we shouldn't
@@ -809,6 +814,7 @@ class SlidingSyncHandler:
                     membership=last_membership_change_in_from_to_range.membership,
                     sender=last_membership_change_in_from_to_range.sender,
                     newly_joined=False,
+                    is_dm=False,
                 )
 
         # 3) Figure out `newly_joined`
@@ -846,6 +852,35 @@ class SlidingSyncHandler:
                         room_id
                     ].copy_and_replace(newly_joined=True)
 
+        # 4) Figure out which rooms the user considers to be direct-message (DM) rooms
+        #
+        # We're using global account data (`m.direct`) instead of checking for
+        # `is_direct` on membership events because that property only appears for
+        # the invitee membership event (doesn't show up for the inviter).
+        #
+        # We're unable to take `to_token` into account for global account data since
+        # we only keep track of the latest account data for the user.
+        dm_map = await self.store.get_global_account_data_by_type_for_user(
+            user_id, AccountDataTypes.DIRECT
+        )
+
+        # Flatten out the map. Account data is set by the client so it needs to be
+        # scrutinized.
+        dm_room_id_set = set()
+        if isinstance(dm_map, dict):
+            for room_ids in dm_map.values():
+                # Account data should be a list of room IDs. Ignore anything else
+                if isinstance(room_ids, list):
+                    for room_id in room_ids:
+                        if isinstance(room_id, str):
+                            dm_room_id_set.add(room_id)
+
+        # 4) Fixup
+        for room_id in filtered_sync_room_id_set:
+            filtered_sync_room_id_set[room_id] = filtered_sync_room_id_set[
+                room_id
+            ].copy_and_replace(is_dm=room_id in dm_room_id_set)
+
         return filtered_sync_room_id_set
 
     async def filter_rooms(
@@ -869,41 +904,24 @@ class SlidingSyncHandler:
             A filtered dictionary of room IDs along with membership information in the
             room at the time of `to_token`.
         """
-        user_id = user.to_string()
-
-        # TODO: Apply filters
-
         filtered_room_id_set = set(sync_room_map.keys())
 
         # Filter for Direct-Message (DM) rooms
         if filters.is_dm is not None:
-            # We're using global account data (`m.direct`) instead of checking for
-            # `is_direct` on membership events because that property only appears for
-            # the invitee membership event (doesn't show up for the inviter). Account
-            # data is set by the client so it needs to be scrutinized.
-            #
-            # We're unable to take `to_token` into account for global account data since
-            # we only keep track of the latest account data for the user.
-            dm_map = await self.store.get_global_account_data_by_type_for_user(
-                user_id, AccountDataTypes.DIRECT
-            )
-
-            # Flatten out the map
-            dm_room_id_set = set()
-            if isinstance(dm_map, dict):
-                for room_ids in dm_map.values():
-                    # Account data should be a list of room IDs. Ignore anything else
-                    if isinstance(room_ids, list):
-                        for room_id in room_ids:
-                            if isinstance(room_id, str):
-                                dm_room_id_set.add(room_id)
-
             if filters.is_dm:
                 # Only DM rooms please
-                filtered_room_id_set = filtered_room_id_set.intersection(dm_room_id_set)
+                filtered_room_id_set = {
+                    room_id
+                    for room_id in filtered_room_id_set
+                    if sync_room_map[room_id].is_dm
+                }
             else:
                 # Only non-DM rooms please
-                filtered_room_id_set = filtered_room_id_set.difference(dm_room_id_set)
+                filtered_room_id_set = {
+                    room_id
+                    for room_id in filtered_room_id_set
+                    if not sync_room_map[room_id].is_dm
+                }
 
         if filters.spaces:
             raise NotImplementedError()
@@ -1538,8 +1556,7 @@ class SlidingSyncHandler:
             name=room_name,
             avatar=room_avatar,
             heroes=heroes,
-            # TODO: Dummy value
-            is_dm=False,
+            is_dm=room_membership_for_user_at_to_token.is_dm,
             initial=initial,
             required_state=list(required_room_state.values()),
             timeline_events=timeline_events,
diff --git a/tests/rest/client/test_sync.py b/tests/rest/client/test_sync.py
index 0d0bea538b..4236812db5 100644
--- a/tests/rest/client/test_sync.py
+++ b/tests/rest/client/test_sync.py
@@ -1662,6 +1662,20 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             list(channel.json_body["lists"]["room-invites"]),
         )
 
+        # Ensure DM's are correctly marked
+        self.assertDictEqual(
+            {
+                room_id: room.get("is_dm")
+                for room_id, room in channel.json_body["rooms"].items()
+            },
+            {
+                invite_room_id: None,
+                room_id: None,
+                invited_dm_room_id: True,
+                joined_dm_room_id: True,
+            },
+        )
+
     def test_sort_list(self) -> None:
         """
         Test that the `lists` are sorted by `stream_ordering`
@@ -1874,6 +1888,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             channel.json_body["rooms"][room_id1]["invited_count"],
             0,
         )
+        self.assertIsNone(
+            channel.json_body["rooms"][room_id1].get("is_dm"),
+        )
 
     def test_rooms_meta_when_invited(self) -> None:
         """
@@ -1955,6 +1972,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             channel.json_body["rooms"][room_id1]["invited_count"],
             1,
         )
+        self.assertIsNone(
+            channel.json_body["rooms"][room_id1].get("is_dm"),
+        )
 
     def test_rooms_meta_when_banned(self) -> None:
         """
@@ -2037,6 +2057,9 @@ class SlidingSyncTestCase(unittest.HomeserverTestCase):
             channel.json_body["rooms"][room_id1]["invited_count"],
             0,
         )
+        self.assertIsNone(
+            channel.json_body["rooms"][room_id1].get("is_dm"),
+        )
 
     def test_rooms_meta_heroes(self) -> None:
         """