summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/app/_base.py18
-rwxr-xr-xsynapse/app/homeserver.py17
-rw-r--r--synapse/config/_base.py8
-rw-r--r--synapse/config/tls.py5
-rw-r--r--synapse/config/workers.py28
-rw-r--r--synapse/federation/transport/client.py2
-rw-r--r--synapse/federation/transport/server.py14
-rw-r--r--synapse/handlers/user_directory.py125
-rw-r--r--synapse/server.py13
-rw-r--r--synapse/storage/_base.py13
-rw-r--r--synapse/storage/schema/delta/53/user_share.sql3
-rw-r--r--synapse/storage/schema/delta/53/users_in_public_rooms.sql28
-rw-r--r--synapse/storage/user_directory.py123
13 files changed, 249 insertions, 148 deletions
diff --git a/synapse/app/_base.py b/synapse/app/_base.py
index 32e8b8a3f5..d4c6c4c8e2 100644
--- a/synapse/app/_base.py
+++ b/synapse/app/_base.py
@@ -63,12 +63,13 @@ def start_worker_reactor(appname, config):
 
     start_reactor(
         appname,
-        config.soft_file_limit,
-        config.gc_thresholds,
-        config.worker_pid_file,
-        config.worker_daemonize,
-        config.worker_cpu_affinity,
-        logger,
+        soft_file_limit=config.soft_file_limit,
+        gc_thresholds=config.gc_thresholds,
+        pid_file=config.worker_pid_file,
+        daemonize=config.worker_daemonize,
+        cpu_affinity=config.worker_cpu_affinity,
+        print_pidfile=config.print_pidfile,
+        logger=logger,
     )
 
 
@@ -79,6 +80,7 @@ def start_reactor(
         pid_file,
         daemonize,
         cpu_affinity,
+        print_pidfile,
         logger,
 ):
     """ Run the reactor in the main process
@@ -93,6 +95,7 @@ def start_reactor(
         pid_file (str): name of pid file to write to if daemonize is True
         daemonize (bool): true to run the reactor in a background process
         cpu_affinity (int|None): cpu affinity mask
+        print_pidfile (bool): whether to print the pid file, if daemonize is True
         logger (logging.Logger): logger instance to pass to Daemonize
     """
 
@@ -124,6 +127,9 @@ def start_reactor(
             reactor.run()
 
     if daemonize:
+        if print_pidfile:
+            print(pid_file)
+
         daemon = Daemonize(
             app=appname,
             pid=pid_file,
diff --git a/synapse/app/homeserver.py b/synapse/app/homeserver.py
index e8b6cc3114..869c028d1f 100755
--- a/synapse/app/homeserver.py
+++ b/synapse/app/homeserver.py
@@ -376,6 +376,7 @@ def setup(config_options):
     logger.info("Database prepared in %s.", config.database_config['name'])
 
     hs.setup()
+    hs.setup_master()
 
     @defer.inlineCallbacks
     def do_acme():
@@ -636,17 +637,15 @@ def run(hs):
         # be quite busy the first few minutes
         clock.call_later(5 * 60, start_phone_stats_home)
 
-    if hs.config.daemonize and hs.config.print_pidfile:
-        print(hs.config.pid_file)
-
     _base.start_reactor(
         "synapse-homeserver",
-        hs.config.soft_file_limit,
-        hs.config.gc_thresholds,
-        hs.config.pid_file,
-        hs.config.daemonize,
-        hs.config.cpu_affinity,
-        logger,
+        soft_file_limit=hs.config.soft_file_limit,
+        gc_thresholds=hs.config.gc_thresholds,
+        pid_file=hs.config.pid_file,
+        daemonize=hs.config.daemonize,
+        cpu_affinity=hs.config.cpu_affinity,
+        print_pidfile=hs.config.print_pidfile,
+        logger=logger,
     )
 
 
diff --git a/synapse/config/_base.py b/synapse/config/_base.py
index c4d3087fa4..5613f38e4d 100644
--- a/synapse/config/_base.py
+++ b/synapse/config/_base.py
@@ -214,14 +214,20 @@ class Config(object):
             " Defaults to the directory containing the last config file",
         )
 
+        obj = cls()
+
+        obj.invoke_all("add_arguments", config_parser)
+
         config_args = config_parser.parse_args(argv)
 
         config_files = find_config_files(search_paths=config_args.config_path)
 
-        obj = cls()
         obj.read_config_files(
             config_files, keys_directory=config_args.keys_directory, generate_keys=False
         )
+
+        obj.invoke_all("read_arguments", config_args)
+
         return obj
 
     @classmethod
diff --git a/synapse/config/tls.py b/synapse/config/tls.py
index 40045de7ac..f0014902da 100644
--- a/synapse/config/tls.py
+++ b/synapse/config/tls.py
@@ -181,6 +181,11 @@ class TlsConfig(Config):
         # See 'ACME support' below to enable auto-provisioning this certificate via
         # Let's Encrypt.
         #
+        # If supplying your own, be sure to use a `.pem` file that includes the
+        # full certificate chain including any intermediate certificates (for
+        # instance, if using certbot, use `fullchain.pem` as your certificate,
+        # not `cert.pem`).
+        #
         #tls_certificate_path: "%(tls_certificate_path)s"
 
         # PEM-encoded private key for TLS
diff --git a/synapse/config/workers.py b/synapse/config/workers.py
index 80baf0ce0e..bfbd8b6c91 100644
--- a/synapse/config/workers.py
+++ b/synapse/config/workers.py
@@ -28,7 +28,7 @@ class WorkerConfig(Config):
         if self.worker_app == "synapse.app.homeserver":
             self.worker_app = None
 
-        self.worker_listeners = config.get("worker_listeners")
+        self.worker_listeners = config.get("worker_listeners", [])
         self.worker_daemonize = config.get("worker_daemonize")
         self.worker_pid_file = config.get("worker_pid_file")
         self.worker_log_file = config.get("worker_log_file")
@@ -48,6 +48,17 @@ class WorkerConfig(Config):
         self.worker_main_http_uri = config.get("worker_main_http_uri", None)
         self.worker_cpu_affinity = config.get("worker_cpu_affinity")
 
+        # This option is really only here to support `--manhole` command line
+        # argument.
+        manhole = config.get("worker_manhole")
+        if manhole:
+            self.worker_listeners.append({
+                "port": manhole,
+                "bind_addresses": ["127.0.0.1"],
+                "type": "manhole",
+                "tls": False,
+            })
+
         if self.worker_listeners:
             for listener in self.worker_listeners:
                 bind_address = listener.pop("bind_address", None)
@@ -57,3 +68,18 @@ class WorkerConfig(Config):
                     bind_addresses.append(bind_address)
                 elif not bind_addresses:
                     bind_addresses.append('')
+
+    def read_arguments(self, args):
+        # We support a bunch of command line arguments that override options in
+        # the config. A lot of these options have a worker_* prefix when running
+        # on workers so we also have to override them when command line options
+        # are specified.
+
+        if args.daemonize is not None:
+            self.worker_daemonize = args.daemonize
+        if args.log_config is not None:
+            self.worker_log_config = args.log_config
+        if args.log_file is not None:
+            self.worker_log_file = args.log_file
+        if args.manhole is not None:
+            self.worker_manhole = args.worker_manhole
diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py
index 4e8919d657..8e2be218e2 100644
--- a/synapse/federation/transport/client.py
+++ b/synapse/federation/transport/client.py
@@ -167,7 +167,7 @@ class TransportLayerClient(object):
         # generated by the json_data_callback.
         json_data = transaction.get_dict()
 
-        path = _create_v1_path("/send/%s", transaction.transaction_id)
+        path = _create_v1_path("/send/%s/", transaction.transaction_id)
 
         response = yield self.client.put_json(
             transaction.destination,
diff --git a/synapse/federation/transport/server.py b/synapse/federation/transport/server.py
index efb6bdca48..96d680a5ad 100644
--- a/synapse/federation/transport/server.py
+++ b/synapse/federation/transport/server.py
@@ -312,7 +312,7 @@ class BaseFederationServlet(object):
 
 
 class FederationSendServlet(BaseFederationServlet):
-    PATH = "/send/(?P<transaction_id>[^/]*)/?"
+    PATH = "/send/(?P<transaction_id>[^/]*)/"
 
     def __init__(self, handler, server_name, **kwargs):
         super(FederationSendServlet, self).__init__(
@@ -378,7 +378,7 @@ class FederationSendServlet(BaseFederationServlet):
 
 
 class FederationEventServlet(BaseFederationServlet):
-    PATH = "/event/(?P<event_id>[^/]*)/?"
+    PATH = "/event/(?P<event_id>[^/]*)/"
 
     # This is when someone asks for a data item for a given server data_id pair.
     def on_GET(self, origin, content, query, event_id):
@@ -386,7 +386,7 @@ class FederationEventServlet(BaseFederationServlet):
 
 
 class FederationStateServlet(BaseFederationServlet):
-    PATH = "/state/(?P<context>[^/]*)/?"
+    PATH = "/state/(?P<context>[^/]*)/"
 
     # This is when someone asks for all data for a given context.
     def on_GET(self, origin, content, query, context):
@@ -398,7 +398,7 @@ class FederationStateServlet(BaseFederationServlet):
 
 
 class FederationStateIdsServlet(BaseFederationServlet):
-    PATH = "/state_ids/(?P<room_id>[^/]*)/?"
+    PATH = "/state_ids/(?P<room_id>[^/]*)/"
 
     def on_GET(self, origin, content, query, room_id):
         return self.handler.on_state_ids_request(
@@ -409,7 +409,7 @@ class FederationStateIdsServlet(BaseFederationServlet):
 
 
 class FederationBackfillServlet(BaseFederationServlet):
-    PATH = "/backfill/(?P<context>[^/]*)/?"
+    PATH = "/backfill/(?P<context>[^/]*)/"
 
     def on_GET(self, origin, content, query, context):
         versions = [x.decode('ascii') for x in query[b"v"]]
@@ -1080,7 +1080,7 @@ class FederationGroupsCategoriesServlet(BaseFederationServlet):
     """Get all categories for a group
     """
     PATH = (
-        "/groups/(?P<group_id>[^/]*)/categories/?"
+        "/groups/(?P<group_id>[^/]*)/categories/"
     )
 
     @defer.inlineCallbacks
@@ -1150,7 +1150,7 @@ class FederationGroupsRolesServlet(BaseFederationServlet):
     """Get roles in a group
     """
     PATH = (
-        "/groups/(?P<group_id>[^/]*)/roles/?"
+        "/groups/(?P<group_id>[^/]*)/roles/"
     )
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/user_directory.py b/synapse/handlers/user_directory.py
index c21da8343a..d92f8c529c 100644
--- a/synapse/handlers/user_directory.py
+++ b/synapse/handlers/user_directory.py
@@ -60,6 +60,12 @@ class UserDirectoryHandler(object):
         self.update_user_directory = hs.config.update_user_directory
         self.search_all_users = hs.config.user_directory_search_all_users
 
+        # If we're a worker, don't sleep when doing the initial room work, as it
+        # won't monopolise the master's CPU.
+        if hs.config.worker_app:
+            self.INITIAL_ROOM_SLEEP_MS = 0
+            self.INITIAL_USER_SLEEP_MS = 0
+
         # When start up for the first time we need to populate the user_directory.
         # This is a set of user_id's we've inserted already
         self.initially_handled_users = set()
@@ -231,7 +237,7 @@ class UserDirectoryHandler(object):
         unhandled_users = user_ids - self.initially_handled_users
 
         yield self.store.add_profiles_to_user_dir(
-            {user_id: users_with_profile[user_id] for user_id in unhandled_users},
+            {user_id: users_with_profile[user_id] for user_id in unhandled_users}
         )
 
         self.initially_handled_users |= unhandled_users
@@ -241,38 +247,58 @@ class UserDirectoryHandler(object):
         # We also batch up inserts/updates, but try to avoid too many at once.
         to_insert = set()
         count = 0
-        for user_id in user_ids:
-            if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
-                yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
-
-            if not self.is_mine_id(user_id):
-                count += 1
-                continue
 
-            if self.store.get_if_app_services_interested_in_user(user_id):
-                count += 1
-                continue
+        if is_public:
+            for user_id in user_ids:
+                if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                    yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
 
-            for other_user_id in user_ids:
-                if user_id == other_user_id:
+                if self.store.get_if_app_services_interested_in_user(user_id):
+                    count += 1
                     continue
 
+                to_insert.add(user_id)
+                if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
+                    yield self.store.add_users_in_public_rooms(room_id, to_insert)
+                    to_insert.clear()
+
+            if to_insert:
+                yield self.store.add_users_in_public_rooms(room_id, to_insert)
+                to_insert.clear()
+        else:
+
+            for user_id in user_ids:
                 if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
                     yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
-                count += 1
 
-                user_set = (user_id, other_user_id)
-                to_insert.add(user_set)
+                if not self.is_mine_id(user_id):
+                    count += 1
+                    continue
 
-                if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
-                    yield self.store.add_users_who_share_room(
-                        room_id, not is_public, to_insert
-                    )
-                    to_insert.clear()
+                if self.store.get_if_app_services_interested_in_user(user_id):
+                    count += 1
+                    continue
+
+                for other_user_id in user_ids:
+                    if user_id == other_user_id:
+                        continue
+
+                    if count % self.INITIAL_ROOM_SLEEP_COUNT == 0:
+                        yield self.clock.sleep(self.INITIAL_ROOM_SLEEP_MS / 1000.0)
+                    count += 1
+
+                    user_set = (user_id, other_user_id)
+                    to_insert.add(user_set)
+
+                    if len(to_insert) > self.INITIAL_ROOM_BATCH_SIZE:
+                        yield self.store.add_users_who_share_private_room(
+                            room_id, not is_public, to_insert
+                        )
+                        to_insert.clear()
 
-        if to_insert:
-            yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
-            to_insert.clear()
+            if to_insert:
+                yield self.store.add_users_who_share_private_room(room_id, to_insert)
+                to_insert.clear()
 
     @defer.inlineCallbacks
     def _handle_deltas(self, deltas):
@@ -445,34 +471,37 @@ class UserDirectoryHandler(object):
         # Now we update users who share rooms with users.
         users_with_profile = yield self.state.get_current_user_in_room(room_id)
 
-        to_insert = set()
+        if is_public:
+            yield self.store.add_users_in_public_rooms(room_id, (user_id,))
+        else:
+            to_insert = set()
 
-        # First, if they're our user then we need to update for every user
-        if self.is_mine_id(user_id):
+            # First, if they're our user then we need to update for every user
+            if self.is_mine_id(user_id):
 
-            is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
+                is_appservice = self.store.get_if_app_services_interested_in_user(user_id)
 
-            # We don't care about appservice users.
-            if not is_appservice:
-                for other_user_id in users_with_profile:
-                    if user_id == other_user_id:
-                        continue
+                # We don't care about appservice users.
+                if not is_appservice:
+                    for other_user_id in users_with_profile:
+                        if user_id == other_user_id:
+                            continue
 
-                    to_insert.add((user_id, other_user_id))
+                        to_insert.add((user_id, other_user_id))
 
-        # Next we need to update for every local user in the room
-        for other_user_id in users_with_profile:
-            if user_id == other_user_id:
-                continue
+            # Next we need to update for every local user in the room
+            for other_user_id in users_with_profile:
+                if user_id == other_user_id:
+                    continue
 
-            is_appservice = self.store.get_if_app_services_interested_in_user(
-                other_user_id
-            )
-            if self.is_mine_id(other_user_id) and not is_appservice:
-                to_insert.add((other_user_id, user_id))
+                is_appservice = self.store.get_if_app_services_interested_in_user(
+                    other_user_id
+                )
+                if self.is_mine_id(other_user_id) and not is_appservice:
+                    to_insert.add((other_user_id, user_id))
 
-        if to_insert:
-            yield self.store.add_users_who_share_room(room_id, not is_public, to_insert)
+            if to_insert:
+                yield self.store.add_users_who_share_private_room(room_id, to_insert)
 
     @defer.inlineCallbacks
     def _handle_remove_user(self, room_id, user_id):
@@ -487,10 +516,10 @@ class UserDirectoryHandler(object):
         # Remove user from sharing tables
         yield self.store.remove_user_who_share_room(user_id, room_id)
 
-        # Are they still in a room with members? If not, remove them entirely.
-        users_in_room_with = yield self.store.get_users_who_share_room_from_dir(user_id)
+        # Are they still in any rooms? If not, remove them entirely.
+        rooms_user_is_in = yield self.store.get_user_dir_rooms_user_is_in(user_id)
 
-        if len(users_in_room_with) == 0:
+        if len(rooms_user_is_in) == 0:
             yield self.store.remove_from_user_dir(user_id)
 
     @defer.inlineCallbacks
diff --git a/synapse/server.py b/synapse/server.py
index 72835e8c86..b9549dd042 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -185,6 +185,10 @@ class HomeServer(object):
         'registration_handler',
     ]
 
+    REQUIRED_ON_MASTER_STARTUP = [
+        "user_directory_handler",
+    ]
+
     # This is overridden in derived application classes
     # (such as synapse.app.homeserver.SynapseHomeServer) and gives the class to be
     # instantiated during setup() for future return by get_datastore()
@@ -221,6 +225,15 @@ class HomeServer(object):
             conn.commit()
         logger.info("Finished setting up.")
 
+    def setup_master(self):
+        """
+        Some handlers have side effects on instantiation (like registering
+        background updates). This function causes them to be fetched, and
+        therefore instantiated, to run those side effects.
+        """
+        for i in self.REQUIRED_ON_MASTER_STARTUP:
+            getattr(self, "get_" + i)()
+
     def get_reactor(self):
         """
         Fetch the Twisted reactor in use by this HomeServer.
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a0333d5309..7e3903859b 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -767,18 +767,25 @@ class SQLBaseStore(object):
         """
         allvalues = {}
         allvalues.update(keyvalues)
-        allvalues.update(values)
         allvalues.update(insertion_values)
 
+        if not values:
+            latter = "NOTHING"
+        else:
+            allvalues.update(values)
+            latter = (
+                "UPDATE SET " + ", ".join(k + "=EXCLUDED." + k for k in values)
+            )
+
         sql = (
             "INSERT INTO %s (%s) VALUES (%s) "
-            "ON CONFLICT (%s) DO UPDATE SET %s"
+            "ON CONFLICT (%s) DO %s"
         ) % (
             table,
             ", ".join(k for k in allvalues),
             ", ".join("?" for _ in allvalues),
             ", ".join(k for k in keyvalues),
-            ", ".join(k + "=EXCLUDED." + k for k in values),
+            latter
         )
         txn.execute(sql, list(allvalues.values()))
 
diff --git a/synapse/storage/schema/delta/53/user_share.sql b/synapse/storage/schema/delta/53/user_share.sql
index 14424ded0c..5831b1a6f8 100644
--- a/synapse/storage/schema/delta/53/user_share.sql
+++ b/synapse/storage/schema/delta/53/user_share.sql
@@ -16,9 +16,6 @@
 -- Old disused version of the tables below.
 DROP TABLE IF EXISTS users_who_share_rooms;
 
--- This is no longer used because it's duplicated by the users_who_share_public_rooms
-DROP TABLE IF EXISTS users_in_public_rooms;
-
 -- Tables keeping track of what users share rooms. This is a map of local users
 -- to local or remote users, per room. Remote users cannot be in the user_id
 -- column, only the other_user_id column. There are two tables, one for public
diff --git a/synapse/storage/schema/delta/53/users_in_public_rooms.sql b/synapse/storage/schema/delta/53/users_in_public_rooms.sql
new file mode 100644
index 0000000000..f7827ca6d2
--- /dev/null
+++ b/synapse/storage/schema/delta/53/users_in_public_rooms.sql
@@ -0,0 +1,28 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+-- We don't need the old version of this table.
+DROP TABLE IF EXISTS users_in_public_rooms;
+
+-- Old version of users_in_public_rooms
+DROP TABLE IF EXISTS users_who_share_public_rooms;
+
+-- Track what users are in public rooms.
+CREATE TABLE IF NOT EXISTS users_in_public_rooms (
+    user_id TEXT NOT NULL,
+    room_id TEXT NOT NULL
+);
+
+CREATE UNIQUE INDEX users_in_public_rooms_u_idx ON users_in_public_rooms(user_id, room_id);
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 2317d22ed6..1c00b956e5 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -21,12 +21,11 @@ from six import iteritems
 from twisted.internet import defer
 
 from synapse.api.constants import EventTypes, JoinRules
+from synapse.storage._base import SQLBaseStore
 from synapse.storage.engines import PostgresEngine, Sqlite3Engine
 from synapse.storage.state import StateFilter
 from synapse.types import get_domain_from_id, get_localpart_from_id
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
-
-from ._base import SQLBaseStore
+from synapse.util.caches.descriptors import cached
 
 logger = logging.getLogger(__name__)
 
@@ -242,14 +241,7 @@ class UserDirectoryStore(SQLBaseStore):
                 txn, table="user_directory_search", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
-                txn,
-                table="users_who_share_public_rooms",
-                keyvalues={"user_id": user_id},
-            )
-            self._simple_delete_txn(
-                txn,
-                table="users_who_share_public_rooms",
-                keyvalues={"other_user_id": user_id},
+                txn, table="users_in_public_rooms", keyvalues={"user_id": user_id}
             )
             self._simple_delete_txn(
                 txn,
@@ -271,9 +263,9 @@ class UserDirectoryStore(SQLBaseStore):
         in the given room_id
         """
         user_ids_share_pub = yield self._simple_select_onecol(
-            table="users_who_share_public_rooms",
+            table="users_in_public_rooms",
             keyvalues={"room_id": room_id},
-            retcol="other_user_id",
+            retcol="user_id",
             desc="get_users_in_dir_due_to_room",
         )
 
@@ -311,26 +303,19 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._execute("get_all_local_users", None, sql)
         defer.returnValue([name for name, in rows])
 
-    def add_users_who_share_room(self, room_id, share_private, user_id_tuples):
-        """Insert entries into the users_who_share_*_rooms table. The first
+    def add_users_who_share_private_room(self, room_id, user_id_tuples):
+        """Insert entries into the users_who_share_private_rooms table. The first
         user should be a local user.
 
         Args:
             room_id (str)
-            share_private (bool): Is the room private
             user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs.
         """
 
         def _add_users_who_share_room_txn(txn):
-
-            if share_private:
-                tbl = "users_who_share_private_rooms"
-            else:
-                tbl = "users_who_share_public_rooms"
-
             self._simple_upsert_many_txn(
                 txn,
-                table=tbl,
+                table="users_who_share_private_rooms",
                 key_names=["user_id", "other_user_id", "room_id"],
                 key_values=[
                     (user_id, other_user_id, room_id)
@@ -339,15 +324,35 @@ class UserDirectoryStore(SQLBaseStore):
                 value_names=(),
                 value_values=None,
             )
-            for user_id, other_user_id in user_id_tuples:
-                txn.call_after(
-                    self.get_users_who_share_room_from_dir.invalidate, (user_id,)
-                )
 
         return self.runInteraction(
             "add_users_who_share_room", _add_users_who_share_room_txn
         )
 
+    def add_users_in_public_rooms(self, room_id, user_ids):
+        """Insert entries into the users_who_share_private_rooms table. The first
+        user should be a local user.
+
+        Args:
+            room_id (str)
+            user_ids (list[str])
+        """
+
+        def _add_users_in_public_rooms_txn(txn):
+
+            self._simple_upsert_many_txn(
+                txn,
+                table="users_in_public_rooms",
+                key_names=["user_id", "room_id"],
+                key_values=[(user_id, room_id) for user_id in user_ids],
+                value_names=(),
+                value_values=None,
+            )
+
+        return self.runInteraction(
+            "add_users_in_public_rooms", _add_users_in_public_rooms_txn
+        )
+
     def remove_user_who_share_room(self, user_id, room_id):
         """
         Deletes entries in the users_who_share_*_rooms table. The first
@@ -371,25 +376,18 @@ class UserDirectoryStore(SQLBaseStore):
             )
             self._simple_delete_txn(
                 txn,
-                table="users_who_share_public_rooms",
+                table="users_in_public_rooms",
                 keyvalues={"user_id": user_id, "room_id": room_id},
             )
-            self._simple_delete_txn(
-                txn,
-                table="users_who_share_public_rooms",
-                keyvalues={"other_user_id": user_id, "room_id": room_id},
-            )
-            txn.call_after(
-                self.get_users_who_share_room_from_dir.invalidate, (user_id,)
-            )
 
         return self.runInteraction(
             "remove_user_who_share_room", _remove_user_who_share_room_txn
         )
 
-    @cachedInlineCallbacks(max_entries=500000, iterable=True)
-    def get_users_who_share_room_from_dir(self, user_id):
-        """Returns the set of users who share a room with `user_id`
+    @defer.inlineCallbacks
+    def get_user_dir_rooms_user_is_in(self, user_id):
+        """
+        Returns the rooms that a user is in.
 
         Args:
             user_id(str): Must be a local user
@@ -400,23 +398,19 @@ class UserDirectoryStore(SQLBaseStore):
         rows = yield self._simple_select_onecol(
             table="users_who_share_private_rooms",
             keyvalues={"user_id": user_id},
-            retcol="other_user_id",
-            desc="get_users_who_share_room_with_user",
+            retcol="room_id",
+            desc="get_rooms_user_is_in",
         )
 
         pub_rows = yield self._simple_select_onecol(
-            table="users_who_share_public_rooms",
+            table="users_in_public_rooms",
             keyvalues={"user_id": user_id},
-            retcol="other_user_id",
-            desc="get_users_who_share_room_with_user",
+            retcol="room_id",
+            desc="get_rooms_user_is_in",
         )
 
         users = set(pub_rows)
         users.update(rows)
-
-        # Remove the user themselves from this list.
-        users.discard(user_id)
-
         defer.returnValue(list(users))
 
     @defer.inlineCallbacks
@@ -452,10 +446,9 @@ class UserDirectoryStore(SQLBaseStore):
         def _delete_all_from_user_dir_txn(txn):
             txn.execute("DELETE FROM user_directory")
             txn.execute("DELETE FROM user_directory_search")
-            txn.execute("DELETE FROM users_who_share_public_rooms")
+            txn.execute("DELETE FROM users_in_public_rooms")
             txn.execute("DELETE FROM users_who_share_private_rooms")
             txn.call_after(self.get_user_in_directory.invalidate_all)
-            txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all)
 
         return self.runInteraction(
             "delete_all_from_user_dir", _delete_all_from_user_dir_txn
@@ -560,23 +553,19 @@ class UserDirectoryStore(SQLBaseStore):
         """
 
         if self.hs.config.user_directory_search_all_users:
-            # make s.user_id null to keep the ordering algorithm happy
-            join_clause = """
-                CROSS JOIN (SELECT NULL as user_id) AS s
-            """
             join_args = ()
             where_clause = "1=1"
         else:
-            join_clause = """
-                LEFT JOIN (
-                    SELECT other_user_id AS user_id FROM users_who_share_public_rooms
-                    UNION
-                    SELECT other_user_id AS user_id FROM users_who_share_private_rooms
-                    WHERE user_id = ?
-                ) AS p USING (user_id)
-            """
             join_args = (user_id,)
-            where_clause = "p.user_id IS NOT NULL"
+            where_clause = """
+                (
+                    EXISTS (select 1 from users_in_public_rooms WHERE user_id = t.user_id)
+                    OR EXISTS (
+                        SELECT 1 FROM users_who_share_private_rooms
+                        WHERE user_id = ? AND other_user_id = t.user_id
+                    )
+                )
+            """
 
         if isinstance(self.database_engine, PostgresEngine):
             full_query, exact_query, prefix_query = _parse_query_postgres(search_term)
@@ -588,9 +577,8 @@ class UserDirectoryStore(SQLBaseStore):
             # search: (domain, _, display name, localpart)
             sql = """
                 SELECT d.user_id AS user_id, display_name, avatar_url
-                FROM user_directory_search
+                FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
-                %s
                 WHERE
                     %s
                     AND vector @@ to_tsquery('english', ?)
@@ -617,7 +605,6 @@ class UserDirectoryStore(SQLBaseStore):
                     avatar_url IS NULL
                 LIMIT ?
             """ % (
-                join_clause,
                 where_clause,
             )
             args = join_args + (full_query, exact_query, prefix_query, limit + 1)
@@ -626,9 +613,8 @@ class UserDirectoryStore(SQLBaseStore):
 
             sql = """
                 SELECT d.user_id AS user_id, display_name, avatar_url
-                FROM user_directory_search
+                FROM user_directory_search as t
                 INNER JOIN user_directory AS d USING (user_id)
-                %s
                 WHERE
                     %s
                     AND value MATCH ?
@@ -638,7 +624,6 @@ class UserDirectoryStore(SQLBaseStore):
                     avatar_url IS NULL
                 LIMIT ?
             """ % (
-                join_clause,
                 where_clause,
             )
             args = join_args + (search_query, limit + 1)