summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/14491.feature1
-rw-r--r--synapse/handlers/relations.py197
-rw-r--r--synapse/storage/databases/main/relations.py139
-rw-r--r--synapse/util/caches/descriptors.py2
-rw-r--r--tests/rest/client/test_relations.py4
5 files changed, 202 insertions, 141 deletions
diff --git a/changelog.d/14491.feature b/changelog.d/14491.feature
new file mode 100644
index 0000000000..4fca7282f7
--- /dev/null
+++ b/changelog.d/14491.feature
@@ -0,0 +1 @@
+Reduce database load of [Client-Server endpoints](https://spec.matrix.org/v1.4/client-server-api/#aggregations) which return bundled aggregations.
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 8e71dda970..ca94239f61 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -13,7 +13,16 @@
 # limitations under the License.
 import enum
 import logging
-from typing import TYPE_CHECKING, Dict, FrozenSet, Iterable, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+)
 
 import attr
 
@@ -259,48 +268,64 @@ class RelationsHandler:
                     e.msg,
                 )
 
-    async def get_annotations_for_event(
-        self,
-        event_id: str,
-        room_id: str,
-        limit: int = 5,
-        ignored_users: FrozenSet[str] = frozenset(),
-    ) -> List[JsonDict]:
-        """Get a list of annotations on the event, grouped by event type and
+    async def get_annotations_for_events(
+        self, event_ids: Collection[str], ignored_users: FrozenSet[str] = frozenset()
+    ) -> Dict[str, List[JsonDict]]:
+        """Get a list of annotations to the given events, grouped by event type and
         aggregation key, sorted by count.
 
-        This is used e.g. to get the what and how many reactions have happend
+        This is used e.g. to get the what and how many reactions have happened
         on an event.
 
         Args:
-            event_id: Fetch events that relate to this event ID.
-            room_id: The room the event belongs to.
-            limit: Only fetch the `limit` groups.
+            event_ids: Fetch events that relate to these event IDs.
             ignored_users: The users ignored by the requesting user.
 
         Returns:
-            List of groups of annotations that match. Each row is a dict with
-            `type`, `key` and `count` fields.
+            A map of event IDs to a list of groups of annotations that match.
+            Each entry is a dict with `type`, `key` and `count` fields.
         """
         # Get the base results for all users.
-        full_results = await self._main_store.get_aggregation_groups_for_event(
-            event_id, room_id, limit
+        full_results = await self._main_store.get_aggregation_groups_for_events(
+            event_ids
         )
 
+        # Avoid additional logic if there are no ignored users.
+        if not ignored_users:
+            return {
+                event_id: results
+                for event_id, results in full_results.items()
+                if results
+            }
+
         # Then subtract off the results for any ignored users.
         ignored_results = await self._main_store.get_aggregation_groups_for_users(
-            event_id, room_id, limit, ignored_users
+            [event_id for event_id, results in full_results.items() if results],
+            ignored_users,
         )
 
-        filtered_results = []
-        for result in full_results:
-            key = (result["type"], result["key"])
-            if key in ignored_results:
-                result = result.copy()
-                result["count"] -= ignored_results[key]
-                if result["count"] <= 0:
-                    continue
-            filtered_results.append(result)
+        filtered_results = {}
+        for event_id, results in full_results.items():
+            # If no annotations, skip.
+            if not results:
+                continue
+
+            # If there are not ignored results for this event, copy verbatim.
+            if event_id not in ignored_results:
+                filtered_results[event_id] = results
+                continue
+
+            # Otherwise, subtract out the ignored results.
+            event_ignored_results = ignored_results[event_id]
+            for result in results:
+                key = (result["type"], result["key"])
+                if key in event_ignored_results:
+                    # Ensure to not modify the cache.
+                    result = result.copy()
+                    result["count"] -= event_ignored_results[key]
+                    if result["count"] <= 0:
+                        continue
+                filtered_results.setdefault(event_id, []).append(result)
 
         return filtered_results
 
@@ -366,59 +391,62 @@ class RelationsHandler:
         results = {}
 
         for event_id, summary in summaries.items():
-            if summary:
-                thread_count, latest_thread_event = summary
-
-                # Subtract off the count of any ignored users.
-                for ignored_user in ignored_users:
-                    thread_count -= ignored_results.get((event_id, ignored_user), 0)
-
-                # This is gnarly, but if the latest event is from an ignored user,
-                # attempt to find one that isn't from an ignored user.
-                if latest_thread_event.sender in ignored_users:
-                    room_id = latest_thread_event.room_id
-
-                    # If the root event is not found, something went wrong, do
-                    # not include a summary of the thread.
-                    event = await self._event_handler.get_event(user, room_id, event_id)
-                    if event is None:
-                        continue
+            # If no thread, skip.
+            if not summary:
+                continue
 
-                    potential_events, _ = await self.get_relations_for_event(
-                        event_id,
-                        event,
-                        room_id,
-                        RelationTypes.THREAD,
-                        ignored_users,
-                    )
+            thread_count, latest_thread_event = summary
 
-                    # If all found events are from ignored users, do not include
-                    # a summary of the thread.
-                    if not potential_events:
-                        continue
+            # Subtract off the count of any ignored users.
+            for ignored_user in ignored_users:
+                thread_count -= ignored_results.get((event_id, ignored_user), 0)
 
-                    # The *last* event returned is the one that is cared about.
-                    event = await self._event_handler.get_event(
-                        user, room_id, potential_events[-1].event_id
-                    )
-                    # It is unexpected that the event will not exist.
-                    if event is None:
-                        logger.warning(
-                            "Unable to fetch latest event in a thread with event ID: %s",
-                            potential_events[-1].event_id,
-                        )
-                        continue
-                    latest_thread_event = event
-
-                results[event_id] = _ThreadAggregation(
-                    latest_event=latest_thread_event,
-                    count=thread_count,
-                    # If there's a thread summary it must also exist in the
-                    # participated dictionary.
-                    current_user_participated=events_by_id[event_id].sender == user_id
-                    or participated[event_id],
+            # This is gnarly, but if the latest event is from an ignored user,
+            # attempt to find one that isn't from an ignored user.
+            if latest_thread_event.sender in ignored_users:
+                room_id = latest_thread_event.room_id
+
+                # If the root event is not found, something went wrong, do
+                # not include a summary of the thread.
+                event = await self._event_handler.get_event(user, room_id, event_id)
+                if event is None:
+                    continue
+
+                potential_events, _ = await self.get_relations_for_event(
+                    event_id,
+                    event,
+                    room_id,
+                    RelationTypes.THREAD,
+                    ignored_users,
                 )
 
+                # If all found events are from ignored users, do not include
+                # a summary of the thread.
+                if not potential_events:
+                    continue
+
+                # The *last* event returned is the one that is cared about.
+                event = await self._event_handler.get_event(
+                    user, room_id, potential_events[-1].event_id
+                )
+                # It is unexpected that the event will not exist.
+                if event is None:
+                    logger.warning(
+                        "Unable to fetch latest event in a thread with event ID: %s",
+                        potential_events[-1].event_id,
+                    )
+                    continue
+                latest_thread_event = event
+
+            results[event_id] = _ThreadAggregation(
+                latest_event=latest_thread_event,
+                count=thread_count,
+                # If there's a thread summary it must also exist in the
+                # participated dictionary.
+                current_user_participated=events_by_id[event_id].sender == user_id
+                or participated[event_id],
+            )
+
         return results
 
     @trace
@@ -496,17 +524,18 @@ class RelationsHandler:
                 # (as that is what makes it part of the thread).
                 relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
 
-        # Fetch other relations per event.
-        for event in events_by_id.values():
-            # Fetch any annotations (ie, reactions) to bundle with this event.
-            annotations = await self.get_annotations_for_event(
-                event.event_id, event.room_id, ignored_users=ignored_users
-            )
+        # Fetch any annotations (ie, reactions) to bundle with this event.
+        annotations_by_event_id = await self.get_annotations_for_events(
+            events_by_id.keys(), ignored_users=ignored_users
+        )
+        for event_id, annotations in annotations_by_event_id.items():
             if annotations:
-                results.setdefault(
-                    event.event_id, BundledAggregations()
-                ).annotations = {"chunk": annotations}
+                results.setdefault(event_id, BundledAggregations()).annotations = {
+                    "chunk": annotations
+                }
 
+        # Fetch other relations per event.
+        for event in events_by_id.values():
             # Fetch any references to bundle with this event.
             references, next_token = await self.get_relations_for_event(
                 event.event_id,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index ca431002c8..f96a16956a 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,6 +20,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore):
         )
         return result is not None
 
-    @cached(tree=True)
-    async def get_aggregation_groups_for_event(
-        self, event_id: str, room_id: str, limit: int = 5
-    ) -> List[JsonDict]:
-        """Get a list of annotations on the event, grouped by event type and
+    @cached()
+    async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
+    )
+    async def get_aggregation_groups_for_events(
+        self, event_ids: Collection[str]
+    ) -> Mapping[str, Optional[List[JsonDict]]]:
+        """Get a list of annotations on the given events, grouped by event type and
         aggregation key, sorted by count.
 
         This is used e.g. to get the what and how many reactions have happend
         on an event.
 
         Args:
-            event_id: Fetch events that relate to this event ID.
-            room_id: The room the event belongs to.
-            limit: Only fetch the `limit` groups.
+            event_ids: Fetch events that relate to these event IDs.
 
         Returns:
-            List of groups of annotations that match. Each row is a dict with
-            `type`, `key` and `count` fields.
+            A map of event IDs to a list of groups of annotations that match.
+            Each entry is a dict with `type`, `key` and `count` fields.
         """
+        # The number of entries to return per event ID.
+        limit = 5
 
-        args = [
-            event_id,
-            room_id,
-            RelationTypes.ANNOTATION,
-            limit,
-        ]
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "relates_to_id", event_ids
+        )
+        args.append(RelationTypes.ANNOTATION)
 
-        sql = """
-            SELECT type, aggregation_key, COUNT(DISTINCT sender)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
-            GROUP BY relation_type, type, aggregation_key
-            ORDER BY COUNT(*) DESC
-            LIMIT ?
+        sql = f"""
+            SELECT
+                relates_to_id,
+                annotation.type,
+                aggregation_key,
+                COUNT(DISTINCT annotation.sender)
+            FROM events AS annotation
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = annotation.room_id
+            WHERE
+                {clause}
+                AND relation_type = ?
+            GROUP BY relates_to_id, annotation.type, aggregation_key
+            ORDER BY relates_to_id, COUNT(*) DESC
         """
 
-        def _get_aggregation_groups_for_event_txn(
+        def _get_aggregation_groups_for_events_txn(
             txn: LoggingTransaction,
-        ) -> List[JsonDict]:
+        ) -> Mapping[str, List[JsonDict]]:
             txn.execute(sql, args)
 
-            return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
+            result: Dict[str, List[JsonDict]] = {}
+            for event_id, type, key, count in cast(
+                List[Tuple[str, str, str, int]], txn
+            ):
+                event_results = result.setdefault(event_id, [])
+
+                # Limit the number of results per event ID.
+                if len(event_results) == limit:
+                    continue
+
+                event_results.append({"type": type, "key": key, "count": count})
+
+            return result
 
         return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+            "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
         )
 
     async def get_aggregation_groups_for_users(
-        self,
-        event_id: str,
-        room_id: str,
-        limit: int,
-        users: FrozenSet[str] = frozenset(),
-    ) -> Dict[Tuple[str, str], int]:
+        self, event_ids: Collection[str], users: FrozenSet[str]
+    ) -> Dict[str, Dict[Tuple[str, str], int]]:
         """Fetch the partial aggregations for an event for specific users.
 
         This is used, in conjunction with get_aggregation_groups_for_event, to
         remove information from the results for ignored users.
 
         Args:
-            event_id: Fetch events that relate to this event ID.
-            room_id: The room the event belongs to.
-            limit: Only fetch the `limit` groups.
+            event_ids: Fetch events that relate to these event IDs.
             users: The users to fetch information for.
 
         Returns:
-            A map of (event type, aggregation key) to a count of users.
+            A map of event ID to a map of (event type, aggregation key) to a
+            count of users.
         """
 
         if not users:
             return {}
 
-        args: List[Union[str, int]] = [
-            event_id,
-            room_id,
-            RelationTypes.ANNOTATION,
-        ]
+        events_sql, args = make_in_list_sql_clause(
+            self.database_engine, "relates_to_id", event_ids
+        )
 
         users_sql, users_args = make_in_list_sql_clause(
-            self.database_engine, "sender", users
+            self.database_engine, "annotation.sender", users
         )
         args.extend(users_args)
+        args.append(RelationTypes.ANNOTATION)
 
         sql = f"""
-            SELECT type, aggregation_key, COUNT(DISTINCT sender)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
-            GROUP BY relation_type, type, aggregation_key
-            ORDER BY COUNT(*) DESC
-            LIMIT ?
+            SELECT
+                relates_to_id,
+                annotation.type,
+                aggregation_key,
+                COUNT(DISTINCT annotation.sender)
+            FROM events AS annotation
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = annotation.room_id
+            WHERE {events_sql} AND {users_sql} AND relation_type = ?
+            GROUP BY relates_to_id, annotation.type, aggregation_key
+            ORDER BY relates_to_id, COUNT(*) DESC
         """
 
         def _get_aggregation_groups_for_users_txn(
             txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, str], int]:
-            txn.execute(sql, args + [limit])
+        ) -> Dict[str, Dict[Tuple[str, str], int]]:
+            txn.execute(sql, args)
 
-            return {(row[0], row[1]): row[2] for row in txn}
+            result: Dict[str, Dict[Tuple[str, str], int]] = {}
+            for event_id, type, key, count in cast(
+                List[Tuple[str, str, str, int]], txn
+            ):
+                result.setdefault(event_id, {})[(type, key)] = count
+
+            return result
 
         return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 75428d19ba..72227359b9 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -503,7 +503,7 @@ def cachedList(
     is specified as a list that is iterated through to lookup keys in the
     original cache. A new tuple consisting of the (deduplicated) keys that weren't in
     the cache gets passed to the original function, which is expected to results
-    in a map of key to value for each passed value. THe new results are stored in the
+    in a map of key to value for each passed value. The new results are stored in the
     original cache. Note that any missing values are cached as None.
 
     Args:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index e3d801f7a8..2d2b683548 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1108,7 +1108,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
 
         # The "user" sent the root event and is making queries for the bundled
         # aggregations: they have participated.
-        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
         # The "user2" sent replies in the thread and is making queries for the
         # bundled aggregations: they have participated.
         #
@@ -1170,7 +1170,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations["latest_event"].get("unsigned"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
 
     def test_nested_thread(self) -> None:
         """