summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/background_updates.py55
-rw-r--r--synapse/storage/databases/main/receipts.py51
2 files changed, 58 insertions, 48 deletions
diff --git a/synapse/storage/background_updates.py b/synapse/storage/background_updates.py
index 2056ecb2c3..a99aea8926 100644
--- a/synapse/storage/background_updates.py
+++ b/synapse/storage/background_updates.py
@@ -544,6 +544,48 @@ class BackgroundUpdater:
                 The named index will be dropped upon completion of the new index.
         """
 
+        async def updater(progress: JsonDict, batch_size: int) -> int:
+            await self.create_index_in_background(
+                index_name=index_name,
+                table=table,
+                columns=columns,
+                where_clause=where_clause,
+                unique=unique,
+                psql_only=psql_only,
+                replaces_index=replaces_index,
+            )
+            await self._end_background_update(update_name)
+            return 1
+
+        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
+            updater, oneshot=True
+        )
+
+    async def create_index_in_background(
+        self,
+        index_name: str,
+        table: str,
+        columns: Iterable[str],
+        where_clause: Optional[str] = None,
+        unique: bool = False,
+        psql_only: bool = False,
+        replaces_index: Optional[str] = None,
+    ) -> None:
+        """Add an index in the background.
+
+        Args:
+            update_name: update_name to register for
+            index_name: name of index to add
+            table: table to add index to
+            columns: columns/expressions to include in index
+            where_clause: A WHERE clause to specify a partial unique index.
+            unique: true to make a UNIQUE index
+            psql_only: true to only create this index on psql databases (useful
+                for virtual sqlite tables)
+            replaces_index: The name of an index that this index replaces.
+                The named index will be dropped upon completion of the new index.
+        """
+
         def create_index_psql(conn: Connection) -> None:
             conn.rollback()
             # postgres insists on autocommit for the index
@@ -618,16 +660,11 @@ class BackgroundUpdater:
         else:
             runner = create_index_sqlite
 
-        async def updater(progress: JsonDict, batch_size: int) -> int:
-            if runner is not None:
-                logger.info("Adding index %s to %s", index_name, table)
-                await self.db_pool.runWithConnection(runner)
-            await self._end_background_update(update_name)
-            return 1
+        if runner is None:
+            return
 
-        self._background_update_handlers[update_name] = _BackgroundUpdateHandler(
-            updater, oneshot=True
-        )
+        logger.info("Adding index %s to %s", index_name, table)
+        await self.db_pool.runWithConnection(runner)
 
     async def _end_background_update(self, update_name: str) -> None:
         """Removes a completed background update task from the queue.
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index a580e4bdda..e06725f69c 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -924,39 +924,6 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
 
         return batch_size
 
-    async def _create_receipts_index(self, index_name: str, table: str) -> None:
-        """Adds a unique index on `(room_id, receipt_type, user_id)` to the given
-        receipts table, for non-thread receipts."""
-
-        def _create_index(conn: LoggingDatabaseConnection) -> None:
-            conn.rollback()
-
-            # we have to set autocommit, because postgres refuses to
-            # CREATE INDEX CONCURRENTLY without it.
-            if isinstance(self.database_engine, PostgresEngine):
-                conn.set_session(autocommit=True)
-
-            try:
-                c = conn.cursor()
-
-                # Now that the duplicates are gone, we can create the index.
-                concurrently = (
-                    "CONCURRENTLY"
-                    if isinstance(self.database_engine, PostgresEngine)
-                    else ""
-                )
-                sql = f"""
-                    CREATE UNIQUE INDEX {concurrently} {index_name}
-                    ON {table}(room_id, receipt_type, user_id)
-                    WHERE thread_id IS NULL
-                """
-                c.execute(sql)
-            finally:
-                if isinstance(self.database_engine, PostgresEngine):
-                    conn.set_session(autocommit=False)
-
-        await self.db_pool.runWithConnection(_create_index)
-
     async def _background_receipts_linearized_unique_index(
         self, progress: dict, batch_size: int
     ) -> int:
@@ -999,9 +966,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
             _remote_duplicate_receipts_txn,
         )
 
-        await self._create_receipts_index(
-            "receipts_linearized_unique_index",
-            "receipts_linearized",
+        await self.db_pool.updates.create_index_in_background(
+            index_name="receipts_linearized_unique_index",
+            table="receipts_linearized",
+            columns=["room_id", "receipt_type", "user_id"],
+            where_clause="thread_id IS NULL",
+            unique=True,
         )
 
         await self.db_pool.updates._end_background_update(
@@ -1050,9 +1020,12 @@ class ReceiptsBackgroundUpdateStore(SQLBaseStore):
             _remote_duplicate_receipts_txn,
         )
 
-        await self._create_receipts_index(
-            "receipts_graph_unique_index",
-            "receipts_graph",
+        await self.db_pool.updates.create_index_in_background(
+            index_name="receipts_graph_unique_index",
+            table="receipts_graph",
+            columns=["room_id", "receipt_type", "user_id"],
+            where_clause="thread_id IS NULL",
+            unique=True,
         )
 
         await self.db_pool.updates._end_background_update(