summary refs log tree commit diff
diff options
context:
space:
mode:
authorAmber Brown <hawkowl@atleastfornow.net>2019-03-07 01:22:53 -0800
committerGitHub <noreply@github.com>2019-03-07 01:22:53 -0800
commitf6135d06cf94fdef9942051f43872c7518511e74 (patch)
treee33eda50d2942aee1be2374122465979eb100375
parentReword the sample config header to be less scary (#4801) (diff)
downloadsynapse-f6135d06cf94fdef9942051f43872c7518511e74.tar.xz
Rewrite userdir to be faster (#4537)
-rw-r--r--changelog.d/4537.feature1
-rw-r--r--synapse/handlers/user_directory.py222
-rw-r--r--synapse/storage/schema/delta/53/user_share.sql47
-rw-r--r--synapse/storage/user_directory.py271
-rw-r--r--tests/handlers/test_user_directory.py266
-rw-r--r--tests/storage/test_user_directory.py2
6 files changed, 400 insertions, 409 deletions
diff --git a/changelog.d/4537.feature b/changelog.d/4537.feature
new file mode 100644
index 0000000000..8f792b8890
--- /dev/null
+++ b/changelog.d/4537.feature
@@ -0,0 +1 @@
+The user directory has been rewritten to make it faster, with less chance of falling behind on a large server.
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index 283c6c1b81..c21da8343a 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -15,7 +15,7 @@
 
 import logging
 
-from six import iteritems
+from six import iteritems, iterkeys
 
 from twisted.internet import defer
 
@@ -63,10 +63,6 @@ class UserDirectoryHandler(object):
         # When start up for the first time we need to populate the user_directory.
         # This is a set of user_id's we've inserted already
         self.initially_handled_users = set()
-        self.initially_handled_users_in_public = set()
-
-        self.initially_handled_users_share = set()
-        self.initially_handled_users_share_private_room = set()
 
         # The current position in the current_state_delta stream
         self.pos = None
@@ -140,7 +136,6 @@ class UserDirectoryHandler(object):
         # FIXME(#3714): We should probably do this in the same worker as all
         # the other changes.
         yield self.store.remove_from_user_dir(user_id)
-        yield self.store.remove_from_user_in_public_room(user_id)
 
     @defer.inlineCallbacks
     def _unsafe_process(self):
@@ -215,15 +210,13 @@ class UserDirectoryHandler(object):
             logger.info("Processed all users")
 
         self.initially_handled_users = None
-        self.initially_handled_users_in_public = None
-        self.initially_handled_users_share = None
-        self.initially_handled_users_share_private_room = None
 
         yield self.store.update_user_directory_stream_pos(new_pos)
 
     @defer.inlineCallbacks
     def _handle_initial_room(self, room_id):
-        """Called when we initially fill out user_directory one room at a time
+        """
+        Called when we initially fill out user_directory one room at a time
         """
         is_in_room = yield self.store.is_host_joined(room_id, self.server_name)
         if not is_in_room:
@@ -238,23 +231,15 @@ class UserDirectoryHandler(object):
         unhandled_users = user_ids - self.initially_handled_users
 
         yield self.store.add_profiles_to_user_dir(
-            room_id,
             {user_id: users_with_profile[user_id] for user_id in unhandled_users},
         )
 
         self.initially_handled_users |= unhandled_users
 
-        if is_public:
-            yield self.store.add_users_to_public_room(
-                room_id, user_ids=user_ids - self.initially_handled_users_in_public
-            )
-            self.initially_handled_users_in_public |= user_ids
-
         # We now go and figure out the new users who share rooms with user entries
         # We sleep aggressively here as otherwise it can starve resources.
         # We also batch up inserts/updates, but try to avoid too many at once.
         to_insert = set()
-        to_update = set()
         count = 0
         for user_id in user_ids:
             if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
@@ -277,21 +262,7 @@ class UserDirectoryHandler(object):
                 count += 1
 
                 user_set = (user_id, other_user_id)
-
-                if user_set in self.initially_handled_users_share_private_room:
-                    continue
-
-                if user_set in self.initially_handled_users_share:
-                    if is_public:
-                        continue
-                    to_update.add(user_set)
-                else:
-                    to_insert.add(user_set)
-
-                if is_public:
-                    self.initially_handled_users_share.add(user_set)
-                else:
-                    self.initially_handled_users_share_private_room.add(user_set)
+                to_insert.add(user_set)
 
                 if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
                     yield self.store.add_users_who_share_room(
@@ -299,22 +270,10 @@ class UserDirectoryHandler(object):
                     )
                     to_insert.clear()
 
-                if len(to_update) > self.INITIAL_ROOM_BATCH_SIZE:
-                    yield self.store.update_users_who_share_room(
-                        room_id, not is_public, to_update
-                    )
-                    to_update.clear()
-
         if to_insert:
             yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
             to_insert.clear()
 
-        if to_update:
-            yield self.store.update_users_who_share_room(
-                room_id, not is_public, to_update
-            )
-            to_update.clear()
-
     @defer.inlineCallbacks
     def _handle_deltas(self, deltas):
         """Called with the state deltas to process
@@ -356,6 +315,7 @@ class UserDirectoryHandler(object):
                         user_ids = yield self.store.get_users_in_dir_due_to_room(
                             room_id
                         )
+
                         for user_id in user_ids:
                             yield self._handle_remove_user(room_id, user_id)
                         return
@@ -436,14 +396,20 @@ class UserDirectoryHandler(object):
             # ignore the change
             return
 
-        if change:
-            users_with_profile = yield self.state.get_current_user_in_room(room_id)
-            for user_id, profile in iteritems(users_with_profile):
-                yield self._handle_new_user(room_id, user_id, profile)
-        else:
-            users = yield self.store.get_users_in_public_due_to_room(room_id)
-            for user_id in users:
-                yield self._handle_remove_user(room_id, user_id)
+        users_with_profile = yield self.state.get_current_user_in_room(room_id)
+
+        # Remove every user from the sharing tables for that room.
+        for user_id in iterkeys(users_with_profile):
+            yield self.store.remove_user_who_share_room(user_id, room_id)
+
+        # Then, re-add them to the tables.
+        # NOTE: this is not the most efficient method, as handle_new_user sets
+        # up local_user -> other_user and other_user_whos_local -> local_user,
+        # which when ran over an entire room, will result in the same values
+        # being added multiple times. The batching upserts shouldn't make this
+        # too bad, though.
+        for user_id, profile in iteritems(users_with_profile):
+            yield self._handle_new_user(room_id, user_id, profile)
 
     @defer.inlineCallbacks
     def _handle_local_user(self, user_id):
@@ -457,7 +423,7 @@ class UserDirectoryHandler(object):
 
         row = yield self.store.get_user_in_directory(user_id)
         if not row:
-            yield self.store.add_profiles_to_user_dir(None, {user_id: profile})
+            yield self.store.add_profiles_to_user_dir({user_id: profile})
 
     @defer.inlineCallbacks
     def _handle_new_user(self, room_id, user_id, profile):
@@ -471,55 +437,27 @@ class UserDirectoryHandler(object):
 
         row = yield self.store.get_user_in_directory(user_id)
         if not row:
-            yield self.store.add_profiles_to_user_dir(room_id, {user_id: profile})
+            yield self.store.add_profiles_to_user_dir({user_id: profile})
 
         is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
             room_id
         )
-
-        if is_public:
-            row = yield self.store.get_user_in_public_room(user_id)
-            if not row:
-                yield self.store.add_users_to_public_room(room_id, [user_id])
-        else:
-            logger.debug("Not adding new user to public dir, %r", user_id)
-
-        # Now we update users who share rooms with users. We do this by getting
-        # all the current users in the room and seeing which aren't already
-        # marked in the database as sharing with `user_id`
-
+        # Now we update users who share rooms with users.
         users_with_profile = yield self.state.get_current_user_in_room(room_id)
 
         to_insert = set()
-        to_update = set()
-
-        is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
 
         # First, if they're our user then we need to update for every user
-        if self.is_mine_id(user_id) and not is_appservice:
-            # Returns a map of other_user_id -> shared_private. We only need
-            # to update mappings if for users that either don't share a room
-            # already (aren't in the map) or, if the room is private, those that
-            # only share a public room.
-            user_ids_shared = yield self.store.get_users_who_share_room_from_dir(
-                user_id
-            )
+        if self.is_mine_id(user_id):
 
-            for other_user_id in users_with_profile:
-                if user_id == other_user_id:
-                    continue
+            is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
+
+            # We don't care about appservice users.
+            if not is_appservice:
+                for other_user_id in users_with_profile:
+                    if user_id == other_user_id:
+                        continue
 
-                shared_is_private = user_ids_shared.get(other_user_id)
-                if shared_is_private is True:
-                    # We've already marked in the database they share a private room
-                    continue
-                elif shared_is_private is False:
-                    # They already share a public room, so only update if this is
-                    # a private room
-                    if not is_public:
-                        to_update.add((user_id, other_user_id))
-                elif shared_is_private is None:
-                    # This is the first time they both share a room
                     to_insert.add((user_id, other_user_id))
 
         # Next we need to update for every local user in the room
@@ -531,29 +469,11 @@ class UserDirectoryHandler(object):
                 other_user_id
             )
             if self.is_mine_id(other_user_id) and not is_appservice:
-                shared_is_private = yield self.store.get_if_users_share_a_room(
-                    other_user_id, user_id
-                )
-                if shared_is_private is True:
-                    # We've already marked in the database they share a private room
-                    continue
-                elif shared_is_private is False:
-                    # They already share a public room, so only update if this is
-                    # a private room
-                    if not is_public:
-                        to_update.add((other_user_id, user_id))
-                elif shared_is_private is None:
-                    # This is the first time they both share a room
-                    to_insert.add((other_user_id, user_id))
+                to_insert.add((other_user_id, user_id))
 
         if to_insert:
             yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
 
-        if to_update:
-            yield self.store.update_users_who_share_room(
-                room_id, not is_public, to_update
-            )
-
     @defer.inlineCallbacks
     def _handle_remove_user(self, room_id, user_id):
         """Called when we might need to remove user to directory
@@ -562,84 +482,16 @@ class UserDirectoryHandler(object):
             room_id (str): room_id that user left or stopped being public that
             user_id (str)
         """
-        logger.debug("Maybe removing user %r", user_id)
-
-        row = yield self.store.get_user_in_directory(user_id)
-        update_user_dir = row and row["room_id"] == room_id
-
-        row = yield self.store.get_user_in_public_room(user_id)
-        update_user_in_public = row and row["room_id"] == room_id
-
-        if update_user_in_public or update_user_dir:
-            # XXX: Make this faster?
-            rooms = yield self.store.get_rooms_for_user(user_id)
-            for j_room_id in rooms:
-                if not update_user_in_public and not update_user_dir:
-                    break
-
-                is_in_room = yield self.store.is_host_joined(
-                    j_room_id, self.server_name
-                )
-
-                if not is_in_room:
-                    continue
-
-                if update_user_dir:
-                    update_user_dir = False
-                    yield self.store.update_user_in_user_dir(user_id, j_room_id)
+        logger.debug("Removing user %r", user_id)
 
-                is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
-                    j_room_id
-                )
+        # Remove user from sharing tables
+        yield self.store.remove_user_who_share_room(user_id, room_id)
 
-                if update_user_in_public and is_public:
-                    yield self.store.update_user_in_public_user_list(user_id, j_room_id)
-                    update_user_in_public = False
+        # Are they still in a room with members? If not, remove them entirely.
+        users_in_room_with = yield self.store.get_users_who_share_room_from_dir(user_id)
 
-        if update_user_dir:
+        if len(users_in_room_with) == 0:
             yield self.store.remove_from_user_dir(user_id)
-        elif update_user_in_public:
-            yield self.store.remove_from_user_in_public_room(user_id)
-
-        # Now handle users_who_share_rooms.
-
-        # Get a list of user tuples that were in the DB due to this room and
-        # users (this includes tuples where the other user matches `user_id`)
-        user_tuples = yield self.store.get_users_in_share_dir_with_room_id(
-            user_id, room_id
-        )
-
-        for user_id, other_user_id in user_tuples:
-            # For each user tuple get a list of rooms that they still share,
-            # trying to find a private room, and update the entry in the DB
-            rooms = yield self.store.get_rooms_in_common_for_users(
-                user_id, other_user_id
-            )
-
-            # If they dont share a room anymore, remove the mapping
-            if not rooms:
-                yield self.store.remove_user_who_share_room(user_id, other_user_id)
-                continue
-
-            found_public_share = None
-            for j_room_id in rooms:
-                is_public = yield self.store.is_room_world_readable_or_publicly_joinable(
-                    j_room_id
-                )
-
-                if is_public:
-                    found_public_share = j_room_id
-                else:
-                    found_public_share = None
-                    yield self.store.update_users_who_share_room(
-                        room_id, not is_public, [(user_id, other_user_id)]
-                    )
-                    break
-
-            if found_public_share:
-                yield self.store.update_users_who_share_room(
-                    room_id, not is_public, [(user_id, other_user_id)]
-                )
 
     @defer.inlineCallbacks
     def _handle_profile_change(self, user_id, room_id, prev_event_id, event_id):
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/schema/delta/53/user_share.sql
new file mode 100644
index 0000000000..14424ded0c
--- /dev/null
+++ b/synapse/storage/schema/delta/53/user_share.sql
@@ -0,0 +1,47 @@
+/* Copyright 2017 Vector Creations Ltd, 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- Old disused version of the tables below.
+DROP TABLE IF EXISTS users_who_share_rooms;
+
+-- This is no longer used because it's duplicated by the users_who_share_public_rooms
+DROP TABLE IF EXISTS users_in_public_rooms;
+
+-- Tables keeping track of what users share rooms. This is a map of local users
+-- to local or remote users, per room. Remote users cannot be in the user_id
+-- column, only the other_user_id column. There are two tables, one for public
+-- rooms and those for private rooms.
+CREATE TABLE IF NOT EXISTS users_who_share_public_rooms (
+    user_id TEXT NOT NULL,
+    other_user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+CREATE TABLE IF NOT EXISTS users_who_share_private_rooms (
+    user_id TEXT NOT NULL,
+    other_user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX users_who_share_public_rooms_u_idx ON users_who_share_public_rooms(user_id, other_user_id, room_id);
+CREATE INDEX users_who_share_public_rooms_r_idx ON users_who_share_public_rooms(room_id);
+CREATE INDEX users_who_share_public_rooms_o_idx ON users_who_share_public_rooms(other_user_id);
+
+CREATE UNIQUE INDEX users_who_share_private_rooms_u_idx ON users_who_share_private_rooms(user_id, other_user_id, room_id);
+CREATE INDEX users_who_share_private_rooms_r_idx ON users_who_share_private_rooms(room_id);
+CREATE INDEX users_who_share_private_rooms_o_idx ON users_who_share_private_rooms(other_user_id);
+
+-- Make sure that we populate the tables initially by resetting the stream ID
+UPDATE user_directory_stream_pos SET stream_id = NULL;
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index fea866c043..2317d22ed6 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -63,31 +63,14 @@ class UserDirectoryStore(SQLBaseStore):
 
         defer.returnValue(False)
 
-    @defer.inlineCallbacks
-    def add_users_to_public_room(self, room_id, user_ids):
-        """Add user to the list of users in public rooms
-
-        Args:
-            room_id (str): A room_id that all users are in that is world_readable
-                or publically joinable
-            user_ids (list(str)): Users to add
-        """
-        yield self._simple_insert_many(
-            table="users_in_public_rooms",
-            values=[{"user_id": user_id, "room_id": room_id} for user_id in user_ids],
-            desc="add_users_to_public_room",
-        )
-        for user_id in user_ids:
-            self.get_user_in_public_room.invalidate((user_id,))
-
-    def add_profiles_to_user_dir(self, room_id, users_with_profile):
+    def add_profiles_to_user_dir(self, users_with_profile):
         """Add profiles to the user directory
 
         Args:
-            room_id (str): A room_id that all users are joined to
             users_with_profile (dict): Users to add to directory in the form of
                 mapping of user_id -> ProfileInfo
         """
+
         if isinstance(self.database_engine, PostgresEngine):
             # We weight the loclpart most highly, then display name and finally
             # server name
@@ -113,7 +96,7 @@ class UserDirectoryStore(SQLBaseStore):
                 INSERT INTO user_directory_search(user_id, value)
                 VALUES (?,?)
             """
-            args = (
+            args = tuple(
                 (
                     user_id,
                     "%s %s" % (user_id, p.display_name) if p.display_name else user_id,
@@ -132,7 +115,7 @@ class UserDirectoryStore(SQLBaseStore):
                 values=[
                     {
                         "user_id": user_id,
-                        "room_id": room_id,
+                        "room_id": None,
                         "display_name": profile.display_name,
                         "avatar_url": profile.avatar_url,
                     }
@@ -250,16 +233,6 @@ class UserDirectoryStore(SQLBaseStore):
             "update_profile_in_user_dir", _update_profile_in_user_dir_txn
         )
 
-    @defer.inlineCallbacks
-    def update_user_in_public_user_list(self, user_id, room_id):
-        yield self._simple_update_one(
-            table="users_in_public_rooms",
-            keyvalues={"user_id": user_id},
-            updatevalues={"room_id": room_id},
-            desc="update_user_in_public_user_list",
-        )
-        self.get_user_in_public_room.invalidate((user_id,))
-
     def remove_from_user_dir(self, user_id):
         def _remove_from_user_dir_txn(txn):
             self._simple_delete_txn(
@@ -269,62 +242,50 @@ class UserDirectoryStore(SQLBaseStore):
                 txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
+                txn,
+                table="users_who_share_public_rooms",
+                keyvalues={"user_id": user_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_public_rooms",
+                keyvalues={"other_user_id": user_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_private_rooms",
+                keyvalues={"user_id": user_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_private_rooms",
+                keyvalues={"other_user_id": user_id},
             )
             txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
-            txn.call_after(self.get_user_in_public_room.invalidate, (user_id,))
 
         return self.runInteraction("remove_from_user_dir", _remove_from_user_dir_txn)
 
     @defer.inlineCallbacks
-    def remove_from_user_in_public_room(self, user_id):
-        yield self._simple_delete(
-            table="users_in_public_rooms",
-            keyvalues={"user_id": user_id},
-            desc="remove_from_user_in_public_room",
-        )
-        self.get_user_in_public_room.invalidate((user_id,))
-
-    def get_users_in_public_due_to_room(self, room_id):
-        """Get all user_ids that are in the room directory because they're
-        in the given room_id
-        """
-        return self._simple_select_onecol(
-            table="users_in_public_rooms",
-            keyvalues={"room_id": room_id},
-            retcol="user_id",
-            desc="get_users_in_public_due_to_room",
-        )
-
-    @defer.inlineCallbacks
     def get_users_in_dir_due_to_room(self, room_id):
         """Get all user_ids that are in the room directory because they're
         in the given room_id
         """
-        user_ids_dir = yield self._simple_select_onecol(
-            table="user_directory",
-            keyvalues={"room_id": room_id},
-            retcol="user_id",
-            desc="get_users_in_dir_due_to_room",
-        )
-
-        user_ids_pub = yield self._simple_select_onecol(
-            table="users_in_public_rooms",
+        user_ids_share_pub = yield self._simple_select_onecol(
+            table="users_who_share_public_rooms",
             keyvalues={"room_id": room_id},
-            retcol="user_id",
+            retcol="other_user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
-        user_ids_share = yield self._simple_select_onecol(
-            table="users_who_share_rooms",
+        user_ids_share_priv = yield self._simple_select_onecol(
+            table="users_who_share_private_rooms",
             keyvalues={"room_id": room_id},
-            retcol="user_id",
+            retcol="other_user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
-        user_ids = set(user_ids_dir)
-        user_ids.update(user_ids_pub)
-        user_ids.update(user_ids_share)
+        user_ids = set(user_ids_share_pub)
+        user_ids.update(user_ids_share_priv)
 
         defer.returnValue(user_ids)
 
@@ -351,7 +312,7 @@ class UserDirectoryStore(SQLBaseStore):
         defer.returnValue([name for name, in rows])
 
     def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
-        """Insert entries into the users_who_share_rooms table. The first
+        """Insert entries into the users_who_share_*_rooms table. The first
         user should be a local user.
 
         Args:
@@ -361,109 +322,71 @@ class UserDirectoryStore(SQLBaseStore):
         """
 
         def _add_users_who_share_room_txn(txn):
-            self._simple_insert_many_txn(
+
+            if share_private:
+                tbl = "users_who_share_private_rooms"
+            else:
+                tbl = "users_who_share_public_rooms"
+
+            self._simple_upsert_many_txn(
                 txn,
-                table="users_who_share_rooms",
-                values=[
-                    {
-                        "user_id": user_id,
-                        "other_user_id": other_user_id,
-                        "room_id": room_id,
-                        "share_private": share_private,
-                    }
+                table=tbl,
+                key_names=["user_id", "other_user_id", "room_id"],
+                key_values=[
+                    (user_id, other_user_id, room_id)
                     for user_id, other_user_id in user_id_tuples
                 ],
+                value_names=(),
+                value_values=None,
             )
             for user_id, other_user_id in user_id_tuples:
                 txn.call_after(
                     self.get_users_who_share_room_from_dir.invalidate, (user_id,)
                 )
-                txn.call_after(
-                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
-                )
 
         return self.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
-    def update_users_who_share_room(self, room_id, share_private, user_id_sets):
-        """Updates entries in the users_who_share_rooms table. The first
-        user should be a local user.
-
-        Args:
-            room_id (str)
-            share_private (bool): Is the room private
-            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
+    def remove_user_who_share_room(self, user_id, room_id):
         """
-
-        def _update_users_who_share_room_txn(txn):
-            sql = """
-                UPDATE users_who_share_rooms
-                SET room_id = ?, share_private = ?
-                WHERE user_id = ? AND other_user_id = ?
-            """
-            txn.executemany(
-                sql, ((room_id, share_private, uid, oid) for uid, oid in user_id_sets)
-            )
-            for user_id, other_user_id in user_id_sets:
-                txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
-                )
-                txn.call_after(
-                    self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
-                )
-
-        return self.runInteraction(
-            "update_users_who_share_room", _update_users_who_share_room_txn
-        )
-
-    def remove_user_who_share_room(self, user_id, other_user_id):
-        """Deletes entries in the users_who_share_rooms table. The first
+        Deletes entries in the users_who_share_*_rooms table. The first
         user should be a local user.
 
         Args:
+            user_id (str)
             room_id (str)
-            share_private (bool): Is the room private
-            user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
 
         def _remove_user_who_share_room_txn(txn):
             self._simple_delete_txn(
                 txn,
-                table="users_who_share_rooms",
-                keyvalues={"user_id": user_id, "other_user_id": other_user_id},
+                table="users_who_share_private_rooms",
+                keyvalues={"user_id": user_id, "room_id": room_id},
             )
-            txn.call_after(
-                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_private_rooms",
+                keyvalues={"other_user_id": user_id, "room_id": room_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_public_rooms",
+                keyvalues={"user_id": user_id, "room_id": room_id},
+            )
+            self._simple_delete_txn(
+                txn,
+                table="users_who_share_public_rooms",
+                keyvalues={"other_user_id": user_id, "room_id": room_id},
             )
             txn.call_after(
-                self.get_if_users_share_a_room.invalidate, (user_id, other_user_id)
+                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
             )
 
         return self.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
-    @cached(max_entries=500000)
-    def get_if_users_share_a_room(self, user_id, other_user_id):
-        """Gets if users share a room.
-
-        Args:
-            user_id (str): Must be a local user_id
-            other_user_id (str)
-
-        Returns:
-            bool|None: None if they don't share a room, otherwise whether they
-            share a private room or not.
-        """
-        return self._simple_select_one_onecol(
-            table="users_who_share_rooms",
-            keyvalues={"user_id": user_id, "other_user_id": other_user_id},
-            retcol="share_private",
-            allow_none=True,
-            desc="get_if_users_share_a_room",
-        )
-
     @cachedInlineCallbacks(max_entries=500000, iterable=True)
     def get_users_who_share_room_from_dir(self, user_id):
         """Returns the set of users who share a room with `user_id`
@@ -472,32 +395,29 @@ class UserDirectoryStore(SQLBaseStore):
             user_id(str): Must be a local user
 
         Returns:
-            dict: user_id -> share_private mapping
+            list: user_id
         """
-        rows = yield self._simple_select_list(
-            table="users_who_share_rooms",
+        rows = yield self._simple_select_onecol(
+            table="users_who_share_private_rooms",
+            keyvalues={"user_id": user_id},
+            retcol="other_user_id",
+            desc="get_users_who_share_room_with_user",
+        )
+
+        pub_rows = yield self._simple_select_onecol(
+            table="users_who_share_public_rooms",
             keyvalues={"user_id": user_id},
-            retcols=("other_user_id", "share_private"),
+            retcol="other_user_id",
             desc="get_users_who_share_room_with_user",
         )
 
-        defer.returnValue({row["other_user_id"]: row["share_private"] for row in rows})
+        users = set(pub_rows)
+        users.update(rows)
 
-    def get_users_in_share_dir_with_room_id(self, user_id, room_id):
-        """Get all user tuples that are in the users_who_share_rooms due to the
-        given room_id.
+        # Remove the user themselves from this list.
+        users.discard(user_id)
 
-        Returns:
-            [(user_id, other_user_id)]: where one of the two will match the given
-            user_id.
-        """
-        sql = """
-            SELECT user_id, other_user_id FROM users_who_share_rooms
-            WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
-        """
-        return self._execute(
-            "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
-        )
+        defer.returnValue(list(users))
 
     @defer.inlineCallbacks
     def get_rooms_in_common_for_users(self, user_id, other_user_id):
@@ -532,12 +452,10 @@ class UserDirectoryStore(SQLBaseStore):
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_in_public_rooms")
-            txn.execute("DELETE FROM users_who_share_rooms")
+            txn.execute("DELETE FROM users_who_share_public_rooms")
+            txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
-            txn.call_after(self.get_user_in_public_room.invalidate_all)
             txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
-            txn.call_after(self.get_if_users_share_a_room.invalidate_all)
 
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
@@ -548,21 +466,11 @@ class UserDirectoryStore(SQLBaseStore):
         return self._simple_select_one(
             table="user_directory",
             keyvalues={"user_id": user_id},
-            retcols=("room_id", "display_name", "avatar_url"),
+            retcols=("display_name", "avatar_url"),
             allow_none=True,
             desc="get_user_in_directory",
         )
 
-    @cached()
-    def get_user_in_public_room(self, user_id):
-        return self._simple_select_one(
-            table="users_in_public_rooms",
-            keyvalues={"user_id": user_id},
-            retcols=("room_id",),
-            allow_none=True,
-            desc="get_user_in_public_room",
-        )
-
     def get_user_directory_stream_pos(self):
         return self._simple_select_one_onecol(
             table="user_directory_stream_pos",
@@ -660,14 +568,15 @@ class UserDirectoryStore(SQLBaseStore):
             where_clause = "1=1"
         else:
             join_clause = """
-                LEFT JOIN users_in_public_rooms AS p USING (user_id)
                 LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_rooms
-                    WHERE user_id = ? AND share_private
-                ) AS s USING (user_id)
+                    SELECT other_user_id AS user_id FROM users_who_share_public_rooms
+                    UNION
+                    SELECT other_user_id AS user_id FROM users_who_share_private_rooms
+                    WHERE user_id = ?
+                ) AS p USING (user_id)
             """
             join_args = (user_id,)
-            where_clause = "(s.user_id IS NOT NULL OR p.user_id IS NOT NULL)"
+            where_clause = "p.user_id IS NOT NULL"
 
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -686,7 +595,7 @@ class UserDirectoryStore(SQLBaseStore):
                     %s
                     AND vector @@ to_tsquery('english', ?)
                 ORDER BY
-                    (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
+                    (CASE WHEN d.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END)
                     * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END)
                     * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END)
                     * (
diff --git a/tests/handlers/test_user_directory.py b/tests/handlers/test_user_directory.py
index 11f2bae698..a16a2dc67b 100644
--- a/tests/handlers/test_user_directory.py
+++ b/tests/handlers/test_user_directory.py
@@ -14,78 +14,262 @@
 # limitations under the License.
 from mock import Mock
 
-from twisted.internet import defer
-
 from synapse.api.constants import UserTypes
-from synapse.handlers.user_directory import UserDirectoryHandler
+from synapse.rest.client.v1 import admin, login, room
 from synapse.storage.roommember import ProfileInfo
 
 from tests import unittest
-from tests.utils import setup_test_homeserver
 
 
-class UserDirectoryHandlers(object):
-    def __init__(self, hs):
-        self.user_directory_handler = UserDirectoryHandler(hs)
+class UserDirectoryTestCase(unittest.HomeserverTestCase):
+    """
+    Tests the UserDirectoryHandler.
+    """
 
+    servlets = [
+        login.register_servlets,
+        admin.register_servlets,
+        room.register_servlets,
+    ]
 
-class UserDirectoryTestCase(unittest.TestCase):
-    """ Tests the UserDirectoryHandler. """
+    def make_homeserver(self, reactor, clock):
 
-    @defer.inlineCallbacks
-    def setUp(self):
-        hs = yield setup_test_homeserver(self.addCleanup)
-        self.store = hs.get_datastore()
-        hs.handlers = UserDirectoryHandlers(hs)
+        config = self.default_config()
+        config.update_user_directory = True
+        return self.setup_test_homeserver(config=config)
 
-        self.handler = hs.get_handlers().user_directory_handler
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.handler = hs.get_user_directory_handler()
 
-    @defer.inlineCallbacks
     def test_handle_local_profile_change_with_support_user(self):
         support_user_id = "@support:test"
-        yield self.store.register(
-            user_id=support_user_id,
-            token="123",
-            password_hash=None,
-            user_type=UserTypes.SUPPORT
+        self.get_success(
+            self.store.register(
+                user_id=support_user_id,
+                token="123",
+                password_hash=None,
+                user_type=UserTypes.SUPPORT,
+            )
         )
 
-        yield self.handler.handle_local_profile_change(support_user_id, None)
-        profile = yield self.store.get_user_in_directory(support_user_id)
+        self.get_success(
+            self.handler.handle_local_profile_change(support_user_id, None)
+        )
+        profile = self.get_success(self.store.get_user_in_directory(support_user_id))
         self.assertTrue(profile is None)
         display_name = 'display_name'
 
-        profile_info = ProfileInfo(
-            avatar_url='avatar_url',
-            display_name=display_name,
-        )
+        profile_info = ProfileInfo(avatar_url='avatar_url', display_name=display_name)
         regular_user_id = '@regular:test'
-        yield self.handler.handle_local_profile_change(regular_user_id, profile_info)
-        profile = yield self.store.get_user_in_directory(regular_user_id)
+        self.get_success(
+            self.handler.handle_local_profile_change(regular_user_id, profile_info)
+        )
+        profile = self.get_success(self.store.get_user_in_directory(regular_user_id))
         self.assertTrue(profile['display_name'] == display_name)
 
-    @defer.inlineCallbacks
     def test_handle_user_deactivated_support_user(self):
         s_user_id = "@support:test"
-        self.store.register(
-            user_id=s_user_id,
-            token="123",
-            password_hash=None,
-            user_type=UserTypes.SUPPORT
+        self.get_success(
+            self.store.register(
+                user_id=s_user_id,
+                token="123",
+                password_hash=None,
+                user_type=UserTypes.SUPPORT,
+            )
         )
 
         self.store.remove_from_user_dir = Mock()
         self.store.remove_from_user_in_public_room = Mock()
-        yield self.handler.handle_user_deactivated(s_user_id)
+        self.get_success(self.handler.handle_user_deactivated(s_user_id))
         self.store.remove_from_user_dir.not_called()
         self.store.remove_from_user_in_public_room.not_called()
 
-    @defer.inlineCallbacks
     def test_handle_user_deactivated_regular_user(self):
         r_user_id = "@regular:test"
-        self.store.register(user_id=r_user_id, token="123", password_hash=None)
+        self.get_success(
+            self.store.register(user_id=r_user_id, token="123", password_hash=None)
+        )
         self.store.remove_from_user_dir = Mock()
-        self.store.remove_from_user_in_public_room = Mock()
-        yield self.handler.handle_user_deactivated(r_user_id)
+        self.get_success(self.handler.handle_user_deactivated(r_user_id))
         self.store.remove_from_user_dir.called_once_with(r_user_id)
-        self.store.remove_from_user_in_public_room.assert_called_once_with(r_user_id)
+
+    def test_private_room(self):
+        """
+        A user can be searched for only by people that are either in a public
+        room, or that share a private chat.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+        u3 = self.register_user("user3", "pass")
+
+        # We do not add users to the directory until they join a room.
+        s = self.get_success(self.handler.search_users(u1, "user2", 10))
+        self.assertEqual(len(s["results"]), 0)
+
+        room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        # Check we have populated the database correctly.
+        shares_public = self.get_users_who_share_public_rooms()
+        shares_private = self.get_users_who_share_private_rooms()
+
+        self.assertEqual(shares_public, [])
+        self.assertEqual(
+            self._compress_shared(shares_private), set([(u1, u2, room), (u2, u1, room)])
+        )
+
+        # We get one search result when searching for user2 by user1.
+        s = self.get_success(self.handler.search_users(u1, "user2", 10))
+        self.assertEqual(len(s["results"]), 1)
+
+        # We get NO search results when searching for user2 by user3.
+        s = self.get_success(self.handler.search_users(u3, "user2", 10))
+        self.assertEqual(len(s["results"]), 0)
+
+        # We get NO search results when searching for user3 by user1.
+        s = self.get_success(self.handler.search_users(u1, "user3", 10))
+        self.assertEqual(len(s["results"]), 0)
+
+        # User 2 then leaves.
+        self.helper.leave(room, user=u2, tok=u2_token)
+
+        # Check we have removed the values.
+        shares_public = self.get_users_who_share_public_rooms()
+        shares_private = self.get_users_who_share_private_rooms()
+
+        self.assertEqual(shares_public, [])
+        self.assertEqual(self._compress_shared(shares_private), set())
+
+        # User1 now gets no search results for any of the other users.
+        s = self.get_success(self.handler.search_users(u1, "user2", 10))
+        self.assertEqual(len(s["results"]), 0)
+
+        s = self.get_success(self.handler.search_users(u1, "user3", 10))
+        self.assertEqual(len(s["results"]), 0)
+
+    def _compress_shared(self, shared):
+        """
+        Compress a list of users who share rooms dicts to a list of tuples.
+        """
+        r = set()
+        for i in shared:
+            r.add((i["user_id"], i["other_user_id"], i["room_id"]))
+        return r
+
+    def get_users_who_share_public_rooms(self):
+        return self.get_success(
+            self.store._simple_select_list(
+                "users_who_share_public_rooms",
+                None,
+                ["user_id", "other_user_id", "room_id"],
+            )
+        )
+
+    def get_users_who_share_private_rooms(self):
+        return self.get_success(
+            self.store._simple_select_list(
+                "users_who_share_private_rooms",
+                None,
+                ["user_id", "other_user_id", "room_id"],
+            )
+        )
+
+    def test_initial(self):
+        """
+        The user directory's initial handler correctly updates the search tables.
+        """
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+        u3 = self.register_user("user3", "pass")
+        u3_token = self.login(u3, "pass")
+
+        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        private_room = self.helper.create_room_as(u1, is_public=False, tok=u1_token)
+        self.helper.invite(private_room, src=u1, targ=u3, tok=u1_token)
+        self.helper.join(private_room, user=u3, tok=u3_token)
+
+        self.get_success(self.store.update_user_directory_stream_pos(None))
+        self.get_success(self.store.delete_all_from_user_dir())
+
+        shares_public = self.get_users_who_share_public_rooms()
+        shares_private = self.get_users_who_share_private_rooms()
+
+        self.assertEqual(shares_private, [])
+        self.assertEqual(shares_public, [])
+
+        # Reset the handled users caches
+        self.handler.initially_handled_users = set()
+
+        # Do the initial population
+        d = self.handler._do_initial_spam()
+
+        # This takes a while, so pump it a bunch of times to get through the
+        # sleep delays
+        for i in range(10):
+            self.pump(1)
+
+        self.get_success(d)
+
+        shares_public = self.get_users_who_share_public_rooms()
+        shares_private = self.get_users_who_share_private_rooms()
+
+        # User 1 and User 2 share public rooms
+        self.assertEqual(
+            self._compress_shared(shares_public), set([(u1, u2, room), (u2, u1, room)])
+        )
+
+        # User 1 and User 3 share private rooms
+        self.assertEqual(
+            self._compress_shared(shares_private),
+            set([(u1, u3, private_room), (u3, u1, private_room)]),
+        )
+
+    def test_search_all_users(self):
+        """
+        Search all users = True means that a user does not have to share a
+        private room with the searching user or be in a public room to be search
+        visible.
+        """
+        self.handler.search_all_users = True
+        self.hs.config.user_directory_search_all_users = True
+
+        u1 = self.register_user("user1", "pass")
+        u1_token = self.login(u1, "pass")
+        u2 = self.register_user("user2", "pass")
+        u2_token = self.login(u2, "pass")
+        u3 = self.register_user("user3", "pass")
+
+        # User 1 and User 2 join a room. User 3 never does.
+        room = self.helper.create_room_as(u1, is_public=True, tok=u1_token)
+        self.helper.invite(room, src=u1, targ=u2, tok=u1_token)
+        self.helper.join(room, user=u2, tok=u2_token)
+
+        self.get_success(self.store.update_user_directory_stream_pos(None))
+        self.get_success(self.store.delete_all_from_user_dir())
+
+        # Reset the handled users caches
+        self.handler.initially_handled_users = set()
+
+        # Do the initial population
+        d = self.handler._do_initial_spam()
+
+        # This takes a while, so pump it a bunch of times to get through the
+        # sleep delays
+        for i in range(10):
+            self.pump(1)
+
+        self.get_success(d)
+
+        # Despite not sharing a room, search_all_users means we get a search
+        # result.
+        s = self.get_success(self.handler.search_users(u1, u3, 10))
+        self.assertEqual(len(s["results"]), 1)
diff --git a/tests/storage/test_user_directory.py b/tests/storage/test_user_directory.py
index 0dde1ab2fe..a2a652a235 100644
--- a/tests/storage/test_user_directory.py
+++ b/tests/storage/test_user_directory.py
@@ -35,14 +35,12 @@ class UserDirectoryStoreTestCase(unittest.TestCase):
         # alice and bob are both in !room_id. bobby is not but shares
         # a homeserver with alice.
         yield self.store.add_profiles_to_user_dir(
-            "!room:id",
             {
                 ALICE: ProfileInfo(None, "alice"),
                 BOB: ProfileInfo(None, "bob"),
                 BOBBY: ProfileInfo(None, "bobby"),
             },
         )
-        yield self.store.add_users_to_public_room("!room:id", [ALICE, BOB])
         yield self.store.add_users_who_share_room(
             "!room:id", False, ((ALICE, BOB), (BOB, ALICE))
         )