summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-05-03 14:41:37 +0100
committerGitHub <noreply@github.com>2023-05-03 13:41:37 +0000
commitfc3a878220f934a248b008277e89b85ad187d220 (patch)
tree19835f32ff9977a68272f0e4b6012392899ba2fb
parentSuppress the trusted key server warning for matrix.org in the demo scripts (#... (diff)
downloadsynapse-fc3a878220f934a248b008277e89b85ad187d220.tar.xz
Speed up rebuilding of the user directory for local users (#15529)
The idea here is to batch up the work.
Diffstat (limited to '')
-rw-r--r--changelog.d/15529.misc1
-rw-r--r--synapse/storage/database.py13
-rw-r--r--synapse/storage/databases/main/user_directory.py235
3 files changed, 172 insertions, 77 deletions
diff --git a/changelog.d/15529.misc b/changelog.d/15529.misc
new file mode 100644
index 0000000000..7ad424d8df
--- /dev/null
+++ b/changelog.d/15529.misc
@@ -0,0 +1 @@
+Speed up rebuilding of the user directory for local users.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 1f5f5eb6f8..313cf1a8d0 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -386,13 +386,20 @@ class LoggingTransaction:
             self.executemany(sql, args)
 
     def execute_values(
-        self, sql: str, values: Iterable[Iterable[Any]], fetch: bool = True
+        self,
+        sql: str,
+        values: Iterable[Iterable[Any]],
+        template: Optional[str] = None,
+        fetch: bool = True,
     ) -> List[Tuple]:
         """Corresponds to psycopg2.extras.execute_values. Only available when
         using postgres.
 
         The `fetch` parameter must be set to False if the query does not return
         rows (e.g. INSERTs).
+
+        The `template` is the snippet to merge to every item in argslist to
+        compose the query.
         """
         assert isinstance(self.database_engine, PostgresEngine)
         from psycopg2.extras import execute_values
@@ -400,7 +407,9 @@ class LoggingTransaction:
         return self._do_execute(
             # TODO: is it safe for values to be Iterable[Iterable[Any]] here?
             # https://www.psycopg.org/docs/extras.html?highlight=execute_batch#psycopg2.extras.execute_values says values should be Sequence[Sequence]
-            lambda the_sql: execute_values(self.txn, the_sql, values, fetch=fetch),
+            lambda the_sql: execute_values(
+                self.txn, the_sql, values, template=template, fetch=fetch
+            ),
             sql,
         )
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 5d65faed16..b7d58978de 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -27,6 +27,8 @@ from typing import (
     cast,
 )
 
+import attr
+
 try:
     # Figure out if ICU support is available for searching users.
     import icu
@@ -66,6 +68,19 @@ logger = logging.getLogger(__name__)
 TEMP_TABLE = "_temp_populate_user_directory"
 
 
+@attr.s(auto_attribs=True, frozen=True)
+class _UserDirProfile:
+    """Helper type for the user directory code for an entry to be inserted into
+    the directory.
+    """
+
+    user_id: str
+
+    # If the display name or avatar URL are unexpected types, replace with None
+    display_name: Optional[str] = attr.ib(default=None, converter=non_null_str_or_none)
+    avatar_url: Optional[str] = attr.ib(default=None, converter=non_null_str_or_none)
+
+
 class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
     # How many records do we calculate before sending it to
     # add_users_who_share_private_rooms?
@@ -381,25 +396,65 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
             % (len(users_to_work_on), progress["remaining"])
         )
 
-        for user_id in users_to_work_on:
-            if await self.should_include_local_user_in_dir(user_id):
-                profile = await self.get_profileinfo(get_localpart_from_id(user_id))  # type: ignore[attr-defined]
-                await self.update_profile_in_user_dir(
-                    user_id, profile.display_name, profile.avatar_url
-                )
-
-            # We've finished processing a user. Delete it from the table.
-            await self.db_pool.simple_delete_one(
-                TEMP_TABLE + "_users", {"user_id": user_id}
-            )
-            # Update the remaining counter.
-            progress["remaining"] -= 1
-            await self.db_pool.runInteraction(
-                "populate_user_directory",
-                self.db_pool.updates._background_update_progress_txn,
-                "populate_user_directory_process_users",
-                progress,
+        # First filter down to users we want to insert into the user directory.
+        users_to_insert = [
+            user_id
+            for user_id in users_to_work_on
+            if await self.should_include_local_user_in_dir(user_id)
+        ]
+
+        # Next fetch their profiles. Note that the `user_id` here is the
+        # *localpart*, and that not all users have profiles.
+        profile_rows = await self.db_pool.simple_select_many_batch(
+            table="profiles",
+            column="user_id",
+            iterable=[get_localpart_from_id(u) for u in users_to_insert],
+            retcols=(
+                "user_id",
+                "displayname",
+                "avatar_url",
+            ),
+            keyvalues={},
+            desc="populate_user_directory_process_users_get_profiles",
+        )
+        profiles = {
+            f"@{row['user_id']}:{self.server_name}": _UserDirProfile(
+                f"@{row['user_id']}:{self.server_name}",
+                row["displayname"],
+                row["avatar_url"],
             )
+            for row in profile_rows
+        }
+
+        profiles_to_insert = [
+            profiles.get(user_id) or _UserDirProfile(user_id)
+            for user_id in users_to_insert
+        ]
+
+        # Actually insert the users with their profiles into the directory.
+        await self.db_pool.runInteraction(
+            "populate_user_directory_process_users_insertion",
+            self._update_profiles_in_user_dir_txn,
+            profiles_to_insert,
+        )
+
+        # We've finished processing the users. Delete it from the table.
+        await self.db_pool.simple_delete_many(
+            table=TEMP_TABLE + "_users",
+            column="user_id",
+            iterable=users_to_work_on,
+            keyvalues={},
+            desc="populate_user_directory_process_users_delete",
+        )
+
+        # Update the remaining counter.
+        progress["remaining"] -= len(users_to_work_on)
+        await self.db_pool.runInteraction(
+            "populate_user_directory",
+            self.db_pool.updates._background_update_progress_txn,
+            "populate_user_directory_process_users",
+            progress,
+        )
 
         return len(users_to_work_on)
 
@@ -584,72 +639,102 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore):
         Update or add a user's profile in the user directory.
         If the user is remote, the profile will be marked as not stale.
         """
-        # If the display name or avatar URL are unexpected types, replace with None.
-        display_name = non_null_str_or_none(display_name)
-        avatar_url = non_null_str_or_none(avatar_url)
+        await self.db_pool.runInteraction(
+            "update_profiles_in_user_dir",
+            self._update_profiles_in_user_dir_txn,
+            [_UserDirProfile(user_id, display_name, avatar_url)],
+        )
+
+    def _update_profiles_in_user_dir_txn(
+        self,
+        txn: LoggingTransaction,
+        profiles: Sequence[_UserDirProfile],
+    ) -> None:
+        self.db_pool.simple_upsert_many_txn(
+            txn,
+            table="user_directory",
+            key_names=("user_id",),
+            key_values=[(p.user_id,) for p in profiles],
+            value_names=("display_name", "avatar_url"),
+            value_values=[
+                (
+                    p.display_name,
+                    p.avatar_url,
+                )
+                for p in profiles
+            ],
+        )
 
-        def _update_profile_in_user_dir_txn(txn: LoggingTransaction) -> None:
-            self.db_pool.simple_upsert_txn(
+        # Remote users: Make sure the profile is not marked as stale anymore.
+        remote_users = [
+            p.user_id for p in profiles if not self.hs.is_mine_id(p.user_id)
+        ]
+        if remote_users:
+            self.db_pool.simple_delete_many_txn(
                 txn,
-                table="user_directory",
-                keyvalues={"user_id": user_id},
-                values={"display_name": display_name, "avatar_url": avatar_url},
+                table="user_directory_stale_remote_users",
+                column="user_id",
+                values=remote_users,
+                keyvalues={},
             )
 
-            if not self.hs.is_mine_id(user_id):
-                # Remote users: Make sure the profile is not marked as stale anymore.
-                self.db_pool.simple_delete_txn(
-                    txn,
-                    table="user_directory_stale_remote_users",
-                    keyvalues={"user_id": user_id},
+        if isinstance(self.database_engine, PostgresEngine):
+            # We weight the localpart most highly, then display name and finally
+            # server name
+            template = """
+                (
+                    %s,
+                    setweight(to_tsvector('simple', %s), 'A')
+                    || setweight(to_tsvector('simple', %s), 'D')
+                    || setweight(to_tsvector('simple', COALESCE(%s, '')), 'B')
                 )
+            """
 
-            # The display name that goes into the database index.
-            index_display_name = display_name
-            if index_display_name is not None:
-                index_display_name = _filter_text_for_index(index_display_name)
-
-            if isinstance(self.database_engine, PostgresEngine):
-                # We weight the localpart most highly, then display name and finally
-                # server name
-                sql = """
-                        INSERT INTO user_directory_search(user_id, vector)
-                        VALUES (?,
-                            setweight(to_tsvector('simple', ?), 'A')
-                            || setweight(to_tsvector('simple', ?), 'D')
-                            || setweight(to_tsvector('simple', COALESCE(?, '')), 'B')
-                        ) ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
-                    """
-                txn.execute(
-                    sql,
+            sql = """
+                    INSERT INTO user_directory_search(user_id, vector)
+                    VALUES ? ON CONFLICT (user_id) DO UPDATE SET vector=EXCLUDED.vector
+                """
+            txn.execute_values(
+                sql,
+                [
                     (
-                        user_id,
-                        get_localpart_from_id(user_id),
-                        get_domain_from_id(user_id),
-                        index_display_name,
-                    ),
-                )
-            elif isinstance(self.database_engine, Sqlite3Engine):
-                value = (
-                    "%s %s" % (user_id, index_display_name)
-                    if index_display_name
-                    else user_id
-                )
-                self.db_pool.simple_upsert_txn(
-                    txn,
-                    table="user_directory_search",
-                    keyvalues={"user_id": user_id},
-                    values={"value": value},
-                )
-            else:
-                # This should be unreachable.
-                raise Exception("Unrecognized database engine")
+                        p.user_id,
+                        get_localpart_from_id(p.user_id),
+                        get_domain_from_id(p.user_id),
+                        _filter_text_for_index(p.display_name)
+                        if p.display_name
+                        else None,
+                    )
+                    for p in profiles
+                ],
+                template=template,
+                fetch=False,
+            )
+        elif isinstance(self.database_engine, Sqlite3Engine):
+            values = []
+            for p in profiles:
+                if p.display_name is not None:
+                    index_display_name = _filter_text_for_index(p.display_name)
+                    value = f"{p.user_id} {index_display_name}"
+                else:
+                    value = p.user_id
 
-            txn.call_after(self.get_user_in_directory.invalidate, (user_id,))
+                values.append((value,))
 
-        await self.db_pool.runInteraction(
-            "update_profile_in_user_dir", _update_profile_in_user_dir_txn
-        )
+            self.db_pool.simple_upsert_many_txn(
+                txn,
+                table="user_directory_search",
+                key_names=("user_id",),
+                key_values=[(p.user_id,) for p in profiles],
+                value_names=("value",),
+                value_values=values,
+            )
+        else:
+            # This should be unreachable.
+            raise Exception("Unrecognized database engine")
+
+        for p in profiles:
+            txn.call_after(self.get_user_in_directory.invalidate, (p.user_id,))
 
     async def add_users_who_share_private_room(
         self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]]