summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/room.py126
-rw-r--r--synapse/storage/databases/main/stats.py10
-rw-r--r--synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql19
3 files changed, 148 insertions, 7 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5760d3428e..d8026e3fac 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    JoinRules,
+    PublicRoomsFilterFields,
+)
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             desc="get_public_room_ids",
         )
 
+    def _construct_room_type_where_clause(
+        self, room_types: Union[List[Union[str, None]], None]
+    ) -> Tuple[Union[str, None], List[str]]:
+        if not room_types or not self.config.experimental.msc3827_enabled:
+            return None, []
+        else:
+            # We use None when we want get rooms without a type
+            is_null_clause = ""
+            if None in room_types:
+                is_null_clause = "OR room_type IS NULL"
+                room_types = [value for value in room_types if value is not None]
+
+            list_clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_type", room_types
+            )
+
+            return f"({list_clause} {is_null_clause})", args
+
     async def count_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
         ignore_non_federatable: bool,
+        search_filter: Optional[dict],
     ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         Args:
             network_tuple
             ignore_non_federatable: If true filters out non-federatable rooms
+            search_filter
         """
 
         def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
+            room_type_clause, args = self._construct_room_type_where_clause(
+                search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+                if search_filter
+                else None
+            )
+            room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+            query_args += args
+
             if network_tuple:
                 if network_tuple.appservice_id:
                     published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
                         OR history_visibility = 'world_readable'
                     )
+                    {room_type_clause}
                     AND joined_members > 0
             """
 
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         if ignore_non_federatable:
             where_clauses.append("is_federatable")
 
-        if search_filter and search_filter.get("generic_search_term", None):
-            search_term = "%" + search_filter["generic_search_term"] + "%"
+        if search_filter and search_filter.get(
+            PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+        ):
+            search_term = (
+                "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+            )
 
             where_clauses.append(
                 """
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 search_term.lower(),
             ]
 
+        room_type_clause, args = self._construct_room_type_where_clause(
+            search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+            if search_filter
+            else None
+        )
+        if room_type_clause:
+            where_clauses.append(room_type_clause)
+        query_args += args
+
         where_clause = ""
         if where_clauses:
             where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         sql = f"""
             SELECT
                 room_id, name, topic, canonical_alias, joined_members,
-                avatar, history_visibility, guest_access, join_rules
+                avatar, history_visibility, guest_access, join_rules, room_type
             FROM (
                 {published_sql}
             ) published
@@ -1166,6 +1213,7 @@ class _BackgroundUpdates:
     POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
     REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
     POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+    ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
 
 
 _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1200,6 +1248,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_add_rooms_room_version_column,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+            self._background_add_room_type_column,
+        )
+
         # BG updates to change the type of room_depth.min_depth
         self.db_pool.updates.register_background_update_handler(
             _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1569,6 +1622,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
+    async def _background_add_room_type_column(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update to go and add room_type information to `room_stats_state`
+        table from `event_json` table.
+        """
+
+        last_room_id = progress.get("room_id", "")
+
+        def _background_add_room_type_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            sql = """
+                SELECT state.room_id, json FROM event_json
+                INNER JOIN current_state_events AS state USING (event_id)
+                WHERE state.room_id > ? AND type = 'm.room.create'
+                ORDER BY state.room_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+            room_id_to_create_event_results = txn.fetchall()
+
+            new_last_room_id = None
+            for room_id, event_json in room_id_to_create_event_results:
+                event_dict = db_to_json(event_json)
+
+                room_type = event_dict.get("content", {}).get(
+                    EventContentFields.ROOM_TYPE, None
+                )
+                if isinstance(room_type, str):
+                    self.db_pool.simple_update_txn(
+                        txn,
+                        table="room_stats_state",
+                        keyvalues={"room_id": room_id},
+                        updatevalues={"room_type": room_type},
+                    )
+
+                new_last_room_id = room_id
+
+            if new_last_room_id is None:
+                return True
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+                {"room_id": new_last_room_id},
+            )
+
+            return False
+
+        end = await self.db_pool.runInteraction(
+            "_background_add_room_type_column",
+            _background_add_room_type_column_txn,
+        )
+
+        if end:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+            )
+
+        return batch_size
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     def __init__(
diff --git a/synapse/storage/databases/main/stats.py b/synapse/storage/databases/main/stats.py
index 82851ffa95..b4c652acf3 100644
--- a/synapse/storage/databases/main/stats.py
+++ b/synapse/storage/databases/main/stats.py
@@ -16,7 +16,7 @@
 import logging
 from enum import Enum
 from itertools import chain
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
 
 from typing_extensions import Counter
 
@@ -238,6 +238,7 @@ class StatsStore(StateDeltasStore):
         * avatar
         * canonical_alias
         * guest_access
+        * room_type
 
         A is_federatable key can also be included with a boolean value.
 
@@ -263,6 +264,7 @@ class StatsStore(StateDeltasStore):
             "avatar",
             "canonical_alias",
             "guest_access",
+            "room_type",
         ):
             field = fields.get(col, sentinel)
             if field is not sentinel and (not isinstance(field, str) or "\0" in field):
@@ -572,7 +574,7 @@ class StatsStore(StateDeltasStore):
 
         state_event_map = await self.get_events(event_ids, get_prev_content=False)  # type: ignore[attr-defined]
 
-        room_state = {
+        room_state: Dict[str, Union[None, bool, str]] = {
             "join_rules": None,
             "history_visibility": None,
             "encryption": None,
@@ -581,6 +583,7 @@ class StatsStore(StateDeltasStore):
             "avatar": None,
             "canonical_alias": None,
             "is_federatable": True,
+            "room_type": None,
         }
 
         for event in state_event_map.values():
@@ -604,6 +607,9 @@ class StatsStore(StateDeltasStore):
                 room_state["is_federatable"] = (
                     event.content.get(EventContentFields.FEDERATE, True) is True
                 )
+                room_type = event.content.get(EventContentFields.ROOM_TYPE)
+                if isinstance(room_type, str):
+                    room_state["room_type"] = room_type
 
         await self.update_room_state(room_id, room_state)
 
diff --git a/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
new file mode 100644
index 0000000000..d5e0765471
--- /dev/null
+++ b/synapse/storage/schema/main/delta/72/01add_room_type_to_state_stats.sql
@@ -0,0 +1,19 @@
+/* Copyright 2022 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+ALTER TABLE room_stats_state ADD room_type TEXT;
+
+INSERT INTO background_updates (update_name, progress_json)
+    VALUES ('add_room_type_column', '{}');