summary refs log tree commit diff
path: root/synapse/events/utils.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-03-03 10:43:06 -0500
committerGitHub <noreply@github.com>2022-03-03 10:43:06 -0500
commit1d11b452b70c768e4919bd9cf6bcaeda2050a3d4 (patch)
tree772b95de8e7ec9714b9e6334088174abad0a9222 /synapse/events/utils.py
parentEnable MSC2716 Complement tests in Synapse (#12145) (diff)
downloadsynapse-1d11b452b70c768e4919bd9cf6bcaeda2050a3d4.tar.xz
Use the proper serialization format when bundling aggregations. (#12090)
This ensures that the `latest_event` field of the bundled aggregation
for threads uses the same format as the other events in the response.
Diffstat (limited to 'synapse/events/utils.py')
-rw-r--r--synapse/events/utils.py81
1 files changed, 54 insertions, 27 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 9386fa29dd..ee34cb46e4 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -26,6 +26,7 @@ from typing import (
     Union,
 )
 
+import attr
 from frozendict import frozendict
 
 from synapse.api.constants import EventContentFields, EventTypes, RelationTypes
@@ -303,29 +304,37 @@ def format_event_for_client_v2_without_room_id(d: JsonDict) -> JsonDict:
     return d
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class SerializeEventConfig:
+    as_client_event: bool = True
+    # Function to convert from federation format to client format
+    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1
+    # ID of the user's auth token - used for namespacing of transaction IDs
+    token_id: Optional[int] = None
+    # List of event fields to include. If empty, all fields will be returned.
+    only_event_fields: Optional[List[str]] = None
+    # Some events can have stripped room state stored in the `unsigned` field.
+    # This is required for invite and knock functionality. If this option is
+    # False, that state will be removed from the event before it is returned.
+    # Otherwise, it will be kept.
+    include_stripped_room_state: bool = False
+
+
+_DEFAULT_SERIALIZE_EVENT_CONFIG = SerializeEventConfig()
+
+
 def serialize_event(
     e: Union[JsonDict, EventBase],
     time_now_ms: int,
     *,
-    as_client_event: bool = True,
-    event_format: Callable[[JsonDict], JsonDict] = format_event_for_client_v1,
-    token_id: Optional[str] = None,
-    only_event_fields: Optional[List[str]] = None,
-    include_stripped_room_state: bool = False,
+    config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
 ) -> JsonDict:
     """Serialize event for clients
 
     Args:
         e
         time_now_ms
-        as_client_event
-        event_format
-        token_id
-        only_event_fields
-        include_stripped_room_state: Some events can have stripped room state
-            stored in the `unsigned` field. This is required for invite and knock
-            functionality. If this option is False, that state will be removed from the
-            event before it is returned. Otherwise, it will be kept.
+        config: Event serialization config
 
     Returns:
         The serialized event dictionary.
@@ -348,11 +357,11 @@ def serialize_event(
 
     if "redacted_because" in e.unsigned:
         d["unsigned"]["redacted_because"] = serialize_event(
-            e.unsigned["redacted_because"], time_now_ms, event_format=event_format
+            e.unsigned["redacted_because"], time_now_ms, config=config
         )
 
-    if token_id is not None:
-        if token_id == getattr(e.internal_metadata, "token_id", None):
+    if config.token_id is not None:
+        if config.token_id == getattr(e.internal_metadata, "token_id", None):
             txn_id = getattr(e.internal_metadata, "txn_id", None)
             if txn_id is not None:
                 d["unsigned"]["transaction_id"] = txn_id
@@ -361,13 +370,14 @@ def serialize_event(
     # that are meant to provide metadata about a room to an invitee/knocker. They are
     # intended to only be included in specific circumstances, such as down sync, and
     # should not be included in any other case.
-    if not include_stripped_room_state:
+    if not config.include_stripped_room_state:
         d["unsigned"].pop("invite_room_state", None)
         d["unsigned"].pop("knock_room_state", None)
 
-    if as_client_event:
-        d = event_format(d)
+    if config.as_client_event:
+        d = config.event_format(d)
 
+    only_event_fields = config.only_event_fields
     if only_event_fields:
         if not isinstance(only_event_fields, list) or not all(
             isinstance(f, str) for f in only_event_fields
@@ -390,18 +400,18 @@ class EventClientSerializer:
         event: Union[JsonDict, EventBase],
         time_now: int,
         *,
+        config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
         bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
-        **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
 
         Args:
             event: The event being serialized.
             time_now: The current time in milliseconds
+            config: Event serialization config
             bundle_aggregations: Whether to include the bundled aggregations for this
                 event. Only applies to non-state events. (State events never include
                 bundled aggregations.)
-            **kwargs: Arguments to pass to `serialize_event`
 
         Returns:
             The serialized event
@@ -410,7 +420,7 @@ class EventClientSerializer:
         if not isinstance(event, EventBase):
             return event
 
-        serialized_event = serialize_event(event, time_now, **kwargs)
+        serialized_event = serialize_event(event, time_now, config=config)
 
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
@@ -419,6 +429,7 @@ class EventClientSerializer:
                 self._inject_bundled_aggregations(
                     event,
                     time_now,
+                    config,
                     bundle_aggregations[event.event_id],
                     serialized_event,
                 )
@@ -456,6 +467,7 @@ class EventClientSerializer:
         self,
         event: EventBase,
         time_now: int,
+        config: SerializeEventConfig,
         aggregations: "BundledAggregations",
         serialized_event: JsonDict,
     ) -> None:
@@ -466,6 +478,7 @@ class EventClientSerializer:
             time_now: The current time in milliseconds
             aggregations: The bundled aggregation to serialize.
             serialized_event: The serialized event which may be modified.
+            config: Event serialization config
 
         """
         serialized_aggregations = {}
@@ -493,8 +506,8 @@ class EventClientSerializer:
             thread = aggregations.thread
 
             # Don't bundle aggregations as this could recurse forever.
-            serialized_latest_event = self.serialize_event(
-                thread.latest_event, time_now, bundle_aggregations=None
+            serialized_latest_event = serialize_event(
+                thread.latest_event, time_now, config=config
             )
             # Manually apply an edit, if one exists.
             if thread.latest_edit:
@@ -515,20 +528,34 @@ class EventClientSerializer:
             )
 
     def serialize_events(
-        self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
+        self,
+        events: Iterable[Union[JsonDict, EventBase]],
+        time_now: int,
+        *,
+        config: SerializeEventConfig = _DEFAULT_SERIALIZE_EVENT_CONFIG,
+        bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
     ) -> List[JsonDict]:
         """Serializes multiple events.
 
         Args:
             event
             time_now: The current time in milliseconds
-            **kwargs: Arguments to pass to `serialize_event`
+            config: Event serialization config
+            bundle_aggregations: Whether to include the bundled aggregations for this
+                event. Only applies to non-state events. (State events never include
+                bundled aggregations.)
 
         Returns:
             The list of serialized events
         """
         return [
-            self.serialize_event(event, time_now=time_now, **kwargs) for event in events
+            self.serialize_event(
+                event,
+                time_now,
+                config=config,
+                bundle_aggregations=bundle_aggregations,
+            )
+            for event in events
         ]