summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py154
1 files changed, 144 insertions, 10 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 64a7808140..407158ceee 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -17,6 +17,7 @@ from typing import (
     TYPE_CHECKING,
     Collection,
     Dict,
+    FrozenSet,
     Iterable,
     List,
     Optional,
@@ -26,6 +27,8 @@ from typing import (
     cast,
 )
 
+import attr
+
 from synapse.api.constants import RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
@@ -46,6 +49,19 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _RelatedEvent:
+    """
+    Contains enough information about a related event in order to properly filter
+    events from ignored users.
+    """
+
+    # The event ID of the related event.
+    event_id: str
+    # The sender of the related event.
+    sender: str
+
+
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -70,7 +86,7 @@ class RelationsWorkerStore(SQLBaseStore):
         direction: str = "b",
         from_token: Optional[StreamToken] = None,
         to_token: Optional[StreamToken] = None,
-    ) -> Tuple[List[str], Optional[StreamToken]]:
+    ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
         """Get a list of relations for an event, ordered by topological ordering.
 
         Args:
@@ -88,7 +104,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Returns:
             A tuple of:
-                A list of related event IDs
+                A list of related event IDs & their senders.
 
                 The next stream token, if one exists.
         """
@@ -131,7 +147,7 @@ class RelationsWorkerStore(SQLBaseStore):
             order = "ASC"
 
         sql = """
-            SELECT event_id, relation_type, topological_ordering, stream_ordering
+            SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
             FROM event_relations
             INNER JOIN events USING (event_id)
             WHERE %s
@@ -145,7 +161,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         def _get_recent_references_for_event_txn(
             txn: LoggingTransaction,
-        ) -> Tuple[List[str], Optional[StreamToken]]:
+        ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
             txn.execute(sql, where_args + [limit + 1])
 
             last_topo_id = None
@@ -155,9 +171,9 @@ class RelationsWorkerStore(SQLBaseStore):
                 # Do not include edits for redacted events as they leak event
                 # content.
                 if not is_redacted or row[1] != RelationTypes.REPLACE:
-                    events.append(row[0])
-                last_topo_id = row[2]
-                last_stream_id = row[3]
+                    events.append(_RelatedEvent(row[0], row[2]))
+                last_topo_id = row[3]
+                last_stream_id = row[4]
 
             # If there are more events, generate the next pagination key.
             next_token = None
@@ -267,7 +283,7 @@ class RelationsWorkerStore(SQLBaseStore):
             `type`, `key` and `count` fields.
         """
 
-        where_args = [
+        args = [
             event_id,
             room_id,
             RelationTypes.ANNOTATION,
@@ -287,7 +303,7 @@ class RelationsWorkerStore(SQLBaseStore):
         def _get_aggregation_groups_for_event_txn(
             txn: LoggingTransaction,
         ) -> List[JsonDict]:
-            txn.execute(sql, where_args)
+            txn.execute(sql, args)
 
             return [{"type": row[0], "key": row[1], "count": row[2]} for row in txn]
 
@@ -295,6 +311,63 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_aggregation_groups_for_event", _get_aggregation_groups_for_event_txn
         )
 
+    async def get_aggregation_groups_for_users(
+        self,
+        event_id: str,
+        room_id: str,
+        limit: int,
+        users: FrozenSet[str] = frozenset(),
+    ) -> Dict[Tuple[str, str], int]:
+        """Fetch the partial aggregations for an event for specific users.
+
+        This is used, in conjunction with get_aggregation_groups_for_event, to
+        remove information from the results for ignored users.
+
+        Args:
+            event_id: Fetch events that relate to this event ID.
+            room_id: The room the event belongs to.
+            limit: Only fetch the `limit` groups.
+            users: The users to fetch information for.
+
+        Returns:
+            A map of (event type, aggregation key) to a count of users.
+        """
+
+        if not users:
+            return {}
+
+        args: List[Union[str, int]] = [
+            event_id,
+            room_id,
+            RelationTypes.ANNOTATION,
+        ]
+
+        users_sql, users_args = make_in_list_sql_clause(
+            self.database_engine, "sender", users
+        )
+        args.extend(users_args)
+
+        sql = f"""
+            SELECT type, aggregation_key, COUNT(DISTINCT sender)
+            FROM event_relations
+            INNER JOIN events USING (event_id)
+            WHERE relates_to_id = ? AND room_id = ? AND relation_type = ? AND {users_sql}
+            GROUP BY relation_type, type, aggregation_key
+            ORDER BY COUNT(*) DESC
+            LIMIT ?
+        """
+
+        def _get_aggregation_groups_for_users_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[Tuple[str, str], int]:
+            txn.execute(sql, args + [limit])
+
+            return {(row[0], row[1]): row[2] for row in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_aggregation_groups_for_users", _get_aggregation_groups_for_users_txn
+        )
+
     @cached()
     def get_applicable_edit(self, event_id: str) -> Optional[EventBase]:
         raise NotImplementedError()
@@ -521,6 +594,67 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return summaries
 
+    async def get_threaded_messages_per_user(
+        self,
+        event_ids: Collection[str],
+        users: FrozenSet[str] = frozenset(),
+    ) -> Dict[Tuple[str, str], int]:
+        """Get the number of threaded replies for a set of users.
+
+        This is used, in conjunction with get_thread_summaries, to calculate an
+        accurate count of the replies to a thread by subtracting ignored users.
+
+        Args:
+            event_ids: The events to check for threaded replies.
+            users: The user to calculate the count of their replies.
+
+        Returns:
+            A map of the (event_id, sender) to the count of their replies.
+        """
+        if not users:
+            return {}
+
+        # Fetch the number of threaded replies.
+        sql = """
+            SELECT parent.event_id, child.sender, COUNT(child.event_id) FROM events AS child
+            INNER JOIN event_relations USING (event_id)
+            INNER JOIN events AS parent ON
+                parent.event_id = relates_to_id
+                AND parent.room_id = child.room_id
+            WHERE
+                %s
+                AND %s
+                AND %s
+            GROUP BY parent.event_id, child.sender
+        """
+
+        def _get_threaded_messages_per_user_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[Tuple[str, str], int]:
+            users_sql, users_args = make_in_list_sql_clause(
+                self.database_engine, "child.sender", users
+            )
+            events_clause, events_args = make_in_list_sql_clause(
+                txn.database_engine, "relates_to_id", event_ids
+            )
+
+            if self._msc3440_enabled:
+                relations_clause = "(relation_type = ? OR relation_type = ?)"
+                relations_args = [RelationTypes.THREAD, RelationTypes.UNSTABLE_THREAD]
+            else:
+                relations_clause = "relation_type = ?"
+                relations_args = [RelationTypes.THREAD]
+
+            txn.execute(
+                sql % (users_sql, events_clause, relations_clause),
+                users_args + events_args + relations_args,
+            )
+            return {(row[0], row[1]): row[2] for row in txn}
+
+        return await self.db_pool.runInteraction(
+            "get_threaded_messages_per_user", _get_threaded_messages_per_user_txn
+        )
+
     @cached()
     def get_thread_participated(self, event_id: str, user_id: str) -> bool:
         raise NotImplementedError()
@@ -608,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 %s;
         """
 
-        def _get_if_events_have_relations(txn) -> List[str]:
+        def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
             clauses: List[str] = []
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", parent_ids