summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13031.feature1
-rw-r--r--synapse/api/constants.py10
-rw-r--r--synapse/config/experimental.py3
-rw-r--r--synapse/handlers/room_list.py23
-rw-r--r--synapse/handlers/stats.py3
-rw-r--r--synapse/rest/client/versions.py2
-rw-r--r--synapse/storage/databases/main/room.py126
-rw-r--r--synapse/storage/databases/main/stats.py10
-rw-r--r--synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql19
-rw-r--r--tests/rest/client/test_rooms.py92
-rw-r--r--tests/storage/databases/main/test_room.py69
11 files changed, 345 insertions, 13 deletions
diff --git a/changelog.d/13031.feature b/changelog.d/13031.feature
new file mode 100644
index 0000000000..fee8e9d1ff
--- /dev/null
+++ b/changelog.d/13031.feature
@@ -0,0 +1 @@
+Implement [MSC3827](https://github.com/matrix-org/matrix-spec-proposals/pull/3827): Filtering of /publicRooms by room type.
diff --git a/synapse/api/constants.py b/synapse/api/constants.py
index e1d31cabed..2653764119 100644
--- a/synapse/api/constants.py
+++ b/synapse/api/constants.py
@@ -259,3 +259,13 @@ class ReceiptTypes:
     READ: Final = "m.read"
     READ_PRIVATE: Final = "org.matrix.msc2285.read.private"
     FULLY_READ: Final = "m.fully_read"
+
+
+class PublicRoomsFilterFields:
+    """Fields in the search filter for `/publicRooms` that we understand.
+
+    As defined in https://spec.matrix.org/v1.3/client-server-api/#post_matrixclientv3publicrooms
+    """
+
+    GENERIC_SEARCH_TERM: Final = "generic_search_term"
+    ROOM_TYPES: Final = "org.matrix.msc3827.room_types"
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 0a285dba31..ee443cea00 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -87,3 +87,6 @@ class ExperimentalConfig(Config):
 
         # MSC3715: dir param on /relations.
         self.msc3715_enabled: bool = experimental.get("msc3715_enabled", False)
+
+        # MSC3827: Filtering of /publicRooms by room type
+        self.msc3827_enabled: bool = experimental.get("msc3827_enabled", False)
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 183d4ae3c4..29868eb743 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -25,6 +25,7 @@ from synapse.api.constants import (
     GuestAccess,
     HistoryVisibility,
     JoinRules,
+    PublicRoomsFilterFields,
 )
 from synapse.api.errors import (
     Codes,
@@ -181,6 +182,7 @@ class RoomListHandler:
                 == HistoryVisibility.WORLD_READABLE,
                 "guest_can_join": room["guest_access"] == "can_join",
                 "join_rule": room["join_rules"],
+                "org.matrix.msc3827.room_type": room["room_type"],
             }
 
             # Filter out Nones – rather omit the field altogether
@@ -239,7 +241,9 @@ class RoomListHandler:
         response["chunk"] = results
 
         response["total_room_count_estimate"] = await self.store.count_public_rooms(
-            network_tuple, ignore_non_federatable=from_federation
+            network_tuple,
+            ignore_non_federatable=from_federation,
+            search_filter=search_filter,
         )
 
         return response
@@ -508,8 +512,21 @@ class RoomListNextBatch:
 
 
 def _matches_room_entry(room_entry: JsonDict, search_filter: dict) -> bool:
-    if search_filter and search_filter.get("generic_search_term", None):
-        generic_search_term = search_filter["generic_search_term"].upper()
+    """Determines whether the given search filter matches a room entry returned over
+    federation.
+
+    Only used if the remote server does not support MSC2197 remote-filtered search, and
+    hence does not support MSC3827 filtering of `/publicRooms` by room type either.
+
+    In this case, we cannot apply the `room_type` filter since no `room_type` field is
+    returned.
+    """
+    if search_filter and search_filter.get(
+        PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+    ):
+        generic_search_term = search_filter[
+            PublicRoomsFilterFields.GENERIC_SEARCH_TERM
+        ].upper()
         if generic_search_term in room_entry.get("name", "").upper():
             return True
         elif generic_search_term in room_entry.get("topic", "").upper():
diff --git a/synapse/handlers/stats.py b/synapse/handlers/stats.py
index f45e06eb0e..5c01482acf 100644
--- a/synapse/handlers/stats.py
+++ b/synapse/handlers/stats.py
@@ -271,6 +271,9 @@ class StatsHandler:
                 room_state["is_federatable"] = (
                     event_content.get(EventContentFields.FEDERATE, True) is True
                 )
+                room_type = event_content.get(EventContentFields.ROOM_TYPE)
+                if isinstance(room_type, str):
+                    room_state["room_type"] = room_type
             elif typ == EventTypes.JoinRules:
                 room_state["join_rules"] = event_content.get("join_rule")
             elif typ == EventTypes.RoomHistoryVisibility:
diff --git a/synapse/rest/client/versions.py b/synapse/rest/client/versions.py
index c1bd775fec..f4f06563dd 100644
--- a/synapse/rest/client/versions.py
+++ b/synapse/rest/client/versions.py
@@ -95,6 +95,8 @@ class VersionsRestServlet(RestServlet):
                     "org.matrix.msc3026.busy_presence": self.config.experimental.msc3026_enabled,
                     # Supports receiving private read receipts as per MSC2285
                     "org.matrix.msc2285": self.config.experimental.msc2285_enabled,
+                    # Supports filtering of /publicRooms by room type MSC3827
+                    "org.matrix.msc3827": self.config.experimental.msc3827_enabled,
                     # Adds support for importing historical messages as per MSC2716
                     "org.matrix.msc2716": self.config.experimental.msc2716_enabled,
                     # Adds support for jump to date endpoints (/timestamp_to_event) as per MSC3030
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5760d3428e..d8026e3fac 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    JoinRules,
+    PublicRoomsFilterFields,
+)
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             desc="get_public_room_ids",
         )
 
+    def _construct_room_type_where_clause(
+        self, room_types: Union[List[Union[str, None]], None]
+    ) -> Tuple[Union[str, None], List[str]]:
+        if not room_types or not self.config.experimental.msc3827_enabled:
+            return None, []
+        else:
+            # We use None when we want get rooms without a type
+            is_null_clause = ""
+            if None in room_types:
+                is_null_clause = "OR room_type IS NULL"
+                room_types = [value for value in room_types if value is not None]
+
+            list_clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_type", room_types
+            )
+
+            return f"({list_clause} {is_null_clause})", args
+
     async def count_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
         ignore_non_federatable: bool,
+        search_filter: Optional[dict],
     ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         Args:
             network_tuple
             ignore_non_federatable: If true filters out non-federatable rooms
+            search_filter
         """
 
         def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
+            room_type_clause, args = self._construct_room_type_where_clause(
+                search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+                if search_filter
+                else None
+            )
+            room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+            query_args += args
+
             if network_tuple:
                 if network_tuple.appservice_id:
                     published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
                         OR history_visibility = 'world_readable'
                     )
+                    {room_type_clause}
                     AND joined_members > 0
             """
 
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         if ignore_non_federatable:
             where_clauses.append("is_federatable")
 
-        if search_filter and search_filter.get("generic_search_term", None):
-            search_term = "%" + search_filter["generic_search_term"] + "%"
+        if search_filter and search_filter.get(
+            PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+        ):
+            search_term = (
+                "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+            )
 
             where_clauses.append(
                 """
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 search_term.lower(),
             ]
 
+        room_type_clause, args = self._construct_room_type_where_clause(
+            search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+            if search_filter
+            else None
+        )
+        if room_type_clause:
+            where_clauses.append(room_type_clause)
+        query_args += args
+
         where_clause = ""
         if where_clauses:
             where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         sql = f"""
             SELECT
                 room_id, name, topic, canonical_alias, joined_members,
-                avatar, history_visibility, guest_access, join_rules
+                avatar, history_visibility, guest_access, join_rules, room_type
             FROM (
                 {published_sql}
             ) published
@@ -1166,6 +1213,7 @@ class _BackgroundUpdates:
     POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
     REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
     POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+    ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
 
 
 _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1200,6 +1248,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_add_rooms_room_version_column,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+            self._background_add_room_type_column,
+        )
+
         # BG updates to change the type of room_depth.min_depth
         self.db_pool.updates.register_background_update_handler(
             _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1569,6 +1622,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
+    async def _background_add_room_type_column(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update to go and add room_type information to `room_stats_state`
+        table from `event_json` table.
+        """
+
+        last_room_id = progress.get("room_id", "")
+
+        def _background_add_room_type_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            sql = """
+                SELECT state.room_id, json FROM event_json
+                INNER JOIN current_state_events AS state USING (event_id)
+                WHERE state.room_id > ? AND type = 'm.room.create'
+                ORDER BY state.room_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+            room_id_to_create_event_results = txn.fetchall()
+
+            new_last_room_id = None
+            for room_id, event_json in room_id_to_create_event_results:
+                event_dict = db_to_json(event_json)
+
+                room_type = event_dict.get("content", {}).get(
+                    EventContentFields.ROOM_TYPE, None
+                )
+                if isinstance(room_type, str):
+                    self.db_pool.simple_update_txn(
+                        txn,
+                        table="room_stats_state",
+                        keyvalues={"room_id": room_id},
+                        updatevalues={"room_type": room_type},
+                    )
+
+                new_last_room_id = room_id
+
+            if new_last_room_id is None:
+                return True
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+                {"room_id": new_last_room_id},
+            )
+
+            return False
+
+        end = await self.db_pool.runInteraction(
+            "_background_add_room_type_column",
+            _background_add_room_type_column_txn,
+        )
+
+        if end:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+            )
+
+        return batch_size
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     def __init__(
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 82851ffa95..b4c652acf3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 from typing_extensions import Counter
 
@@ -238,6 +238,7 @@ class StatsStore(StateDeltasStore):
         * avatar
         * canonical_alias
         * guest_access
+        * room_type
 
         A is_federatable key can also be included with a boolean value.
 
@@ -263,6 +264,7 @@ class StatsStore(StateDeltasStore):
             "avatar",
             "canonical_alias",
             "guest_access",
+            "room_type",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
@@ -572,7 +574,7 @@ class StatsStore(StateDeltasStore):
 
         state_event_map = await self.get_events(event_ids, get_prev_content=False)  # type: ignore[attr-defined]
 
-        room_state = {
+        room_state: Dict[str, Union[None, bool, str]] = {
             "join_rules": None,
             "history_visibility": None,
             "encryption": None,
@@ -581,6 +583,7 @@ class StatsStore(StateDeltasStore):
             "avatar": None,
             "canonical_alias": None,
             "is_federatable": True,
+            "room_type": None,
         }
 
         for event in state_event_map.values():
@@ -604,6 +607,9 @@ class StatsStore(StateDeltasStore):
                 room_state["is_federatable"] = (
                     event.content.get(EventContentFields.FEDERATE, True) is True
                 )
+                room_type = event.content.get(EventContentFields.ROOM_TYPE)
+                if isinstance(room_type, str):
+                    room_state["room_type"] = room_type
 
         await self.update_room_state(room_id, room_state)
 
diff --git a/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
new file mode 100644
index 0000000000..d5e0765471
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_stats_state ADD room_type TEXT;
+
+INSERT INTO background_updates (update_name, progress_json)
+    VALUES ('add_room_type_column', '{}');
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 35c59ee9e0..1ccd96a207 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -18,7 +18,7 @@
 """Tests REST events for /rooms paths."""
 
 import json
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 from unittest.mock import Mock, call
 from urllib import parse as urlparse
 
@@ -33,7 +33,9 @@ from synapse.api.constants import (
     EventContentFields,
     EventTypes,
     Membership,
+    PublicRoomsFilterFields,
     RelationTypes,
+    RoomTypes,
 )
 from synapse.api.errors import Codes, HttpResponseException
 from synapse.handlers.pagination import PurgeStatus
@@ -1858,6 +1860,90 @@ class PublicRoomsRestrictedTestCase(unittest.HomeserverTestCase):
         self.assertEqual(channel.code, 200, channel.result)
 
 
+class PublicRoomsRoomTypeFilterTestCase(unittest.HomeserverTestCase):
+
+    servlets = [
+        synapse.rest.admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+
+        config = self.default_config()
+        config["allow_public_rooms_without_auth"] = True
+        config["experimental_features"] = {"msc3827_enabled": True}
+        self.hs = self.setup_test_homeserver(config=config)
+        self.url = b"/_matrix/client/r0/publicRooms"
+
+        return self.hs
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        user = self.register_user("alice", "pass")
+        self.token = self.login(user, "pass")
+
+        # Create a room
+        self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={"visibility": "public"},
+            tok=self.token,
+        )
+        # Create a space
+        self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={
+                "visibility": "public",
+                "creation_content": {EventContentFields.ROOM_TYPE: RoomTypes.SPACE},
+            },
+            tok=self.token,
+        )
+
+    def make_public_rooms_request(
+        self, room_types: Union[List[Union[str, None]], None]
+    ) -> Tuple[List[Dict[str, Any]], int]:
+        channel = self.make_request(
+            "POST",
+            self.url,
+            {"filter": {PublicRoomsFilterFields.ROOM_TYPES: room_types}},
+            self.token,
+        )
+        chunk = channel.json_body["chunk"]
+        count = channel.json_body["total_room_count_estimate"]
+
+        self.assertEqual(len(chunk), count)
+
+        return chunk, count
+
+    def test_returns_both_rooms_and_spaces_if_no_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request(None)
+
+        self.assertEqual(count, 2)
+
+    def test_returns_only_rooms_based_on_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request([None])
+
+        self.assertEqual(count, 1)
+        self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), None)
+
+    def test_returns_only_space_based_on_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request(["m.space"])
+
+        self.assertEqual(count, 1)
+        self.assertEqual(chunk[0].get("org.matrix.msc3827.room_type", None), "m.space")
+
+    def test_returns_both_rooms_and_space_based_on_filter(self) -> None:
+        chunk, count = self.make_public_rooms_request(["m.space", None])
+
+        self.assertEqual(count, 2)
+
+    def test_returns_both_rooms_and_spaces_if_array_is_empty(self) -> None:
+        chunk, count = self.make_public_rooms_request([])
+
+        self.assertEqual(count, 2)
+
+
 class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
     """Test that we correctly fallback to local filtering if a remote server
     doesn't support search.
@@ -1882,7 +1968,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
         "Simple test for searching rooms over federation"
         self.federation_client.get_public_rooms.return_value = make_awaitable({})  # type: ignore[attr-defined]
 
-        search_filter = {"generic_search_term": "foobar"}
+        search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
         channel = self.make_request(
             "POST",
@@ -1911,7 +1997,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
             make_awaitable({}),
         )
 
-        search_filter = {"generic_search_term": "foobar"}
+        search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
         channel = self.make_request(
             "POST",
diff --git a/tests/storage/databases/main/test_room.py b/tests/storage/databases/main/test_room.py
index 9abd0cb446..1edb619630 100644
--- a/tests/storage/databases/main/test_room.py
+++ b/tests/storage/databases/main/test_room.py
@@ -12,6 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
+
+from synapse.api.constants import RoomTypes
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.storage.databases.main.room import _BackgroundUpdates
@@ -91,3 +94,69 @@ class RoomBackgroundUpdateStoreTestCase(HomeserverTestCase):
             )
         )
         self.assertEqual(room_creator_after, self.user_id)
+
+    def test_background_add_room_type_column(self):
+        """Test that the background update to populate the `room_type` column in
+        `room_stats_state` works properly.
+        """
+
+        # Create a room without a type
+        room_id = self._generate_room()
+
+        # Get event_id of the m.room.create event
+        event_id = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                table="current_state_events",
+                keyvalues={
+                    "room_id": room_id,
+                    "type": "m.room.create",
+                },
+                retcol="event_id",
+            )
+        )
+
+        # Fake a room creation event with a room type
+        event = {
+            "content": {
+                "creator": "@user:server.org",
+                "room_version": "9",
+                "type": RoomTypes.SPACE,
+            },
+            "type": "m.room.create",
+        }
+        self.get_success(
+            self.store.db_pool.simple_update(
+                table="event_json",
+                keyvalues={"event_id": event_id},
+                updatevalues={"json": json.dumps(event)},
+                desc="test",
+            )
+        )
+
+        # Insert and run the background update
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+                    "progress_json": "{}",
+                },
+            )
+        )
+
+        # ... and tell the DataStore that it hasn't finished all updates yet
+        self.store.db_pool.updates._all_done = False
+
+        # Now let's actually drive the updates to completion
+        self.wait_for_background_updates()
+
+        # Make sure the background update filled in the room type
+        room_type_after = self.get_success(
+            self.store.db_pool.simple_select_one_onecol(
+                table="room_stats_state",
+                keyvalues={"room_id": room_id},
+                retcol="room_type",
+                allow_none=True,
+            )
+        )
+        self.assertEqual(room_type_after, RoomTypes.SPACE)