summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/4846.feature1
-rwxr-xr-xsynapse/app/homeserver.py1
-rw-r--r--synapse/handlers/user_directory.py125
-rw-r--r--synapse/server.py13
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/schema/delta/53/user_share.sql3
-rw-r--r--synapse/storage/schema/delta/53/users_in_public_rooms.sql28
-rw-r--r--synapse/storage/user_directory.py123
-rw-r--r--tests/handlers/test_user_directory.py31
-rw-r--r--tests/storage/test_user_directory.py4
-rw-r--r--tests/utils.py2
11 files changed, 206 insertions, 138 deletions
diff --git a/changelog.d/4846.feature b/changelog.d/4846.feature
new file mode 100644
index 0000000000..8f792b8890
--- /dev/null
+++ b/changelog.d/4846.feature
@@ -0,0 +1 @@
+The user directory has been rewritten to make it faster, with less chance of falling behind on a large server.
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e8b6cc3114..e0431608e8 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -376,6 +376,7 @@ def setup(config_options):
     logger.info("Database prepared in %s.", config.database_config['name'])
 
     hs.setup()
+    hs.setup_master()
 
     @defer.inlineCallbacks
     def do_acme():
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index c21da8343a..d92f8c529c 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -60,6 +60,12 @@ class UserDirectoryHandler(object):
         self.update_user_directory = hs.config.update_user_directory
         self.search_all_users = hs.config.user_directory_search_all_users
 
+        # If we're a worker, don't sleep when doing the initial room work, as it
+        # won't monopolise the master's CPU.
+        if hs.config.worker_app:
+            self.INITIAL_ROOM_SLEEP_MS = 0
+            self.INITIAL_USER_SLEEP_MS = 0
+
         # When start up for the first time we need to populate the user_directory.
         # This is a set of user_id's we've inserted already
         self.initially_handled_users = set()
@@ -231,7 +237,7 @@ class UserDirectoryHandler(object):
         unhandled_users = user_ids - self.initially_handled_users
 
         yield self.store.add_profiles_to_user_dir(
-            {user_id: users_with_profile[user_id] for user_id in unhandled_users},
+            {user_id: users_with_profile[user_id] for user_id in unhandled_users}
         )
 
         self.initially_handled_users |= unhandled_users
@@ -241,38 +247,58 @@ class UserDirectoryHandler(object):
         # We also batch up inserts/updates, but try to avoid too many at once.
         to_insert = set()
         count = 0
-        for user_id in user_ids:
-            if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
-                yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
-
-            if not self.is_mine_id(user_id):
-                count += 1
-                continue
 
-            if self.store.get_if_app_services_interested_in_user(user_id):
-                count += 1
-                continue
+        if is_public:
+            for user_id in user_ids:
+                if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                    yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
 
-            for other_user_id in user_ids:
-                if user_id == other_user_id:
+                if self.store.get_if_app_services_interested_in_user(user_id):
+                    count += 1
                     continue
 
+                to_insert.add(user_id)
+                if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
+                    yield self.store.add_users_in_public_rooms(room_id, to_insert)
+                    to_insert.clear()
+
+            if to_insert:
+                yield self.store.add_users_in_public_rooms(room_id, to_insert)
+                to_insert.clear()
+        else:
+
+            for user_id in user_ids:
                 if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
                     yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
-                count += 1
 
-                user_set = (user_id, other_user_id)
-                to_insert.add(user_set)
+                if not self.is_mine_id(user_id):
+                    count += 1
+                    continue
 
-                if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
-                    yield self.store.add_users_who_share_room(
-                        room_id, not is_public, to_insert
-                    )
-                    to_insert.clear()
+                if self.store.get_if_app_services_interested_in_user(user_id):
+                    count += 1
+                    continue
+
+                for other_user_id in user_ids:
+                    if user_id == other_user_id:
+                        continue
+
+                    if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                        yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
+                    count += 1
+
+                    user_set = (user_id, other_user_id)
+                    to_insert.add(user_set)
+
+                    if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
+                        yield self.store.add_users_who_share_private_room(
+                            room_id, not is_public, to_insert
+                        )
+                        to_insert.clear()
 
-        if to_insert:
-            yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
-            to_insert.clear()
+            if to_insert:
+                yield self.store.add_users_who_share_private_room(room_id, to_insert)
+                to_insert.clear()
 
     @defer.inlineCallbacks
     def _handle_deltas(self, deltas):
@@ -445,34 +471,37 @@ class UserDirectoryHandler(object):
         # Now we update users who share rooms with users.
         users_with_profile = yield self.state.get_current_user_in_room(room_id)
 
-        to_insert = set()
+        if is_public:
+            yield self.store.add_users_in_public_rooms(room_id, (user_id,))
+        else:
+            to_insert = set()
 
-        # First, if they're our user then we need to update for every user
-        if self.is_mine_id(user_id):
+            # First, if they're our user then we need to update for every user
+            if self.is_mine_id(user_id):
 
-            is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
+                is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
 
-            # We don't care about appservice users.
-            if not is_appservice:
-                for other_user_id in users_with_profile:
-                    if user_id == other_user_id:
-                        continue
+                # We don't care about appservice users.
+                if not is_appservice:
+                    for other_user_id in users_with_profile:
+                        if user_id == other_user_id:
+                            continue
 
-                    to_insert.add((user_id, other_user_id))
+                        to_insert.add((user_id, other_user_id))
 
-        # Next we need to update for every local user in the room
-        for other_user_id in users_with_profile:
-            if user_id == other_user_id:
-                continue
+            # Next we need to update for every local user in the room
+            for other_user_id in users_with_profile:
+                if user_id == other_user_id:
+                    continue
 
-            is_appservice = self.store.get_if_app_services_interested_in_user(
-                other_user_id
-            )
-            if self.is_mine_id(other_user_id) and not is_appservice:
-                to_insert.add((other_user_id, user_id))
+                is_appservice = self.store.get_if_app_services_interested_in_user(
+                    other_user_id
+                )
+                if self.is_mine_id(other_user_id) and not is_appservice:
+                    to_insert.add((other_user_id, user_id))
 
-        if to_insert:
-            yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
+            if to_insert:
+                yield self.store.add_users_who_share_private_room(room_id, to_insert)
 
     @defer.inlineCallbacks
     def _handle_remove_user(self, room_id, user_id):
@@ -487,10 +516,10 @@ class UserDirectoryHandler(object):
         # Remove user from sharing tables
         yield self.store.remove_user_who_share_room(user_id, room_id)
 
-        # Are they still in a room with members? If not, remove them entirely.
-        users_in_room_with = yield self.store.get_users_who_share_room_from_dir(user_id)
+        # Are they still in any rooms? If not, remove them entirely.
+        rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id)
 
-        if len(users_in_room_with) == 0:
+        if len(rooms_user_is_in) == 0:
             yield self.store.remove_from_user_dir(user_id)
 
     @defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index 72835e8c86..b9549dd042 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -185,6 +185,10 @@ class HomeServer(object):
         'registration_handler',
     ]
 
+    REQUIRED_ON_MASTER_STARTUP = [
+        "user_directory_handler",
+    ]
+
     # This is overridden in derived application classes
     # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
     # instantiated during setup() for future return by get_datastore()
@@ -221,6 +225,15 @@ class HomeServer(object):
             conn.commit()
         logger.info("Finished setting up.")
 
+    def setup_master(self):
+        """
+        Some handlers have side effects on instantiation (like registering
+        background updates). This function causes them to be fetched, and
+        therefore instantiated, to run those side effects.
+        """
+        for i in self.REQUIRED_ON_MASTER_STARTUP:
+            getattr(self, "get_" + i)()
+
     def get_reactor(self):
         """
         Fetch the Twisted reactor in use by this HomeServer.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a0333d5309..7e3903859b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -767,18 +767,25 @@ class SQLBaseStore(object):
         """
         allvalues = {}
         allvalues.update(keyvalues)
-        allvalues.update(values)
         allvalues.update(insertion_values)
 
+        if not values:
+            latter = "NOTHING"
+        else:
+            allvalues.update(values)
+            latter = (
+                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+            )
+
         sql = (
             "INSERT INTO %s (%s) VALUES (%s) "
-            "ON CONFLICT (%s) DO UPDATE SET %s"
+            "ON CONFLICT (%s) DO %s"
         ) % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
             ", ".join(k for k in keyvalues),
-            ", ".join(k + "=EXCLUDED." + k for k in values),
+            latter
         )
         txn.execute(sql, list(allvalues.values()))
 
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/schema/delta/53/user_share.sql
index 14424ded0c..5831b1a6f8 100644
--- a/synapse/storage/schema/delta/53/user_share.sql
+++ b/synapse/storage/schema/delta/53/user_share.sql
@@ -16,9 +16,6 @@
 -- Old disused version of the tables below.
 DROP TABLE IF EXISTS users_who_share_rooms;
 
--- This is no longer used because it's duplicated by the users_who_share_public_rooms
-DROP TABLE IF EXISTS users_in_public_rooms;
-
 -- Tables keeping track of what users share rooms. This is a map of local users
 -- to local or remote users, per room. Remote users cannot be in the user_id
 -- column, only the other_user_id column. There are two tables, one for public
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/schema/delta/53/users_in_public_rooms.sql
new file mode 100644
index 0000000000..f7827ca6d2
--- /dev/null
+++ b/synapse/storage/schema/delta/53/users_in_public_rooms.sql
@@ -0,0 +1,28 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We don't need the old version of this table.
+DROP TABLE IF EXISTS users_in_public_rooms;
+
+-- Old version of users_in_public_rooms
+DROP TABLE IF EXISTS users_who_share_public_rooms;
+
+-- Track what users are in public rooms.
+CREATE TABLE IF NOT EXISTS users_in_public_rooms (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id);
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2317d22ed6..1c00b956e5 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -21,12 +21,11 @@ from six import iteritems
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id, get_localpart_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-
-from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -242,14 +241,7 @@ class UserDirectoryStore(SQLBaseStore):
                 txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="users_who_share_public_rooms",
-                keyvalues={"user_id": user_id},
-            )
-            self._simple_delete_txn(
-                txn,
-                table="users_who_share_public_rooms",
-                keyvalues={"other_user_id": user_id},
+                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
                 txn,
@@ -271,9 +263,9 @@ class UserDirectoryStore(SQLBaseStore):
         in the given room_id
         """
         user_ids_share_pub = yield self._simple_select_onecol(
-            table="users_who_share_public_rooms",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
-            retcol="other_user_id",
+            retcol="user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
@@ -311,26 +303,19 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._execute("get_all_local_users", None, sql)
         defer.returnValue([name for name, in rows])
 
-    def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
-        """Insert entries into the users_who_share_*_rooms table. The first
+    def add_users_who_share_private_room(self, room_id, user_id_tuples):
+        """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
             room_id (str)
-            share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
 
         def _add_users_who_share_room_txn(txn):
-
-            if share_private:
-                tbl = "users_who_share_private_rooms"
-            else:
-                tbl = "users_who_share_public_rooms"
-
             self._simple_upsert_many_txn(
                 txn,
-                table=tbl,
+                table="users_who_share_private_rooms",
                 key_names=["user_id", "other_user_id", "room_id"],
                 key_values=[
                     (user_id, other_user_id, room_id)
@@ -339,15 +324,35 @@ class UserDirectoryStore(SQLBaseStore):
                 value_names=(),
                 value_values=None,
             )
-            for user_id, other_user_id in user_id_tuples:
-                txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
-                )
 
         return self.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
+    def add_users_in_public_rooms(self, room_id, user_ids):
+        """Insert entries into the users_who_share_private_rooms table. The first
+        user should be a local user.
+
+        Args:
+            room_id (str)
+            user_ids (list[str])
+        """
+
+        def _add_users_in_public_rooms_txn(txn):
+
+            self._simple_upsert_many_txn(
+                txn,
+                table="users_in_public_rooms",
+                key_names=["user_id", "room_id"],
+                key_values=[(user_id, room_id) for user_id in user_ids],
+                value_names=(),
+                value_values=None,
+            )
+
+        return self.runInteraction(
+            "add_users_in_public_rooms", _add_users_in_public_rooms_txn
+        )
+
     def remove_user_who_share_room(self, user_id, room_id):
         """
         Deletes entries in the users_who_share_*_rooms table. The first
@@ -371,25 +376,18 @@ class UserDirectoryStore(SQLBaseStore):
             )
             self._simple_delete_txn(
                 txn,
-                table="users_who_share_public_rooms",
+                table="users_in_public_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
-            self._simple_delete_txn(
-                txn,
-                table="users_who_share_public_rooms",
-                keyvalues={"other_user_id": user_id, "room_id": room_id},
-            )
-            txn.call_after(
-                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
-            )
 
         return self.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
-    @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_users_who_share_room_from_dir(self, user_id):
-        """Returns the set of users who share a room with `user_id`
+    @defer.inlineCallbacks
+    def get_user_dir_rooms_user_is_in(self, user_id):
+        """
+        Returns the rooms that a user is in.
 
         Args:
             user_id(str): Must be a local user
@@ -400,23 +398,19 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"user_id": user_id},
-            retcol="other_user_id",
-            desc="get_users_who_share_room_with_user",
+            retcol="room_id",
+            desc="get_rooms_user_is_in",
         )
 
         pub_rows = yield self._simple_select_onecol(
-            table="users_who_share_public_rooms",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
-            retcol="other_user_id",
-            desc="get_users_who_share_room_with_user",
+            retcol="room_id",
+            desc="get_rooms_user_is_in",
         )
 
         users = set(pub_rows)
         users.update(rows)
-
-        # Remove the user themselves from this list.
-        users.discard(user_id)
-
         defer.returnValue(list(users))
 
     @defer.inlineCallbacks
@@ -452,10 +446,9 @@ class UserDirectoryStore(SQLBaseStore):
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_who_share_public_rooms")
+            txn.execute("DELETE FROM users_in_public_rooms")
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
-            txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
 
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
@@ -560,23 +553,19 @@ class UserDirectoryStore(SQLBaseStore):
         """
 
         if self.hs.config.user_directory_search_all_users:
-            # make s.user_id null to keep the ordering algorithm happy
-            join_clause = """
-                CROSS JOIN (SELECT NULL as user_id) AS s
-            """
             join_args = ()
             where_clause = "1=1"
         else:
-            join_clause = """
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_public_rooms
-                    UNION
-                    SELECT other_user_id AS user_id FROM users_who_share_private_rooms
-                    WHERE user_id = ?
-                ) AS p USING (user_id)
-            """
             join_args = (user_id,)
-            where_clause = "p.user_id IS NOT NULL"
+            where_clause = """
+                (
+                    EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id)
+                    OR EXISTS (
+                        SELECT 1 FROM users_who_share_private_rooms
+                        WHERE user_id = ? AND other_user_id = t.user_id
+                    )
+                )
+            """
 
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -588,9 +577,8 @@ class UserDirectoryStore(SQLBaseStore):
             # search: (domain, _, display name, localpart)
             sql = """
                 SELECT d.user_id AS user_id, display_name, avatar_url
-                FROM user_directory_search
+                FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
-                %s
                 WHERE
                     %s
                     AND vector @@ to_tsquery('english', ?)
@@ -617,7 +605,6 @@ class UserDirectoryStore(SQLBaseStore):
                     avatar_url IS NULL
                 LIMIT ?
             """ % (
-                join_clause,
                 where_clause,
             )
             args = join_args + (full_query, exact_query, prefix_query, limit + 1)
@@ -626,9 +613,8 @@ class UserDirectoryStore(SQLBaseStore):
 
             sql = """
                 SELECT d.user_id AS user_id, display_name, avatar_url
-                FROM user_directory_search
+                FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
-                %s
                 WHERE
                     %s
                     AND value MATCH ?
@@ -638,7 +624,6 @@ class UserDirectoryStore(SQLBaseStore):
                     avatar_url IS NULL
                 LIMIT ?
             """ % (
-                join_clause,
                 where_clause,
             )
             args = join_args + (search_query, limit + 1)
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index a16a2dc67b..114807efc1 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -114,13 +114,13 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.join(room, user=u2, tok=u2_token)
 
         # Check we have populated the database correctly.
-        shares_public = self.get_users_who_share_public_rooms()
         shares_private = self.get_users_who_share_private_rooms()
+        public_users = self.get_users_in_public_rooms()
 
-        self.assertEqual(shares_public, [])
         self.assertEqual(
             self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
         )
+        self.assertEqual(public_users, [])
 
         # We get one search result when searching for user2 by user1.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -138,11 +138,11 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.helper.leave(room, user=u2, tok=u2_token)
 
         # Check we have removed the values.
-        shares_public = self.get_users_who_share_public_rooms()
         shares_private = self.get_users_who_share_private_rooms()
+        public_users = self.get_users_in_public_rooms()
 
-        self.assertEqual(shares_public, [])
         self.assertEqual(self._compress_shared(shares_private), set())
+        self.assertEqual(public_users, [])
 
         # User1 now gets no search results for any of the other users.
         s = self.get_success(self.handler.search_users(u1, "user2", 10))
@@ -160,14 +160,18 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
             r.add((i["user_id"], i["other_user_id"], i["room_id"]))
         return r
 
-    def get_users_who_share_public_rooms(self):
-        return self.get_success(
+    def get_users_in_public_rooms(self):
+        r = self.get_success(
             self.store._simple_select_list(
-                "users_who_share_public_rooms",
+                "users_in_public_rooms",
                 None,
-                ["user_id", "other_user_id", "room_id"],
+                ("user_id", "room_id"),
             )
         )
+        retval = []
+        for i in r:
+            retval.append((i["user_id"], i["room_id"]))
+        return retval
 
     def get_users_who_share_private_rooms(self):
         return self.get_success(
@@ -200,11 +204,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
         self.get_success(self.store.update_user_directory_stream_pos(None))
         self.get_success(self.store.delete_all_from_user_dir())
 
-        shares_public = self.get_users_who_share_public_rooms()
         shares_private = self.get_users_who_share_private_rooms()
+        public_users = self.get_users_in_public_rooms()
 
+        # Nothing updated yet
         self.assertEqual(shares_private, [])
-        self.assertEqual(shares_public, [])
+        self.assertEqual(public_users, [])
 
         # Reset the handled users caches
         self.handler.initially_handled_users = set()
@@ -219,12 +224,12 @@ class UserDirectoryTestCase(unittest.HomeserverTestCase):
 
         self.get_success(d)
 
-        shares_public = self.get_users_who_share_public_rooms()
         shares_private = self.get_users_who_share_private_rooms()
+        public_users = self.get_users_in_public_rooms()
 
-        # User 1 and User 2 share public rooms
+        # User 1 and User 2 are in the same public room
         self.assertEqual(
-            self._compress_shared(shares_public), set([(u1, u2, room), (u2, u1, room)])
+            set(public_users), set([(u1, room), (u2, room)])
         )
 
         # User 1 and User 3 share private rooms
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index a2a652a235..512d76e7a3 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -41,8 +41,8 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
                 BOBBY: ProfileInfo(None, "bobby"),
             },
         )
-        yield self.store.add_users_who_share_room(
-            "!room:id", False, ((ALICE, BOB), (BOB, ALICE))
+        yield self.store.add_users_in_public_rooms(
+            "!room:id", (ALICE, BOB)
         )
 
     @defer.inlineCallbacks
diff --git a/tests/utils.py b/tests/utils.py
index 9c8dc9dbce..03b5a05b22 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -331,6 +331,8 @@ def setup_test_homeserver(
                 cleanup_func(cleanup)
 
         hs.setup()
+        if homeserverToUse.__name__ == "TestHomeServer":
+            hs.setup_master()
     else:
         hs = homeserverToUse(
             name,