summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/relations.py139
1 files changed, 85 insertions, 54 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index ca431002c8..f96a16956a 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -20,6 +20,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -394,106 +395,136 @@ class RelationsWorkerStore(SQLBaseStore):
         )
         return result is not None
 
-    @cached(tree=True)
-    async def get_aggregation_groups_for_event(
-        self, event_id: str, room_id: str, limit: int = 5
-    ) -> List[JsonDict]:
-        """Get a list of annotations on the event, grouped by event type and
+    @cached()
+    async def get_aggregation_groups_for_event(self, event_id: str) -> List[JsonDict]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="get_aggregation_groups_for_event", list_name="event_ids"
+    )
+    async def get_aggregation_groups_for_events(
+        self, event_ids: Collection[str]
+    ) -> Mapping[str, Optional[List[JsonDict]]]:
+        """Get a list of annotations on the given events, grouped by event type and
         aggregation key, sorted by count.
 
         This is used e.g. to get the what and how many reactions have happend
         on an event.
 
         Args:
-            event_id: Fetch events that relate to this event ID.
-            room_id: The room the event belongs to.
-            limit: Only fetch the `limit` groups.
+            event_ids: Fetch events that relate to these event IDs.
 
         Returns:
-            List of groups of annotations that match. Each row is a dict with
-            `type`, `key` and `count` fields.
+            A map of event IDs to a list of groups of annotations that match.
+            Each entry is a dict with `type`, `key` and `count` fields.
         """
+        # The number of entries to return per event ID.
+        limit = 5
 
-        args = [
-            event_id,
-            room_id,
-            RelationTypes.ANNOTATION,
-            limit,
-        ]
+        clause, args = make_in_list_sql_clause(
+            self.database_engine, "relates_to_id", event_ids
+        )
+        args.append(RelationTypes.ANNOTATION)
 
-        sql = """
-            SELECT type, aggregation_key, COUNT(DISTINCT sender)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ?
-            GROUP BY relation_type, type, aggregation_key
-            ORDER BY COUNT(*) DESC
-            LIMIT ?
+        sql = f"""
+            SELECT
+                relates_to_id,
+                annotation.type,
+                aggregation_key,
+                COUNT(DISTINCT annotation.sender)
+            FROM events AS annotation
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = annotation.room_id
+            WHERE
+                {clause}
+                AND relation_type = ?
+            GROUP BY relates_to_id, annotation.type, aggregation_key
+            ORDER BY relates_to_id, COUNT(*) DESC
         """
 
-        def _get_aggregation_groups_for_event_txn(
+        def _get_aggregation_groups_for_events_txn(
             txn: LoggingTransaction,
-        ) -> List[JsonDict]:
+        ) -> Mapping[str, List[JsonDict]]:
             txn.execute(sql, args)
 
-            return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
+            result: Dict[str, List[JsonDict]] = {}
+            for event_id, type, key, count in cast(
+                List[Tuple[str, str, str, int]], txn
+            ):
+                event_results = result.setdefault(event_id, [])
+
+                # Limit the number of results per event ID.
+                if len(event_results) == limit:
+                    continue
+
+                event_results.append({"type": type, "key": key, "count": count})
+
+            return result
 
         return await self.db_pool.runInteraction(
-            "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
+            "get_aggregation_groups_for_events", _get_aggregation_groups_for_events_txn
         )
 
     async def get_aggregation_groups_for_users(
-        self,
-        event_id: str,
-        room_id: str,
-        limit: int,
-        users: FrozenSet[str] = frozenset(),
-    ) -> Dict[Tuple[str, str], int]:
+        self, event_ids: Collection[str], users: FrozenSet[str]
+    ) -> Dict[str, Dict[Tuple[str, str], int]]:
         """Fetch the partial aggregations for an event for specific users.
 
         This is used, in conjunction with get_aggregation_groups_for_event, to
         remove information from the results for ignored users.
 
         Args:
-            event_id: Fetch events that relate to this event ID.
-            room_id: The room the event belongs to.
-            limit: Only fetch the `limit` groups.
+            event_ids: Fetch events that relate to these event IDs.
             users: The users to fetch information for.
 
         Returns:
-            A map of (event type, aggregation key) to a count of users.
+            A map of event ID to a map of (event type, aggregation key) to a
+            count of users.
         """
 
         if not users:
             return {}
 
-        args: List[Union[str, int]] = [
-            event_id,
-            room_id,
-            RelationTypes.ANNOTATION,
-        ]
+        events_sql, args = make_in_list_sql_clause(
+            self.database_engine, "relates_to_id", event_ids
+        )
 
         users_sql, users_args = make_in_list_sql_clause(
-            self.database_engine, "sender", users
+            self.database_engine, "annotation.sender", users
         )
         args.extend(users_args)
+        args.append(RelationTypes.ANNOTATION)
 
         sql = f"""
-            SELECT type, aggregation_key, COUNT(DISTINCT sender)
-            FROM event_relations
-            INNER JOIN events USING (event_id)
-            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
-            GROUP BY relation_type, type, aggregation_key
-            ORDER BY COUNT(*) DESC
-            LIMIT ?
+            SELECT
+                relates_to_id,
+                annotation.type,
+                aggregation_key,
+                COUNT(DISTINCT annotation.sender)
+            FROM events AS annotation
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = annotation.room_id
+            WHERE {events_sql} AND {users_sql} AND relation_type = ?
+            GROUP BY relates_to_id, annotation.type, aggregation_key
+            ORDER BY relates_to_id, COUNT(*) DESC
         """
 
         def _get_aggregation_groups_for_users_txn(
             txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, str], int]:
-            txn.execute(sql, args + [limit])
+        ) -> Dict[str, Dict[Tuple[str, str], int]]:
+            txn.execute(sql, args)
 
-            return {(row[0], row[1]): row[2] for row in txn}
+            result: Dict[str, Dict[Tuple[str, str], int]] = {}
+            for event_id, type, key, count in cast(
+                List[Tuple[str, str, str, int]], txn
+            ):
+                result.setdefault(event_id, {})[(type, key)] = count
+
+            return result
 
         return await self.db_pool.runInteraction(
             "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn