summary refs log tree commit diff
path: root/synapse/storage/databases/main/profile.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/profile.py')
-rw-r--r--synapse/storage/databases/main/profile.py102
1 files changed, 100 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/profile.py b/synapse/storage/databases/main/profile.py
index c4022d2427..65c92bef51 100644
--- a/synapse/storage/databases/main/profile.py
+++ b/synapse/storage/databases/main/profile.py
@@ -15,9 +15,14 @@ from typing import TYPE_CHECKING, Optional
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.roommember import ProfileInfo
-from synapse.types import UserID
+from synapse.storage.engines import PostgresEngine
+from synapse.types import JsonDict, UserID
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -31,6 +36,8 @@ class ProfileWorkerStore(SQLBaseStore):
         hs: "HomeServer",
     ):
         super().__init__(database, db_conn, hs)
+        self.server_name: str = hs.hostname
+        self.database_engine = database.engine
         self.db_pool.updates.register_background_index_update(
             "profiles_full_user_id_key_idx",
             index_name="profiles_full_user_id_key",
@@ -39,6 +46,97 @@ class ProfileWorkerStore(SQLBaseStore):
             unique=True,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "populate_full_user_id_profiles", self.populate_full_user_id_profiles
+        )
+
+    async def populate_full_user_id_profiles(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        """
+        Background update to populate the column `full_user_id` of the table
+        profiles from entries in the column `user_local_part` of the same table
+        """
+
+        lower_bound_id = progress.get("lower_bound_id", "")
+
+        def _get_last_id(txn: LoggingTransaction) -> Optional[str]:
+            sql = """
+                    SELECT user_id FROM profiles
+                    WHERE user_id > ?
+                    ORDER BY user_id
+                    LIMIT 1 OFFSET 50
+                  """
+            txn.execute(sql, (lower_bound_id,))
+            res = txn.fetchone()
+            if res:
+                upper_bound_id = res[0]
+                return upper_bound_id
+            else:
+                return None
+
+        def _process_batch(
+            txn: LoggingTransaction, lower_bound_id: str, upper_bound_id: str
+        ) -> None:
+            sql = """
+                    UPDATE profiles
+                    SET full_user_id = '@' || user_id || ?
+                    WHERE ? < user_id AND user_id <= ? AND full_user_id IS NULL
+                   """
+            txn.execute(sql, (f":{self.server_name}", lower_bound_id, upper_bound_id))
+
+        def _final_batch(txn: LoggingTransaction, lower_bound_id: str) -> None:
+            sql = """
+                    UPDATE profiles
+                    SET full_user_id = '@' || user_id || ?
+                    WHERE ? < user_id AND full_user_id IS NULL
+                   """
+            txn.execute(
+                sql,
+                (
+                    f":{self.server_name}",
+                    lower_bound_id,
+                ),
+            )
+
+            if isinstance(self.database_engine, PostgresEngine):
+                sql = """
+                        ALTER TABLE profiles VALIDATE CONSTRAINT full_user_id_not_null
+                      """
+                txn.execute(sql)
+
+        upper_bound_id = await self.db_pool.runInteraction(
+            "populate_full_user_id_profiles", _get_last_id
+        )
+
+        if upper_bound_id is None:
+            await self.db_pool.runInteraction(
+                "populate_full_user_id_profiles", _final_batch, lower_bound_id
+            )
+
+            await self.db_pool.updates._end_background_update(
+                "populate_full_user_id_profiles"
+            )
+            return 1
+
+        await self.db_pool.runInteraction(
+            "populate_full_user_id_profiles",
+            _process_batch,
+            lower_bound_id,
+            upper_bound_id,
+        )
+
+        progress["lower_bound_id"] = upper_bound_id
+
+        await self.db_pool.runInteraction(
+            "populate_full_user_id_profiles",
+            self.db_pool.updates._background_update_progress_txn,
+            "populate_full_user_id_profiles",
+            progress,
+        )
+
+        return 50
+
     async def get_profileinfo(self, user_localpart: str) -> ProfileInfo:
         try:
             profile = await self.db_pool.simple_select_one(