diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 1260721dbf..dc433e1487 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -18,8 +18,9 @@ from twisted.internet import defer
import synapse.rest.admin
from synapse.api.constants import EventTypes, RoomEncryptionAlgorithms, UserTypes
+from synapse.api.room_versions import RoomVersion, RoomVersions
from synapse.rest.client.v1 import login, room
-from synapse.rest.client.v2_alpha import user_directory
+from synapse.rest.client.v2_alpha import account, account_validity, user_directory
from synapse.storage.roommember import ProfileInfo
from tests import unittest
@@ -46,6 +47,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor, clock, hs):
self.store = hs.get_datastore()
self.handler = hs.get_user_directory_handler()
+ self.event_builder_factory = self.hs.get_event_builder_factory()
+ self.event_creation_handler = self.hs.get_event_creation_handler()
def test_handle_local_profile_change_with_support_user(self):
support_user_id = "@support:test"
@@ -503,6 +506,97 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
s = self.get_success(self.handler.search_users(u1, u4, 10))
self.assertEqual(len(s["results"]), 1)
+ @override_config(
+ {
+ "user_directory": {
+ "enabled": True,
+ "search_all_users": True,
+ "prefer_local_users": True,
+ }
+ }
+ )
+ def test_prefer_local_users(self):
+ """Tests that local users are shown higher in search results when
+ user_directory.prefer_local_users is True.
+ """
+ # Create a room and few users to test the directory with
+ searching_user = self.register_user("searcher", "password")
+ searching_user_tok = self.login("searcher", "password")
+
+ room_id = self.helper.create_room_as(
+ searching_user,
+ room_version=RoomVersions.V1.identifier,
+ tok=searching_user_tok,
+ )
+
+ # Create a few local users and join them to the room
+ local_user_1 = self.register_user("user_xxxxx", "password")
+ local_user_2 = self.register_user("user_bbbbb", "password")
+ local_user_3 = self.register_user("user_zzzzz", "password")
+
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, local_user_3)
+
+ # Create a few "remote" users and join them to the room
+ remote_user_1 = "@user_aaaaa:remote_server"
+ remote_user_2 = "@user_yyyyy:remote_server"
+ remote_user_3 = "@user_ccccc:remote_server"
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_1)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_2)
+ self._add_user_to_room(room_id, RoomVersions.V1, remote_user_3)
+
+ local_users = [local_user_1, local_user_2, local_user_3]
+ remote_users = [remote_user_1, remote_user_2, remote_user_3]
+
+ # Populate the user directory via background update
+ self._add_background_updates()
+ while not self.get_success(
+ self.store.db_pool.updates.has_completed_background_updates()
+ ):
+ self.get_success(
+ self.store.db_pool.updates.do_next_background_update(100), by=0.1
+ )
+
+ # The local searching user searches for the term "user", which other users have
+ # in their user id
+ results = self.get_success(
+ self.handler.search_users(searching_user, "user", 20)
+ )["results"]
+ received_user_id_ordering = [result["user_id"] for result in results]
+
+ # Typically we'd expect Synapse to return users in lexicographical order,
+ # assuming they have similar User IDs/display names, and profile information.
+
+ # Check that the order of returned results using our module is as we expect,
+ # i.e our local users show up first, despite all users having lexographically mixed
+ # user IDs.
+ [self.assertIn(user, local_users) for user in received_user_id_ordering[:3]]
+ [self.assertIn(user, remote_users) for user in received_user_id_ordering[3:]]
+
+ def _add_user_to_room(
+ self, room_id: str, room_version: RoomVersion, user_id: str,
+ ):
+ # Add a user to the room.
+ builder = self.event_builder_factory.for_room_version(
+ room_version,
+ {
+ "type": "m.room.member",
+ "sender": user_id,
+ "state_key": user_id,
+ "room_id": room_id,
+ "content": {"membership": "join"},
+ },
+ )
+
+ event, context = self.get_success(
+ self.event_creation_handler.create_new_client_event(builder)
+ )
+
+ self.get_success(
+ self.hs.get_storage().persistence.persist_event(event, context)
+ )
+
class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
user_id = "@test:test"
@@ -547,3 +641,132 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
)
self.assertEquals(200, channel.code, channel.result)
self.assertTrue(len(channel.json_body["results"]) == 0)
+
+
+class UserInfoTestCase(unittest.FederatingHomeserverTestCase):
+ servlets = [
+ login.register_servlets,
+ synapse.rest.admin.register_servlets_for_client_rest_resource,
+ account_validity.register_servlets,
+ synapse.rest.client.v2_alpha.user_directory.register_servlets,
+ account.register_servlets,
+ ]
+
+ def default_config(self):
+ config = super().default_config()
+
+ # Set accounts to expire after a week
+ config["account_validity"] = {
+ "enabled": True,
+ "period": 604800000, # Time in ms for 1 week
+ }
+ return config
+
+ def prepare(self, reactor, clock, hs):
+ super(UserInfoTestCase, self).prepare(reactor, clock, hs)
+ self.store = hs.get_datastore()
+ self.handler = hs.get_user_directory_handler()
+
+ def test_user_info(self):
+ """Test /users/info for local users from the Client-Server API"""
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request info about each user from user_three
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/client/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ access_token=user_three_token,
+ shorthand=False,
+ )
+ self.assertEquals(200, channel.code, channel.result)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def test_user_info_federation(self):
+ """Test that /users/info can be called from the Federation API, and
+ and that we can query remote users from the Client-Server API
+ """
+ user_one, user_two, user_three, user_three_token = self.setup_test_users()
+
+ # Request information about our local users from the perspective of a remote server
+ request, channel = self.make_request(
+ "POST",
+ path="/_matrix/federation/unstable/users/info",
+ content={"user_ids": [user_one, user_two, user_three]},
+ )
+ self.assertEquals(200, channel.code)
+
+ # Check the state of user_one matches
+ user_one_info = channel.json_body[user_one]
+ self.assertTrue(user_one_info["deactivated"])
+ self.assertFalse(user_one_info["expired"])
+
+ # Check the state of user_two matches
+ user_two_info = channel.json_body[user_two]
+ self.assertFalse(user_two_info["deactivated"])
+ self.assertTrue(user_two_info["expired"])
+
+ # Check the state of user_three matches
+ user_three_info = channel.json_body[user_three]
+ self.assertFalse(user_three_info["deactivated"])
+ self.assertFalse(user_three_info["expired"])
+
+ def setup_test_users(self):
+ """Create an admin user and three test users, each with a different state"""
+
+ # Create an admin user to expire other users with
+ self.register_user("admin", "adminpassword", admin=True)
+ admin_token = self.login("admin", "adminpassword")
+
+ # Create three users
+ user_one = self.register_user("alice", "pass")
+ user_one_token = self.login("alice", "pass")
+ user_two = self.register_user("bob", "pass")
+ user_three = self.register_user("carl", "pass")
+ user_three_token = self.login("carl", "pass")
+
+ # Deactivate user_one
+ self.deactivate(user_one, user_one_token)
+
+ # Expire user_two
+ self.expire(user_two, admin_token)
+
+ # Do nothing to user_three
+
+ return user_one, user_two, user_three, user_three_token
+
+ def expire(self, user_id_to_expire, admin_tok):
+ url = "/_synapse/admin/v1/account_validity/validity"
+ request_data = {
+ "user_id": user_id_to_expire,
+ "expiration_ts": 0,
+ "enable_renewal_emails": False,
+ }
+ request, channel = self.make_request(
+ "POST", url, request_data, access_token=admin_tok
+ )
+ self.assertEquals(channel.result["code"], b"200", channel.result)
+
+ def deactivate(self, user_id, tok):
+ request_data = {
+ "auth": {"type": "m.login.password", "user": user_id, "password": "pass"},
+ "erase": False,
+ }
+ request, channel = self.make_request(
+ "POST", "account/deactivate", request_data, access_token=tok
+ )
+ self.assertEqual(request.code, 200)
|