summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py85
1 files changed, 83 insertions, 2 deletions
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index fc49112063..f92d824876 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -17,11 +17,15 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import attr
 
-from synapse.api.constants import EventContentFields
+from synapse.api.constants import EventContentFields, RelationTypes
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingTransaction,
+    make_tuple_comparison_clause,
+)
 from synapse.storage.databases.main.events import PersistEventsStore
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
@@ -167,6 +171,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             self._purged_chain_cover_index,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "event_thread_relation", self._event_thread_relation
+        )
+
         ################################################################################
 
         # bg updates for replacing stream_ordering with a BIGINT
@@ -1091,6 +1099,79 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         return result
 
+    async def _event_thread_relation(self, progress: JsonDict, batch_size: int) -> int:
+        """Background update handler which will store thread relations for existing events."""
+        last_event_id = progress.get("last_event_id", "")
+
+        def _event_thread_relation_txn(txn: LoggingTransaction) -> int:
+            txn.execute(
+                """
+                SELECT event_id, json FROM event_json
+                LEFT JOIN event_relations USING (event_id)
+                WHERE event_id > ? AND relates_to_id IS NULL
+                ORDER BY event_id LIMIT ?
+                """,
+                (last_event_id, batch_size),
+            )
+
+            results = list(txn)
+            missing_thread_relations = []
+            for (event_id, event_json_raw) in results:
+                try:
+                    event_json = db_to_json(event_json_raw)
+                except Exception as e:
+                    logger.warning(
+                        "Unable to load event %s (no relations will be updated): %s",
+                        event_id,
+                        e,
+                    )
+                    continue
+
+                # If there's no relation (or it is not a thread), skip!
+                relates_to = event_json["content"].get("m.relates_to")
+                if not relates_to or not isinstance(relates_to, dict):
+                    continue
+                if relates_to.get("rel_type") != RelationTypes.THREAD:
+                    continue
+
+                # Get the parent ID.
+                parent_id = relates_to.get("event_id")
+                if not isinstance(parent_id, str):
+                    continue
+
+                missing_thread_relations.append((event_id, parent_id))
+
+            # Insert the missing data.
+            self.db_pool.simple_insert_many_txn(
+                txn=txn,
+                table="event_relations",
+                values=[
+                    {
+                        "event_id": event_id,
+                        "relates_to_Id": parent_id,
+                        "relation_type": RelationTypes.THREAD,
+                    }
+                    for event_id, parent_id in missing_thread_relations
+                ],
+            )
+
+            if results:
+                latest_event_id = results[-1][0]
+                self.db_pool.updates._background_update_progress_txn(
+                    txn, "event_thread_relation", {"last_event_id": latest_event_id}
+                )
+
+            return len(results)
+
+        num_rows = await self.db_pool.runInteraction(
+            desc="event_thread_relation", func=_event_thread_relation_txn
+        )
+
+        if not num_rows:
+            await self.db_pool.updates._end_background_update("event_thread_relation")
+
+        return num_rows
+
     async def _background_populate_stream_ordering2(
         self, progress: JsonDict, batch_size: int
     ) -> int: