summary refs log tree commit diff
diff options
context:
space:
mode:
authorAndrew Morgan <andrew@amorgan.xyz>2023-03-15 17:19:05 +0000
committerAndrew Morgan <andrew@amorgan.xyz>2023-05-02 15:23:32 +0100
commit5c1e9f24da77bebdde557a312a532a4f5c857c69 (patch)
tree351fc73d2ac6f0fd072862394d28d964a30d9631
parentAdd a new public rooms callback class, a new fetch_public_rooms callback (diff)
downloadsynapse-5c1e9f24da77bebdde557a312a532a4f5c857c69.tar.xz
wip: call the public room callback
-rw-r--r--synapse/handlers/room_list.py87
-rw-r--r--synapse/module_api/callbacks/public_rooms_callbacks.py6
-rw-r--r--synapse/rest/client/room.py12
-rw-r--r--synapse/storage/databases/main/room.py22
-rw-r--r--synapse/types/__init__.py13
5 files changed, 91 insertions, 49 deletions
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index bb0bdb8e6f..9e2b46699d 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -27,6 +27,7 @@ from synapse.api.constants import (
     JoinRules,
     PublicRoomsFilterFields,
 )
+from synapse.types import Requester
 from synapse.api.errors import (
     Codes,
     HttpResponseException,
@@ -60,6 +61,7 @@ class RoomListHandler:
         self.remote_response_cache: ResponseCache[
             Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
         ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
+        self._module_api_callbacks = hs.get_module_api_callbacks().public_rooms
 
     async def get_local_public_room_list(
         self,
@@ -67,7 +69,8 @@ class RoomListHandler:
         since_token: Optional[str] = None,
         search_filter: Optional[dict] = None,
         network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
-        from_federation: bool = False,
+        from_client_mxid: Optional[str] = None,
+        from_remote_server_name: Optional[str] = None,
     ) -> JsonDict:
         """Generate a local public room list.
 
@@ -75,14 +78,20 @@ class RoomListHandler:
         party network. A client can ask for a specific list or to return all.
 
         Args:
-            limit
-            since_token
-            search_filter
+            limit: The maximum number of rooms to return, or None to return all rooms.
+            since_token: A pagination token, or None to return the head of the public
+                rooms list.
+            search_filter: An optional dictionary with the following keys:
+                * generic_search_term: A string to search for in room ...
+                * room_types: A list to filter returned rooms by their type. If None or
+                    an empty list is passed, rooms will not be filtered by type.
             network_tuple: Which public list to use.
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation: true iff the request comes from the federation API
+            from_client_mxid: A user's MXID if this request came from a registered user.
+            from_remote_server_name: A remote homeserver's server name, if this
+                request came from the federation API.
         """
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
@@ -105,7 +114,8 @@ class RoomListHandler:
                 since_token,
                 search_filter,
                 network_tuple=network_tuple,
-                from_federation=from_federation,
+                from_client_mxid=from_client_mxid,
+                from_remote_server_name=from_remote_server_name,
             )
 
         key = (limit, since_token, network_tuple)
@@ -115,7 +125,8 @@ class RoomListHandler:
             limit,
             since_token,
             network_tuple=network_tuple,
-            from_federation=from_federation,
+            from_client_mxid=from_client_mxid,
+            from_remote_server_name=from_remote_server_name,
         )
 
     async def _get_public_room_list(
@@ -124,7 +135,8 @@ class RoomListHandler:
         since_token: Optional[str] = None,
         search_filter: Optional[dict] = None,
         network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
-        from_federation: bool = False,
+        from_client_mxid: Optional[str] = None,
+        from_remote_server_name: Optional[str] = None,
     ) -> JsonDict:
         """Generate a public room list.
         Args:
@@ -135,8 +147,9 @@ class RoomListHandler:
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation: Whether this request originated from a
-                federating server or a client. Used for room filtering.
+            from_client_mxid: A user's MXID if this request came from a registered user.
+            from_remote_server_name: A remote homeserver's server name, if this
+                request came from the federation API.
         """
 
         # Pagination tokens work by storing the room ID sent in the last batch,
@@ -145,50 +158,38 @@ class RoomListHandler:
 
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
-
-            bounds: Optional[Tuple[int, str]] = (
-                batch_token.last_joined_members,
-                batch_token.last_room_id,
-            )
             forwards = batch_token.direction_is_forward
-            has_batch_token = True
         else:
-            bounds = None
-
+            batch_token = None
             forwards = True
-            has_batch_token = False
 
         # we request one more than wanted to see if there are more pages to come
         probing_limit = limit + 1 if limit is not None else None
 
-        results = await self.store.get_largest_public_rooms(
+        public_rooms = await self.store.get_largest_public_rooms(
             network_tuple,
             search_filter,
             probing_limit,
-            bounds=bounds,
+            bounds=(
+                [batch_token.last_joined_members, batch_token.last_room_id]
+                if batch_token else None
+            ),
             forwards=forwards,
-            ignore_non_federatable=from_federation,
+            ignore_non_federatable=bool(from_remote_server_name),
         )
 
-        def build_room_entry(room: JsonDict) -> JsonDict:
-            entry = {
-                "room_id": room["room_id"],
-                "name": room["name"],
-                "topic": room["topic"],
-                "canonical_alias": room["canonical_alias"],
-                "num_joined_members": room["joined_members"],
-                "avatar_url": room["avatar"],
-                "world_readable": room["history_visibility"]
-                == HistoryVisibility.WORLD_READABLE,
-                "guest_can_join": room["guest_access"] == "can_join",
-                "join_rule": room["join_rules"],
-                "room_type": room["room_type"],
-            }
-
-            # Filter out Nones – rather omit the field altogether
-            return {k: v for k, v in entry.items() if v is not None}
+        for fetch_public_rooms in self._module_api_callbacks.fetch_public_rooms_callbacks:
+            # Ask each module for a list of public rooms given the last_joined_members
+            # value from the since token and the probing limit.
+            module_public_rooms = await fetch_public_rooms(
+                limit=probing_limit,
+                max_member_count=(
+                    batch_token.last_joined_members
+                    if batch_token else None
+                ),
+            )
 
-        results = [build_room_entry(r) for r in results]
+            # Insert the module's reported public rooms into the list
 
         response: JsonDict = {}
         num_results = len(results)
@@ -208,7 +209,7 @@ class RoomListHandler:
             initial_entry = results[0]
 
             if forwards:
-                if has_batch_token:
+                if batch_token is not None:
                     # If there was a token given then we assume that there
                     # must be previous results.
                     response["prev_batch"] = RoomListNextBatch(
@@ -224,7 +225,7 @@ class RoomListHandler:
                         direction_is_forward=True,
                     ).to_token()
             else:
-                if has_batch_token:
+                if batch_token is not None:
                     response["next_batch"] = RoomListNextBatch(
                         last_joined_members=final_entry["num_joined_members"],
                         last_room_id=final_entry["room_id"],
@@ -242,7 +243,7 @@ class RoomListHandler:
 
         response["total_room_count_estimate"] = await self.store.count_public_rooms(
             network_tuple,
-            ignore_non_federatable=from_federation,
+            ignore_non_federatable=bool(from_remote_server_name),
             search_filter=search_filter,
         )
 
diff --git a/synapse/module_api/callbacks/public_rooms_callbacks.py b/synapse/module_api/callbacks/public_rooms_callbacks.py
index 017ba53b50..b2a02e3ab8 100644
--- a/synapse/module_api/callbacks/public_rooms_callbacks.py
+++ b/synapse/module_api/callbacks/public_rooms_callbacks.py
@@ -21,7 +21,7 @@ logger = logging.getLogger(__name__)
 
 
 @attr.s(auto_attribs=True)
-class PublicRoomChunk:
+class PublicRoom:
     room_id: str
     name: str
     topic: str
@@ -36,8 +36,8 @@ class PublicRoomChunk:
 
 # Types for callbacks to be registered via the module api
 FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
-    [int, Optional[int], Optional[dict], Optional[str], Optional[str]],
-    Awaitable[Tuple[Iterable[PublicRoomChunk], bool]],
+    [int, Optional[Tuple[int, bool]], Optional[dict], Optional[str], Optional[str]],
+    Awaitable[Tuple[Iterable[PublicRoom], bool]],
 ]
 
 
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 7699cc8d1b..ab992c5159 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -476,8 +476,9 @@ class PublicRoomListRestServlet(RestServlet):
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         server = parse_string(request, "server")
 
+        requester: Optional[Requester] = None
         try:
-            await self.auth.get_user_by_req(request, allow_guest=True)
+            requester = await self.auth.get_user_by_req(request, allow_guest=True)
         except InvalidClientCredentialsError as e:
             # Option to allow servers to require auth when accessing
             # /publicRooms via CS API. This is especially helpful in private
@@ -516,8 +517,15 @@ class PublicRoomListRestServlet(RestServlet):
                 server, limit=limit, since_token=since_token
             )
         else:
+            # If a user we know made this request, pass that information to the
+            # public rooms list handler.
+            if requester is None:
+                from_client_mxid = None
+            else:
+                from_client_mxid = requester.user.to_string()
+
             data = await handler.get_local_public_room_list(
-                limit=limit, since_token=since_token
+                limit=limit, since_token=since_token, from_client_mxid=from_client_mxid
             )
 
         return 200, data
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index dd7dbb6901..d6e8f62f5b 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -16,6 +16,7 @@
 import logging
 from abc import abstractmethod
 from enum import Enum
+from synapse.api.constants import HistoryVisibility
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -518,7 +519,26 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         ret_val = await self.db_pool.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        return ret_val
+
+        def build_room_entry(room: JsonDict) -> JsonDict:
+            entry = {
+                "room_id": room["room_id"],
+                "name": room["name"],
+                "topic": room["topic"],
+                "canonical_alias": room["canonical_alias"],
+                "num_joined_members": room["joined_members"],
+                "avatar_url": room["avatar"],
+                "world_readable": room["history_visibility"]
+                                  == HistoryVisibility.WORLD_READABLE,
+                "guest_can_join": room["guest_access"] == "can_join",
+                "join_rule": room["join_rules"],
+                "room_type": room["room_type"],
+            }
+
+            # Filter out Nones – rather omit the field altogether
+            return {k: v for k, v in entry.items() if v is not None}
+
+        return [build_room_entry(r) for r in ret_val]
 
     @cached(max_entries=10000)
     async def is_room_blocked(self, room_id: str) -> Optional[bool]:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 5cee9c3194..ba2effd413 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -936,6 +936,19 @@ class UserInfo:
     is_shadow_banned: bool
 
 
+class PublicRoomsChunk:
+    room_id: str
+    name: str
+    topic: str
+    num_joined_members: int
+    canonical_alias: str
+    avatar_url: str
+    world_readable: bool
+    guest_can_join: bool
+    join_rule: str
+    room_type: str
+
+
 class UserProfile(TypedDict):
     user_id: str
     display_name: Optional[str]