summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
authorDavid Robertson <davidr@element.io>2021-09-30 11:04:40 +0100
committerGitHub <noreply@github.com>2021-09-30 11:04:40 +0100
commit3aefc7b66d9c7fb98addc71eaf5ef501a4c6a583 (patch)
treeea5fba5e1ee0bafb997e08ad8a1dc56d841b44fb /tests
parentSplit `event_auth.check` into two parts (#10940) (diff)
downloadsynapse-3aefc7b66d9c7fb98addc71eaf5ef501a4c6a583.tar.xz
Refactor user directory tests (#10935)
* Pull out GetUserDirectoryTables helper
* Don't rebuild the dir in tests that don't need it

In #10796 I changed registering a user to add directory entries under.
This means we don't have to force a directory regbuild in to tests of
the user directory search.

* Move test_initial to tests/storage
* Add type hints to both test_user_directory files

Co-authored-by: Richard van der Hoff <1389908+richvdh@users.noreply.github.com>
Diffstat (limited to '')
-rw-r--r--tests/handlers/test_user_directory.py283
-rw-r--r--tests/storage/test_user_directory.py192
-rw-r--r--tests/unittest.py4
3 files changed, 280 insertions, 199 deletions
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 266333c553..2988befb21 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -11,26 +11,37 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import List, Tuple
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 from urllib.parse import quote
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 import synapse.rest.admin
 from synapse.api.constants import UserTypes
 from synapse.api.room_versions import RoomVersion, RoomVersions
 from synapse.rest.client import login, room, user_directory
+from synapse.server import HomeServer
 from synapse.storage.roommember import ProfileInfo
 from synapse.types import create_requester
+from synapse.util import Clock
 
 from tests import unittest
+from tests.storage.test_user_directory import GetUserDirectoryTables
 from tests.unittest import override_config
 
 
 class UserDirectoryTestCase(unittest.HomeserverTestCase):
-    """
-    Tests the UserDirectoryHandler.
+    """Tests the UserDirectoryHandler.
+
+    We're broadly testing two kinds of things here.
+
+    1. Check that we correctly update the user directory in response
+       to events (e.g. join a room, leave a room, change name, make public)
+    2. Check that the search logic behaves as expected.
+
+    The background process that rebuilds the user directory is tested in
+    tests/storage/test_user_directory.py.
     """
 
     servlets = [
@@ -39,19 +50,19 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         room.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
-
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["update_user_directory"] = True
         return self.setup_test_homeserver(config=config)
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastore()
         self.handler = hs.get_user_directory_handler()
         self.event_builder_factory = self.hs.get_event_builder_factory()
         self.event_creation_handler = self.hs.get_event_creation_handler()
+        self.user_dir_helper = GetUserDirectoryTables(self.store)
 
-    def test_handle_local_profile_change_with_support_user(self):
+    def test_handle_local_profile_change_with_support_user(self) -> None:
         support_user_id = "@support:test"
         self.get_success(
             self.store.register_user(
@@ -64,7 +75,9 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         )
 
         self.get_success(
-            self.handler.handle_local_profile_change(support_user_id, None)
+            self.handler.handle_local_profile_change(
+                support_user_id, ProfileInfo("I love support me", None)
+            )
         )
         profile = self.get_success(self.store.get_user_in_directory(support_user_id))
         self.assertTrue(profile is None)
@@ -77,7 +90,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
         self.assertTrue(profile["display_name"] == display_name)
 
-    def test_handle_local_profile_change_with_deactivated_user(self):
+    def test_handle_local_profile_change_with_deactivated_user(self) -> None:
         # create user
         r_user_id = "@regular:test"
         self.get_success(
@@ -112,7 +125,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         profile = self.get_success(self.store.get_user_in_directory(r_user_id))
         self.assertTrue(profile is None)
 
-    def test_handle_user_deactivated_support_user(self):
+    def test_handle_user_deactivated_support_user(self) -> None:
         s_user_id = "@support:test"
         self.get_success(
             self.store.register_user(
@@ -120,20 +133,29 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             )
         )
 
-        self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
-        self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
-        self.store.remove_from_user_dir.not_called()
+        mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+        with patch.object(
+            self.store, "remove_from_user_dir", mock_remove_from_user_dir
+        ):
+            self.get_success(self.handler.handle_local_user_deactivated(s_user_id))
+        # BUG: the correct spelling is assert_not_called, but that makes the test fail
+        # and it's not clear that this is actually the behaviour we want.
+        mock_remove_from_user_dir.not_called()
 
-    def test_handle_user_deactivated_regular_user(self):
+    def test_handle_user_deactivated_regular_user(self) -> None:
         r_user_id = "@regular:test"
         self.get_success(
             self.store.register_user(user_id=r_user_id, password_hash=None)
         )
-        self.store.remove_from_user_dir = Mock(return_value=defer.succeed(None))
-        self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
-        self.store.remove_from_user_dir.called_once_with(r_user_id)
 
-    def test_reactivation_makes_regular_user_searchable(self):
+        mock_remove_from_user_dir = Mock(return_value=defer.succeed(None))
+        with patch.object(
+            self.store, "remove_from_user_dir", mock_remove_from_user_dir
+        ):
+            self.get_success(self.handler.handle_local_user_deactivated(r_user_id))
+        mock_remove_from_user_dir.assert_called_once_with(r_user_id)
+
+    def test_reactivation_makes_regular_user_searchable(self) -> None:
         user = self.register_user("regular", "pass")
         user_token = self.login(user, "pass")
         admin_user = self.register_user("admin", "pass", admin=True)
@@ -171,7 +193,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
         self.assertEqual(s["results"][0]["user_id"], user)
 
-    def test_private_room(self):
+    def test_private_room(self) -> None:
         """
         A user can be searched for only by people that are either in a public
         room, or that share a private chat.
@@ -191,11 +213,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Check we have populated the database correctly.
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
 
         self.assertEqual(
-            self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+            self.user_dir_helper._compress_shared(shares_private),
+            {(u1, u2, room), (u2, u1, room)},
         )
         self.assertEqual(public_users, [])
 
@@ -215,10 +242,14 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.leave(room, user=u2, tok=u2_token)
 
         # Check we have removed the values.
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
 
-        self.assertEqual(self._compress_shared(shares_private), set())
+        self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
         self.assertEqual(public_users, [])
 
         # User1 now gets no search results for any of the other users.
@@ -228,7 +259,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user3", 10))
         self.assertEqual(len(s["results"]), 0)
 
-    def test_spam_checker(self):
+    def test_spam_checker(self) -> None:
         """
         A user which fails the spam checks will not appear in search results.
         """
@@ -246,11 +277,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Check we have populated the database correctly.
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
 
         self.assertEqual(
-            self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+            self.user_dir_helper._compress_shared(shares_private),
+            {(u1, u2, room), (u2, u1, room)},
         )
         self.assertEqual(public_users, [])
 
@@ -258,7 +294,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
 
-        async def allow_all(user_profile):
+        async def allow_all(user_profile: ProfileInfo) -> bool:
             # Allow all users.
             return False
 
@@ -272,7 +308,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.assertEqual(len(s["results"]), 1)
 
         # Configure a spam checker that filters all users.
-        async def block_all(user_profile):
+        async def block_all(user_profile: ProfileInfo) -> bool:
             # All users are spammy.
             return True
 
@@ -282,7 +318,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 0)
 
-    def test_legacy_spam_checker(self):
+    def test_legacy_spam_checker(self) -> None:
         """
         A spam checker without the expected method should be ignored.
         """
@@ -300,11 +336,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Check we have populated the database correctly.
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
 
         self.assertEqual(
-            self._compress_shared(shares_private), {(u1, u2, room), (u2, u1, room)}
+            self.user_dir_helper._compress_shared(shares_private),
+            {(u1, u2, room), (u2, u1, room)},
         )
         self.assertEqual(public_users, [])
 
@@ -317,134 +358,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
         self.assertEqual(len(s["results"]), 1)
 
-    def _compress_shared(self, shared):
-        """
-        Compress a list of users who share rooms dicts to a list of tuples.
-        """
-        r = set()
-        for i in shared:
-            r.add((i["user_id"], i["other_user_id"], i["room_id"]))
-        return r
-
-    def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
-        r = self.get_success(
-            self.store.db_pool.simple_select_list(
-                "users_in_public_rooms", None, ("user_id", "room_id")
-            )
-        )
-        retval = []
-        for i in r:
-            retval.append((i["user_id"], i["room_id"]))
-        return retval
-
-    def get_users_who_share_private_rooms(self) -> List[Tuple[str, str, str]]:
-        return self.get_success(
-            self.store.db_pool.simple_select_list(
-                "users_who_share_private_rooms",
-                None,
-                ["user_id", "other_user_id", "room_id"],
-            )
-        )
-
-    def _add_background_updates(self):
-        """
-        Add the background updates we need to run.
-        """
-        # Ugh, have to reset this flag
-        self.store.db_pool.updates._all_done = False
-
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_user_directory_createtables",
-                    "progress_json": "{}",
-                },
-            )
-        )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_user_directory_process_rooms",
-                    "progress_json": "{}",
-                    "depends_on": "populate_user_directory_createtables",
-                },
-            )
-        )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_user_directory_process_users",
-                    "progress_json": "{}",
-                    "depends_on": "populate_user_directory_process_rooms",
-                },
-            )
-        )
-        self.get_success(
-            self.store.db_pool.simple_insert(
-                "background_updates",
-                {
-                    "update_name": "populate_user_directory_cleanup",
-                    "progress_json": "{}",
-                    "depends_on": "populate_user_directory_process_users",
-                },
-            )
-        )
-
-    def test_initial(self):
-        """
-        The user directory's initial handler correctly updates the search tables.
-        """
-        u1 = self.register_user("user1", "pass")
-        u1_token = self.login(u1, "pass")
-        u2 = self.register_user("user2", "pass")
-        u2_token = self.login(u2, "pass")
-        u3 = self.register_user("user3", "pass")
-        u3_token = self.login(u3, "pass")
-
-        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
-        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
-        self.helper.join(room, user=u2, tok=u2_token)
-
-        private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
-        self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
-        self.helper.join(private_room, user=u3, tok=u3_token)
-
-        self.get_success(self.store.update_user_directory_stream_pos(None))
-        self.get_success(self.store.delete_all_from_user_dir())
-
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
-
-        # Nothing updated yet
-        self.assertEqual(shares_private, [])
-        self.assertEqual(public_users, [])
-
-        # Do the initial population of the user directory via the background update
-        self._add_background_updates()
-
-        while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
-        ):
-            self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
-            )
-
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
-
-        # User 1 and User 2 are in the same public room
-        self.assertEqual(set(public_users), {(u1, room), (u2, room)})
-
-        # User 1 and User 3 share private rooms
-        self.assertEqual(
-            self._compress_shared(shares_private),
-            {(u1, u3, private_room), (u3, u1, private_room)},
-        )
-
-    def test_initial_share_all_users(self):
+    def test_initial_share_all_users(self) -> None:
         """
         Search all users = True means that a user does not have to share a
         private room with the searching user or be in a public room to be search
@@ -457,26 +371,16 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.register_user("user2", "pass")
         u3 = self.register_user("user3", "pass")
 
-        # Wipe the user dir
-        self.get_success(self.store.update_user_directory_stream_pos(None))
-        self.get_success(self.store.delete_all_from_user_dir())
-
-        # Do the initial population of the user directory via the background update
-        self._add_background_updates()
-
-        while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
-        ):
-            self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
-            )
-
-        shares_private = self.get_users_who_share_private_rooms()
-        public_users = self.get_users_in_public_rooms()
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
 
         # No users share rooms
         self.assertEqual(public_users, [])
-        self.assertEqual(self._compress_shared(shares_private), set())
+        self.assertEqual(self.user_dir_helper._compress_shared(shares_private), set())
 
         # Despite not sharing a room, search_all_users means we get a search
         # result.
@@ -501,7 +405,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             }
         }
     )
-    def test_prefer_local_users(self):
+    def test_prefer_local_users(self) -> None:
         """Tests that local users are shown higher in search results when
         user_directory.prefer_local_users is True.
         """
@@ -535,15 +439,6 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         local_users = [local_user_1, local_user_2, local_user_3]
         remote_users = [remote_user_1, remote_user_2, remote_user_3]
 
-        # Populate the user directory via background update
-        self._add_background_updates()
-        while not self.get_success(
-            self.store.db_pool.updates.has_completed_background_updates()
-        ):
-            self.get_success(
-                self.store.db_pool.updates.do_next_background_update(100), by=0.1
-            )
-
         # The local searching user searches for the term "user", which other users have
         # in their user id
         results = self.get_success(
@@ -565,7 +460,7 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         room_id: str,
         room_version: RoomVersion,
         user_id: str,
-    ):
+    ) -> None:
         # Add a user to the room.
         builder = self.event_builder_factory.for_room_version(
             room_version,
@@ -597,7 +492,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
         synapse.rest.admin.register_servlets_for_client_rest_resource,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         config = self.default_config()
         config["update_user_directory"] = True
         hs = self.setup_test_homeserver(config=config)
@@ -606,7 +501,7 @@ class TestUserDirSearchDisabled(unittest.HomeserverTestCase):
 
         return hs
 
-    def test_disabling_room_list(self):
+    def test_disabling_room_list(self) -> None:
         self.config.userdirectory.user_directory_search_enabled = True
 
         # First we create a room with another user so that user dir is non-empty
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 222e5d129d..74c8a8599e 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -11,6 +11,15 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Dict, List, Set, Tuple
+
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.rest import admin
+from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.storage import DataStore
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase, override_config
 
@@ -21,8 +30,183 @@ BOBBY = "@bobby:a"
 BELA = "@somenickname:a"
 
 
+class GetUserDirectoryTables:
+    """Helper functions that we want to reuse in tests/handlers/test_user_directory.py"""
+
+    def __init__(self, store: DataStore):
+        self.store = store
+
+    def _compress_shared(
+        self, shared: List[Dict[str, str]]
+    ) -> Set[Tuple[str, str, str]]:
+        """
+        Compress a list of users who share rooms dicts to a list of tuples.
+        """
+        r = set()
+        for i in shared:
+            r.add((i["user_id"], i["other_user_id"], i["room_id"]))
+        return r
+
+    async def get_users_in_public_rooms(self) -> List[Tuple[str, str]]:
+        r = await self.store.db_pool.simple_select_list(
+            "users_in_public_rooms", None, ("user_id", "room_id")
+        )
+
+        retval = []
+        for i in r:
+            retval.append((i["user_id"], i["room_id"]))
+        return retval
+
+    async def get_users_who_share_private_rooms(self) -> List[Dict[str, str]]:
+        return await self.store.db_pool.simple_select_list(
+            "users_who_share_private_rooms",
+            None,
+            ["user_id", "other_user_id", "room_id"],
+        )
+
+
+class UserDirectoryInitialPopulationTestcase(HomeserverTestCase):
+    """Ensure that rebuilding the directory writes the correct data to the DB.
+
+    See also tests/handlers/test_user_directory.py for similar checks. They
+    test the incremental updates, rather than the big rebuild.
+    """
+
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets_for_client_rest_resource,
+        room.register_servlets,
+    ]
+
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
+        self.store = hs.get_datastore()
+        self.user_dir_helper = GetUserDirectoryTables(self.store)
+
+    def _purge_and_rebuild_user_dir(self) -> None:
+        """Nuke the user directory tables, start the background process to
+        repopulate them, and wait for the process to complete. This allows us
+        to inspect the outcome of the background process alone, without any of
+        the other incremental updates.
+        """
+        self.get_success(self.store.update_user_directory_stream_pos(None))
+        self.get_success(self.store.delete_all_from_user_dir())
+
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
+
+        # Nothing updated yet
+        self.assertEqual(shares_private, [])
+        self.assertEqual(public_users, [])
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_user_directory_createtables",
+                    "progress_json": "{}",
+                },
+            )
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_user_directory_process_rooms",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_createtables",
+                },
+            )
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_user_directory_process_users",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_process_rooms",
+                },
+            )
+        )
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {
+                    "update_name": "populate_user_directory_cleanup",
+                    "progress_json": "{}",
+                    "depends_on": "populate_user_directory_process_users",
+                },
+            )
+        )
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+    def test_initial(self) -> None:
+        """
+        The user directory's initial handler correctly updates the search tables.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+        u3 = self.register_user("user3", "pass")
+        u3_token = self.login(u3, "pass")
+
+        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+        self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
+        self.helper.join(private_room, user=u3, tok=u3_token)
+
+        self.get_success(self.store.update_user_directory_stream_pos(None))
+        self.get_success(self.store.delete_all_from_user_dir())
+
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
+
+        # Nothing updated yet
+        self.assertEqual(shares_private, [])
+        self.assertEqual(public_users, [])
+
+        # Do the initial population of the user directory via the background update
+        self._purge_and_rebuild_user_dir()
+
+        shares_private = self.get_success(
+            self.user_dir_helper.get_users_who_share_private_rooms()
+        )
+        public_users = self.get_success(
+            self.user_dir_helper.get_users_in_public_rooms()
+        )
+
+        # User 1 and User 2 are in the same public room
+        self.assertEqual(set(public_users), {(u1, room), (u2, room)})
+
+        # User 1 and User 3 share private rooms
+        self.assertEqual(
+            self.user_dir_helper._compress_shared(shares_private),
+            {(u1, u3, private_room), (u3, u1, private_room)},
+        )
+
+
 class UserDirectoryStoreTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.store = hs.get_datastore()
 
         # alice and bob are both in !room_id. bobby is not but shares
@@ -33,7 +217,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
         self.get_success(self.store.update_profile_in_user_dir(BELA, "Bela", None))
         self.get_success(self.store.add_users_in_public_rooms("!room:id", (ALICE, BOB)))
 
-    def test_search_user_dir(self):
+    def test_search_user_dir(self) -> None:
         # normally when alice searches the directory she should just find
         # bob because bobby doesn't share a room with her.
         r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
@@ -44,7 +228,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
         )
 
     @override_config({"user_directory": {"search_all_users": True}})
-    def test_search_user_dir_all_users(self):
+    def test_search_user_dir_all_users(self) -> None:
         r = self.get_success(self.store.search_user_dir(ALICE, "bob", 10))
         self.assertFalse(r["limited"])
         self.assertEqual(2, len(r["results"]))
@@ -58,7 +242,7 @@ class UserDirectoryStoreTestCase(HomeserverTestCase):
         )
 
     @override_config({"user_directory": {"search_all_users": True}})
-    def test_search_user_dir_stop_words(self):
+    def test_search_user_dir_stop_words(self) -> None:
         """Tests that a user can look up another user by searching for the start if its
         display name even if that name happens to be a common English word that would
         usually be ignored in full text searches.
diff --git a/tests/unittest.py b/tests/unittest.py
index 6d5d87cb78..5f93ebf147 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -28,6 +28,7 @@ from canonicaljson import json
 from twisted.internet.defer import Deferred, ensureDeferred, succeed
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
+from twisted.test.proto_helpers import MemoryReactor
 from twisted.trial import unittest
 from twisted.web.resource import Resource
 
@@ -46,6 +47,7 @@ from synapse.logging.context import (
 )
 from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
+from synapse.util import Clock
 from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.ratelimitutils import FederationRateLimiter
 
@@ -371,7 +373,7 @@ class HomeserverTestCase(TestCase):
 
         return config
 
-    def prepare(self, reactor, clock, homeserver):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer):
         """
         Prepare for the test.  This involves things like mocking out parts of
         the homeserver, or building test data common across the whole test