summary refs log tree commit diff
path: root/synapse/storage/databases/main/room.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/room.py')
-rw-r--r--synapse/storage/databases/main/room.py126
1 files changed, 121 insertions, 5 deletions
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5760d3428e..d8026e3fac 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
 
 import attr
 
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+    EventContentFields,
+    EventTypes,
+    JoinRules,
+    PublicRoomsFilterFields,
+)
 from synapse.api.errors import StoreError
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.config.homeserver import HomeServerConfig
 from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
 from synapse.storage.database import (
     DatabasePool,
     LoggingDatabaseConnection,
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             desc="get_public_room_ids",
         )
 
+    def _construct_room_type_where_clause(
+        self, room_types: Union[List[Union[str, None]], None]
+    ) -> Tuple[Union[str, None], List[str]]:
+        if not room_types or not self.config.experimental.msc3827_enabled:
+            return None, []
+        else:
+            # We use None when we want get rooms without a type
+            is_null_clause = ""
+            if None in room_types:
+                is_null_clause = "OR room_type IS NULL"
+                room_types = [value for value in room_types if value is not None]
+
+            list_clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_type", room_types
+            )
+
+            return f"({list_clause} {is_null_clause})", args
+
     async def count_public_rooms(
         self,
         network_tuple: Optional[ThirdPartyInstanceID],
         ignore_non_federatable: bool,
+        search_filter: Optional[dict],
     ) -> int:
         """Counts the number of public rooms as tracked in the room_stats_current
         and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         Args:
             network_tuple
             ignore_non_federatable: If true filters out non-federatable rooms
+            search_filter
         """
 
         def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
             query_args = []
 
+            room_type_clause, args = self._construct_room_type_where_clause(
+                search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+                if search_filter
+                else None
+            )
+            room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+            query_args += args
+
             if network_tuple:
                 if network_tuple.appservice_id:
                     published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                         OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
                         OR history_visibility = 'world_readable'
                     )
+                    {room_type_clause}
                     AND joined_members > 0
             """
 
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         if ignore_non_federatable:
             where_clauses.append("is_federatable")
 
-        if search_filter and search_filter.get("generic_search_term", None):
-            search_term = "%" + search_filter["generic_search_term"] + "%"
+        if search_filter and search_filter.get(
+            PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+        ):
+            search_term = (
+                "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+            )
 
             where_clauses.append(
                 """
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
                 search_term.lower(),
             ]
 
+        room_type_clause, args = self._construct_room_type_where_clause(
+            search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+            if search_filter
+            else None
+        )
+        if room_type_clause:
+            where_clauses.append(room_type_clause)
+        query_args += args
+
         where_clause = ""
         if where_clauses:
             where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         sql = f"""
             SELECT
                 room_id, name, topic, canonical_alias, joined_members,
-                avatar, history_visibility, guest_access, join_rules
+                avatar, history_visibility, guest_access, join_rules, room_type
             FROM (
                 {published_sql}
             ) published
@@ -1166,6 +1213,7 @@ class _BackgroundUpdates:
     POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
     REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
     POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+    ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
 
 
 _REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1200,6 +1248,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
             self._background_add_rooms_room_version_column,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+            self._background_add_room_type_column,
+        )
+
         # BG updates to change the type of room_depth.min_depth
         self.db_pool.updates.register_background_update_handler(
             _BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1569,6 +1622,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
+    async def _background_add_room_type_column(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """Background update to go and add room_type information to `room_stats_state`
+        table from `event_json` table.
+        """
+
+        last_room_id = progress.get("room_id", "")
+
+        def _background_add_room_type_column_txn(
+            txn: LoggingTransaction,
+        ) -> bool:
+            sql = """
+                SELECT state.room_id, json FROM event_json
+                INNER JOIN current_state_events AS state USING (event_id)
+                WHERE state.room_id > ? AND type = 'm.room.create'
+                ORDER BY state.room_id
+                LIMIT ?
+            """
+
+            txn.execute(sql, (last_room_id, batch_size))
+            room_id_to_create_event_results = txn.fetchall()
+
+            new_last_room_id = None
+            for room_id, event_json in room_id_to_create_event_results:
+                event_dict = db_to_json(event_json)
+
+                room_type = event_dict.get("content", {}).get(
+                    EventContentFields.ROOM_TYPE, None
+                )
+                if isinstance(room_type, str):
+                    self.db_pool.simple_update_txn(
+                        txn,
+                        table="room_stats_state",
+                        keyvalues={"room_id": room_id},
+                        updatevalues={"room_type": room_type},
+                    )
+
+                new_last_room_id = room_id
+
+            if new_last_room_id is None:
+                return True
+
+            self.db_pool.updates._background_update_progress_txn(
+                txn,
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+                {"room_id": new_last_room_id},
+            )
+
+            return False
+
+        end = await self.db_pool.runInteraction(
+            "_background_add_room_type_column",
+            _background_add_room_type_column_txn,
+        )
+
+        if end:
+            await self.db_pool.updates._end_background_update(
+                _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+            )
+
+        return batch_size
+
 
 class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
     def __init__(