summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-02-11 09:50:14 -0500
committerGitHub <noreply@github.com>2022-02-11 09:50:14 -0500
commitb65acead428653b988351ae8d7b22127a22039cd (patch)
tree9cda595b79c838b8769570fd60cb4f8e1f262841 /synapse
parentPrepare for rename of default complement branch (#11971) (diff)
downloadsynapse-b65acead428653b988351ae8d7b22127a22039cd.tar.xz
Fetch thread summaries for multiple events in a single query (#11752)
This should reduce database usage when fetching bundled aggregations
as the number of individual queries (and round trips to the database) are
reduced.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/events.py2
-rw-r--r--synapse/storage/databases/main/relations.py222
2 files changed, 150 insertions, 74 deletions
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 2e44c77715..5246fccad5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1812,7 +1812,7 @@ class PersistEventsStore:
             # potentially error-prone) so it is always invalidated.
             txn.call_after(
                 self.store.get_thread_participated.invalidate,
-                (parent_id, event.room_id, event.sender),
+                (parent_id, event.sender),
             )
 
     def _handle_insertion_event(self, txn: LoggingTransaction, event: EventBase):
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index ad79cc5610..e2c27e594b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,6 +20,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
     cast,
@@ -454,106 +455,175 @@ class RelationsWorkerStore(SQLBaseStore):
         }
 
     @cached()
-    async def get_thread_summary(
-        self, event_id: str, room_id: str
-    ) -> Tuple[int, Optional[EventBase]]:
+    def get_thread_summary(self, event_id: str) -> Optional[Tuple[int, EventBase]]:
+        raise NotImplementedError()
+
+    @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
+    async def _get_thread_summaries(
+        self, event_ids: Collection[str]
+    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
         """Get the number of threaded replies and the latest reply (if any) for the given event.
 
         Args:
-            event_id: Summarize the thread related to this event ID.
-            room_id: The room the event belongs to.
+            event_ids: Summarize the thread related to this event ID.
 
         Returns:
-            The number of items in the thread and the most recent response, if any.
+            A map of the thread summary each event. A missing event implies there
+            are no threaded replies.
+
+            Each summary includes the number of items in the thread and the most
+            recent response.
         """
 
-        def _get_thread_summary_txn(
+        def _get_thread_summaries_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[int, Optional[str]]:
-            # Fetch the latest event ID in the thread.
+        ) -> Tuple[Dict[str, int], Dict[str, str]]:
+            # Fetch the count of threaded events and the latest event ID.
             # TODO Should this only allow m.room.message events.
-            sql = """
-                SELECT event_id
-                FROM event_relations
-                INNER JOIN events USING (event_id)
-                WHERE
-                    relates_to_id = ?
-                    AND room_id = ?
-                    AND relation_type = ?
-                ORDER BY topological_ordering DESC, stream_ordering DESC
-                LIMIT 1
-            """
+            if isinstance(self.database_engine, PostgresEngine):
+                # The `DISTINCT ON` clause will pick the *first* row it encounters,
+                # so ordering by topologica ordering + stream ordering desc will
+                # ensure we get the latest event in the thread.
+                sql = """
+                    SELECT DISTINCT ON (parent.event_id) parent.event_id, child.event_id FROM events AS child
+                    INNER JOIN event_relations USING (event_id)
+                    INNER JOIN events AS parent ON
+                        parent.event_id = relates_to_id
+                        AND parent.room_id = child.room_id
+                    WHERE
+                        %s
+                        AND relation_type = ?
+                    ORDER BY parent.event_id, child.topological_ordering DESC, child.stream_ordering DESC
+                """
+            else:
+                # SQLite uses a simplified query which returns all entries for a
+                # thread. The first result for each thread is chosen to and subsequent
+                # results for a thread are ignored.
+                sql = """
+                    SELECT parent.event_id, child.event_id FROM events AS child
+                    INNER JOIN event_relations USING (event_id)
+                    INNER JOIN events AS parent ON
+                        parent.event_id = relates_to_id
+                        AND parent.room_id = child.room_id
+                    WHERE
+                        %s
+                        AND relation_type = ?
+                    ORDER BY child.topological_ordering DESC, child.stream_ordering DESC
+                """
+
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+            args.append(RelationTypes.THREAD)
 
-            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
-            row = txn.fetchone()
-            if row is None:
-                return 0, None
+            txn.execute(sql % (clause,), args)
+            latest_event_ids = {}
+            for parent_event_id, child_event_id in txn:
+                # Only consider the latest threaded reply (by topological ordering).
+                if parent_event_id not in latest_event_ids:
+                    latest_event_ids[parent_event_id] = child_event_id
 
-            latest_event_id = row[0]
+            # If no threads were found, bail.
+            if not latest_event_ids:
+                return {}, latest_event_ids
 
             # Fetch the number of threaded replies.
             sql = """
-                SELECT COUNT(event_id)
-                FROM event_relations
-                INNER JOIN events USING (event_id)
+                SELECT parent.event_id, COUNT(child.event_id) FROM events AS child
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS parent ON
+                    parent.event_id = relates_to_id
+                    AND parent.room_id = child.room_id
                 WHERE
-                    relates_to_id = ?
-                    AND room_id = ?
+                    %s
                     AND relation_type = ?
+                GROUP BY parent.event_id
             """
-            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD))
-            count = cast(Tuple[int], txn.fetchone())[0]
 
-            return count, latest_event_id
+            # Regenerate the arguments since only threads found above could
+            # possibly have any replies.
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", latest_event_ids.keys()
+            )
+            args.append(RelationTypes.THREAD)
+
+            txn.execute(sql % (clause,), args)
+            counts = dict(cast(List[Tuple[str, int]], txn.fetchall()))
 
-        count, latest_event_id = await self.db_pool.runInteraction(
-            "get_thread_summary", _get_thread_summary_txn
+            return counts, latest_event_ids
+
+        counts, latest_event_ids = await self.db_pool.runInteraction(
+            "get_thread_summaries", _get_thread_summaries_txn
         )
 
-        latest_event = None
-        if latest_event_id:
-            latest_event = await self.get_event(latest_event_id, allow_none=True)  # type: ignore[attr-defined]
+        latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
+
+        # Map to the event IDs to the thread summary.
+        #
+        # There might not be a summary due to there not being a thread or
+        # due to the latest event not being known, either case is treated the same.
+        summaries = {}
+        for parent_event_id, latest_event_id in latest_event_ids.items():
+            latest_event = latest_events.get(latest_event_id)
+
+            summary = None
+            if latest_event:
+                summary = (counts[parent_event_id], latest_event)
+            summaries[parent_event_id] = summary
 
-        return count, latest_event
+        return summaries
 
     @cached()
-    async def get_thread_participated(
-        self, event_id: str, room_id: str, user_id: str
-    ) -> bool:
-        """Get whether the requesting user participated in a thread.
+    def get_thread_participated(self, event_id: str, user_id: str) -> bool:
+        raise NotImplementedError()
 
-        This is separate from get_thread_summary since that can be cached across
-        all users while this value is specific to the requeser.
+    @cachedList(cached_method_name="get_thread_participated", list_name="event_ids")
+    async def _get_threads_participated(
+        self, event_ids: Collection[str], user_id: str
+    ) -> Dict[str, bool]:
+        """Get whether the requesting user participated in the given threads.
+
+        This is separate from get_thread_summaries since that can be cached across
+        all users while this value is specific to the requester.
 
         Args:
-            event_id: The thread related to this event ID.
-            room_id: The room the event belongs to.
+            event_ids: The thread related to these event IDs.
             user_id: The user requesting the summary.
 
         Returns:
-            True if the requesting user participated in the thread, otherwise false.
+            A map of event ID to a boolean which represents if the requesting
+            user participated in that event's thread, otherwise false.
         """
 
-        def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+        def _get_thread_summary_txn(txn: LoggingTransaction) -> Set[str]:
             # Fetch whether the requester has participated or not.
             sql = """
-                SELECT 1
-                FROM event_relations
-                INNER JOIN events USING (event_id)
+                SELECT DISTINCT relates_to_id
+                FROM events AS child
+                INNER JOIN event_relations USING (event_id)
+                INNER JOIN events AS parent ON
+                    parent.event_id = relates_to_id
+                    AND parent.room_id = child.room_id
                 WHERE
-                    relates_to_id = ?
-                    AND room_id = ?
+                    %s
                     AND relation_type = ?
-                    AND sender = ?
+                    AND child.sender = ?
             """
 
-            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
-            return bool(txn.fetchone())
+            clause, args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+            args.extend((RelationTypes.THREAD, user_id))
 
-        return await self.db_pool.runInteraction(
+            txn.execute(sql % (clause,), args)
+            return {row[0] for row in txn.fetchall()}
+
+        participated_threads = await self.db_pool.runInteraction(
             "get_thread_summary", _get_thread_summary_txn
         )
 
+        return {event_id: event_id in participated_threads for event_id in event_ids}
+
     async def events_have_relations(
         self,
         parent_ids: List[str],
@@ -700,21 +770,6 @@ class RelationsWorkerStore(SQLBaseStore):
         if references.chunk:
             aggregations.references = await references.to_dict(cast("DataStore", self))
 
-        # If this event is the start of a thread, include a summary of the replies.
-        if self._msc3440_enabled:
-            thread_count, latest_thread_event = await self.get_thread_summary(
-                event_id, room_id
-            )
-            participated = await self.get_thread_participated(
-                event_id, room_id, user_id
-            )
-            if latest_thread_event:
-                aggregations.thread = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    count=thread_count,
-                    current_user_participated=participated,
-                )
-
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
 
@@ -763,6 +818,27 @@ class RelationsWorkerStore(SQLBaseStore):
         for event_id, edit in edits.items():
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
+        # Fetch thread summaries.
+        if self._msc3440_enabled:
+            summaries = await self._get_thread_summaries(seen_event_ids)
+            # Only fetch participated for a limited selection based on what had
+            # summaries.
+            participated = await self._get_threads_participated(
+                summaries.keys(), user_id
+            )
+            for event_id, summary in summaries.items():
+                if summary:
+                    thread_count, latest_thread_event = summary
+                    results.setdefault(
+                        event_id, BundledAggregations()
+                    ).thread = _ThreadAggregation(
+                        latest_event=latest_thread_event,
+                        count=thread_count,
+                        # If there's a thread summary it must also exist in the
+                        # participated dictionary.
+                        current_user_participated=participated[event_id],
+                    )
+
         return results