summary refs log tree commit diff
path: root/synapse/storage/databases/main/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/relations.py')
-rw-r--r--synapse/storage/databases/main/relations.py61
1 files changed, 38 insertions, 23 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..a9a5dd5f03 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-    cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
 
 import attr
 from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+    latest_event: EventBase
+    count: int
+    current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None
+    references: Optional[JsonDict] = None
+    replace: Optional[EventBase] = None
+    thread: Optional[_ThreadAggregation] = None
+
+    def __bool__(self) -> bool:
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
     async def _get_bundled_aggregation_for_event(
         self, event: EventBase, user_id: str
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[BundledAggregations]:
         """Generate bundled aggregations for an event.
 
         Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
         # The bundled aggregations to include, a mapping of relation type to a
         # type-specific value. Some types include the direct return type here
         # while others need more processing during serialization.
-        aggregations: Dict[str, Any] = {}
+        aggregations = BundledAggregations()
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+            aggregations.annotations = annotations.to_dict()
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+            aggregations.references = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
             edit = await self.get_applicable_edit(event_id, room_id)
 
         if edit:
-            aggregations[RelationTypes.REPLACE] = edit
+            aggregations.replace = edit
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
@@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
                 event_id, room_id, user_id
             )
             if latest_thread_event:
-                aggregations[RelationTypes.THREAD] = {
-                    "latest_event": latest_thread_event,
-                    "count": thread_count,
-                    "current_user_participated": participated,
-                }
+                aggregations.thread = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    count=thread_count,
+                    current_user_participated=participated,
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
@@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
         self,
         events: Iterable[EventBase],
         user_id: str,
-    ) -> Dict[str, Dict[str, Any]]:
+    ) -> Dict[str, BundledAggregations]:
         """Generate bundled aggregations for events.
 
         Args:
@@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
         results = {}
         for event in events:
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result is not None:
+            if event_result:
                 results[event.event_id] = event_result
 
         return results