diff --git a/changelog.d/17281.feature b/changelog.d/17281.feature
new file mode 100644
index 0000000000..fce512692c
--- /dev/null
+++ b/changelog.d/17281.feature
@@ -0,0 +1 @@
+Add `is_encrypted` filtering to the experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
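
For orientation, a sketch of how a client might enable this filter on a sliding sync list, written as a Python dict. Only `filters.is_encrypted` is what this change implements; the surrounding fields follow the MSC3575 list shape and are illustrative assumptions.

    # Illustrative MSC3575 list definition (nothing here beyond
    # `filters.is_encrypted` is added by this change; values are assumptions).
    sliding_sync_list = {
        "ranges": [[0, 10]],
        "required_state": [["m.room.encryption", ""]],
        "timeline_limit": 1,
        "filters": {"is_encrypted": True},
    }
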
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index b84cf67f7d..16d94925f5 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
from immutabledict import immutabledict
-from synapse.api.constants import AccountDataTypes, Membership
+from synapse.api.constants import AccountDataTypes, EventTypes, Membership
from synapse.events import EventBase
from synapse.storage.roommember import RoomsForUser
from synapse.types import (
@@ -33,6 +33,7 @@ from synapse.types import (
UserID,
)
from synapse.types.handlers import OperationType, SlidingSyncConfig, SlidingSyncResult
+from synapse.types.state import StateFilter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -85,6 +86,7 @@ class SlidingSyncHandler:
def __init__(self, hs: "HomeServer"):
self.clock = hs.get_clock()
self.store = hs.get_datastores().main
+ self.storage_controllers = hs.get_storage_controllers()
self.auth_blocking = hs.get_auth_blocking()
self.notifier = hs.get_notifier()
self.event_sources = hs.get_event_sources()
@@ -570,8 +572,26 @@ class SlidingSyncHandler:
if filters.spaces:
raise NotImplementedError()
- if filters.is_encrypted:
- raise NotImplementedError()
+ # Filter for encrypted rooms
+ if filters.is_encrypted is not None:
+ # Iterate over a copy of the set so that removing rooms as we filter
+ # doesn't raise a `Set changed size during iteration` error.
+ for room_id in list(filtered_room_id_set):
+ state_at_to_token = await self.storage_controllers.state.get_state_at(
+ room_id,
+ to_token,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.RoomEncryption, "")]
+ ),
+ )
+ is_encrypted = state_at_to_token.get((EventTypes.RoomEncryption, ""))
+
+ # If we're looking for encrypted rooms, filter out the rooms that
+ # aren't encrypted, and vice versa. This is just an XOR between the
+ # requested filter value and whether an encryption event exists at
+ # the token.
+ if bool(filters.is_encrypted) != bool(is_encrypted):
+ filtered_room_id_set.remove(room_id)
if filters.is_invite:
raise NotImplementedError()
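
The keep/drop decision above is effectively an exclusive-or between the requested filter value and whether an `m.room.encryption` state event exists at `to_token`. A minimal standalone sketch of that predicate, assuming `encryption_event_id` stands in for the value returned by the `get_state_at` lookup:

    from typing import Optional


    def room_passes_encryption_filter(
        is_encrypted_filter: bool, encryption_event_id: Optional[str]
    ) -> bool:
        # Keep encrypted rooms when asking for encrypted ones, and unencrypted
        # rooms when asking for unencrypted ones.
        return is_encrypted_filter == (encryption_event_id is not None)


    # With is_encrypted=True, a room with no m.room.encryption event is dropped.
    assert room_passes_encryption_filter(True, None) is False
    assert room_passes_encryption_filter(False, None) is True
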
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 0a40d62c6a..e2563428d2 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -979,91 +979,6 @@ class SyncHandler:
bundled_aggregations=bundled_aggregations,
)
- async def get_state_after_event(
- self,
- event_id: str,
- state_filter: Optional[StateFilter] = None,
- await_full_state: bool = True,
- ) -> StateMap[str]:
- """
- Get the room state after the given event
-
- Args:
- event_id: event of interest
- state_filter: The state filter used to fetch state from the database.
- await_full_state: if `True`, will block if we do not yet have complete state
- at the event and `state_filter` is not satisfied by partial state.
- Defaults to `True`.
- """
- state_ids = await self._state_storage_controller.get_state_ids_for_event(
- event_id,
- state_filter=state_filter or StateFilter.all(),
- await_full_state=await_full_state,
- )
-
- # using get_metadata_for_events here (instead of get_event) sidesteps an issue
- # with redactions: if `event_id` is a redaction event, and we don't have the
- # original (possibly because it got purged), get_event will refuse to return
- # the redaction event, which isn't terribly helpful here.
- #
- # (To be fair, in that case we could assume it's *not* a state event, and
- # therefore we don't need to worry about it. But still, it seems cleaner just
- # to pull the metadata.)
- m = (await self.store.get_metadata_for_events([event_id]))[event_id]
- if m.state_key is not None and m.rejection_reason is None:
- state_ids = dict(state_ids)
- state_ids[(m.event_type, m.state_key)] = event_id
-
- return state_ids
-
- async def get_state_at(
- self,
- room_id: str,
- stream_position: StreamToken,
- state_filter: Optional[StateFilter] = None,
- await_full_state: bool = True,
- ) -> StateMap[str]:
- """Get the room state at a particular stream position
-
- Args:
- room_id: room for which to get state
- stream_position: point at which to get state
- state_filter: The state filter used to fetch state from the database.
- await_full_state: if `True`, will block if we do not yet have complete state
- at the last event in the room before `stream_position` and
- `state_filter` is not satisfied by partial state. Defaults to `True`.
- """
- # FIXME: This gets the state at the latest event before the stream ordering,
- # which might not be the same as the "current state" of the room at the time
- # of the stream token if there were multiple forward extremities at the time.
- last_event_id = (
- await self.store.get_last_event_id_in_room_before_stream_ordering(
- room_id,
- end_token=stream_position.room_key,
- )
- )
-
- if last_event_id:
- state = await self.get_state_after_event(
- last_event_id,
- state_filter=state_filter or StateFilter.all(),
- await_full_state=await_full_state,
- )
-
- else:
- # no events in this room - so presumably no state
- state = {}
-
- # (erikj) This should be rarely hit, but we've had some reports that
- # we get more state down gappy syncs than we should, so let's add
- # some logging.
- logger.info(
- "Failed to find any events in room %s at %s",
- room_id,
- stream_position.room_key,
- )
- return state
-
async def compute_summary(
self,
room_id: str,
@@ -1437,7 +1352,7 @@ class SyncHandler:
await_full_state = True
lazy_load_members = False
- state_at_timeline_end = await self.get_state_at(
+ state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@@ -1565,7 +1480,7 @@ class SyncHandler:
else:
# We can get here if the user has ignored the senders of all
# the recent events.
- state_at_timeline_start = await self.get_state_at(
+ state_at_timeline_start = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@@ -1587,14 +1502,14 @@ class SyncHandler:
# about them).
state_filter = StateFilter.all()
- state_at_previous_sync = await self.get_state_at(
+ state_at_previous_sync = await self._state_storage_controller.get_state_at(
room_id,
stream_position=since_token,
state_filter=state_filter,
await_full_state=await_full_state,
)
- state_at_timeline_end = await self.get_state_at(
+ state_at_timeline_end = await self._state_storage_controller.get_state_at(
room_id,
stream_position=end_token,
state_filter=state_filter,
@@ -2593,7 +2508,7 @@ class SyncHandler:
continue
if room_id in sync_result_builder.joined_room_ids or has_join:
- old_state_ids = await self.get_state_at(
+ old_state_ids = await self._state_storage_controller.get_state_at(
room_id,
since_token,
state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
@@ -2623,12 +2538,14 @@ class SyncHandler:
newly_left_rooms.append(room_id)
else:
if not old_state_ids:
- old_state_ids = await self.get_state_at(
- room_id,
- since_token,
- state_filter=StateFilter.from_types(
- [(EventTypes.Member, user_id)]
- ),
+ old_state_ids = (
+ await self._state_storage_controller.get_state_at(
+ room_id,
+ since_token,
+ state_filter=StateFilter.from_types(
+ [(EventTypes.Member, user_id)]
+ ),
+ )
)
old_mem_ev_id = old_state_ids.get(
(EventTypes.Member, user_id), None
diff --git a/synapse/storage/controllers/state.py b/synapse/storage/controllers/state.py
index f9eced23bf..cc9b162ae4 100644
--- a/synapse/storage/controllers/state.py
+++ b/synapse/storage/controllers/state.py
@@ -45,7 +45,7 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker,
)
from synapse.synapse_rust.acl import ServerAclEvaluator
-from synapse.types import MutableStateMap, StateMap, get_domain_from_id
+from synapse.types import MutableStateMap, StateMap, StreamToken, get_domain_from_id
from synapse.types.state import StateFilter
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
@@ -372,6 +372,91 @@ class StateStorageController:
)
return state_map[event_id]
+ async def get_state_after_event(
+ self,
+ event_id: str,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
+ ) -> StateMap[str]:
+ """
+ Get the room state after the given event
+
+ Args:
+ event_id: event of interest
+ state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the event and `state_filter` is not satisfied by partial state.
+ Defaults to `True`.
+ """
+ state_ids = await self.get_state_ids_for_event(
+ event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
+ )
+
+ # using get_metadata_for_events here (instead of get_event) sidesteps an issue
+ # with redactions: if `event_id` is a redaction event, and we don't have the
+ # original (possibly because it got purged), get_event will refuse to return
+ # the redaction event, which isn't terribly helpful here.
+ #
+ # (To be fair, in that case we could assume it's *not* a state event, and
+ # therefore we don't need to worry about it. But still, it seems cleaner just
+ # to pull the metadata.)
+ m = (await self.stores.main.get_metadata_for_events([event_id]))[event_id]
+ if m.state_key is not None and m.rejection_reason is None:
+ state_ids = dict(state_ids)
+ state_ids[(m.event_type, m.state_key)] = event_id
+
+ return state_ids
+
+ async def get_state_at(
+ self,
+ room_id: str,
+ stream_position: StreamToken,
+ state_filter: Optional[StateFilter] = None,
+ await_full_state: bool = True,
+ ) -> StateMap[str]:
+ """Get the room state at a particular stream position
+
+ Args:
+ room_id: room for which to get state
+ stream_position: point at which to get state
+ state_filter: The state filter used to fetch state from the database.
+ await_full_state: if `True`, will block if we do not yet have complete state
+ at the last event in the room before `stream_position` and
+ `state_filter` is not satisfied by partial state. Defaults to `True`.
+ """
+ # FIXME: This gets the state at the latest event before the stream ordering,
+ # which might not be the same as the "current state" of the room at the time
+ # of the stream token if there were multiple forward extremities at the time.
+ last_event_id = (
+ await self.stores.main.get_last_event_id_in_room_before_stream_ordering(
+ room_id,
+ end_token=stream_position.room_key,
+ )
+ )
+
+ if last_event_id:
+ state = await self.get_state_after_event(
+ last_event_id,
+ state_filter=state_filter or StateFilter.all(),
+ await_full_state=await_full_state,
+ )
+
+ else:
+ # no events in this room - so presumably no state
+ state = {}
+
+ # (erikj) This should be rarely hit, but we've had some reports that
+ # we get more state down gappy syncs than we should, so let's add
+ # some logging.
+ logger.info(
+ "Failed to find any events in room %s at %s",
+ room_id,
+ stream_position.room_key,
+ )
+ return state
+
@trace
@tag_args
async def get_state_for_groups(
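
With `get_state_at` on the storage controller, callers can fetch state at a stream token without going through `SyncHandler`. A hedged usage sketch mirroring the membership lookups at the sync.py call sites above (the wrapper function itself is hypothetical):

    from typing import Optional

    from synapse.api.constants import EventTypes
    from synapse.storage.controllers.state import StateStorageController
    from synapse.types import StreamToken
    from synapse.types.state import StateFilter


    async def membership_event_id_at(
        state_controller: StateStorageController,
        room_id: str,
        user_id: str,
        at_token: StreamToken,
    ) -> Optional[str]:
        # Only pull the single membership event we care about.
        state_ids = await state_controller.get_state_at(
            room_id,
            at_token,
            state_filter=StateFilter.from_types([(EventTypes.Member, user_id)]),
        )
        # `get_state_at` returns an empty map if the room has no events before
        # the token, so a missing key just means "no membership state yet".
        return state_ids.get((EventTypes.Member, user_id))
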
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index af48041f1f..0358239c7f 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -1253,6 +1253,72 @@ class FilterRoomsTestCase(HomeserverTestCase):
self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+ def test_filter_encrypted_rooms(self) -> None:
+ """
+ Test `filters.is_encrypted` against an encrypted room and an unencrypted room.
+ """
+ user1_id = self.register_user("user1", "pass")
+ user1_tok = self.login(user1_id, "pass")
+
+ # Create an unencrypted room
+ room_id = self.helper.create_room_as(
+ user1_id,
+ is_public=False,
+ tok=user1_tok,
+ )
+
+ # Create an encrypted room
+ encrypted_room_id = self.helper.create_room_as(
+ user1_id,
+ is_public=False,
+ tok=user1_tok,
+ )
+ self.helper.send_state(
+ encrypted_room_id,
+ EventTypes.RoomEncryption,
+ {"algorithm": "m.megolm.v1.aes-sha2"},
+ tok=user1_tok,
+ )
+
+ after_rooms_token = self.event_sources.get_current_token()
+
+ # Get the rooms the user should be syncing with
+ sync_room_map = self.get_success(
+ self.sliding_sync_handler.get_sync_room_ids_for_user(
+ UserID.from_string(user1_id),
+ from_token=None,
+ to_token=after_rooms_token,
+ )
+ )
+
+ # Try with `is_encrypted=True`
+ truthy_filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ is_encrypted=True,
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(truthy_filtered_room_map.keys(), {encrypted_room_id})
+
+ # Try with `is_encrypted=False`
+ falsy_filtered_room_map = self.get_success(
+ self.sliding_sync_handler.filter_rooms(
+ UserID.from_string(user1_id),
+ sync_room_map,
+ SlidingSyncConfig.SlidingSyncList.Filters(
+ is_encrypted=False,
+ ),
+ after_rooms_token,
+ )
+ )
+
+ self.assertEqual(falsy_filtered_room_map.keys(), {room_id})
+
class SortRoomsTestCase(HomeserverTestCase):
"""