summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py178
1 files changed, 172 insertions, 6 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed253..2cb5d06c13 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
+from frozendict import frozendict
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
 )
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
     @cached(tree=True)
     async def get_relations_for_event(
         self,
@@ -354,8 +384,7 @@ class RelationsWorkerStore(SQLBaseStore):
     async def get_thread_summary(
         self, event_id: str, room_id: str
     ) -> Tuple[int, Optional[EventBase]]:
-        """Get the number of threaded replies, the senders of those replies, and
-        the latest reply (if any) for the given event.
+        """Get the number of threaded replies and the latest reply (if any) for the given event.
 
         Args:
             event_id: Summarize the thread related to this event ID.
@@ -368,7 +397,7 @@ class RelationsWorkerStore(SQLBaseStore):
         def _get_thread_summary_txn(
             txn: LoggingTransaction,
         ) -> Tuple[int, Optional[str]]:
-            # Fetch the count of threaded events and the latest event ID.
+            # Fetch the latest event ID in the thread.
             # TODO Should this only allow m.room.message events.
             sql = """
                 SELECT event_id
@@ -389,6 +418,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             latest_event_id = row[0]
 
+            # Fetch the number of threaded replies.
             sql = """
                 SELECT COUNT(event_id)
                 FROM event_relations
@@ -413,6 +443,44 @@ class RelationsWorkerStore(SQLBaseStore):
 
         return count, latest_event
 
+    @cached()
+    async def get_thread_participated(
+        self, event_id: str, room_id: str, user_id: str
+    ) -> bool:
+        """Get whether the requesting user participated in a thread.
+
+        This is separate from get_thread_summary since that can be cached across
+        all users while this value is specific to the requeser.
+
+        Args:
+            event_id: The thread related to this event ID.
+            room_id: The room the event belongs to.
+            user_id: The user requesting the summary.
+
+        Returns:
+            True if the requesting user participated in the thread, otherwise false.
+        """
+
+        def _get_thread_summary_txn(txn: LoggingTransaction) -> bool:
+            # Fetch whether the requester has participated or not.
+            sql = """
+                SELECT 1
+                FROM event_relations
+                INNER JOIN events USING (event_id)
+                WHERE
+                    relates_to_id = ?
+                    AND room_id = ?
+                    AND relation_type = ?
+                    AND sender = ?
+            """
+
+            txn.execute(sql, (event_id, room_id, RelationTypes.THREAD, user_id))
+            return bool(txn.fetchone())
+
+        return await self.db_pool.runInteraction(
+            "get_thread_summary", _get_thread_summary_txn
+        )
+
     async def events_have_relations(
         self,
         parent_ids: List[str],
@@ -515,6 +583,104 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    async def _get_bundled_aggregation_for_event(
+        self, event: EventBase, user_id: str
+    ) -> Optional[Dict[str, Any]]:
+        """Generate bundled aggregations for an event.
+
+        Note that this does not use a cache, but depends on cached methods.
+
+        Args:
+            event: The event to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
+
+        Returns:
+            The bundled aggregations for an event, if bundled aggregations are
+            enabled and the event can have bundled aggregations.
+        """
+        # State events and redacted events do not get bundled aggregations.
+        if event.is_state() or event.internal_metadata.is_redacted():
+            return None
+
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return None
+
+        event_id = event.event_id
+        room_id = event.room_id
+
+        # The bundled aggregations to include, a mapping of relation type to a
+        # type-specific value. Some types include the direct return type here
+        # while others need more processing during serialization.
+        aggregations: Dict[str, Any] = {}
+
+        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+        if annotations.chunk:
+            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+        references = await self.get_relations_for_event(
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+        edit = None
+        if event.type == EventTypes.Message:
+            edit = await self.get_applicable_edit(event_id, room_id)
+
+        if edit:
+            aggregations[RelationTypes.REPLACE] = edit
+
+        # If this event is the start of a thread, include a summary of the replies.
+        if self._msc3440_enabled:
+            thread_count, latest_thread_event = await self.get_thread_summary(
+                event_id, room_id
+            )
+            participated = await self.get_thread_participated(
+                event_id, room_id, user_id
+            )
+            if latest_thread_event:
+                aggregations[RelationTypes.THREAD] = {
+                    "latest_event": latest_thread_event,
+                    "count": thread_count,
+                    "current_user_participated": participated,
+                }
+
+        # Store the bundled aggregations in the event metadata for later use.
+        return aggregations
+
+    async def get_bundled_aggregations(
+        self,
+        events: Iterable[EventBase],
+        user_id: str,
+    ) -> Dict[str, Dict[str, Any]]:
+        """Generate bundled aggregations for events.
+
+        Args:
+            events: The iterable of events to calculate bundled aggregations for.
+            user_id: The user requesting the bundled aggregations.
+
+        Returns:
+            A map of event ID to the bundled aggregation for the event. Not all
+            events may have bundled aggregations in the results.
+        """
+        # If bundled aggregations are disabled, nothing to do.
+        if not self._msc1849_enabled:
+            return {}
+
+        # TODO Parallelize.
+        results = {}
+        for event in events:
+            event_result = await self._get_bundled_aggregation_for_event(event, user_id)
+            if event_result is not None:
+                results[event.event_id] = event_result
+
+        return results
+
 
 class RelationsStore(RelationsWorkerStore):
     pass