diff options
-rw-r--r-- | changelog.d/15518.feature | 1 | ||||
-rw-r--r-- | synapse/federation/transport/server/__init__.py | 7 | ||||
-rw-r--r-- | synapse/handlers/room_list.py | 186 | ||||
-rw-r--r-- | synapse/module_api/__init__.py | 17 | ||||
-rw-r--r-- | synapse/module_api/callbacks/__init__.py | 3 | ||||
-rw-r--r-- | synapse/module_api/callbacks/public_rooms_callbacks.py | 45 | ||||
-rw-r--r-- | synapse/rest/client/room.py | 12 | ||||
-rw-r--r-- | synapse/storage/databases/main/room.py | 69 | ||||
-rw-r--r-- | synapse/types/__init__.py | 14 | ||||
-rw-r--r-- | synapse/util/__init__.py | 4 | ||||
-rw-r--r-- | tests/module_api/test_fetch_public_rooms.py | 261 | ||||
-rw-r--r-- | tests/rest/client/test_public_rooms.py | 148 |
12 files changed, 679 insertions, 88 deletions
diff --git a/changelog.d/15518.feature b/changelog.d/15518.feature new file mode 100644 index 0000000000..325724bad3 --- /dev/null +++ b/changelog.d/15518.feature @@ -0,0 +1 @@ +Allow modules to provide local /publicRooms results. diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py index 55d2cd0a9a..16663e97ce 100644 --- a/synapse/federation/transport/server/__init__.py +++ b/synapse/federation/transport/server/__init__.py @@ -149,7 +149,10 @@ class PublicRoomList(BaseFederationServlet): limit = None data = await self.handler.get_local_public_room_list( - limit, since_token, network_tuple=network_tuple, from_federation=True + limit, + since_token, + network_tuple=network_tuple, + from_remote_server_name=origin, ) return 200, data @@ -190,7 +193,7 @@ class PublicRoomList(BaseFederationServlet): since_token=since_token, search_filter=search_filter, network_tuple=network_tuple, - from_federation=True, + from_remote_server_name=origin, ) return 200, data diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py index 36e2db8975..6049708aba 100644 --- a/synapse/handlers/room_list.py +++ b/synapse/handlers/room_list.py @@ -13,7 +13,7 @@ # limitations under the License. import logging -from typing import TYPE_CHECKING, Any, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import attr import msgpack @@ -33,7 +33,8 @@ from synapse.api.errors import ( RequestSendFailed, SynapseError, ) -from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID +from synapse.types import JsonDict, JsonMapping, PublicRoom, ThirdPartyInstanceID +from synapse.util import filter_none from synapse.util.caches.descriptors import _CacheContext, cached from synapse.util.caches.response_cache import ResponseCache @@ -60,6 +61,7 @@ class RoomListHandler: self.remote_response_cache: ResponseCache[ Tuple[str, Optional[int], Optional[str], bool, Optional[str]] ] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000) + self._module_api_callbacks = hs.get_module_api_callbacks().public_rooms async def get_local_public_room_list( self, @@ -67,7 +69,8 @@ class RoomListHandler: since_token: Optional[str] = None, search_filter: Optional[dict] = None, network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, - from_federation: bool = False, + from_client_mxid: Optional[str] = None, + from_remote_server_name: Optional[str] = None, ) -> JsonDict: """Generate a local public room list. @@ -75,14 +78,20 @@ class RoomListHandler: party network. A client can ask for a specific list or to return all. Args: - limit - since_token - search_filter + limit: The maximum number of rooms to return, or None to return all rooms. + since_token: A pagination token, or None to return the head of the public + rooms list. + search_filter: An optional dictionary with the following keys: + * generic_search_term: A string to search for in room ... + * room_types: A list to filter returned rooms by their type. If None or + an empty list is passed, rooms will not be filtered by type. network_tuple: Which public list to use. This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. - from_federation: true iff the request comes from the federation API + from_client_mxid: A user's MXID if this request came from a registered user. + from_remote_server_name: A remote homeserver's server name, if this + request came from the federation API. """ if not self.enable_room_list_search: return {"chunk": [], "total_room_count_estimate": 0} @@ -105,7 +114,8 @@ class RoomListHandler: since_token, search_filter, network_tuple=network_tuple, - from_federation=from_federation, + from_client_mxid=from_client_mxid, + from_remote_server_name=from_remote_server_name, ) key = (limit, since_token, network_tuple) @@ -115,7 +125,8 @@ class RoomListHandler: limit, since_token, network_tuple=network_tuple, - from_federation=from_federation, + from_client_mxid=from_client_mxid, + from_remote_server_name=from_remote_server_name, ) async def _get_public_room_list( @@ -124,7 +135,8 @@ class RoomListHandler: since_token: Optional[str] = None, search_filter: Optional[dict] = None, network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID, - from_federation: bool = False, + from_client_mxid: Optional[str] = None, + from_remote_server_name: Optional[str] = None, ) -> JsonDict: """Generate a public room list. Args: @@ -135,65 +147,106 @@ class RoomListHandler: This can be (None, None) to indicate the main list, or a particular appservice and network id to use an appservice specific one. Setting to None returns all public rooms across all lists. - from_federation: Whether this request originated from a - federating server or a client. Used for room filtering. + from_client_mxid: A user's MXID if this request came from a registered user. + from_remote_server_name: A remote homeserver's server name, if this + request came from the federation API. """ # Pagination tokens work by storing the room ID sent in the last batch, # plus the direction (forwards or backwards). Next batch tokens always # go forwards, prev batch tokens always go backwards. + forwards = True + last_joined_members = None + last_room_id = None + last_module_index = None if since_token: batch_token = RoomListNextBatch.from_token(since_token) - - bounds: Optional[Tuple[int, str]] = ( - batch_token.last_joined_members, - batch_token.last_room_id, - ) + print(batch_token) forwards = batch_token.direction_is_forward - has_batch_token = True - else: - bounds = None - - forwards = True - has_batch_token = False + last_joined_members = batch_token.last_joined_members + last_room_id = batch_token.last_room_id + last_module_index = batch_token.last_module_index - # we request one more than wanted to see if there are more pages to come + # We request one more than wanted to see if there are more pages to come probing_limit = limit + 1 if limit is not None else None - results = await self.store.get_largest_public_rooms( + # We bucket results per joined members number since we want to keep order + # per joined members number + num_joined_members_buckets: Dict[int, List[PublicRoom]] = {} + room_ids_to_module_index: Dict[str, int] = {} + + local_public_rooms = await self.store.get_largest_public_rooms( network_tuple, search_filter, probing_limit, - bounds=bounds, + bounds=( + last_joined_members, + last_room_id if last_module_index is None else None, + ), forwards=forwards, - ignore_non_federatable=from_federation, + ignore_non_federatable=bool(from_remote_server_name), ) - def build_room_entry(room: JsonDict) -> JsonDict: - entry = { - "room_id": room["room_id"], - "name": room["name"], - "topic": room["topic"], - "canonical_alias": room["canonical_alias"], - "num_joined_members": room["joined_members"], - "avatar_url": room["avatar"], - "world_readable": room["history_visibility"] - == HistoryVisibility.WORLD_READABLE, - "guest_can_join": room["guest_access"] == "can_join", - "join_rule": room["join_rules"], - "room_type": room["room_type"], - } + for room in local_public_rooms: + num_joined_members_buckets.setdefault(room.num_joined_members, []).append( + room + ) + + nb_modules = len(self._module_api_callbacks.fetch_public_rooms_callbacks) - # Filter out Nones – rather omit the field altogether - return {k: v for k, v in entry.items() if v is not None} + module_range = range(nb_modules) + # if not forwards: + # module_range = reversed(module_range) - results = [build_room_entry(r) for r in results] + for module_index in module_range: + fetch_public_rooms = ( + self._module_api_callbacks.fetch_public_rooms_callbacks[module_index] + ) + # Ask each module for a list of public rooms given the last_joined_members + # value from the since token and the probing limit + # last_joined_members needs to be reduce by one if this module has already + # given its result for last_joined_members + module_last_joined_members = last_joined_members + if module_last_joined_members is not None and last_module_index is not None: + if forwards and module_index < last_module_index: + module_last_joined_members = module_last_joined_members - 1 + # if not forwards and module_index > last_module_index: + # module_last_joined_members = module_last_joined_members - 1 + + module_public_rooms = await fetch_public_rooms( + network_tuple, + search_filter, + probing_limit, + ( + module_last_joined_members, + last_room_id if last_module_index == module_index else None, + ), + forwards, + ) + + for room in module_public_rooms: + num_joined_members_buckets.setdefault( + room.num_joined_members, [] + ).append(room) + room_ids_to_module_index[room.room_id] = module_index + + nums_joined_members = list(num_joined_members_buckets.keys()) + nums_joined_members.sort(reverse=forwards) + + results = [] + for num_joined_members in nums_joined_members: + rooms = num_joined_members_buckets[num_joined_members] + # if not forwards: + # rooms.reverse() + results += rooms + + print([(r.room_id, r.num_joined_members) for r in results]) response: JsonDict = {} num_results = len(results) - if limit is not None: - more_to_come = num_results == probing_limit + if limit is not None and probing_limit is not None: + more_to_come = num_results >= probing_limit # Depending on direction we trim either the front or back. if forwards: @@ -203,46 +256,60 @@ class RoomListHandler: else: more_to_come = False + print([(r.room_id, r.num_joined_members) for r in results]) + if num_results > 0: final_entry = results[-1] initial_entry = results[0] if forwards: - if has_batch_token: + if since_token is not None: # If there was a token given then we assume that there # must be previous results. response["prev_batch"] = RoomListNextBatch( - last_joined_members=initial_entry["num_joined_members"], - last_room_id=initial_entry["room_id"], + last_joined_members=initial_entry.num_joined_members, + last_room_id=initial_entry.room_id, direction_is_forward=False, + last_module_index=room_ids_to_module_index.get( + initial_entry.room_id + ), ).to_token() if more_to_come: response["next_batch"] = RoomListNextBatch( - last_joined_members=final_entry["num_joined_members"], - last_room_id=final_entry["room_id"], + last_joined_members=final_entry.num_joined_members, + last_room_id=final_entry.room_id, direction_is_forward=True, + last_module_index=room_ids_to_module_index.get( + final_entry.room_id + ), ).to_token() else: - if has_batch_token: + if since_token is not None: response["next_batch"] = RoomListNextBatch( - last_joined_members=final_entry["num_joined_members"], - last_room_id=final_entry["room_id"], + last_joined_members=final_entry.num_joined_members, + last_room_id=final_entry.room_id, direction_is_forward=True, + last_module_index=room_ids_to_module_index.get( + final_entry.room_id + ), ).to_token() if more_to_come: response["prev_batch"] = RoomListNextBatch( - last_joined_members=initial_entry["num_joined_members"], - last_room_id=initial_entry["room_id"], + last_joined_members=initial_entry.num_joined_members, + last_room_id=initial_entry.room_id, direction_is_forward=False, + last_module_index=room_ids_to_module_index.get( + initial_entry.room_id + ), ).to_token() - response["chunk"] = results + response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results] response["total_room_count_estimate"] = await self.store.count_public_rooms( network_tuple, - ignore_non_federatable=from_federation, + ignore_non_federatable=bool(from_remote_server_name), search_filter=search_filter, ) @@ -484,11 +551,13 @@ class RoomListNextBatch: last_joined_members: int # The count to get rooms after/before last_room_id: str # The room_id to get rooms after/before direction_is_forward: bool # True if this is a next_batch, false if prev_batch + last_module_index: Optional[int] = None KEY_DICT = { "last_joined_members": "m", "last_room_id": "r", "direction_is_forward": "d", + "last_module_index": "i", } REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()} @@ -501,6 +570,7 @@ class RoomListNextBatch: ) def to_token(self) -> str: + # print(self) return encode_base64( msgpack.dumps( {self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()} diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py index 65e2aca456..c5dbe9c757 100644 --- a/synapse/module_api/__init__.py +++ b/synapse/module_api/__init__.py @@ -79,6 +79,9 @@ from synapse.module_api.callbacks.account_validity_callbacks import ( ON_LEGACY_SEND_MAIL_CALLBACK, ON_USER_REGISTRATION_CALLBACK, ) +from synapse.module_api.callbacks.public_rooms_callbacks import ( + FETCH_PUBLIC_ROOMS_CALLBACK, +) from synapse.module_api.callbacks.spamchecker_callbacks import ( CHECK_EVENT_FOR_SPAM_CALLBACK, CHECK_LOGIN_FOR_SPAM_CALLBACK, @@ -170,6 +173,7 @@ __all__ = [ "DirectServeJsonResource", "ModuleApi", "PRESENCE_ALL_USERS", + "PublicRoomChunk", "LoginResponse", "JsonDict", "JsonMapping", @@ -472,6 +476,19 @@ class ModuleApi: on_account_data_updated=on_account_data_updated, ) + def register_public_rooms_callbacks( + self, + *, + fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None, + ) -> None: + """Registers callback functions related to the public room directory. + + Added in Synapse v1.80.0 + """ + return self._callbacks.public_rooms.register_callbacks( + fetch_public_rooms=fetch_public_rooms, + ) + def register_web_resource(self, path: str, resource: Resource) -> None: """Registers a web resource to be served at the given path. diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py index dcb036552b..5a0ff22b10 100644 --- a/synapse/module_api/callbacks/__init__.py +++ b/synapse/module_api/callbacks/__init__.py @@ -27,9 +27,12 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import ( ThirdPartyEventRulesModuleApiCallbacks, ) +from .public_rooms_callbacks import PublicRoomsModuleApiCallbacks + class ModuleApiCallbacks: def __init__(self, hs: "HomeServer") -> None: self.account_validity = AccountValidityModuleApiCallbacks() self.spam_checker = SpamCheckerModuleApiCallbacks(hs) self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs) + self.public_rooms = PublicRoomsModuleApiCallbacks() diff --git a/synapse/module_api/callbacks/public_rooms_callbacks.py b/synapse/module_api/callbacks/public_rooms_callbacks.py new file mode 100644 index 0000000000..b3eeb84606 --- /dev/null +++ b/synapse/module_api/callbacks/public_rooms_callbacks.py @@ -0,0 +1,45 @@ +# Copyright 2023 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from typing import Awaitable, Callable, List, Optional, Tuple + +from synapse.types import PublicRoom, ThirdPartyInstanceID + +logger = logging.getLogger(__name__) + + +# Types for callbacks to be registered via the module api +FETCH_PUBLIC_ROOMS_CALLBACK = Callable[ + [ + Optional[ThirdPartyInstanceID], # network_tuple + Optional[dict], # search_filter + Optional[int], # limit + Tuple[Optional[int], Optional[str]], # bounds + bool, # forwards + ], + Awaitable[List[PublicRoom]], +] + + +class PublicRoomsModuleApiCallbacks: + def __init__(self) -> None: + self.fetch_public_rooms_callbacks: List[FETCH_PUBLIC_ROOMS_CALLBACK] = [] + + def register_callbacks( + self, + fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None, + ) -> None: + if fetch_public_rooms is not None: + self.fetch_public_rooms_callbacks.append(fetch_public_rooms) diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py index 553938ce9d..99ad387b50 100644 --- a/synapse/rest/client/room.py +++ b/synapse/rest/client/room.py @@ -476,8 +476,9 @@ class PublicRoomListRestServlet(RestServlet): async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]: server = parse_string(request, "server") + requester: Optional[Requester] = None try: - await self.auth.get_user_by_req(request, allow_guest=True) + requester = await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -516,8 +517,15 @@ class PublicRoomListRestServlet(RestServlet): server, limit=limit, since_token=since_token ) else: + # If a user we know made this request, pass that information to the + # public rooms list handler. + if requester is None: + from_client_mxid = None + else: + from_client_mxid = requester.user.to_string() + data = await handler.get_local_public_room_list( - limit=limit, since_token=since_token + limit=limit, since_token=since_token, from_client_mxid=from_client_mxid ) return 200, data diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py index 1d4d99932b..3972d70358 100644 --- a/synapse/storage/databases/main/room.py +++ b/synapse/storage/databases/main/room.py @@ -38,6 +38,7 @@ from synapse.api.constants import ( Direction, EventContentFields, EventTypes, + HistoryVisibility, JoinRules, PublicRoomsFilterFields, ) @@ -61,7 +62,13 @@ from synapse.storage.util.id_generators import ( MultiWriterIdGenerator, StreamIdGenerator, ) -from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID +from synapse.types import ( + JsonDict, + PublicRoom, + RetentionPolicy, + StrCollection, + ThirdPartyInstanceID, +) from synapse.util import json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.stringutils import MXC_REGEX @@ -365,21 +372,21 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): network_tuple: Optional[ThirdPartyInstanceID], search_filter: Optional[dict], limit: Optional[int], - bounds: Optional[Tuple[int, str]], + bounds: Tuple[Optional[int], Optional[str]], forwards: bool, ignore_non_federatable: bool = False, - ) -> List[Dict[str, Any]]: + ) -> List[PublicRoom]: """Gets the largest public rooms (where largest is in terms of joined members, as tracked in the statistics table). Args: network_tuple search_filter - limit: Maxmimum number of rows to return, unlimited otherwise. - bounds: An uppoer or lower bound to apply to result set if given, + limit: Maximum number of rows to return, unlimited otherwise. + bounds: An upper or lower bound to apply to result set if given, consists of a joined member count and room_id (these are excluded from result set). - forwards: true iff going forwards, going backwards otherwise + forwards: true if going forwards, going backwards otherwise ignore_non_federatable: If true filters out non-federatable rooms. Returns: @@ -413,26 +420,18 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): # Work out the bounds if we're given them, these bounds look slightly # odd, but are designed to help query planner use indices by pulling # out a common bound. - if bounds: - last_joined_members, last_room_id = bounds - if forwards: - where_clauses.append( - """ - joined_members <= ? AND ( - joined_members < ? OR room_id < ? - ) - """ - ) - else: - where_clauses.append( - """ - joined_members >= ? AND ( - joined_members > ? OR room_id > ? - ) - """ - ) + last_joined_members, last_room_id = bounds + if last_joined_members is not None: + comp = "<" if forwards else ">" + + clause = f"joined_members {comp} ?" + query_args += [last_joined_members] - query_args += [last_joined_members, last_joined_members, last_room_id] + if last_room_id is not None: + clause += f" OR (joined_members = ? AND room_id {comp} ?)" + query_args += [last_joined_members, last_room_id] + + where_clauses.append(clause) if ignore_non_federatable: where_clauses.append("is_federatable") @@ -518,7 +517,25 @@ class RoomWorkerStore(CacheInvalidationWorkerStore): ret_val = await self.db_pool.runInteraction( "get_largest_public_rooms", _get_largest_public_rooms_txn ) - return ret_val + + def build_room_entry(room: JsonDict) -> PublicRoom: + entry = PublicRoom( + room_id=room["room_id"], + name=room["name"], + topic=room["topic"], + canonical_alias=room["canonical_alias"], + num_joined_members=room["joined_members"], + avatar_url=room["avatar"], + world_readable=room["history_visibility"] + == HistoryVisibility.WORLD_READABLE, + guest_can_join=room["guest_access"] == "can_join", + join_rule=room["join_rules"], + room_type=room["room_type"], + ) + + return entry + + return [build_room_entry(r) for r in ret_val] @cached(max_entries=10000) async def is_room_blocked(self, room_id: str) -> Optional[bool]: diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py index 09a88c86a7..77d9679582 100644 --- a/synapse/types/__init__.py +++ b/synapse/types/__init__.py @@ -1045,6 +1045,20 @@ class UserInfo: locked: bool +@attr.s(auto_attribs=True, frozen=True, slots=True) +class PublicRoom: + room_id: str + num_joined_members: int + world_readable: bool + guest_can_join: bool + name: Optional[str] = None + topic: Optional[str] = None + canonical_alias: Optional[str] = None + avatar_url: Optional[str] = None + join_rule: Optional[str] = None + room_type: Optional[str] = None + + class UserProfile(TypedDict): user_id: str display_name: Optional[str] diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py index 9f3b8741c1..ba7d4f1246 100644 --- a/synapse/util/__init__.py +++ b/synapse/util/__init__.py @@ -206,3 +206,7 @@ class ExceptionBundle(Exception): parts.append(str(e)) super().__init__("\n - ".join(parts)) self.exceptions = exceptions + + +def filter_none(attr: attr.Attribute, value: Any) -> bool: + return value is not None diff --git a/tests/module_api/test_fetch_public_rooms.py b/tests/module_api/test_fetch_public_rooms.py new file mode 100644 index 0000000000..8daf8c5c40 --- /dev/null +++ b/tests/module_api/test_fetch_public_rooms.py @@ -0,0 +1,261 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Tuple + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin, login, room +from synapse.server import HomeServer +from synapse.types import PublicRoom, ThirdPartyInstanceID +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class FetchPublicRoomsTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + self.hs = self.setup_test_homeserver(config=config) + self.url = "/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self._store = homeserver.get_datastores().main + self._module_api = homeserver.get_module_api() + + async def module1_cb( + network_tuple: Optional[ThirdPartyInstanceID], + search_filter: Optional[dict], + limit: Optional[int], + bounds: Tuple[Optional[int], Optional[str]], + forwards: bool, + ) -> List[PublicRoom]: + room1 = PublicRoom( + room_id="!one_members:module1", + num_joined_members=1, + world_readable=True, + guest_can_join=False, + ) + room3 = PublicRoom( + room_id="!three_members:module1", + num_joined_members=3, + world_readable=True, + guest_can_join=False, + ) + room3_2 = PublicRoom( + room_id="!three_members_2:module1", + num_joined_members=3, + world_readable=True, + guest_can_join=False, + ) + + (last_joined_members, last_room_id) = bounds + + if forwards: + result = [room3_2, room3, room1] + else: + result = [room1, room3, room3_2] + + if last_joined_members is not None: + if last_joined_members == 1: + if forwards: + if last_room_id == room1.room_id: + result = [] + else: + result = [room1] + else: + if last_room_id == room1.room_id: + result = [room3, room3_2] + else: + result = [room1, room3, room3_2] + elif last_joined_members == 2: + if forwards: + result = [room1] + else: + result = [room3, room3_2] + elif last_joined_members == 3: + if forwards: + if last_room_id == room3.room_id: + result = [room1] + elif last_room_id == room3_2.room_id: + result = [room3, room1] + else: + if last_room_id == room3.room_id: + result = [room3_2] + elif last_room_id == room3_2.room_id: + result = [] + else: + result = [room3, room3_2] + + if limit is not None: + result = result[:limit] + + return result + + async def module2_cb( + network_tuple: Optional[ThirdPartyInstanceID], + search_filter: Optional[dict], + limit: Optional[int], + bounds: Tuple[Optional[int], Optional[str]], + forwards: bool, + ) -> List[PublicRoom]: + room3 = PublicRoom( + room_id="!three_members:module2", + num_joined_members=3, + world_readable=True, + guest_can_join=False, + ) + + (last_joined_members, last_room_id) = bounds + + result = [room3] + + if last_joined_members is not None: + if forwards: + if last_joined_members < 3: + result = [] + elif last_joined_members == 3 and last_room_id == room3.room_id: + result = [] + else: + if last_joined_members > 3: + result = [] + elif last_joined_members == 3 and last_room_id == room3.room_id: + result = [] + + return result + + self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module1_cb) + self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module2_cb) + + user = self.register_user("alice", "pass") + token = self.login(user, "pass") + user2 = self.register_user("alice2", "pass") + token2 = self.login(user2, "pass") + user3 = self.register_user("alice3", "pass") + token3 = self.login(user3, "pass") + + # Create a room with 2 people + room_id = self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + self.helper.join(room_id, user2, tok=token2) + + # Create a room with 3 people + room_id = self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + self.helper.join(room_id, user2, tok=token2) + self.helper.join(room_id, user3, tok=token3) + + def test_no_limit(self) -> None: + channel = self.make_request("GET", self.url) + chunk = channel.json_body["chunk"] + + self.assertEquals(len(chunk), 6) + for i in range(4): + self.assertEquals(chunk[i]["num_joined_members"], 3) + self.assertEquals(chunk[4]["num_joined_members"], 2) + self.assertEquals(chunk[5]["num_joined_members"], 1) + + def test_pagination_limit_1(self) -> None: + returned_three_members_rooms = set() + + next_batch = None + for _i in range(4): + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + next_batch = channel.json_body["next_batch"] + + channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 2) + next_batch = channel.json_body["next_batch"] + + channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 1) + prev_batch = channel.json_body["prev_batch"] + + self.assertNotIn("next_batch", channel.json_body) + + channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 2) + + returned_three_members_rooms = set() + for _i in range(4): + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + + self.assertNotIn("prev_batch", channel.json_body) + + def test_pagination_limit_2(self) -> None: + returned_three_members_rooms = set() + + next_batch = None + for _i in range(2): + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[1]["room_id"]) + next_batch = channel.json_body["next_batch"] + + channel = self.make_request("GET", f"{self.url}?limit=2&since={next_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 2) + self.assertEquals(chunk[1]["num_joined_members"], 1) + + self.assertNotIn("next_batch", channel.json_body) + + returned_three_members_rooms = set() + + for _i in range(2): + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(chunk[0]["num_joined_members"], 3) + self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[0]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms) + returned_three_members_rooms.add(chunk[1]["room_id"]) + + self.assertNotIn("prev_batch", channel.json_body) diff --git a/tests/rest/client/test_public_rooms.py b/tests/rest/client/test_public_rooms.py new file mode 100644 index 0000000000..3de43760db --- /dev/null +++ b/tests/rest/client/test_public_rooms.py @@ -0,0 +1,148 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from twisted.test.proto_helpers import MemoryReactor + +from synapse.rest import admin, login, room +from synapse.server import HomeServer +from synapse.util import Clock + +from tests.unittest import HomeserverTestCase + + +class PublicRoomsTestCase(HomeserverTestCase): + servlets = [ + admin.register_servlets, + login.register_servlets, + room.register_servlets, + ] + + def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer: + config = self.default_config() + config["allow_public_rooms_without_auth"] = True + self.hs = self.setup_test_homeserver(config=config) + self.url = "/_matrix/client/r0/publicRooms" + + return self.hs + + def prepare( + self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer + ) -> None: + self._store = homeserver.get_datastores().main + + user = self.register_user("alice", "pass") + token = self.login(user, "pass") + user2 = self.register_user("alice2", "pass") + token2 = self.login(user2, "pass") + user3 = self.register_user("alice3", "pass") + token3 = self.login(user3, "pass") + + # Create 10 rooms + for _ in range(3): + self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + + for _ in range(3): + room_id = self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + self.helper.join(room_id, user2, tok=token2) + + for _ in range(4): + room_id = self.helper.create_room_as( + user, + is_public=True, + extra_content={"visibility": "public"}, + tok=token, + ) + self.helper.join(room_id, user2, tok=token2) + self.helper.join(room_id, user3, tok=token3) + + def test_no_limit(self) -> None: + channel = self.make_request("GET", self.url) + chunk = channel.json_body["chunk"] + + self.assertEquals(len(chunk), 10) + + def test_pagination_limit_1(self) -> None: + returned_rooms = set() + + channel = None + for i in range(10): + next_batch = None if i == 0 else channel.json_body["next_batch"] + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 1) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + + self.assertNotIn("next_batch", channel.json_body) + + returned_rooms = set() + returned_rooms.add(chunk[0]["room_id"]) + + for i in range(9): + print(i) + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 1) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + + def test_pagination_limit_2(self) -> None: + returned_rooms = set() + + channel = None + for i in range(5): + next_batch = None if i == 0 else channel.json_body["next_batch"] + since_query_str = f"&since={next_batch}" if next_batch else "" + channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 2) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + print(chunk[1]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[1]["room_id"]) + + self.assertNotIn("next_batch", channel.json_body) + + returned_rooms = set() + returned_rooms.add(chunk[0]["room_id"]) + returned_rooms.add(chunk[1]["room_id"]) + + for i in range(4): + print(i) + prev_batch = channel.json_body["prev_batch"] + channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}") + chunk = channel.json_body["chunk"] + self.assertEquals(len(chunk), 2) + print(chunk[0]["room_id"]) + self.assertTrue(chunk[0]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[0]["room_id"]) + print(chunk[1]["room_id"]) + self.assertTrue(chunk[1]["room_id"] not in returned_rooms) + returned_rooms.add(chunk[1]["room_id"]) |