diff --git a/changelog.d/15518.feature b/changelog.d/15518.feature
new file mode 100644
index 0000000000..325724bad3
--- /dev/null
+++ b/changelog.d/15518.feature
@@ -0,0 +1 @@
+Allow modules to provide local /publicRooms results.
diff --git a/synapse/federation/transport/server/__init__.py b/synapse/federation/transport/server/__init__.py
index 55d2cd0a9a..16663e97ce 100644
--- a/synapse/federation/transport/server/__init__.py
+++ b/synapse/federation/transport/server/__init__.py
@@ -149,7 +149,10 @@ class PublicRoomList(BaseFederationServlet):
limit = None
data = await self.handler.get_local_public_room_list(
- limit, since_token, network_tuple=network_tuple, from_federation=True
+ limit,
+ since_token,
+ network_tuple=network_tuple,
+ from_remote_server_name=origin,
)
return 200, data
@@ -190,7 +193,7 @@ class PublicRoomList(BaseFederationServlet):
since_token=since_token,
search_filter=search_filter,
network_tuple=network_tuple,
- from_federation=True,
+ from_remote_server_name=origin,
)
return 200, data
diff --git a/synapse/handlers/room_list.py b/synapse/handlers/room_list.py
index 36e2db8975..6049708aba 100644
--- a/synapse/handlers/room_list.py
+++ b/synapse/handlers/room_list.py
@@ -13,7 +13,7 @@
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Any, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
import attr
import msgpack
@@ -33,7 +33,8 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
-from synapse.types import JsonDict, JsonMapping, ThirdPartyInstanceID
+from synapse.types import JsonDict, JsonMapping, PublicRoom, ThirdPartyInstanceID
+from synapse.util import filter_none
from synapse.util.caches.descriptors import _CacheContext, cached
from synapse.util.caches.response_cache import ResponseCache
@@ -60,6 +61,7 @@ class RoomListHandler:
self.remote_response_cache: ResponseCache[
Tuple[str, Optional[int], Optional[str], bool, Optional[str]]
] = ResponseCache(hs.get_clock(), "remote_room_list", timeout_ms=30 * 1000)
+ self._module_api_callbacks = hs.get_module_api_callbacks().public_rooms
async def get_local_public_room_list(
self,
@@ -67,7 +69,8 @@ class RoomListHandler:
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
- from_federation: bool = False,
+ from_client_mxid: Optional[str] = None,
+ from_remote_server_name: Optional[str] = None,
) -> JsonDict:
"""Generate a local public room list.
@@ -75,14 +78,20 @@ class RoomListHandler:
party network. A client can ask for a specific list or to return all.
Args:
- limit
- since_token
- search_filter
+ limit: The maximum number of rooms to return, or None to return all rooms.
+ since_token: A pagination token, or None to return the head of the public
+ rooms list.
+ search_filter: An optional dictionary with the following keys:
+ * generic_search_term: A string to search for in room ...
+ * room_types: A list to filter returned rooms by their type. If None or
+ an empty list is passed, rooms will not be filtered by type.
network_tuple: Which public list to use.
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation: true iff the request comes from the federation API
+ from_client_mxid: A user's MXID if this request came from a registered user.
+ from_remote_server_name: A remote homeserver's server name, if this
+ request came from the federation API.
"""
if not self.enable_room_list_search:
return {"chunk": [], "total_room_count_estimate": 0}
@@ -105,7 +114,8 @@ class RoomListHandler:
since_token,
search_filter,
network_tuple=network_tuple,
- from_federation=from_federation,
+ from_client_mxid=from_client_mxid,
+ from_remote_server_name=from_remote_server_name,
)
key = (limit, since_token, network_tuple)
@@ -115,7 +125,8 @@ class RoomListHandler:
limit,
since_token,
network_tuple=network_tuple,
- from_federation=from_federation,
+ from_client_mxid=from_client_mxid,
+ from_remote_server_name=from_remote_server_name,
)
async def _get_public_room_list(
@@ -124,7 +135,8 @@ class RoomListHandler:
since_token: Optional[str] = None,
search_filter: Optional[dict] = None,
network_tuple: Optional[ThirdPartyInstanceID] = EMPTY_THIRD_PARTY_ID,
- from_federation: bool = False,
+ from_client_mxid: Optional[str] = None,
+ from_remote_server_name: Optional[str] = None,
) -> JsonDict:
"""Generate a public room list.
Args:
@@ -135,65 +147,106 @@ class RoomListHandler:
This can be (None, None) to indicate the main list, or a particular
appservice and network id to use an appservice specific one.
Setting to None returns all public rooms across all lists.
- from_federation: Whether this request originated from a
- federating server or a client. Used for room filtering.
+ from_client_mxid: A user's MXID if this request came from a registered user.
+ from_remote_server_name: A remote homeserver's server name, if this
+ request came from the federation API.
"""
# Pagination tokens work by storing the room ID sent in the last batch,
# plus the direction (forwards or backwards). Next batch tokens always
# go forwards, prev batch tokens always go backwards.
+ forwards = True
+ last_joined_members = None
+ last_room_id = None
+ last_module_index = None
if since_token:
batch_token = RoomListNextBatch.from_token(since_token)
-
- bounds: Optional[Tuple[int, str]] = (
- batch_token.last_joined_members,
- batch_token.last_room_id,
- )
+ print(batch_token)
forwards = batch_token.direction_is_forward
- has_batch_token = True
- else:
- bounds = None
-
- forwards = True
- has_batch_token = False
+ last_joined_members = batch_token.last_joined_members
+ last_room_id = batch_token.last_room_id
+ last_module_index = batch_token.last_module_index
- # we request one more than wanted to see if there are more pages to come
+ # We request one more than wanted to see if there are more pages to come
probing_limit = limit + 1 if limit is not None else None
- results = await self.store.get_largest_public_rooms(
+ # We bucket results per joined members number since we want to keep order
+ # per joined members number
+ num_joined_members_buckets: Dict[int, List[PublicRoom]] = {}
+ room_ids_to_module_index: Dict[str, int] = {}
+
+ local_public_rooms = await self.store.get_largest_public_rooms(
network_tuple,
search_filter,
probing_limit,
- bounds=bounds,
+ bounds=(
+ last_joined_members,
+ last_room_id if last_module_index is None else None,
+ ),
forwards=forwards,
- ignore_non_federatable=from_federation,
+ ignore_non_federatable=bool(from_remote_server_name),
)
- def build_room_entry(room: JsonDict) -> JsonDict:
- entry = {
- "room_id": room["room_id"],
- "name": room["name"],
- "topic": room["topic"],
- "canonical_alias": room["canonical_alias"],
- "num_joined_members": room["joined_members"],
- "avatar_url": room["avatar"],
- "world_readable": room["history_visibility"]
- == HistoryVisibility.WORLD_READABLE,
- "guest_can_join": room["guest_access"] == "can_join",
- "join_rule": room["join_rules"],
- "room_type": room["room_type"],
- }
+ for room in local_public_rooms:
+ num_joined_members_buckets.setdefault(room.num_joined_members, []).append(
+ room
+ )
+
+ nb_modules = len(self._module_api_callbacks.fetch_public_rooms_callbacks)
- # Filter out Nones – rather omit the field altogether
- return {k: v for k, v in entry.items() if v is not None}
+ module_range = range(nb_modules)
+ # if not forwards:
+ # module_range = reversed(module_range)
- results = [build_room_entry(r) for r in results]
+ for module_index in module_range:
+ fetch_public_rooms = (
+ self._module_api_callbacks.fetch_public_rooms_callbacks[module_index]
+ )
+ # Ask each module for a list of public rooms given the last_joined_members
+ # value from the since token and the probing limit
+ # last_joined_members needs to be reduce by one if this module has already
+ # given its result for last_joined_members
+ module_last_joined_members = last_joined_members
+ if module_last_joined_members is not None and last_module_index is not None:
+ if forwards and module_index < last_module_index:
+ module_last_joined_members = module_last_joined_members - 1
+ # if not forwards and module_index > last_module_index:
+ # module_last_joined_members = module_last_joined_members - 1
+
+ module_public_rooms = await fetch_public_rooms(
+ network_tuple,
+ search_filter,
+ probing_limit,
+ (
+ module_last_joined_members,
+ last_room_id if last_module_index == module_index else None,
+ ),
+ forwards,
+ )
+
+ for room in module_public_rooms:
+ num_joined_members_buckets.setdefault(
+ room.num_joined_members, []
+ ).append(room)
+ room_ids_to_module_index[room.room_id] = module_index
+
+ nums_joined_members = list(num_joined_members_buckets.keys())
+ nums_joined_members.sort(reverse=forwards)
+
+ results = []
+ for num_joined_members in nums_joined_members:
+ rooms = num_joined_members_buckets[num_joined_members]
+ # if not forwards:
+ # rooms.reverse()
+ results += rooms
+
+ print([(r.room_id, r.num_joined_members) for r in results])
response: JsonDict = {}
num_results = len(results)
- if limit is not None:
- more_to_come = num_results == probing_limit
+ if limit is not None and probing_limit is not None:
+ more_to_come = num_results >= probing_limit
# Depending on direction we trim either the front or back.
if forwards:
@@ -203,46 +256,60 @@ class RoomListHandler:
else:
more_to_come = False
+ print([(r.room_id, r.num_joined_members) for r in results])
+
if num_results > 0:
final_entry = results[-1]
initial_entry = results[0]
if forwards:
- if has_batch_token:
+ if since_token is not None:
# If there was a token given then we assume that there
# must be previous results.
response["prev_batch"] = RoomListNextBatch(
- last_joined_members=initial_entry["num_joined_members"],
- last_room_id=initial_entry["room_id"],
+ last_joined_members=initial_entry.num_joined_members,
+ last_room_id=initial_entry.room_id,
direction_is_forward=False,
+ last_module_index=room_ids_to_module_index.get(
+ initial_entry.room_id
+ ),
).to_token()
if more_to_come:
response["next_batch"] = RoomListNextBatch(
- last_joined_members=final_entry["num_joined_members"],
- last_room_id=final_entry["room_id"],
+ last_joined_members=final_entry.num_joined_members,
+ last_room_id=final_entry.room_id,
direction_is_forward=True,
+ last_module_index=room_ids_to_module_index.get(
+ final_entry.room_id
+ ),
).to_token()
else:
- if has_batch_token:
+ if since_token is not None:
response["next_batch"] = RoomListNextBatch(
- last_joined_members=final_entry["num_joined_members"],
- last_room_id=final_entry["room_id"],
+ last_joined_members=final_entry.num_joined_members,
+ last_room_id=final_entry.room_id,
direction_is_forward=True,
+ last_module_index=room_ids_to_module_index.get(
+ final_entry.room_id
+ ),
).to_token()
if more_to_come:
response["prev_batch"] = RoomListNextBatch(
- last_joined_members=initial_entry["num_joined_members"],
- last_room_id=initial_entry["room_id"],
+ last_joined_members=initial_entry.num_joined_members,
+ last_room_id=initial_entry.room_id,
direction_is_forward=False,
+ last_module_index=room_ids_to_module_index.get(
+ initial_entry.room_id
+ ),
).to_token()
- response["chunk"] = results
+ response["chunk"] = [attr.asdict(r, filter=filter_none) for r in results]
response["total_room_count_estimate"] = await self.store.count_public_rooms(
network_tuple,
- ignore_non_federatable=from_federation,
+ ignore_non_federatable=bool(from_remote_server_name),
search_filter=search_filter,
)
@@ -484,11 +551,13 @@ class RoomListNextBatch:
last_joined_members: int # The count to get rooms after/before
last_room_id: str # The room_id to get rooms after/before
direction_is_forward: bool # True if this is a next_batch, false if prev_batch
+ last_module_index: Optional[int] = None
KEY_DICT = {
"last_joined_members": "m",
"last_room_id": "r",
"direction_is_forward": "d",
+ "last_module_index": "i",
}
REVERSE_KEY_DICT = {v: k for k, v in KEY_DICT.items()}
@@ -501,6 +570,7 @@ class RoomListNextBatch:
)
def to_token(self) -> str:
+ # print(self)
return encode_base64(
msgpack.dumps(
{self.KEY_DICT[key]: val for key, val in attr.asdict(self).items()}
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 65e2aca456..c5dbe9c757 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -79,6 +79,9 @@ from synapse.module_api.callbacks.account_validity_callbacks import (
ON_LEGACY_SEND_MAIL_CALLBACK,
ON_USER_REGISTRATION_CALLBACK,
)
+from synapse.module_api.callbacks.public_rooms_callbacks import (
+ FETCH_PUBLIC_ROOMS_CALLBACK,
+)
from synapse.module_api.callbacks.spamchecker_callbacks import (
CHECK_EVENT_FOR_SPAM_CALLBACK,
CHECK_LOGIN_FOR_SPAM_CALLBACK,
@@ -170,6 +173,7 @@ __all__ = [
"DirectServeJsonResource",
"ModuleApi",
"PRESENCE_ALL_USERS",
+ "PublicRoomChunk",
"LoginResponse",
"JsonDict",
"JsonMapping",
@@ -472,6 +476,19 @@ class ModuleApi:
on_account_data_updated=on_account_data_updated,
)
+ def register_public_rooms_callbacks(
+ self,
+ *,
+ fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None,
+ ) -> None:
+ """Registers callback functions related to the public room directory.
+
+ Added in Synapse v1.80.0
+ """
+ return self._callbacks.public_rooms.register_callbacks(
+ fetch_public_rooms=fetch_public_rooms,
+ )
+
def register_web_resource(self, path: str, resource: Resource) -> None:
"""Registers a web resource to be served at the given path.
diff --git a/synapse/module_api/callbacks/__init__.py b/synapse/module_api/callbacks/__init__.py
index dcb036552b..5a0ff22b10 100644
--- a/synapse/module_api/callbacks/__init__.py
+++ b/synapse/module_api/callbacks/__init__.py
@@ -27,9 +27,12 @@ from synapse.module_api.callbacks.third_party_event_rules_callbacks import (
ThirdPartyEventRulesModuleApiCallbacks,
)
+from .public_rooms_callbacks import PublicRoomsModuleApiCallbacks
+
class ModuleApiCallbacks:
def __init__(self, hs: "HomeServer") -> None:
self.account_validity = AccountValidityModuleApiCallbacks()
self.spam_checker = SpamCheckerModuleApiCallbacks(hs)
self.third_party_event_rules = ThirdPartyEventRulesModuleApiCallbacks(hs)
+ self.public_rooms = PublicRoomsModuleApiCallbacks()
diff --git a/synapse/module_api/callbacks/public_rooms_callbacks.py b/synapse/module_api/callbacks/public_rooms_callbacks.py
new file mode 100644
index 0000000000..b3eeb84606
--- /dev/null
+++ b/synapse/module_api/callbacks/public_rooms_callbacks.py
@@ -0,0 +1,45 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+from typing import Awaitable, Callable, List, Optional, Tuple
+
+from synapse.types import PublicRoom, ThirdPartyInstanceID
+
+logger = logging.getLogger(__name__)
+
+
+# Types for callbacks to be registered via the module api
+FETCH_PUBLIC_ROOMS_CALLBACK = Callable[
+ [
+ Optional[ThirdPartyInstanceID], # network_tuple
+ Optional[dict], # search_filter
+ Optional[int], # limit
+ Tuple[Optional[int], Optional[str]], # bounds
+ bool, # forwards
+ ],
+ Awaitable[List[PublicRoom]],
+]
+
+
+class PublicRoomsModuleApiCallbacks:
+ def __init__(self) -> None:
+ self.fetch_public_rooms_callbacks: List[FETCH_PUBLIC_ROOMS_CALLBACK] = []
+
+ def register_callbacks(
+ self,
+ fetch_public_rooms: Optional[FETCH_PUBLIC_ROOMS_CALLBACK] = None,
+ ) -> None:
+ if fetch_public_rooms is not None:
+ self.fetch_public_rooms_callbacks.append(fetch_public_rooms)
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 553938ce9d..99ad387b50 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -476,8 +476,9 @@ class PublicRoomListRestServlet(RestServlet):
async def on_GET(self, request: SynapseRequest) -> Tuple[int, JsonDict]:
server = parse_string(request, "server")
+ requester: Optional[Requester] = None
try:
- await self.auth.get_user_by_req(request, allow_guest=True)
+ requester = await self.auth.get_user_by_req(request, allow_guest=True)
except InvalidClientCredentialsError as e:
# Option to allow servers to require auth when accessing
# /publicRooms via CS API. This is especially helpful in private
@@ -516,8 +517,15 @@ class PublicRoomListRestServlet(RestServlet):
server, limit=limit, since_token=since_token
)
else:
+ # If a user we know made this request, pass that information to the
+ # public rooms list handler.
+ if requester is None:
+ from_client_mxid = None
+ else:
+ from_client_mxid = requester.user.to_string()
+
data = await handler.get_local_public_room_list(
- limit=limit, since_token=since_token
+ limit=limit, since_token=since_token, from_client_mxid=from_client_mxid
)
return 200, data
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 1d4d99932b..3972d70358 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -38,6 +38,7 @@ from synapse.api.constants import (
Direction,
EventContentFields,
EventTypes,
+ HistoryVisibility,
JoinRules,
PublicRoomsFilterFields,
)
@@ -61,7 +62,13 @@ from synapse.storage.util.id_generators import (
MultiWriterIdGenerator,
StreamIdGenerator,
)
-from synapse.types import JsonDict, RetentionPolicy, StrCollection, ThirdPartyInstanceID
+from synapse.types import (
+ JsonDict,
+ PublicRoom,
+ RetentionPolicy,
+ StrCollection,
+ ThirdPartyInstanceID,
+)
from synapse.util import json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.stringutils import MXC_REGEX
@@ -365,21 +372,21 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
network_tuple: Optional[ThirdPartyInstanceID],
search_filter: Optional[dict],
limit: Optional[int],
- bounds: Optional[Tuple[int, str]],
+ bounds: Tuple[Optional[int], Optional[str]],
forwards: bool,
ignore_non_federatable: bool = False,
- ) -> List[Dict[str, Any]]:
+ ) -> List[PublicRoom]:
"""Gets the largest public rooms (where largest is in terms of joined
members, as tracked in the statistics table).
Args:
network_tuple
search_filter
- limit: Maxmimum number of rows to return, unlimited otherwise.
- bounds: An uppoer or lower bound to apply to result set if given,
+ limit: Maximum number of rows to return, unlimited otherwise.
+ bounds: An upper or lower bound to apply to result set if given,
consists of a joined member count and room_id (these are
excluded from result set).
- forwards: true iff going forwards, going backwards otherwise
+ forwards: true if going forwards, going backwards otherwise
ignore_non_federatable: If true filters out non-federatable rooms.
Returns:
@@ -413,26 +420,18 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
# Work out the bounds if we're given them, these bounds look slightly
# odd, but are designed to help query planner use indices by pulling
# out a common bound.
- if bounds:
- last_joined_members, last_room_id = bounds
- if forwards:
- where_clauses.append(
- """
- joined_members <= ? AND (
- joined_members < ? OR room_id < ?
- )
- """
- )
- else:
- where_clauses.append(
- """
- joined_members >= ? AND (
- joined_members > ? OR room_id > ?
- )
- """
- )
+ last_joined_members, last_room_id = bounds
+ if last_joined_members is not None:
+ comp = "<" if forwards else ">"
+
+ clause = f"joined_members {comp} ?"
+ query_args += [last_joined_members]
- query_args += [last_joined_members, last_joined_members, last_room_id]
+ if last_room_id is not None:
+ clause += f" OR (joined_members = ? AND room_id {comp} ?)"
+ query_args += [last_joined_members, last_room_id]
+
+ where_clauses.append(clause)
if ignore_non_federatable:
where_clauses.append("is_federatable")
@@ -518,7 +517,25 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
ret_val = await self.db_pool.runInteraction(
"get_largest_public_rooms", _get_largest_public_rooms_txn
)
- return ret_val
+
+ def build_room_entry(room: JsonDict) -> PublicRoom:
+ entry = PublicRoom(
+ room_id=room["room_id"],
+ name=room["name"],
+ topic=room["topic"],
+ canonical_alias=room["canonical_alias"],
+ num_joined_members=room["joined_members"],
+ avatar_url=room["avatar"],
+ world_readable=room["history_visibility"]
+ == HistoryVisibility.WORLD_READABLE,
+ guest_can_join=room["guest_access"] == "can_join",
+ join_rule=room["join_rules"],
+ room_type=room["room_type"],
+ )
+
+ return entry
+
+ return [build_room_entry(r) for r in ret_val]
@cached(max_entries=10000)
async def is_room_blocked(self, room_id: str) -> Optional[bool]:
diff --git a/synapse/types/__init__.py b/synapse/types/__init__.py
index 09a88c86a7..77d9679582 100644
--- a/synapse/types/__init__.py
+++ b/synapse/types/__init__.py
@@ -1045,6 +1045,20 @@ class UserInfo:
locked: bool
+@attr.s(auto_attribs=True, frozen=True, slots=True)
+class PublicRoom:
+ room_id: str
+ num_joined_members: int
+ world_readable: bool
+ guest_can_join: bool
+ name: Optional[str] = None
+ topic: Optional[str] = None
+ canonical_alias: Optional[str] = None
+ avatar_url: Optional[str] = None
+ join_rule: Optional[str] = None
+ room_type: Optional[str] = None
+
+
class UserProfile(TypedDict):
user_id: str
display_name: Optional[str]
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 9f3b8741c1..ba7d4f1246 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -206,3 +206,7 @@ class ExceptionBundle(Exception):
parts.append(str(e))
super().__init__("\n - ".join(parts))
self.exceptions = exceptions
+
+
+def filter_none(attr: attr.Attribute, value: Any) -> bool:
+ return value is not None
diff --git a/tests/module_api/test_fetch_public_rooms.py b/tests/module_api/test_fetch_public_rooms.py
new file mode 100644
index 0000000000..8daf8c5c40
--- /dev/null
+++ b/tests/module_api/test_fetch_public_rooms.py
@@ -0,0 +1,261 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List, Optional, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin, login, room
+from synapse.server import HomeServer
+from synapse.types import PublicRoom, ThirdPartyInstanceID
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class FetchPublicRoomsTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+ config["allow_public_rooms_without_auth"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+ self.url = "/_matrix/client/r0/publicRooms"
+
+ return self.hs
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self._store = homeserver.get_datastores().main
+ self._module_api = homeserver.get_module_api()
+
+ async def module1_cb(
+ network_tuple: Optional[ThirdPartyInstanceID],
+ search_filter: Optional[dict],
+ limit: Optional[int],
+ bounds: Tuple[Optional[int], Optional[str]],
+ forwards: bool,
+ ) -> List[PublicRoom]:
+ room1 = PublicRoom(
+ room_id="!one_members:module1",
+ num_joined_members=1,
+ world_readable=True,
+ guest_can_join=False,
+ )
+ room3 = PublicRoom(
+ room_id="!three_members:module1",
+ num_joined_members=3,
+ world_readable=True,
+ guest_can_join=False,
+ )
+ room3_2 = PublicRoom(
+ room_id="!three_members_2:module1",
+ num_joined_members=3,
+ world_readable=True,
+ guest_can_join=False,
+ )
+
+ (last_joined_members, last_room_id) = bounds
+
+ if forwards:
+ result = [room3_2, room3, room1]
+ else:
+ result = [room1, room3, room3_2]
+
+ if last_joined_members is not None:
+ if last_joined_members == 1:
+ if forwards:
+ if last_room_id == room1.room_id:
+ result = []
+ else:
+ result = [room1]
+ else:
+ if last_room_id == room1.room_id:
+ result = [room3, room3_2]
+ else:
+ result = [room1, room3, room3_2]
+ elif last_joined_members == 2:
+ if forwards:
+ result = [room1]
+ else:
+ result = [room3, room3_2]
+ elif last_joined_members == 3:
+ if forwards:
+ if last_room_id == room3.room_id:
+ result = [room1]
+ elif last_room_id == room3_2.room_id:
+ result = [room3, room1]
+ else:
+ if last_room_id == room3.room_id:
+ result = [room3_2]
+ elif last_room_id == room3_2.room_id:
+ result = []
+ else:
+ result = [room3, room3_2]
+
+ if limit is not None:
+ result = result[:limit]
+
+ return result
+
+ async def module2_cb(
+ network_tuple: Optional[ThirdPartyInstanceID],
+ search_filter: Optional[dict],
+ limit: Optional[int],
+ bounds: Tuple[Optional[int], Optional[str]],
+ forwards: bool,
+ ) -> List[PublicRoom]:
+ room3 = PublicRoom(
+ room_id="!three_members:module2",
+ num_joined_members=3,
+ world_readable=True,
+ guest_can_join=False,
+ )
+
+ (last_joined_members, last_room_id) = bounds
+
+ result = [room3]
+
+ if last_joined_members is not None:
+ if forwards:
+ if last_joined_members < 3:
+ result = []
+ elif last_joined_members == 3 and last_room_id == room3.room_id:
+ result = []
+ else:
+ if last_joined_members > 3:
+ result = []
+ elif last_joined_members == 3 and last_room_id == room3.room_id:
+ result = []
+
+ return result
+
+ self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module1_cb)
+ self._module_api.register_public_rooms_callbacks(fetch_public_rooms=module2_cb)
+
+ user = self.register_user("alice", "pass")
+ token = self.login(user, "pass")
+ user2 = self.register_user("alice2", "pass")
+ token2 = self.login(user2, "pass")
+ user3 = self.register_user("alice3", "pass")
+ token3 = self.login(user3, "pass")
+
+ # Create a room with 2 people
+ room_id = self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ self.helper.join(room_id, user2, tok=token2)
+
+ # Create a room with 3 people
+ room_id = self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ self.helper.join(room_id, user2, tok=token2)
+ self.helper.join(room_id, user3, tok=token3)
+
+ def test_no_limit(self) -> None:
+ channel = self.make_request("GET", self.url)
+ chunk = channel.json_body["chunk"]
+
+ self.assertEquals(len(chunk), 6)
+ for i in range(4):
+ self.assertEquals(chunk[i]["num_joined_members"], 3)
+ self.assertEquals(chunk[4]["num_joined_members"], 2)
+ self.assertEquals(chunk[5]["num_joined_members"], 1)
+
+ def test_pagination_limit_1(self) -> None:
+ returned_three_members_rooms = set()
+
+ next_batch = None
+ for _i in range(4):
+ since_query_str = f"&since={next_batch}" if next_batch else ""
+ channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 3)
+ self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+ returned_three_members_rooms.add(chunk[0]["room_id"])
+ next_batch = channel.json_body["next_batch"]
+
+ channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 2)
+ next_batch = channel.json_body["next_batch"]
+
+ channel = self.make_request("GET", f"{self.url}?limit=1&since={next_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 1)
+ prev_batch = channel.json_body["prev_batch"]
+
+ self.assertNotIn("next_batch", channel.json_body)
+
+ channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 2)
+
+ returned_three_members_rooms = set()
+ for _i in range(4):
+ prev_batch = channel.json_body["prev_batch"]
+ channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 3)
+ self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+ returned_three_members_rooms.add(chunk[0]["room_id"])
+
+ self.assertNotIn("prev_batch", channel.json_body)
+
+ def test_pagination_limit_2(self) -> None:
+ returned_three_members_rooms = set()
+
+ next_batch = None
+ for _i in range(2):
+ since_query_str = f"&since={next_batch}" if next_batch else ""
+ channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 3)
+ self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+ returned_three_members_rooms.add(chunk[0]["room_id"])
+ self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
+ returned_three_members_rooms.add(chunk[1]["room_id"])
+ next_batch = channel.json_body["next_batch"]
+
+ channel = self.make_request("GET", f"{self.url}?limit=2&since={next_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 2)
+ self.assertEquals(chunk[1]["num_joined_members"], 1)
+
+ self.assertNotIn("next_batch", channel.json_body)
+
+ returned_three_members_rooms = set()
+
+ for _i in range(2):
+ prev_batch = channel.json_body["prev_batch"]
+ channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(chunk[0]["num_joined_members"], 3)
+ self.assertTrue(chunk[0]["room_id"] not in returned_three_members_rooms)
+ returned_three_members_rooms.add(chunk[0]["room_id"])
+ self.assertTrue(chunk[1]["room_id"] not in returned_three_members_rooms)
+ returned_three_members_rooms.add(chunk[1]["room_id"])
+
+ self.assertNotIn("prev_batch", channel.json_body)
diff --git a/tests/rest/client/test_public_rooms.py b/tests/rest/client/test_public_rooms.py
new file mode 100644
index 0000000000..3de43760db
--- /dev/null
+++ b/tests/rest/client/test_public_rooms.py
@@ -0,0 +1,148 @@
+# Copyright 2022 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin, login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class PublicRoomsTestCase(HomeserverTestCase):
+ servlets = [
+ admin.register_servlets,
+ login.register_servlets,
+ room.register_servlets,
+ ]
+
+ def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
+ config = self.default_config()
+ config["allow_public_rooms_without_auth"] = True
+ self.hs = self.setup_test_homeserver(config=config)
+ self.url = "/_matrix/client/r0/publicRooms"
+
+ return self.hs
+
+ def prepare(
+ self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+ ) -> None:
+ self._store = homeserver.get_datastores().main
+
+ user = self.register_user("alice", "pass")
+ token = self.login(user, "pass")
+ user2 = self.register_user("alice2", "pass")
+ token2 = self.login(user2, "pass")
+ user3 = self.register_user("alice3", "pass")
+ token3 = self.login(user3, "pass")
+
+ # Create 10 rooms
+ for _ in range(3):
+ self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+
+ for _ in range(3):
+ room_id = self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ self.helper.join(room_id, user2, tok=token2)
+
+ for _ in range(4):
+ room_id = self.helper.create_room_as(
+ user,
+ is_public=True,
+ extra_content={"visibility": "public"},
+ tok=token,
+ )
+ self.helper.join(room_id, user2, tok=token2)
+ self.helper.join(room_id, user3, tok=token3)
+
+ def test_no_limit(self) -> None:
+ channel = self.make_request("GET", self.url)
+ chunk = channel.json_body["chunk"]
+
+ self.assertEquals(len(chunk), 10)
+
+ def test_pagination_limit_1(self) -> None:
+ returned_rooms = set()
+
+ channel = None
+ for i in range(10):
+ next_batch = None if i == 0 else channel.json_body["next_batch"]
+ since_query_str = f"&since={next_batch}" if next_batch else ""
+ channel = self.make_request("GET", f"{self.url}?limit=1{since_query_str}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(len(chunk), 1)
+ print(chunk[0]["room_id"])
+ self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+ returned_rooms.add(chunk[0]["room_id"])
+
+ self.assertNotIn("next_batch", channel.json_body)
+
+ returned_rooms = set()
+ returned_rooms.add(chunk[0]["room_id"])
+
+ for i in range(9):
+ print(i)
+ prev_batch = channel.json_body["prev_batch"]
+ channel = self.make_request("GET", f"{self.url}?limit=1&since={prev_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(len(chunk), 1)
+ print(chunk[0]["room_id"])
+ self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+ returned_rooms.add(chunk[0]["room_id"])
+
+ def test_pagination_limit_2(self) -> None:
+ returned_rooms = set()
+
+ channel = None
+ for i in range(5):
+ next_batch = None if i == 0 else channel.json_body["next_batch"]
+ since_query_str = f"&since={next_batch}" if next_batch else ""
+ channel = self.make_request("GET", f"{self.url}?limit=2{since_query_str}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(len(chunk), 2)
+ print(chunk[0]["room_id"])
+ self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+ returned_rooms.add(chunk[0]["room_id"])
+ print(chunk[1]["room_id"])
+ self.assertTrue(chunk[1]["room_id"] not in returned_rooms)
+ returned_rooms.add(chunk[1]["room_id"])
+
+ self.assertNotIn("next_batch", channel.json_body)
+
+ returned_rooms = set()
+ returned_rooms.add(chunk[0]["room_id"])
+ returned_rooms.add(chunk[1]["room_id"])
+
+ for i in range(4):
+ print(i)
+ prev_batch = channel.json_body["prev_batch"]
+ channel = self.make_request("GET", f"{self.url}?limit=2&since={prev_batch}")
+ chunk = channel.json_body["chunk"]
+ self.assertEquals(len(chunk), 2)
+ print(chunk[0]["room_id"])
+ self.assertTrue(chunk[0]["room_id"] not in returned_rooms)
+ returned_rooms.add(chunk[0]["room_id"])
+ print(chunk[1]["room_id"])
+ self.assertTrue(chunk[1]["room_id"] not in returned_rooms)
+ returned_rooms.add(chunk[1]["room_id"])
|