diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 5760d3428e..d8026e3fac 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -32,12 +32,17 @@ from typing import (
import attr
-from synapse.api.constants import EventContentFields, EventTypes, JoinRules
+from synapse.api.constants import (
+ EventContentFields,
+ EventTypes,
+ JoinRules,
+ PublicRoomsFilterFields,
+)
from synapse.api.errors import StoreError
from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.config.homeserver import HomeServerConfig
from synapse.events import EventBase
-from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
from synapse.storage.database import (
DatabasePool,
LoggingDatabaseConnection,
@@ -199,10 +204,29 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
desc="get_public_room_ids",
)
+ def _construct_room_type_where_clause(
+ self, room_types: Union[List[Union[str, None]], None]
+ ) -> Tuple[Union[str, None], List[str]]:
+ if not room_types or not self.config.experimental.msc3827_enabled:
+ return None, []
+ else:
+ # We use None when we want get rooms without a type
+ is_null_clause = ""
+ if None in room_types:
+ is_null_clause = "OR room_type IS NULL"
+ room_types = [value for value in room_types if value is not None]
+
+ list_clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_type", room_types
+ )
+
+ return f"({list_clause} {is_null_clause})", args
+
async def count_public_rooms(
self,
network_tuple: Optional[ThirdPartyInstanceID],
ignore_non_federatable: bool,
+ search_filter: Optional[dict],
) -> int:
"""Counts the number of public rooms as tracked in the room_stats_current
and room_stats_state table.
@@ -210,11 +234,20 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
Args:
network_tuple
ignore_non_federatable: If true filters out non-federatable rooms
+ search_filter
"""
def _count_public_rooms_txn(txn: LoggingTransaction) -> int:
query_args = []
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ room_type_clause = f" AND {room_type_clause}" if room_type_clause else ""
+ query_args += args
+
if network_tuple:
if network_tuple.appservice_id:
published_sql = """
@@ -249,6 +282,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
OR join_rules = '{JoinRules.KNOCK_RESTRICTED}'
OR history_visibility = 'world_readable'
)
+ {room_type_clause}
AND joined_members > 0
"""
@@ -347,8 +381,12 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
if ignore_non_federatable:
where_clauses.append("is_federatable")
- if search_filter and search_filter.get("generic_search_term", None):
- search_term = "%" + search_filter["generic_search_term"] + "%"
+ if search_filter and search_filter.get(
+ PublicRoomsFilterFields.GENERIC_SEARCH_TERM, None
+ ):
+ search_term = (
+ "%" + search_filter[PublicRoomsFilterFields.GENERIC_SEARCH_TERM] + "%"
+ )
where_clauses.append(
"""
@@ -365,6 +403,15 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
search_term.lower(),
]
+ room_type_clause, args = self._construct_room_type_where_clause(
+ search_filter.get(PublicRoomsFilterFields.ROOM_TYPES, None)
+ if search_filter
+ else None
+ )
+ if room_type_clause:
+ where_clauses.append(room_type_clause)
+ query_args += args
+
where_clause = ""
if where_clauses:
where_clause = " AND " + " AND ".join(where_clauses)
@@ -373,7 +420,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
sql = f"""
SELECT
room_id, name, topic, canonical_alias, joined_members,
- avatar, history_visibility, guest_access, join_rules
+ avatar, history_visibility, guest_access, join_rules, room_type
FROM (
{published_sql}
) published
@@ -1166,6 +1213,7 @@ class _BackgroundUpdates:
POPULATE_ROOM_DEPTH_MIN_DEPTH2 = "populate_room_depth_min_depth2"
REPLACE_ROOM_DEPTH_MIN_DEPTH = "replace_room_depth_min_depth"
POPULATE_ROOMS_CREATOR_COLUMN = "populate_rooms_creator_column"
+ ADD_ROOM_TYPE_COLUMN = "add_room_type_column"
_REPLACE_ROOM_DEPTH_SQL_COMMANDS = (
@@ -1200,6 +1248,11 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
self._background_add_rooms_room_version_column,
)
+ self.db_pool.updates.register_background_update_handler(
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ self._background_add_room_type_column,
+ )
+
# BG updates to change the type of room_depth.min_depth
self.db_pool.updates.register_background_update_handler(
_BackgroundUpdates.POPULATE_ROOM_DEPTH_MIN_DEPTH2,
@@ -1569,6 +1622,69 @@ class RoomBackgroundUpdateStore(SQLBaseStore):
return batch_size
+ async def _background_add_room_type_column(
+ self, progress: JsonDict, batch_size: int
+ ) -> int:
+ """Background update to go and add room_type information to `room_stats_state`
+ table from `event_json` table.
+ """
+
+ last_room_id = progress.get("room_id", "")
+
+ def _background_add_room_type_column_txn(
+ txn: LoggingTransaction,
+ ) -> bool:
+ sql = """
+ SELECT state.room_id, json FROM event_json
+ INNER JOIN current_state_events AS state USING (event_id)
+ WHERE state.room_id > ? AND type = 'm.room.create'
+ ORDER BY state.room_id
+ LIMIT ?
+ """
+
+ txn.execute(sql, (last_room_id, batch_size))
+ room_id_to_create_event_results = txn.fetchall()
+
+ new_last_room_id = None
+ for room_id, event_json in room_id_to_create_event_results:
+ event_dict = db_to_json(event_json)
+
+ room_type = event_dict.get("content", {}).get(
+ EventContentFields.ROOM_TYPE, None
+ )
+ if isinstance(room_type, str):
+ self.db_pool.simple_update_txn(
+ txn,
+ table="room_stats_state",
+ keyvalues={"room_id": room_id},
+ updatevalues={"room_type": room_type},
+ )
+
+ new_last_room_id = room_id
+
+ if new_last_room_id is None:
+ return True
+
+ self.db_pool.updates._background_update_progress_txn(
+ txn,
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN,
+ {"room_id": new_last_room_id},
+ )
+
+ return False
+
+ end = await self.db_pool.runInteraction(
+ "_background_add_room_type_column",
+ _background_add_room_type_column_txn,
+ )
+
+ if end:
+ await self.db_pool.updates._end_background_update(
+ _BackgroundUpdates.ADD_ROOM_TYPE_COLUMN
+ )
+
+ return batch_size
+
class RoomStore(RoomBackgroundUpdateStore, RoomWorkerStore):
def __init__(
|