summary refs log tree commit diff
diff options
context:
space:
mode:
authorEric Eastwood <eric.eastwood@beta.gouv.fr>2024-06-17 12:06:18 -0500
committerGitHub <noreply@github.com>2024-06-17 12:06:18 -0500
commita5485437cf8006b80345f2e0af6e233881e9de21 (patch)
tree823b380e636827f851746513b05a4e4f8e25cc47
parentAdd `stream_ordering` sort to Sliding Sync `/sync` (#17293) (diff)
downloadsynapse-a5485437cf8006b80345f2e0af6e233881e9de21.tar.xz
Add `is_encrypted` filtering to Sliding Sync `/sync` (#17281)
Based on [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575): Sliding Sync
-rw-r--r--changelog.d/17281.feature1
-rw-r--r--synapse/handlers/sliding_sync.py26
-rw-r--r--synapse/handlers/sync.py109
-rw-r--r--synapse/storage/controllers/state.py87
-rw-r--r--tests/handlers/test_sliding_sync.py66
5 files changed, 189 insertions, 100 deletions
diff --git a/changelog.d/17281.feature b/changelog.d/17281.feature
new file mode 100644
index 0000000000..fce512692c
--- /dev/null
+++ b/changelog.d/17281.feature
@@ -0,0 +1 @@
+Add `is_encrypted` filtering to experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index b84cf67f7d..16d94925f5 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 from immutabledict import immutabledict
 
-from synapse.api.constants import AccountDataTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
 from synapse.events import EventBase
 from synapse.storage.roommember import RoomsForUser
 from synapse.types import (
@@ -33,6 +33,7 @@ from synapse.types import (
     UserID,
 )
 from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
+from synapse.types.state import StateFilter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -85,6 +86,7 @@ class SlidingSyncHandler:
     def __init__(self, hs: "HomeServer"):
         self.clock = hs.get_clock()
         self.store = hs.get_datastores().main
+        self.storage_controllers = hs.get_storage_controllers()
         self.auth_blocking = hs.get_auth_blocking()
         self.notifier = hs.get_notifier()
         self.event_sources = hs.get_event_sources()
@@ -570,8 +572,26 @@ class SlidingSyncHandler:
         if filters.spaces:
             raise NotImplementedError()
 
-        if filters.is_encrypted:
-            raise NotImplementedError()
+        # Filter for encrypted rooms
+        if filters.is_encrypted is not None:
+            # Make a copy so we don't run into an error: `Set changed size during
+            # iteration`, when we filter out and remove items
+            for room_id in list(filtered_room_id_set):
+                state_at_to_token = await self.storage_controllers.state.get_state_at(
+                    room_id,
+                    to_token,
+                    state_filter=StateFilter.from_types(
+                        [(EventTypes.RoomEncryption, "")]
+                    ),
+                )
+                is_encrypted = state_at_to_token.get((EventTypes.RoomEncryption, ""))
+
+                # If we're looking for encrypted rooms, filter out rooms that are not
+                # encrypted and vice versa
+                if (filters.is_encrypted and not is_encrypted) or (
+                    not filters.is_encrypted and is_encrypted
+                ):
+                    filtered_room_id_set.remove(room_id)
 
         if filters.is_invite:
             raise NotImplementedError()
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0a40d62c6a..e2563428d2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -979,91 +979,6 @@ class SyncHandler:
             bundled_aggregations=bundled_aggregations,
         )
 
-    async def get_state_after_event(
-        self,
-        event_id: str,
-        state_filter: Optional[StateFilter] = None,
-        await_full_state: bool = True,
-    ) -> StateMap[str]:
-        """
-        Get the room state after the given event
-
-        Args:
-            event_id: event of interest
-            state_filter: The state filter used to fetch state from the database.
-            await_full_state: if `True`, will block if we do not yet have complete state
-                at the event and `state_filter` is not satisfied by partial state.
-                Defaults to `True`.
-        """
-        state_ids = await self._state_storage_controller.get_state_ids_for_event(
-            event_id,
-            state_filter=state_filter or StateFilter.all(),
-            await_full_state=await_full_state,
-        )
-
-        # using get_metadata_for_events here (instead of get_event) sidesteps an issue
-        # with redactions: if `event_id` is a redaction event, and we don't have the
-        # original (possibly because it got purged), get_event will refuse to return
-        # the redaction event, which isn't terribly helpful here.
-        #
-        # (To be fair, in that case we could assume it's *not* a state event, and
-        # therefore we don't need to worry about it. But still, it seems cleaner just
-        # to pull the metadata.)
-        m = (await self.store.get_metadata_for_events([event_id]))[event_id]
-        if m.state_key is not None and m.rejection_reason is None:
-            state_ids = dict(state_ids)
-            state_ids[(m.event_type, m.state_key)] = event_id
-
-        return state_ids
-
-    async def get_state_at(
-        self,
-        room_id: str,
-        stream_position: StreamToken,
-        state_filter: Optional[StateFilter] = None,
-        await_full_state: bool = True,
-    ) -> StateMap[str]:
-        """Get the room state at a particular stream position
-
-        Args:
-            room_id: room for which to get state
-            stream_position: point at which to get state
-            state_filter: The state filter used to fetch state from the database.
-            await_full_state: if `True`, will block if we do not yet have complete state
-                at the last event in the room before `stream_position` and
-                `state_filter` is not satisfied by partial state. Defaults to `True`.
-        """
-        # FIXME: This gets the state at the latest event before the stream ordering,
-        # which might not be the same as the "current state" of the room at the time
-        # of the stream token if there were multiple forward extremities at the time.
-        last_event_id = (
-            await self.store.get_last_event_id_in_room_before_stream_ordering(
-                room_id,
-                end_token=stream_position.room_key,
-            )
-        )
-
-        if last_event_id:
-            state = await self.get_state_after_event(
-                last_event_id,
-                state_filter=state_filter or StateFilter.all(),
-                await_full_state=await_full_state,
-            )
-
-        else:
-            # no events in this room - so presumably no state
-            state = {}
-
-            # (erikj) This should be rarely hit, but we've had some reports that
-            # we get more state down gappy syncs than we should, so let's add
-            # some logging.
-            logger.info(
-                "Failed to find any events in room %s at %s",
-                room_id,
-                stream_position.room_key,
-            )
-        return state
-
     async def compute_summary(
         self,
         room_id: str,
@@ -1437,7 +1352,7 @@ class SyncHandler:
             await_full_state = True
             lazy_load_members = False
 
-        state_at_timeline_end = await self.get_state_at(
+        state_at_timeline_end = await self._state_storage_controller.get_state_at(
             room_id,
             stream_position=end_token,
             state_filter=state_filter,
@@ -1565,7 +1480,7 @@ class SyncHandler:
         else:
             # We can get here if the user has ignored the senders of all
             # the recent events.
-            state_at_timeline_start = await self.get_state_at(
+            state_at_timeline_start = await self._state_storage_controller.get_state_at(
                 room_id,
                 stream_position=end_token,
                 state_filter=state_filter,
@@ -1587,14 +1502,14 @@ class SyncHandler:
             # about them).
             state_filter = StateFilter.all()
 
-        state_at_previous_sync = await self.get_state_at(
+        state_at_previous_sync = await self._state_storage_controller.get_state_at(
             room_id,
             stream_position=since_token,
             state_filter=state_filter,
             await_full_state=await_full_state,
         )
 
-        state_at_timeline_end = await self.get_state_at(
+        state_at_timeline_end = await self._state_storage_controller.get_state_at(
             room_id,
             stream_position=end_token,
             state_filter=state_filter,
@@ -2593,7 +2508,7 @@ class SyncHandler:
                 continue
 
             if room_id in sync_result_builder.joined_room_ids or has_join:
-                old_state_ids = await self.get_state_at(
+                old_state_ids = await self._state_storage_controller.get_state_at(
                     room_id,
                     since_token,
                     state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
@@ -2623,12 +2538,14 @@ class SyncHandler:
                     newly_left_rooms.append(room_id)
                 else:
                     if not old_state_ids:
-                        old_state_ids = await self.get_state_at(
-                            room_id,
-                            since_token,
-                            state_filter=StateFilter.from_types(
-                                [(EventTypes.Member, user_id)]
-                            ),
+                        old_state_ids = (
+                            await self._state_storage_controller.get_state_at(
+                                room_id,
+                                since_token,
+                                state_filter=StateFilter.from_types(
+                                    [(EventTypes.Member, user_id)]
+                                ),
+                            )
                         )
                         old_mem_ev_id = old_state_ids.get(
                             (EventTypes.Member, user_id), None
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9eced23bf..cc9b162ae4 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -45,7 +45,7 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialStateEventsTracker,
 )
 from synapse.synapse_rust.acl import ServerAclEvaluator
-from synapse.types import MutableStateMap, StateMap, get_domain_from_id
+from synapse.types import MutableStateMap, StateMap, StreamToken, get_domain_from_id
 from synapse.types.state import StateFilter
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
@@ -372,6 +372,91 @@ class StateStorageController:
         )
         return state_map[event_id]
 
+    async def get_state_after_event(
+        self,
+        event_id: str,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
+    ) -> StateMap[str]:
+        """
+        Get the room state after the given event
+
+        Args:
+            event_id: event of interest
+            state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the event and `state_filter` is not satisfied by partial state.
+                Defaults to `True`.
+        """
+        state_ids = await self.get_state_ids_for_event(
+            event_id,
+            state_filter=state_filter or StateFilter.all(),
+            await_full_state=await_full_state,
+        )
+
+        # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+        # with redactions: if `event_id` is a redaction event, and we don't have the
+        # original (possibly because it got purged), get_event will refuse to return
+        # the redaction event, which isn't terribly helpful here.
+        #
+        # (To be fair, in that case we could assume it's *not* a state event, and
+        # therefore we don't need to worry about it. But still, it seems cleaner just
+        # to pull the metadata.)
+        m = (await self.stores.main.get_metadata_for_events([event_id]))[event_id]
+        if m.state_key is not None and m.rejection_reason is None:
+            state_ids = dict(state_ids)
+            state_ids[(m.event_type, m.state_key)] = event_id
+
+        return state_ids
+
+    async def get_state_at(
+        self,
+        room_id: str,
+        stream_position: StreamToken,
+        state_filter: Optional[StateFilter] = None,
+        await_full_state: bool = True,
+    ) -> StateMap[str]:
+        """Get the room state at a particular stream position
+
+        Args:
+            room_id: room for which to get state
+            stream_position: point at which to get state
+            state_filter: The state filter used to fetch state from the database.
+            await_full_state: if `True`, will block if we do not yet have complete state
+                at the last event in the room before `stream_position` and
+                `state_filter` is not satisfied by partial state. Defaults to `True`.
+        """
+        # FIXME: This gets the state at the latest event before the stream ordering,
+        # which might not be the same as the "current state" of the room at the time
+        # of the stream token if there were multiple forward extremities at the time.
+        last_event_id = (
+            await self.stores.main.get_last_event_id_in_room_before_stream_ordering(
+                room_id,
+                end_token=stream_position.room_key,
+            )
+        )
+
+        if last_event_id:
+            state = await self.get_state_after_event(
+                last_event_id,
+                state_filter=state_filter or StateFilter.all(),
+                await_full_state=await_full_state,
+            )
+
+        else:
+            # no events in this room - so presumably no state
+            state = {}
+
+            # (erikj) This should be rarely hit, but we've had some reports that
+            # we get more state down gappy syncs than we should, so let's add
+            # some logging.
+            logger.info(
+                "Failed to find any events in room %s at %s",
+                room_id,
+                stream_position.room_key,
+            )
+        return state
+
     @trace
     @tag_args
     async def get_state_for_groups(
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index af48041f1f..0358239c7f 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -1253,6 +1253,72 @@ class FilterRoomsTestCase(HomeserverTestCase):
 
         self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
 
+    def test_filter_encrypted_rooms(self) -> None:
+        """
+        Test `filter.is_encrypted` for encrypted rooms
+        """
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        # Create a normal room
+        room_id = self.helper.create_room_as(
+            user1_id,
+            is_public=False,
+            tok=user1_tok,
+        )
+
+        # Create an encrypted room
+        encrypted_room_id = self.helper.create_room_as(
+            user1_id,
+            is_public=False,
+            tok=user1_tok,
+        )
+        self.helper.send_state(
+            encrypted_room_id,
+            EventTypes.RoomEncryption,
+            {"algorithm": "m.megolm.v1.aes-sha2"},
+            tok=user1_tok,
+        )
+
+        after_rooms_token = self.event_sources.get_current_token()
+
+        # Get the rooms the user should be syncing with
+        sync_room_map = self.get_success(
+            self.sliding_sync_handler.get_sync_room_ids_for_user(
+                UserID.from_string(user1_id),
+                from_token=None,
+                to_token=after_rooms_token,
+            )
+        )
+
+        # Try with `is_encrypted=True`
+        truthy_filtered_room_map = self.get_success(
+            self.sliding_sync_handler.filter_rooms(
+                UserID.from_string(user1_id),
+                sync_room_map,
+                SlidingSyncConfig.SlidingSyncList.Filters(
+                    is_encrypted=True,
+                ),
+                after_rooms_token,
+            )
+        )
+
+        self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
+
+        # Try with `is_encrypted=False`
+        falsy_filtered_room_map = self.get_success(
+            self.sliding_sync_handler.filter_rooms(
+                UserID.from_string(user1_id),
+                sync_room_map,
+                SlidingSyncConfig.SlidingSyncList.Filters(
+                    is_encrypted=False,
+                ),
+                after_rooms_token,
+            )
+        )
+
+        self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+
 
 class SortRoomsTestCase(HomeserverTestCase):
     """