summary refs log tree commit diff
path: root/synapse/events/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events/utils.py')
-rw-r--r--synapse/events/utils.py61
1 files changed, 42 insertions, 19 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index de0e0c1731..243696b357 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 import collections.abc
 import re
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from frozendict import frozendict
 
@@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
 
 from . import EventBase
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main.relations import BundledAggregations
+
+
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # (?<!stuff) matches if the current position in the string is not preceded
 # by a match for 'stuff'.
@@ -376,7 +390,7 @@ class EventClientSerializer:
         event: Union[JsonDict, EventBase],
         time_now: int,
         *,
-        bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
+        bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
         **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
@@ -402,7 +416,7 @@ class EventClientSerializer:
         if bundle_aggregations:
             event_aggregations = bundle_aggregations.get(event.event_id)
             if event_aggregations:
-                self._injected_bundled_aggregations(
+                self._inject_bundled_aggregations(
                     event,
                     time_now,
                     bundle_aggregations[event.event_id],
@@ -411,11 +425,11 @@ class EventClientSerializer:
 
         return serialized_event
 
-    def _injected_bundled_aggregations(
+    def _inject_bundled_aggregations(
         self,
         event: EventBase,
         time_now: int,
-        aggregations: JsonDict,
+        aggregations: "BundledAggregations",
         serialized_event: JsonDict,
     ) -> None:
         """Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -427,13 +441,18 @@ class EventClientSerializer:
             serialized_event: The serialized event which may be modified.
 
         """
-        # Make a copy in-case the object is cached.
-        aggregations = aggregations.copy()
+        serialized_aggregations = {}
+
+        if aggregations.annotations:
+            serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
+
+        if aggregations.references:
+            serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
 
-        if RelationTypes.REPLACE in aggregations:
+        if aggregations.replace:
             # If there is an edit replace the content, preserving existing
             # relations.
-            edit = aggregations[RelationTypes.REPLACE]
+            edit = aggregations.replace
 
             # Ensure we take copies of the edit content, otherwise we risk modifying
             # the original event.
@@ -451,24 +470,28 @@ class EventClientSerializer:
             else:
                 serialized_event["content"].pop("m.relates_to", None)
 
-            aggregations[RelationTypes.REPLACE] = {
+            serialized_aggregations[RelationTypes.REPLACE] = {
                 "event_id": edit.event_id,
                 "origin_server_ts": edit.origin_server_ts,
                 "sender": edit.sender,
             }
 
         # If this event is the start of a thread, include a summary of the replies.
-        if RelationTypes.THREAD in aggregations:
-            # Serialize the latest thread event.
-            latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
-
-            # Don't bundle aggregations as this could recurse forever.
-            aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
-                latest_thread_event, time_now, bundle_aggregations=None
-            )
+        if aggregations.thread:
+            serialized_aggregations[RelationTypes.THREAD] = {
+                # Don't bundle aggregations as this could recurse forever.
+                "latest_event": self.serialize_event(
+                    aggregations.thread.latest_event, time_now, bundle_aggregations=None
+                ),
+                "count": aggregations.thread.count,
+                "current_user_participated": aggregations.thread.current_user_participated,
+            }
 
         # Include the bundled aggregations in the event.
-        serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+        if serialized_aggregations:
+            serialized_event["unsigned"].setdefault("m.relations", {}).update(
+                serialized_aggregations
+            )
 
     def serialize_events(
         self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any