summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_user_directory.py142
-rw-r--r--tests/storage/test_user_directory.py77
2 files changed, 140 insertions, 79 deletions
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 0120b4688b..e0635c8898 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -109,18 +109,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             tok=alice_token,
         )
 
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
-        in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
-        in_private = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
+        # The user directory should reflect the room memberships above.
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
         )
-
         self.assertEqual(users, {alice, bob})
+        self.assertEqual(in_public, {(alice, public), (bob, public), (alice, public2)})
         self.assertEqual(
-            set(in_public), {(alice, public), (bob, public), (alice, public2)}
-        )
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(in_private),
+            in_private,
             {(alice, bob, private), (bob, alice, private)},
         )
 
@@ -209,6 +205,88 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
         self.assertEqual(set(in_public), {(user1, room), (user2, room)})
 
+    def test_excludes_users_when_making_room_public(self) -> None:
+        # Create a regular user and a support user.
+        alice = self.register_user("alice", "pass")
+        alice_token = self.login(alice, "pass")
+        support = "@support1:test"
+        self.get_success(
+            self.store.register_user(
+                user_id=support, password_hash=None, user_type=UserTypes.SUPPORT
+            )
+        )
+
+        # Make a public and private room containing Alice and the support user
+        public, initially_private = self._create_rooms_and_inject_memberships(
+            alice, alice_token, support
+        )
+        self._check_only_one_user_in_directory(alice, public)
+
+        # Alice makes the private room public.
+        self.helper.send_state(
+            initially_private,
+            "m.room.join_rules",
+            {"join_rule": "public"},
+            tok=alice_token,
+        )
+
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice})
+        self.assertEqual(in_public, {(alice, public), (alice, initially_private)})
+        self.assertEqual(in_private, set())
+
+    def test_switching_from_private_to_public_to_private(self) -> None:
+        """Check we update the room sharing tables when switching a room
+        from private to public, then back again to private."""
+        # Alice and Bob share a private room.
+        alice = self.register_user("alice", "pass")
+        alice_token = self.login(alice, "pass")
+        bob = self.register_user("bob", "pass")
+        bob_token = self.login(bob, "pass")
+        room = self.helper.create_room_as(alice, is_public=False, tok=alice_token)
+        self.helper.invite(room, alice, bob, tok=alice_token)
+        self.helper.join(room, bob, tok=bob_token)
+
+        # The user directory should reflect this.
+        def check_user_dir_for_private_room() -> None:
+            users, in_public, in_private = self.get_success(
+                self.user_dir_helper.get_tables()
+            )
+            self.assertEqual(users, {alice, bob})
+            self.assertEqual(in_public, set())
+            self.assertEqual(in_private, {(alice, bob, room), (bob, alice, room)})
+
+        check_user_dir_for_private_room()
+
+        # Alice makes the room public.
+        self.helper.send_state(
+            room,
+            "m.room.join_rules",
+            {"join_rule": "public"},
+            tok=alice_token,
+        )
+
+        # The user directory should be updated accordingly
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
+        self.assertEqual(users, {alice, bob})
+        self.assertEqual(in_public, {(alice, room), (bob, room)})
+        self.assertEqual(in_private, set())
+
+        # Alice makes the room private.
+        self.helper.send_state(
+            room,
+            "m.room.join_rules",
+            {"join_rule": "invite"},
+            tok=alice_token,
+        )
+
+        # The user directory should be updated accordingly
+        check_user_dir_for_private_room()
+
     def _create_rooms_and_inject_memberships(
         self, creator: str, token: str, joiner: str
     ) -> Tuple[str, str]:
@@ -232,15 +310,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         return public_room, private_room
 
     def _check_only_one_user_in_directory(self, user: str, public: str) -> None:
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
-        in_public = self.get_success(self.user_dir_helper.get_users_in_public_rooms())
-        in_private = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
-        )
+        """Check that the user directory DB tables show that:
 
+        - only one user is in the user directory
+        - they belong to exactly one public room
+        - they don't share a private room with anyone.
+        """
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
+        )
         self.assertEqual(users, {user})
-        self.assertEqual(set(in_public), {(user, public)})
-        self.assertEqual(in_private, [])
+        self.assertEqual(in_public, {(user, public)})
+        self.assertEqual(in_private, set())
 
     def test_handle_local_profile_change_with_support_user(self) -> None:
         support_user_id = "@support:test"
@@ -581,11 +662,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u2, room), (u2, u1, room)},
-        )
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+        self.assertEqual(public_users, set())
 
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -610,8 +688,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, set())
+        self.assertEqual(public_users, set())
 
         # User1 now gets no search results for any of the other users.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -645,11 +723,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u2, room), (u2, u1, room)},
-        )
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+        self.assertEqual(public_users, set())
 
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -704,11 +779,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             self.user_dir_helper.get_users_in_public_rooms()
         )
 
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u2, room), (u2, u1, room)},
-        )
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, {(u1, u2, room), (u2, u1, room)})
+        self.assertEqual(public_users, set())
 
         # Configure a spam checker.
         spam_checker = self.hs.get_spam_checker()
@@ -740,8 +812,8 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         # No users share rooms
-        self.assertEqual(public_users, [])
-        self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
+        self.assertEqual(public_users, set())
+        self.assertEqual(shares_private, set())
 
         # Despite not sharing a room, search_all_users means we get a search
         # result.
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index be3ed64f5e..37cf7bb232 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, List, Set, Tuple
+from typing import Any, Dict, Set, Tuple
 from unittest import mock
 from unittest.mock import Mock, patch
 
@@ -42,18 +42,7 @@ class GetUserDirectoryTables:
     def __init__(self, store: DataStore):
         self.store = store
 
-    def _compress_shared(
-        self, shared: List[Dict[str, str]]
-    ) -> Set[Tuple[str, str, str]]:
-        """
-        Compress a list of users who share rooms dicts to a list of tuples.
-        """
-        r = set()
-        for i in shared:
-            r.add((i["user_id"], i["other_user_id"], i["room_id"]))
-        return r
-
-    async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
+    async def get_users_in_public_rooms(self) -> Set[Tuple[str, str]]:
         """Fetch the entire `users_in_public_rooms` table.
 
         Returns a list of tuples (user_id, room_id) where room_id is public and
@@ -63,24 +52,27 @@ class GetUserDirectoryTables:
             "users_in_public_rooms", None, ("user_id", "room_id")
         )
 
-        retval = []
+        retval = set()
         for i in r:
-            retval.append((i["user_id"], i["room_id"]))
+            retval.add((i["user_id"], i["room_id"]))
         return retval
 
-    async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
+    async def get_users_who_share_private_rooms(self) -> Set[Tuple[str, str, str]]:
         """Fetch the entire `users_who_share_private_rooms` table.
 
-        Returns a dict containing "user_id", "other_user_id" and "room_id" keys.
-        The dicts can be flattened to Tuples with the `_compress_shared` method.
-        (This seems a little awkward---maybe we could clean this up.)
+        Returns a set of tuples (user_id, other_user_id, room_id) corresponding
+        to the rows of `users_who_share_private_rooms`.
         """
 
-        return await self.store.db_pool.simple_select_list(
+        rows = await self.store.db_pool.simple_select_list(
             "users_who_share_private_rooms",
             None,
             ["user_id", "other_user_id", "room_id"],
         )
+        rv = set()
+        for row in rows:
+            rv.add((row["user_id"], row["other_user_id"], row["room_id"]))
+        return rv
 
     async def get_users_in_user_directory(self) -> Set[str]:
         """Fetch the set of users in the `user_directory` table.
@@ -113,6 +105,16 @@ class GetUserDirectoryTables:
             for row in rows
         }
 
+    async def get_tables(
+        self,
+    ) -> Tuple[Set[str], Set[Tuple[str, str]], Set[Tuple[str, str, str]]]:
+        """Multiple tests want to inspect these tables, so expose them together."""
+        return (
+            await self.get_users_in_user_directory(),
+            await self.get_users_in_public_rooms(),
+            await self.get_users_who_share_private_rooms(),
+        )
+
 
 class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
     """Ensure that rebuilding the directory writes the correct data to the DB.
@@ -166,8 +168,8 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         )
 
         # Nothing updated yet
-        self.assertEqual(shares_private, [])
-        self.assertEqual(public_users, [])
+        self.assertEqual(shares_private, set())
+        self.assertEqual(public_users, set())
 
         # Ugh, have to reset this flag
         self.store.db_pool.updates._all_done = False
@@ -236,24 +238,15 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         # Do the initial population of the user directory via the background update
         self._purge_and_rebuild_user_dir()
 
-        shares_private = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
-        )
-        public_users = self.get_success(
-            self.user_dir_helper.get_users_in_public_rooms()
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
         )
 
         # User 1 and User 2 are in the same public room
-        self.assertEqual(set(public_users), {(u1, room), (u2, room)})
-
+        self.assertEqual(in_public, {(u1, room), (u2, room)})
         # User 1 and User 3 share private rooms
-        self.assertEqual(
-            self.user_dir_helper._compress_shared(shares_private),
-            {(u1, u3, private_room), (u3, u1, private_room)},
-        )
-
+        self.assertEqual(in_private, {(u1, u3, private_room), (u3, u1, private_room)})
         # All three should have entries in the directory
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
         self.assertEqual(users, {u1, u2, u3})
 
     # The next four tests (test_population_excludes_*) all set up
@@ -289,16 +282,12 @@ class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
         self, normal_user: str, public_room: str, private_room: str
     ) -> None:
         # After rebuilding the directory, we should only see the normal user.
-        users = self.get_success(self.user_dir_helper.get_users_in_user_directory())
-        self.assertEqual(users, {normal_user})
-        in_public_rooms = self.get_success(
-            self.user_dir_helper.get_users_in_public_rooms()
+        users, in_public, in_private = self.get_success(
+            self.user_dir_helper.get_tables()
         )
-        self.assertEqual(set(in_public_rooms), {(normal_user, public_room)})
-        in_private_rooms = self.get_success(
-            self.user_dir_helper.get_users_who_share_private_rooms()
-        )
-        self.assertEqual(in_private_rooms, [])
+        self.assertEqual(users, {normal_user})
+        self.assertEqual(in_public, {(normal_user, public_room)})
+        self.assertEqual(in_private, set())
 
     def test_population_excludes_support_user(self) -> None:
         # Create a normal and support user.