summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py128
1 files changed, 125 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed253..c6c4bd18da 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
+from frozendict import frozendict
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
 )
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
     @cached(tree=True)
     async def get_relations_for_event(
         self,
@@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    async def _get_bundled_aggregation_for_event(
+        self, event: EventBase
+    ) -> Optional[Dict[str, Any]]:
+        """Generate bundled aggregations for an event.
+
+        Note that this does not use a cache, but depends on cached methods.
+
+        Args:
+            event: The event to calculate bundled aggregations for.
+
+        Returns:
+            The bundled aggregations for an event, if bundled aggregations are
+            enabled and the event can have bundled aggregations.
+        """
+        # State events and redacted events do not get bundled aggregations.
+        if event.is_state() or event.internal_metadata.is_redacted():
+            return None
+
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return None
+
+        event_id = event.event_id
+        room_id = event.room_id
+
+        # The bundled aggregations to include, a mapping of relation type to a
+        # type-specific value. Some types include the direct return type here
+        # while others need more processing during serialization.
+        aggregations: Dict[str, Any] = {}
+
+        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+        if annotations.chunk:
+            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+        references = await self.get_relations_for_event(
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+        edit = None
+        if event.type == EventTypes.Message:
+            edit = await self.get_applicable_edit(event_id, room_id)
+
+        if edit:
+            aggregations[RelationTypes.REPLACE] = edit
+
+        # If this event is the start of a thread, include a summary of the replies.
+        if self._msc3440_enabled:
+            (
+                thread_count,
+                latest_thread_event,
+            ) = await self.get_thread_summary(event_id, room_id)
+            if latest_thread_event:
+                aggregations[RelationTypes.THREAD] = {
+                    # Don't bundle aggregations as this could recurse forever.
+                    "latest_event": latest_thread_event,
+                    "count": thread_count,
+                }
+
+        # Store the bundled aggregations in the event metadata for later use.
+        return aggregations
+
+    async def get_bundled_aggregations(
+        self, events: Iterable[EventBase]
+    ) -> Dict[str, Dict[str, Any]]:
+        """Generate bundled aggregations for events.
+
+        Args:
+            events: The iterable of events to calculate bundled aggregations for.
+
+        Returns:
+            A map of event ID to the bundled aggregation for the event. Not all
+            events may have bundled aggregations in the results.
+        """
+        # If bundled aggregations are disabled, nothing to do.
+        if not self._msc1849_enabled:
+            return {}
+
+        # TODO Parallelize.
+        results = {}
+        for event in events:
+            event_result = await self._get_bundled_aggregation_for_event(event)
+            if event_result is not None:
+                results[event.event_id] = event_result
+
+        return results
+
 
 class RelationsStore(RelationsWorkerStore):
     pass