summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17434.bugfix1
-rw-r--r--synapse/handlers/sliding_sync.py22
-rw-r--r--synapse/storage/databases/main/state.py52
-rw-r--r--tests/handlers/test_sliding_sync.py68
4 files changed, 130 insertions, 13 deletions
diff --git a/changelog.d/17434.bugfix b/changelog.d/17434.bugfix
new file mode 100644
index 0000000000..c7cce52397
--- /dev/null
+++ b/changelog.d/17434.bugfix
@@ -0,0 +1 @@
+Fix bug in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint when using room type filters and the user has one or more remote invites.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 818b13621c..8e6c2fb860 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -24,13 +24,7 @@ from typing import TYPE_CHECKING, Any, Dict, Final, List, Optional, Set, Tuple
 import attr
 from immutabledict import immutabledict
 
-from synapse.api.constants import (
-    AccountDataTypes,
-    Direction,
-    EventContentFields,
-    EventTypes,
-    Membership,
-)
+from synapse.api.constants import AccountDataTypes, Direction, EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.utils import strip_event
 from synapse.handlers.relations import BundledAggregations
@@ -959,11 +953,15 @@ class SlidingSyncHandler:
         # provided in the list. `None` is a valid type for rooms which do not have a
         # room type.
         if filters.room_types is not None or filters.not_room_types is not None:
-            # Make a copy so we don't run into an error: `Set changed size during
-            # iteration`, when we filter out and remove items
-            for room_id in filtered_room_id_set.copy():
-                create_event = await self.store.get_create_event_for_room(room_id)
-                room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+            room_to_type = await self.store.bulk_get_room_type(
+                {
+                    room_id
+                    for room_id in filtered_room_id_set
+                    # We only know the room types for joined rooms
+                    if sync_room_map[room_id].membership == Membership.JOIN
+                }
+            )
+            for room_id, room_type in room_to_type.items():
                 if (
                     filters.room_types is not None
                     and room_type not in filters.room_types
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index b2a67aff89..5188b2f7a4 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -41,7 +41,7 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventTypes, Membership
+from synapse.api.constants import EventContentFields, EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, RoomVersion
 from synapse.events import EventBase
@@ -298,6 +298,56 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         create_event = await self.get_event(create_id)
         return create_event
 
+    @cached(max_entries=10000)
+    async def get_room_type(self, room_id: str) -> Optional[str]:
+        """Get the room type for a given room. The server must be joined to the
+        given room.
+        """
+
+        row = await self.db_pool.simple_select_one(
+            table="room_stats_state",
+            keyvalues={"room_id": room_id},
+            retcols=("room_type",),
+            allow_none=True,
+            desc="get_room_type",
+        )
+
+        if row is not None:
+            return row[0]
+
+        # If we haven't updated `room_stats_state` with the room yet, query the
+        # create event directly.
+        create_event = await self.get_create_event_for_room(room_id)
+        room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+        return room_type
+
+    @cachedList(cached_method_name="get_room_type", list_name="room_ids")
+    async def bulk_get_room_type(
+        self, room_ids: Set[str]
+    ) -> Mapping[str, Optional[str]]:
+        """Bulk fetch room types for the given rooms, the server must be in all
+        the rooms given.
+        """
+
+        rows = await self.db_pool.simple_select_many_batch(
+            table="room_stats_state",
+            column="room_id",
+            iterable=room_ids,
+            retcols=("room_id", "room_type"),
+            desc="bulk_get_room_type",
+        )
+
+        # If we haven't updated `room_stats_state` with the room yet, query the
+        # create events directly. This should happen only rarely so we don't
+        # mind if we do this in a loop.
+        results = dict(rows)
+        for room_id in room_ids - results.keys():
+            create_event = await self.get_create_event_for_room(room_id)
+            room_type = create_event.content.get(EventContentFields.ROOM_TYPE)
+            results[room_id] = room_type
+
+        return results
+
     @cached(max_entries=100000, iterable=True)
     async def get_partial_current_state_ids(self, room_id: str) -> StateMap[str]:
         """Get the current state event ids for a room based on the
diff --git a/tests/handlers/test_sliding_sync.py b/tests/handlers/test_sliding_sync.py
index 9dd2363adc..eb4b0a05c7 100644
--- a/tests/handlers/test_sliding_sync.py
+++ b/tests/handlers/test_sliding_sync.py
@@ -35,6 +35,8 @@ from synapse.api.constants import (
     RoomTypes,
 )
 from synapse.api.room_versions import RoomVersions
+from synapse.events import make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.handlers.sliding_sync import RoomSyncConfig, StateValues
 from synapse.rest import admin
 from synapse.rest.client import knock, login, room
@@ -2791,6 +2793,72 @@ class FilterRoomsTestCase(HomeserverTestCase):
 
         self.assertEqual(filtered_room_map.keys(), {space_room_id})
 
+    def test_filter_room_types_with_invite_remote_room(self) -> None:
+        """Test that we can apply a room type filter, even if we have an invite
+        for a remote room.
+
+        This is a regression test.
+        """
+
+        user1_id = self.register_user("user1", "pass")
+        user1_tok = self.login(user1_id, "pass")
+
+        # Create a fake remote invite and persist it.
+        invite_room_id = "!some:room"
+        invite_event = make_event_from_dict(
+            {
+                "room_id": invite_room_id,
+                "sender": "@user:test.serv",
+                "state_key": user1_id,
+                "depth": 1,
+                "origin_server_ts": 1,
+                "type": EventTypes.Member,
+                "content": {"membership": Membership.INVITE},
+                "auth_events": [],
+                "prev_events": [],
+            },
+            room_version=RoomVersions.V10,
+        )
+        invite_event.internal_metadata.outlier = True
+        invite_event.internal_metadata.out_of_band_membership = True
+
+        self.get_success(
+            self.store.maybe_store_room_on_outlier_membership(
+                room_id=invite_room_id, room_version=invite_event.room_version
+            )
+        )
+        context = EventContext.for_outlier(self.hs.get_storage_controllers())
+        persist_controller = self.hs.get_storage_controllers().persistence
+        assert persist_controller is not None
+        self.get_success(persist_controller.persist_event(invite_event, context))
+
+        # Create a normal room (no room type)
+        room_id = self.helper.create_room_as(user1_id, tok=user1_tok)
+
+        after_rooms_token = self.event_sources.get_current_token()
+
+        # Get the rooms the user should be syncing with
+        sync_room_map = self.get_success(
+            self.sliding_sync_handler.get_sync_room_ids_for_user(
+                UserID.from_string(user1_id),
+                from_token=None,
+                to_token=after_rooms_token,
+            )
+        )
+
+        filtered_room_map = self.get_success(
+            self.sliding_sync_handler.filter_rooms(
+                UserID.from_string(user1_id),
+                sync_room_map,
+                SlidingSyncConfig.SlidingSyncList.Filters(
+                    room_types=[None, RoomTypes.SPACE],
+                ),
+                after_rooms_token,
+            )
+        )
+
+        self.assertEqual(filtered_room_map.keys(), {room_id, invite_room_id})
+
 
 class SortRoomsTestCase(HomeserverTestCase):
     """