summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15518.feature1
-rw-r--r--synapse/federation/transport/server/__init__.py7
-rw-r--r--synapse/handlers/room_list.py186
-rw-r--r--synapse/module_api/__init__.py17
-rw-r--r--synapse/module_api/callbacks/__init__.py3
-rw-r--r--synapse/module_api/callbacks/public_rooms_callbacks.py45
-rw-r--r--synapse/rest/client/room.py12
-rw-r--r--synapse/storage/databases/main/room.py69
-rw-r--r--synapse/types/__init__.py14
-rw-r--r--synapse/util/__init__.py4
-rw-r--r--tests/module_api/test_fetch_public_rooms.py261
-rw-r--r--tests/rest/client/test_public_rooms.py148
12 files changed, 679 insertions, 88 deletions
diff --git a/changelog.d/15518.feature b/changelog.d/15518.feature
new file mode 100644
index 0000000000..325724bad3
--- /dev/null
+++ b/changelog.d/15518.feature
@@ -0,0 +1 @@
+Allow modules to provide local /publicRooms results.
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 55d2cd0a9a..16663e97ce 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -149,7 +149,10 @@ class PublicRoomList(BaseFederationServlet):
             limit = None
 
         data = await self.handler.get_local_public_room_list(
-            limit, since_token, network_tuple=network_tuple, from_federation=True
+            limit,
+            since_token,
+            network_tuple=network_tuple,
+            from_remote_server_name=origin,
         )
         return 200, data
 
@@ -190,7 +193,7 @@ class PublicRoomList(BaseFederationServlet):
             since_token=since_token,
             search_filter=search_filter,
             network_tuple=network_tuple,
-            from_federation=True,
+            from_remote_server_name=origin,
         )
 
         return 200, data
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 36e2db8975..6049708aba 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import attr
 import msgpack
@@ -33,7 +33,8 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
-from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
+from synapse.types import JsonDict, JsonMapping, PublicRoom, ThirdPartyInstanceID
+from synapse.util import filter_none
 from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.caches.response_cache import ResponseCache
 
@@ -60,6 +61,7 @@ class RoomListHandler:
         self.remote_response_cache: ResponseCache[
             Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
         ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
+        self._module_api_callbacks = hs.get_module_api_callbacks().public_rooms
 
     async def get_local_public_room_list(
         self,
@@ -67,7 +69,8 @@ class RoomListHandler:
         since_token: Optional[str] = None,
         search_filter: Optional[dict] = None,
         network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
-        from_federation: bool = False,
+        from_client_mxid: Optional[str] = None,
+        from_remote_server_name: Optional[str] = None,
     ) -> JsonDict:
         """Generate a local public room list.
 
@@ -75,14 +78,20 @@ class RoomListHandler:
         party network. A client can ask for a specific list or to return all.
 
         Args:
-            limit
-            since_token
-            search_filter
+            limit: The maximum number of rooms to return, or None to return all rooms.
+            since_token: A pagination token, or None to return the head of the public
+                rooms list.
+            search_filter: An optional dictionary with the following keys:
+                * generic_search_term: A string to search for in room ...
+                * room_types: A list to filter returned rooms by their type. If None or
+                    an empty list is passed, rooms will not be filtered by type.
             network_tuple: Which public list to use.
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation: true iff the request comes from the federation API
+            from_client_mxid: A user's MXID if this request came from a registered user.
+            from_remote_server_name: A remote homeserver's server name, if this
+                request came from the federation API.
         """
         if not self.enable_room_list_search:
             return {"chunk": [], "total_room_count_estimate": 0}
@@ -105,7 +114,8 @@ class RoomListHandler:
                 since_token,
                 search_filter,
                 network_tuple=network_tuple,
-                from_federation=from_federation,
+                from_client_mxid=from_client_mxid,
+                from_remote_server_name=from_remote_server_name,
             )
 
         key = (limit, since_token, network_tuple)
@@ -115,7 +125,8 @@ class RoomListHandler:
             limit,
             since_token,
             network_tuple=network_tuple,
-            from_federation=from_federation,
+            from_client_mxid=from_client_mxid,
+            from_remote_server_name=from_remote_server_name,
         )
 
     async def _get_public_room_list(
@@ -124,7 +135,8 @@ class RoomListHandler:
         since_token: Optional[str] = None,
         search_filter: Optional[dict] = None,
         network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
-        from_federation: bool = False,
+        from_client_mxid: Optional[str] = None,
+        from_remote_server_name: Optional[str] = None,
     ) -> JsonDict:
         """Generate a public room list.
         Args:
@@ -135,65 +147,106 @@ class RoomListHandler:
                 This can be (None, None) to indicate the main list, or a particular
                 appservice and network id to use an appservice specific one.
                 Setting to None returns all public rooms across all lists.
-            from_federation: Whether this request originated from a
-                federating server or a client. Used for room filtering.
+            from_client_mxid: A user's MXID if this request came from a registered user.
+            from_remote_server_name: A remote homeserver's server name, if this
+                request came from the federation API.
         """
 
         # Pagination tokens work by storing the room ID sent in the last batch,
         # plus the direction (forwards or backwards). Next batch tokens always
         # go forwards, prev batch tokens always go backwards.
 
+        forwards = True
+        last_joined_members = None
+        last_room_id = None
+        last_module_index = None
         if since_token:
             batch_token = RoomListNextBatch.from_token(since_token)
-
-            bounds: Optional[Tuple[int, str]] = (
-                batch_token.last_joined_members,
-                batch_token.last_room_id,
-            )
+            print(batch_token)
             forwards = batch_token.direction_is_forward
-            has_batch_token = True
-        else:
-            bounds = None
-
-            forwards = True
-            has_batch_token = False
+            last_joined_members = batch_token.last_joined_members
+            last_room_id = batch_token.last_room_id
+            last_module_index = batch_token.last_module_index
 
-        # we request one more than wanted to see if there are more pages to come
+        # We request one more than wanted to see if there are more pages to come
         probing_limit = limit + 1 if limit is not None else None
 
-        results = await self.store.get_largest_public_rooms(
+        # We bucket results per joined members number since we want to keep order
+        # per joined members number
+        num_joined_members_buckets: Dict[int, List[PublicRoom]] = {}
+        room_ids_to_module_index: Dict[str, int] = {}
+
+        local_public_rooms = await self.store.get_largest_public_rooms(
             network_tuple,
             search_filter,
             probing_limit,
-            bounds=bounds,
+            bounds=(
+                last_joined_members,
+                last_room_id if last_module_index is None else None,
+            ),
             forwards=forwards,
-            ignore_non_federatable=from_federation,
+            ignore_non_federatable=bool(from_remote_server_name),
         )
 
-        def build_room_entry(room: JsonDict) -> JsonDict:
-            entry = {
-                "room_id": room["room_id"],
-                "name": room["name"],
-                "topic": room["topic"],
-                "canonical_alias": room["canonical_alias"],
-                "num_joined_members": room["joined_members"],
-                "avatar_url": room["avatar"],
-                "world_readable": room["history_visibility"]
-                == HistoryVisibility.WORLD_READABLE,
-                "guest_can_join": room["guest_access"] == "can_join",
-                "join_rule": room["join_rules"],
-                "room_type": room["room_type"],
-            }
+        for room in local_public_rooms:
+            num_joined_members_buckets.setdefault(room.num_joined_members, []).append(
+                room
+            )
+
+        nb_modules = len(self._module_api_callbacks.fetch_public_rooms_callbacks)
 
-            # Filter out Nones – rather omit the field altogether
-            return {k: v for k, v in entry.items() if v is not None}
+        module_range = range(nb_modules)
+        # if not forwards:
+        #     module_range = reversed(module_range)
 
-        results = [build_room_entry(r) for r in results]
+        for module_index in module_range:
+            fetch_public_rooms = (
+                self._module_api_callbacks.fetch_public_rooms_callbacks[module_index]
+            )
+            # Ask each module for a list of public rooms given the last_joined_members
+            # value from the since token and the probing limit
+            # last_joined_members needs to be reduce by one if this module has already
+            # given its result for last_joined_members
+            module_last_joined_members = last_joined_members
+            if module_last_joined_members is not None and last_module_index is not None:
+                if forwards and module_index < last_module_index:
+                    module_last_joined_members = module_last_joined_members - 1
+                # if not forwards and module_index > last_module_index:
+                #     module_last_joined_members = module_last_joined_members - 1
+
+            module_public_rooms = await fetch_public_rooms(
+                network_tuple,
+                search_filter,
+                probing_limit,
+                (
+                    module_last_joined_members,
+                    last_room_id if last_module_index == module_index else None,
+                ),
+                forwards,
+            )
+
+            for room in module_public_rooms:
+                num_joined_members_buckets.setdefault(
+                    room.num_joined_members, []
+                ).append(room)
+                room_ids_to_module_index[room.room_id] = module_index
+
+        nums_joined_members = list(num_joined_members_buckets.keys())
+        nums_joined_members.sort(reverse=forwards)
+
+        results = []
+        for num_joined_members in nums_joined_members:
+            rooms = num_joined_members_buckets[num_joined_members]
+            # if not forwards:
+            #     rooms.reverse()
+            results += rooms
+
+        print([(r.room_id, r.num_joined_members) for r in results])
 
         response: JsonDict = {}
         num_results = len(results)
-        if limit is not None:
-            more_to_come = num_results == probing_limit
+        if limit is not None and probing_limit is not None:
+            more_to_come = num_results >= probing_limit
 
             # Depending on direction we trim either the front or back.
             if forwards:
@@ -203,46 +256,60 @@ class RoomListHandler:
         else:
             more_to_come = False
 
+        print([(r.room_id, r.num_joined_members) for r in results])
+
         if num_results > 0:
             final_entry = results[-1]
             initial_entry = results[0]
 
             if forwards:
-                if has_batch_token:
+                if since_token is not None:
                     # If there was a token given then we assume that there
                     # must be previous results.
                     response["prev_batch"] = RoomListNextBatch(
-                        last_joined_members=initial_entry["num_joined_members"],
-                        last_room_id=initial_entry["room_id"],
+                        last_joined_members=initial_entry.num_joined_members,
+                        last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
+                        last_module_index=room_ids_to_module_index.get(
+                            initial_entry.room_id
+                        ),
                     ).to_token()
 
                 if more_to_come:
                     response["next_batch"] = RoomListNextBatch(
-                        last_joined_members=final_entry["num_joined_members"],
-                        last_room_id=final_entry["room_id"],
+                        last_joined_members=final_entry.num_joined_members,
+                        last_room_id=final_entry.room_id,
                         direction_is_forward=True,
+                        last_module_index=room_ids_to_module_index.get(
+                            final_entry.room_id
+                        ),
                     ).to_token()
             else:
-                if has_batch_token:
+                if since_token is not None:
                     response["next_batch"] = RoomListNextBatch(
-                        last_joined_members=final_entry["num_joined_members"],
-                        last_room_id=final_entry["room_id"],
+                        last_joined_members=final_entry.num_joined_members,
+                        last_room_id=final_entry.room_id,
                         direction_is_forward=True,
+                        last_module_index=room_ids_to_module_index.get(
+                            final_entry.room_id
+                        ),
                     ).to_token()
 
                 if more_to_come:
                     response["prev_batch"] = RoomListNextBatch(
-                        last_joined_members=initial_entry["num_joined_members"],
-                        last_room_id=initial_entry["room_id"],
+                        last_joined_members=initial_entry.num_joined_members,
+                        last_room_id=initial_entry.room_id,
                         direction_is_forward=False,
+                        last_module_index=room_ids_to_module_index.get(
+                            initial_entry.room_id
+                        ),
                     ).to_token()
 
-        response["chunk"] = results
+        response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results]
 
         response["total_room_count_estimate"] = await self.store.count_public_rooms(
             network_tuple,
-            ignore_non_federatable=from_federation,
+            ignore_non_federatable=bool(from_remote_server_name),
             search_filter=search_filter,
         )
 
@@ -484,11 +551,13 @@ class RoomListNextBatch:
     last_joined_members: int  # The count to get rooms after/before
     last_room_id: str  # The room_id to get rooms after/before
     direction_is_forward: bool  # True if this is a next_batch, false if prev_batch
+    last_module_index: Optional[int] = None
 
     KEY_DICT = {
         "last_joined_members": "m",
         "last_room_id": "r",
         "direction_is_forward": "d",
+        "last_module_index": "i",
     }
 
     REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@@ -501,6 +570,7 @@ class RoomListNextBatch:
         )
 
     def to_token(self) -> str:
+        # print(self)
         return encode_base64(
             msgpack.dumps(
                 {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 65e2aca456..c5dbe9c757 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -79,6 +79,9 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
     ON_LEGACY_SEND_MAIL_CALLBACK,
     ON_USER_REGISTRATION_CALLBACK,
 )
+from synapse.module_api.callbacks.public_rooms_callbacks import (
+    FETCH_PUBLIC_ROOMS_CALLBACK,
+)
 from synapse.module_api.callbacks.spamchecker_callbacks import (
     CHECK_EVENT_FOR_SPAM_CALLBACK,
     CHECK_LOGIN_FOR_SPAM_CALLBACK,
@@ -170,6 +173,7 @@ __all__ = [
     "DirectServeJsonResource",
     "ModuleApi",
     "PRESENCE_ALL_USERS",
+    "PublicRoomChunk",
     "LoginResponse",
     "JsonDict",
     "JsonMapping",
@@ -472,6 +476,19 @@ class ModuleApi:
             on_account_data_updated=on_account_data_updated,
         )
 
+    def register_public_rooms_callbacks(
+        self,
+        *,
+        fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None,
+    ) -> None:
+        """Registers callback functions related to the public room directory.
+
+        Added in Synapse v1.80.0
+        """
+        return self._callbacks.public_rooms.register_callbacks(
+            fetch_public_rooms=fetch_public_rooms,
+        )
+
     def register_web_resource(self, path: str, resource: Resource) -> None:
         """Registers a web resource to be served at the given path.
 
diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py
index dcb036552b..5a0ff22b10 100644
--- a/synapse/module_api/callbacks/__init__.py
+++ b/synapse/module_api/callbacks/__init__.py
@@ -27,9 +27,12 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
     ThirdPartyEventRulesModuleApiCallbacks,
 )
 
+from .public_rooms_callbacks import PublicRoomsModuleApiCallbacks
+
 
 class ModuleApiCallbacks:
     def __init__(self, hs: "HomeServer") -> None:
         self.account_validity = AccountValidityModuleApiCallbacks()
         self.spam_checker = SpamCheckerModuleApiCallbacks(hs)
         self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs)
+        self.public_rooms = PublicRoomsModuleApiCallbacks()
diff --git a/synapse/module_api/callbacks/public_rooms_callbacks.py b/synapse/module_api/callbacks/public_rooms_callbacks.py
new file mode 100644
index 0000000000..b3eeb84606
--- /dev/null
+++ b/synapse/module_api/callbacks/public_rooms_callbacks.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Awaitable, Callable, List, Optional, Tuple
+
+from synapse.types import PublicRoom, ThirdPartyInstanceID
+
+logger = logging.getLogger(__name__)
+
+
+# Types for callbacks to be registered via the module api
+FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
+    [
+        Optional[ThirdPartyInstanceID],  # network_tuple
+        Optional[dict],  # search_filter
+        Optional[int],  # limit
+        Tuple[Optional[int], Optional[str]],  # bounds
+        bool,  # forwards
+    ],
+    Awaitable[List[PublicRoom]],
+]
+
+
+class PublicRoomsModuleApiCallbacks:
+    def __init__(self) -> None:
+        self.fetch_public_rooms_callbacks: List[FETCH_PUBLIC_ROOMS_CALLBACK] = []
+
+    def register_callbacks(
+        self,
+        fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None,
+    ) -> None:
+        if fetch_public_rooms is not None:
+            self.fetch_public_rooms_callbacks.append(fetch_public_rooms)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 553938ce9d..99ad387b50 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -476,8 +476,9 @@ class PublicRoomListRestServlet(RestServlet):
     async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
         server = parse_string(request, "server")
 
+        requester: Optional[Requester] = None
         try:
-            await self.auth.get_user_by_req(request, allow_guest=True)
+            requester = await self.auth.get_user_by_req(request, allow_guest=True)
         except InvalidClientCredentialsError as e:
             # Option to allow servers to require auth when accessing
             # /publicRooms via CS API. This is especially helpful in private
@@ -516,8 +517,15 @@ class PublicRoomListRestServlet(RestServlet):
                 server, limit=limit, since_token=since_token
             )
         else:
+            # If a user we know made this request, pass that information to the
+            # public rooms list handler.
+            if requester is None:
+                from_client_mxid = None
+            else:
+                from_client_mxid = requester.user.to_string()
+
             data = await handler.get_local_public_room_list(
-                limit=limit, since_token=since_token
+                limit=limit, since_token=since_token, from_client_mxid=from_client_mxid
             )
 
         return 200, data
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 1d4d99932b..3972d70358 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -38,6 +38,7 @@ from synapse.api.constants import (
     Direction,
     EventContentFields,
     EventTypes,
+    HistoryVisibility,
     JoinRules,
     PublicRoomsFilterFields,
 )
@@ -61,7 +62,13 @@ from synapse.storage.util.id_generators import (
     MultiWriterIdGenerator,
     StreamIdGenerator,
 )
-from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
+from synapse.types import (
+    JsonDict,
+    PublicRoom,
+    RetentionPolicy,
+    StrCollection,
+    ThirdPartyInstanceID,
+)
 from synapse.util import json_encoder
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.stringutils import MXC_REGEX
@@ -365,21 +372,21 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         network_tuple: Optional[ThirdPartyInstanceID],
         search_filter: Optional[dict],
         limit: Optional[int],
-        bounds: Optional[Tuple[int, str]],
+        bounds: Tuple[Optional[int], Optional[str]],
         forwards: bool,
         ignore_non_federatable: bool = False,
-    ) -> List[Dict[str, Any]]:
+    ) -> List[PublicRoom]:
         """Gets the largest public rooms (where largest is in terms of joined
         members, as tracked in the statistics table).
 
         Args:
             network_tuple
             search_filter
-            limit: Maxmimum number of rows to return, unlimited otherwise.
-            bounds: An uppoer or lower bound to apply to result set if given,
+            limit: Maximum number of rows to return, unlimited otherwise.
+            bounds: An upper or lower bound to apply to result set if given,
                 consists of a joined member count and room_id (these are
                 excluded from result set).
-            forwards: true iff going forwards, going backwards otherwise
+            forwards: true if going forwards, going backwards otherwise
             ignore_non_federatable: If true filters out non-federatable rooms.
 
         Returns:
@@ -413,26 +420,18 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         # Work out the bounds if we're given them, these bounds look slightly
         # odd, but are designed to help query planner use indices by pulling
         # out a common bound.
-        if bounds:
-            last_joined_members, last_room_id = bounds
-            if forwards:
-                where_clauses.append(
-                    """
-                        joined_members <= ? AND (
-                            joined_members < ? OR room_id < ?
-                        )
-                    """
-                )
-            else:
-                where_clauses.append(
-                    """
-                        joined_members >= ? AND (
-                            joined_members > ? OR room_id > ?
-                        )
-                    """
-                )
+        last_joined_members, last_room_id = bounds
+        if last_joined_members is not None:
+            comp = "<" if forwards else ">"
+
+            clause = f"joined_members {comp} ?"
+            query_args += [last_joined_members]
 
-            query_args += [last_joined_members, last_joined_members, last_room_id]
+            if last_room_id is not None:
+                clause += f" OR (joined_members = ? AND room_id {comp} ?)"
+                query_args += [last_joined_members, last_room_id]
+
+            where_clauses.append(clause)
 
         if ignore_non_federatable:
             where_clauses.append("is_federatable")
@@ -518,7 +517,25 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
         ret_val = await self.db_pool.runInteraction(
             "get_largest_public_rooms", _get_largest_public_rooms_txn
         )
-        return ret_val
+
+        def build_room_entry(room: JsonDict) -> PublicRoom:
+            entry = PublicRoom(
+                room_id=room["room_id"],
+                name=room["name"],
+                topic=room["topic"],
+                canonical_alias=room["canonical_alias"],
+                num_joined_members=room["joined_members"],
+                avatar_url=room["avatar"],
+                world_readable=room["history_visibility"]
+                == HistoryVisibility.WORLD_READABLE,
+                guest_can_join=room["guest_access"] == "can_join",
+                join_rule=room["join_rules"],
+                room_type=room["room_type"],
+            )
+
+            return entry
+
+        return [build_room_entry(r) for r in ret_val]
 
     @cached(max_entries=10000)
     async def is_room_blocked(self, room_id: str) -> Optional[bool]:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 09a88c86a7..77d9679582 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -1045,6 +1045,20 @@ class UserInfo:
     locked: bool
 
 
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class PublicRoom:
+    room_id: str
+    num_joined_members: int
+    world_readable: bool
+    guest_can_join: bool
+    name: Optional[str] = None
+    topic: Optional[str] = None
+    canonical_alias: Optional[str] = None
+    avatar_url: Optional[str] = None
+    join_rule: Optional[str] = None
+    room_type: Optional[str] = None
+
+
 class UserProfile(TypedDict):
     user_id: str
     display_name: Optional[str]
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9f3b8741c1..ba7d4f1246 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -206,3 +206,7 @@ class ExceptionBundle(Exception):
             parts.append(str(e))
         super().__init__("\n  - ".join(parts))
         self.exceptions = exceptions
+
+
+def filter_none(attr: attr.Attribute, value: Any) -> bool:
+    return value is not None
diff --git a/tests/module_api/test_fetch_public_rooms.py b/tests/module_api/test_fetch_public_rooms.py
new file mode 100644
index 0000000000..8daf8c5c40
--- /dev/null
+++ b/tests/module_api/test_fetch_public_rooms.py
@@ -0,0 +1,261 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin, login, room
+from synapse.server import HomeServer
+from synapse.types import PublicRoom, ThirdPartyInstanceID
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class FetchPublicRoomsTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+        config["allow_public_rooms_without_auth"] = True
+        self.hs = self.setup_test_homeserver(config=config)
+        self.url = "/_matrix/client/r0/publicRooms"
+
+        return self.hs
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self._store = homeserver.get_datastores().main
+        self._module_api = homeserver.get_module_api()
+
+        async def module1_cb(
+            network_tuple: Optional[ThirdPartyInstanceID],
+            search_filter: Optional[dict],
+            limit: Optional[int],
+            bounds: Tuple[Optional[int], Optional[str]],
+            forwards: bool,
+        ) -> List[PublicRoom]:
+            room1 = PublicRoom(
+                room_id="!one_members:module1",
+                num_joined_members=1,
+                world_readable=True,
+                guest_can_join=False,
+            )
+            room3 = PublicRoom(
+                room_id="!three_members:module1",
+                num_joined_members=3,
+                world_readable=True,
+                guest_can_join=False,
+            )
+            room3_2 = PublicRoom(
+                room_id="!three_members_2:module1",
+                num_joined_members=3,
+                world_readable=True,
+                guest_can_join=False,
+            )
+
+            (last_joined_members, last_room_id) = bounds
+
+            if forwards:
+                result = [room3_2, room3, room1]
+            else:
+                result = [room1, room3, room3_2]
+
+            if last_joined_members is not None:
+                if last_joined_members == 1:
+                    if forwards:
+                        if last_room_id == room1.room_id:
+                            result = []
+                        else:
+                            result = [room1]
+                    else:
+                        if last_room_id == room1.room_id:
+                            result = [room3, room3_2]
+                        else:
+                            result = [room1, room3, room3_2]
+                elif last_joined_members == 2:
+                    if forwards:
+                        result = [room1]
+                    else:
+                        result = [room3, room3_2]
+                elif last_joined_members == 3:
+                    if forwards:
+                        if last_room_id == room3.room_id:
+                            result = [room1]
+                        elif last_room_id == room3_2.room_id:
+                            result = [room3, room1]
+                    else:
+                        if last_room_id == room3.room_id:
+                            result = [room3_2]
+                        elif last_room_id == room3_2.room_id:
+                            result = []
+                        else:
+                            result = [room3, room3_2]
+
+            if limit is not None:
+                result = result[:limit]
+
+            return result
+
+        async def module2_cb(
+            network_tuple: Optional[ThirdPartyInstanceID],
+            search_filter: Optional[dict],
+            limit: Optional[int],
+            bounds: Tuple[Optional[int], Optional[str]],
+            forwards: bool,
+        ) -> List[PublicRoom]:
+            room3 = PublicRoom(
+                room_id="!three_members:module2",
+                num_joined_members=3,
+                world_readable=True,
+                guest_can_join=False,
+            )
+
+            (last_joined_members, last_room_id) = bounds
+
+            result = [room3]
+
+            if last_joined_members is not None:
+                if forwards:
+                    if last_joined_members < 3:
+                        result = []
+                    elif last_joined_members == 3 and last_room_id == room3.room_id:
+                        result = []
+                else:
+                    if last_joined_members > 3:
+                        result = []
+                    elif last_joined_members == 3 and last_room_id == room3.room_id:
+                        result = []
+
+            return result
+
+        self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module1_cb)
+        self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module2_cb)
+
+        user = self.register_user("alice", "pass")
+        token = self.login(user, "pass")
+        user2 = self.register_user("alice2", "pass")
+        token2 = self.login(user2, "pass")
+        user3 = self.register_user("alice3", "pass")
+        token3 = self.login(user3, "pass")
+
+        # Create a room with 2 people
+        room_id = self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={"visibility": "public"},
+            tok=token,
+        )
+        self.helper.join(room_id, user2, tok=token2)
+
+        # Create a room with 3 people
+        room_id = self.helper.create_room_as(
+            user,
+            is_public=True,
+            extra_content={"visibility": "public"},
+            tok=token,
+        )
+        self.helper.join(room_id, user2, tok=token2)
+        self.helper.join(room_id, user3, tok=token3)
+
+    def test_no_limit(self) -> None:
+        channel = self.make_request("GET", self.url)
+        chunk = channel.json_body["chunk"]
+
+        self.assertEquals(len(chunk), 6)
+        for i in range(4):
+            self.assertEquals(chunk[i]["num_joined_members"], 3)
+        self.assertEquals(chunk[4]["num_joined_members"], 2)
+        self.assertEquals(chunk[5]["num_joined_members"], 1)
+
+    def test_pagination_limit_1(self) -> None:
+        returned_three_members_rooms = set()
+
+        next_batch = None
+        for _i in range(4):
+            since_query_str = f"&since={next_batch}" if next_batch else ""
+            channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(chunk[0]["num_joined_members"], 3)
+            self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[0]["room_id"])
+            next_batch = channel.json_body["next_batch"]
+
+        channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
+        chunk = channel.json_body["chunk"]
+        self.assertEquals(chunk[0]["num_joined_members"], 2)
+        next_batch = channel.json_body["next_batch"]
+
+        channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
+        chunk = channel.json_body["chunk"]
+        self.assertEquals(chunk[0]["num_joined_members"], 1)
+        prev_batch = channel.json_body["prev_batch"]
+
+        self.assertNotIn("next_batch", channel.json_body)
+
+        channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
+        chunk = channel.json_body["chunk"]
+        self.assertEquals(chunk[0]["num_joined_members"], 2)
+
+        returned_three_members_rooms = set()
+        for _i in range(4):
+            prev_batch = channel.json_body["prev_batch"]
+            channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(chunk[0]["num_joined_members"], 3)
+            self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[0]["room_id"])
+
+        self.assertNotIn("prev_batch", channel.json_body)
+
+    def test_pagination_limit_2(self) -> None:
+        returned_three_members_rooms = set()
+
+        next_batch = None
+        for _i in range(2):
+            since_query_str = f"&since={next_batch}" if next_batch else ""
+            channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(chunk[0]["num_joined_members"], 3)
+            self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[0]["room_id"])
+            self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[1]["room_id"])
+            next_batch = channel.json_body["next_batch"]
+
+        channel = self.make_request("GET", f"{self.url}?limit=2&since={next_batch}")
+        chunk = channel.json_body["chunk"]
+        self.assertEquals(chunk[0]["num_joined_members"], 2)
+        self.assertEquals(chunk[1]["num_joined_members"], 1)
+
+        self.assertNotIn("next_batch", channel.json_body)
+
+        returned_three_members_rooms = set()
+
+        for _i in range(2):
+            prev_batch = channel.json_body["prev_batch"]
+            channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(chunk[0]["num_joined_members"], 3)
+            self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[0]["room_id"])
+            self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
+            returned_three_members_rooms.add(chunk[1]["room_id"])
+
+        self.assertNotIn("prev_batch", channel.json_body)
diff --git a/tests/rest/client/test_public_rooms.py b/tests/rest/client/test_public_rooms.py
new file mode 100644
index 0000000000..3de43760db
--- /dev/null
+++ b/tests/rest/client/test_public_rooms.py
@@ -0,0 +1,148 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class PublicRoomsTestCase(HomeserverTestCase):
+    servlets = [
+        admin.register_servlets,
+        login.register_servlets,
+        room.register_servlets,
+    ]
+
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+        config = self.default_config()
+        config["allow_public_rooms_without_auth"] = True
+        self.hs = self.setup_test_homeserver(config=config)
+        self.url = "/_matrix/client/r0/publicRooms"
+
+        return self.hs
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self._store = homeserver.get_datastores().main
+
+        user = self.register_user("alice", "pass")
+        token = self.login(user, "pass")
+        user2 = self.register_user("alice2", "pass")
+        token2 = self.login(user2, "pass")
+        user3 = self.register_user("alice3", "pass")
+        token3 = self.login(user3, "pass")
+
+        # Create 10 rooms
+        for _ in range(3):
+            self.helper.create_room_as(
+                user,
+                is_public=True,
+                extra_content={"visibility": "public"},
+                tok=token,
+            )
+
+        for _ in range(3):
+            room_id = self.helper.create_room_as(
+                user,
+                is_public=True,
+                extra_content={"visibility": "public"},
+                tok=token,
+            )
+            self.helper.join(room_id, user2, tok=token2)
+
+        for _ in range(4):
+            room_id = self.helper.create_room_as(
+                user,
+                is_public=True,
+                extra_content={"visibility": "public"},
+                tok=token,
+            )
+            self.helper.join(room_id, user2, tok=token2)
+            self.helper.join(room_id, user3, tok=token3)
+
+    def test_no_limit(self) -> None:
+        channel = self.make_request("GET", self.url)
+        chunk = channel.json_body["chunk"]
+
+        self.assertEquals(len(chunk), 10)
+
+    def test_pagination_limit_1(self) -> None:
+        returned_rooms = set()
+
+        channel = None
+        for i in range(10):
+            next_batch = None if i == 0 else channel.json_body["next_batch"]
+            since_query_str = f"&since={next_batch}" if next_batch else ""
+            channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(len(chunk), 1)
+            print(chunk[0]["room_id"])
+            self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+            returned_rooms.add(chunk[0]["room_id"])
+
+        self.assertNotIn("next_batch", channel.json_body)
+
+        returned_rooms = set()
+        returned_rooms.add(chunk[0]["room_id"])
+
+        for i in range(9):
+            print(i)
+            prev_batch = channel.json_body["prev_batch"]
+            channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(len(chunk), 1)
+            print(chunk[0]["room_id"])
+            self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+            returned_rooms.add(chunk[0]["room_id"])
+
+    def test_pagination_limit_2(self) -> None:
+        returned_rooms = set()
+
+        channel = None
+        for i in range(5):
+            next_batch = None if i == 0 else channel.json_body["next_batch"]
+            since_query_str = f"&since={next_batch}" if next_batch else ""
+            channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(len(chunk), 2)
+            print(chunk[0]["room_id"])
+            self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+            returned_rooms.add(chunk[0]["room_id"])
+            print(chunk[1]["room_id"])
+            self.assertTrue(chunk[1]["room_id"] not in returned_rooms)
+            returned_rooms.add(chunk[1]["room_id"])
+
+        self.assertNotIn("next_batch", channel.json_body)
+
+        returned_rooms = set()
+        returned_rooms.add(chunk[0]["room_id"])
+        returned_rooms.add(chunk[1]["room_id"])
+
+        for i in range(4):
+            print(i)
+            prev_batch = channel.json_body["prev_batch"]
+            channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
+            chunk = channel.json_body["chunk"]
+            self.assertEquals(len(chunk), 2)
+            print(chunk[0]["room_id"])
+            self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+            returned_rooms.add(chunk[0]["room_id"])
+            print(chunk[1]["room_id"])
+            self.assertTrue(chunk[1]["room_id"] not in returned_rooms)
+            returned_rooms.add(chunk[1]["room_id"])