diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 878d9683b6..b12ffc3665 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -796,6 +796,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ # Kept old spam checker without `requester_id` tests for backwards compatibility.
async def allow_all(user_profile: UserProfile) -> bool:
# Allow all users.
return False
@@ -809,6 +810,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 1)
+ # Kept old spam checker without `requester_id` tests for backwards compatibility.
# Configure a spam checker that filters all users.
async def block_all(user_profile: UserProfile) -> bool:
# All users are spammy.
@@ -820,6 +822,40 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, "user2", 10))
self.assertEqual(len(s["results"]), 0)
+ async def allow_all_expects_requester_id(
+ user_profile: UserProfile, requester_id: str
+ ) -> bool:
+ self.assertEqual(requester_id, u1)
+ # Allow all users.
+ return False
+
+ # Configure a spam checker that does not filter any users.
+ spam_checker = self.hs.get_module_api_callbacks().spam_checker
+ spam_checker._check_username_for_spam_callbacks = [
+ allow_all_expects_requester_id
+ ]
+
+ # The results do not change:
+ # We get one search result when searching for user2 by user1.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 1)
+
+ # Configure a spam checker that filters all users.
+ async def block_all_expects_requester_id(
+ user_profile: UserProfile, requester_id: str
+ ) -> bool:
+ self.assertEqual(requester_id, u1)
+ # All users are spammy.
+ return True
+
+ spam_checker._check_username_for_spam_callbacks = [
+ block_all_expects_requester_id
+ ]
+
+ # User1 now gets no search results for any of the other users.
+ s = self.get_success(self.handler.search_users(u1, "user2", 10))
+ self.assertEqual(len(s["results"]), 0)
+
@override_config(
{
"spam_checker": {
@@ -956,6 +992,67 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
[self.assertIn(user, local_users) for user in received_user_id_ordering[:3]]
[self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "exclude_remote_users": True,
+ }
+ }
+ )
+ def test_exclude_remote_users(self) -> None:
+ """Tests that only local users are returned when
+ user_directory.exclude_remote_users is True.
+ """
+
+ # Create a room and few users to test the directory with
+ searching_user = self.register_user("searcher", "password")
+ searching_user_tok = self.login("searcher", "password")
+
+ room_id = self.helper.create_room_as(
+ searching_user,
+ room_version=RoomVersions.V1.identifier,
+ tok=searching_user_tok,
+ )
+
+ # Create a few local users and join them to the room
+ local_user_1 = self.register_user("user_xxxxx", "password")
+ local_user_2 = self.register_user("user_bbbbb", "password")
+ local_user_3 = self.register_user("user_zzzzz", "password")
+
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_3)
+
+ # Create a few "remote" users and join them to the room
+ remote_user_1 = "@user_aaaaa:remote_server"
+ remote_user_2 = "@user_yyyyy:remote_server"
+ remote_user_3 = "@user_ccccc:remote_server"
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3)
+
+ local_users = [local_user_1, local_user_2, local_user_3]
+ remote_users = [remote_user_1, remote_user_2, remote_user_3]
+
+ # The local searching user searches for the term "user", which other users have
+ # in their user id
+ results = self.get_success(
+ self.handler.search_users(searching_user, "user", 20)
+ )["results"]
+ received_user_ids = [result["user_id"] for result in results]
+
+ for user in local_users:
+ self.assertIn(
+ user, received_user_ids, f"Local user {user} not found in results"
+ )
+
+ for user in remote_users:
+ self.assertNotIn(
+ user, received_user_ids, f"Remote user {user} should not be in results"
+ )
+
def _add_user_to_room(
self,
room_id: str,
@@ -1081,10 +1178,10 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
for use_numeric in [False, True]:
if use_numeric:
prefix1 = f"{i}"
- prefix2 = f"{i+1}"
+ prefix2 = f"{i + 1}"
else:
prefix1 = f"a{i}"
- prefix2 = f"a{i+1}"
+ prefix2 = f"a{i + 1}"
local_user_1 = self.register_user(f"user{char}{prefix1}", "password")
local_user_2 = self.register_user(f"user{char}{prefix2}", "password")
|