summary refs log tree commit diff
path: root/synapse/events
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/events')
-rw-r--r--synapse/events/utils.py48
1 files changed, 41 insertions, 7 deletions
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 53af423a5a..ac2cf83d9f 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -17,6 +17,7 @@ import re
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
@@ -45,6 +46,7 @@ from . import EventBase
 
 if TYPE_CHECKING:
     from synapse.handlers.relations import BundledAggregations
+    from synapse.server import HomeServer
 
 
 # Split strings on "." but not "\." (or "\\\.").
@@ -56,6 +58,13 @@ CANONICALJSON_MAX_INT = (2**53) - 1
 CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
 
 
+# Module API callback that allows adding fields to the unsigned section of
+# events that are sent to clients.
+ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[
+    [EventBase], Awaitable[JsonDict]
+]
+
+
 def prune_event(event: EventBase) -> EventBase:
     """Returns a pruned version of the given event, which removes all keys we
     don't know about or think could potentially be dodgy.
@@ -509,7 +518,13 @@ class EventClientSerializer:
     clients.
     """
 
-    def serialize_event(
+    def __init__(self, hs: "HomeServer") -> None:
+        self._store = hs.get_datastores().main
+        self._add_extra_fields_to_unsigned_client_event_callbacks: List[
+            ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+        ] = []
+
+    async def serialize_event(
         self,
         event: Union[JsonDict, EventBase],
         time_now: int,
@@ -535,10 +550,21 @@ class EventClientSerializer:
 
         serialized_event = serialize_event(event, time_now, config=config)
 
+        new_unsigned = {}
+        for callback in self._add_extra_fields_to_unsigned_client_event_callbacks:
+            u = await callback(event)
+            new_unsigned.update(u)
+
+        if new_unsigned:
+            # We do the `update` this way round so that modules can't clobber
+            # existing fields.
+            new_unsigned.update(serialized_event["unsigned"])
+            serialized_event["unsigned"] = new_unsigned
+
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
             if event.event_id in bundle_aggregations:
-                self._inject_bundled_aggregations(
+                await self._inject_bundled_aggregations(
                     event,
                     time_now,
                     config,
@@ -548,7 +574,7 @@ class EventClientSerializer:
 
         return serialized_event
 
-    def _inject_bundled_aggregations(
+    async def _inject_bundled_aggregations(
         self,
         event: EventBase,
         time_now: int,
@@ -590,7 +616,7 @@ class EventClientSerializer:
             # said that we should only include the `event_id`, `origin_server_ts` and
             # `sender` of the edit; however MSC3925 proposes extending it to the whole
             # of the edit, which is what we do here.
-            serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
+            serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event(
                 event_aggregations.replace,
                 time_now,
                 config=config,
@@ -600,7 +626,7 @@ class EventClientSerializer:
         if event_aggregations.thread:
             thread = event_aggregations.thread
 
-            serialized_latest_event = self.serialize_event(
+            serialized_latest_event = await self.serialize_event(
                 thread.latest_event,
                 time_now,
                 config=config,
@@ -623,7 +649,7 @@ class EventClientSerializer:
                 "m.relations", {}
             ).update(serialized_aggregations)
 
-    def serialize_events(
+    async def serialize_events(
         self,
         events: Iterable[Union[JsonDict, EventBase]],
         time_now: int,
@@ -645,7 +671,7 @@ class EventClientSerializer:
             The list of serialized events
         """
         return [
-            self.serialize_event(
+            await self.serialize_event(
                 event,
                 time_now,
                 config=config,
@@ -654,6 +680,14 @@ class EventClientSerializer:
             for event in events
         ]
 
+    def register_add_extra_fields_to_unsigned_client_event_callback(
+        self, callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+    ) -> None:
+        """Register a callback that returns additions to the unsigned section of
+        serialized events.
+        """
+        self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
+
 
 _PowerLevel = Union[str, int]
 PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]