summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-01-07 09:10:46 -0500
committerGitHub <noreply@github.com>2022-01-07 09:10:46 -0500
commit6bf81a7a61d8d5248be5def955104c44fcb78dae (patch)
tree2e1222879c207d00a2545f0a3962c31a0b801a9a /synapse/storage/databases/main/relations.py
parentRemove the /send_relation endpoint. (#11682) (diff)
downloadsynapse-6bf81a7a61d8d5248be5def955104c44fcb78dae.tar.xz
Bundle aggregations outside of the serialization method. (#11612)
This makes the serialization of events synchronous (and it no
longer access the database), but we must manually calculate and
provide the bundled aggregations.

Overall this should cause no change in behavior, but is prep work
for other improvements.
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py128
1 files changed, 125 insertions, 3 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 4ff6aed253..c6c4bd18da 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,14 +13,30 @@
 # limitations under the License.
 
 import logging
-from typing import List, Optional, Tuple, Union, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+    cast,
+)
 
 import attr
+from frozendict import frozendict
 
-from synapse.api.constants import RelationTypes
+from synapse.api.constants import EventTypes, RelationTypes
 from synapse.events import EventBase
 from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction, make_in_list_sql_clause
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+    make_in_list_sql_clause,
+)
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.relations import (
     AggregationPaginationToken,
@@ -29,10 +45,24 @@ from synapse.storage.relations import (
 )
 from synapse.util.caches.descriptors import cached
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
 class RelationsWorkerStore(SQLBaseStore):
+    def __init__(
+        self,
+        database: DatabasePool,
+        db_conn: LoggingDatabaseConnection,
+        hs: "HomeServer",
+    ):
+        super().__init__(database, db_conn, hs)
+
+        self._msc1849_enabled = hs.config.experimental.msc1849_enabled
+        self._msc3440_enabled = hs.config.experimental.msc3440_enabled
+
     @cached(tree=True)
     async def get_relations_for_event(
         self,
@@ -515,6 +545,98 @@ class RelationsWorkerStore(SQLBaseStore):
             "get_if_user_has_annotated_event", _get_if_user_has_annotated_event
         )
 
+    async def _get_bundled_aggregation_for_event(
+        self, event: EventBase
+    ) -> Optional[Dict[str, Any]]:
+        """Generate bundled aggregations for an event.
+
+        Note that this does not use a cache, but depends on cached methods.
+
+        Args:
+            event: The event to calculate bundled aggregations for.
+
+        Returns:
+            The bundled aggregations for an event, if bundled aggregations are
+            enabled and the event can have bundled aggregations.
+        """
+        # State events and redacted events do not get bundled aggregations.
+        if event.is_state() or event.internal_metadata.is_redacted():
+            return None
+
+        # Do not bundle aggregations for an event which represents an edit or an
+        # annotation. It does not make sense for them to have related events.
+        relates_to = event.content.get("m.relates_to")
+        if isinstance(relates_to, (dict, frozendict)):
+            relation_type = relates_to.get("rel_type")
+            if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+                return None
+
+        event_id = event.event_id
+        room_id = event.room_id
+
+        # The bundled aggregations to include, a mapping of relation type to a
+        # type-specific value. Some types include the direct return type here
+        # while others need more processing during serialization.
+        aggregations: Dict[str, Any] = {}
+
+        annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
+        if annotations.chunk:
+            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+
+        references = await self.get_relations_for_event(
+            event_id, room_id, RelationTypes.REFERENCE, direction="f"
+        )
+        if references.chunk:
+            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+
+        edit = None
+        if event.type == EventTypes.Message:
+            edit = await self.get_applicable_edit(event_id, room_id)
+
+        if edit:
+            aggregations[RelationTypes.REPLACE] = edit
+
+        # If this event is the start of a thread, include a summary of the replies.
+        if self._msc3440_enabled:
+            (
+                thread_count,
+                latest_thread_event,
+            ) = await self.get_thread_summary(event_id, room_id)
+            if latest_thread_event:
+                aggregations[RelationTypes.THREAD] = {
+                    # Don't bundle aggregations as this could recurse forever.
+                    "latest_event": latest_thread_event,
+                    "count": thread_count,
+                }
+
+        # Store the bundled aggregations in the event metadata for later use.
+        return aggregations
+
+    async def get_bundled_aggregations(
+        self, events: Iterable[EventBase]
+    ) -> Dict[str, Dict[str, Any]]:
+        """Generate bundled aggregations for events.
+
+        Args:
+            events: The iterable of events to calculate bundled aggregations for.
+
+        Returns:
+            A map of event ID to the bundled aggregation for the event. Not all
+            events may have bundled aggregations in the results.
+        """
+        # If bundled aggregations are disabled, nothing to do.
+        if not self._msc1849_enabled:
+            return {}
+
+        # TODO Parallelize.
+        results = {}
+        for event in events:
+            event_result = await self._get_bundled_aggregation_for_event(event)
+            if event_result is not None:
+                results[event.event_id] = event_result
+
+        return results
+
 
 class RelationsStore(RelationsWorkerStore):
     pass