summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16549.feature1
-rw-r--r--docs/SUMMARY.md3
-rw-r--r--docs/modules/add_extra_fields_to_client_events_unsigned.md32
-rw-r--r--synapse/events/utils.py48
-rw-r--r--synapse/handlers/events.py2
-rw-r--r--synapse/handlers/initial_sync.py14
-rw-r--r--synapse/handlers/message.py2
-rw-r--r--synapse/handlers/pagination.py4
-rw-r--r--synapse/handlers/relations.py8
-rw-r--r--synapse/handlers/search.py8
-rw-r--r--synapse/module_api/__init__.py21
-rw-r--r--synapse/rest/admin/rooms.py10
-rw-r--r--synapse/rest/client/events.py2
-rw-r--r--synapse/rest/client/notifications.py2
-rw-r--r--synapse/rest/client/room.py10
-rw-r--r--synapse/rest/client/sync.py8
-rw-r--r--synapse/server.py2
-rw-r--r--tests/module_api/test_event_unsigned_addition.py59
-rw-r--r--tests/rest/client/test_retention.py2
19 files changed, 194 insertions, 44 deletions
diff --git a/changelog.d/16549.feature b/changelog.d/16549.feature
new file mode 100644
index 0000000000..51129200f3
--- /dev/null
+++ b/changelog.d/16549.feature
@@ -0,0 +1 @@
+Add a new module API callback that allows adding extra fields to events' unsigned section when sent down to clients.
diff --git a/docs/SUMMARY.md b/docs/SUMMARY.md
index 31b3032029..c50121d5f7 100644
--- a/docs/SUMMARY.md
+++ b/docs/SUMMARY.md
@@ -19,7 +19,7 @@
 # Usage
   - [Federation](federate.md)
   - [Configuration](usage/configuration/README.md)
-    - [Configuration Manual](usage/configuration/config_documentation.md) 
+    - [Configuration Manual](usage/configuration/config_documentation.md)
     - [Homeserver Sample Config File](usage/configuration/homeserver_sample_config.md)
     - [Logging Sample Config File](usage/configuration/logging_sample_config.md)
     - [Structured Logging](structured_logging.md)
@@ -48,6 +48,7 @@
         - [Password auth provider callbacks](modules/password_auth_provider_callbacks.md)
         - [Background update controller callbacks](modules/background_update_controller_callbacks.md)
         - [Account data callbacks](modules/account_data_callbacks.md)
+        - [Add extra fields to client events unsigned section callbacks](modules/add_extra_fields_to_client_events_unsigned.md)
         - [Porting a legacy module to the new interface](modules/porting_legacy_module.md)
     - [Workers](workers.md)
       - [Using `synctl` with Workers](synctl_workers.md)
diff --git a/docs/modules/add_extra_fields_to_client_events_unsigned.md b/docs/modules/add_extra_fields_to_client_events_unsigned.md
new file mode 100644
index 0000000000..c4fd19bde0
--- /dev/null
+++ b/docs/modules/add_extra_fields_to_client_events_unsigned.md
@@ -0,0 +1,32 @@
+# Add extra fields to client events unsigned section callbacks
+
+_First introduced in Synapse v1.96.0_
+
+This callback allows modules to add extra fields to the unsigned section of
+events when they get sent down to clients.
+
+These get called *every* time an event is to be sent to clients, so care should
+be taken to ensure with respect to performance.
+
+### API
+
+To register the callback, use
+`register_add_extra_fields_to_unsigned_client_event_callbacks` on the
+`ModuleApi`.
+
+The callback should be of the form
+
+```python
+async def add_field_to_unsigned(
+    event: EventBase,
+) -> JsonDict:
+```
+
+where the extra fields to add to the event's unsigned section is returned.
+(Modules must not attempt to modify the `event` directly).
+
+This cannot be used to alter the "core" fields in the unsigned section emitted
+by Synapse itself.
+
+If multiple such callbacks try to add the same field to an event's unsigned
+section, the last-registered callback wins.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 53af423a5a..ac2cf83d9f 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -17,6 +17,7 @@ import re
 from typing import (
     TYPE_CHECKING,
     Any,
+    Awaitable,
     Callable,
     Dict,
     Iterable,
@@ -45,6 +46,7 @@ from . import EventBase
 
 if TYPE_CHECKING:
     from synapse.handlers.relations import BundledAggregations
+    from synapse.server import HomeServer
 
 
 # Split strings on "." but not "\." (or "\\\.").
@@ -56,6 +58,13 @@ CANONICALJSON_MAX_INT = (2**53) - 1
 CANONICALJSON_MIN_INT = -CANONICALJSON_MAX_INT
 
 
+# Module API callback that allows adding fields to the unsigned section of
+# events that are sent to clients.
+ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK = Callable[
+    [EventBase], Awaitable[JsonDict]
+]
+
+
 def prune_event(event: EventBase) -> EventBase:
     """Returns a pruned version of the given event, which removes all keys we
     don't know about or think could potentially be dodgy.
@@ -509,7 +518,13 @@ class EventClientSerializer:
     clients.
     """
 
-    def serialize_event(
+    def __init__(self, hs: "HomeServer") -> None:
+        self._store = hs.get_datastores().main
+        self._add_extra_fields_to_unsigned_client_event_callbacks: List[
+            ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+        ] = []
+
+    async def serialize_event(
         self,
         event: Union[JsonDict, EventBase],
         time_now: int,
@@ -535,10 +550,21 @@ class EventClientSerializer:
 
         serialized_event = serialize_event(event, time_now, config=config)
 
+        new_unsigned = {}
+        for callback in self._add_extra_fields_to_unsigned_client_event_callbacks:
+            u = await callback(event)
+            new_unsigned.update(u)
+
+        if new_unsigned:
+            # We do the `update` this way round so that modules can't clobber
+            # existing fields.
+            new_unsigned.update(serialized_event["unsigned"])
+            serialized_event["unsigned"] = new_unsigned
+
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
             if event.event_id in bundle_aggregations:
-                self._inject_bundled_aggregations(
+                await self._inject_bundled_aggregations(
                     event,
                     time_now,
                     config,
@@ -548,7 +574,7 @@ class EventClientSerializer:
 
         return serialized_event
 
-    def _inject_bundled_aggregations(
+    async def _inject_bundled_aggregations(
         self,
         event: EventBase,
         time_now: int,
@@ -590,7 +616,7 @@ class EventClientSerializer:
             # said that we should only include the `event_id`, `origin_server_ts` and
             # `sender` of the edit; however MSC3925 proposes extending it to the whole
             # of the edit, which is what we do here.
-            serialized_aggregations[RelationTypes.REPLACE] = self.serialize_event(
+            serialized_aggregations[RelationTypes.REPLACE] = await self.serialize_event(
                 event_aggregations.replace,
                 time_now,
                 config=config,
@@ -600,7 +626,7 @@ class EventClientSerializer:
         if event_aggregations.thread:
             thread = event_aggregations.thread
 
-            serialized_latest_event = self.serialize_event(
+            serialized_latest_event = await self.serialize_event(
                 thread.latest_event,
                 time_now,
                 config=config,
@@ -623,7 +649,7 @@ class EventClientSerializer:
                 "m.relations", {}
             ).update(serialized_aggregations)
 
-    def serialize_events(
+    async def serialize_events(
         self,
         events: Iterable[Union[JsonDict, EventBase]],
         time_now: int,
@@ -645,7 +671,7 @@ class EventClientSerializer:
             The list of serialized events
         """
         return [
-            self.serialize_event(
+            await self.serialize_event(
                 event,
                 time_now,
                 config=config,
@@ -654,6 +680,14 @@ class EventClientSerializer:
             for event in events
         ]
 
+    def register_add_extra_fields_to_unsigned_client_event_callback(
+        self, callback: ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+    ) -> None:
+        """Register a callback that returns additions to the unsigned section of
+        serialized events.
+        """
+        self._add_extra_fields_to_unsigned_client_event_callbacks.append(callback)
+
 
 _PowerLevel = Union[str, int]
 PowerLevelsContent = Mapping[str, Union[_PowerLevel, Mapping[str, _PowerLevel]]]
diff --git a/synapse/handlers/events.py b/synapse/handlers/events.py
index d12803bf0f..756825061c 100644
--- a/synapse/handlers/events.py
+++ b/synapse/handlers/events.py
@@ -120,7 +120,7 @@ class EventStreamHandler:
 
             events.extend(to_add)
 
-            chunks = self._event_serializer.serialize_events(
+            chunks = await self._event_serializer.serialize_events(
                 events,
                 time_now,
                 config=SerializeEventConfig(
diff --git a/synapse/handlers/initial_sync.py b/synapse/handlers/initial_sync.py
index 4727efcdba..c4bec955fe 100644
--- a/synapse/handlers/initial_sync.py
+++ b/synapse/handlers/initial_sync.py
@@ -173,7 +173,7 @@ class InitialSyncHandler:
                 d["inviter"] = event.sender
 
                 invite_event = await self.store.get_event(event.event_id)
-                d["invite"] = self._event_serializer.serialize_event(
+                d["invite"] = await self._event_serializer.serialize_event(
                     invite_event,
                     time_now,
                     config=serializer_options,
@@ -225,7 +225,7 @@ class InitialSyncHandler:
 
                 d["messages"] = {
                     "chunk": (
-                        self._event_serializer.serialize_events(
+                        await self._event_serializer.serialize_events(
                             messages,
                             time_now=time_now,
                             config=serializer_options,
@@ -235,7 +235,7 @@ class InitialSyncHandler:
                     "end": await end_token.to_string(self.store),
                 }
 
-                d["state"] = self._event_serializer.serialize_events(
+                d["state"] = await self._event_serializer.serialize_events(
                     current_state.values(),
                     time_now=time_now,
                     config=serializer_options,
@@ -387,7 +387,7 @@ class InitialSyncHandler:
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    self._event_serializer.serialize_events(
+                    await self._event_serializer.serialize_events(
                         messages, time_now, config=serialize_options
                     )
                 ),
@@ -396,7 +396,7 @@ class InitialSyncHandler:
             },
             "state": (
                 # Don't bundle aggregations as this is a deprecated API.
-                self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     room_state.values(), time_now, config=serialize_options
                 )
             ),
@@ -420,7 +420,7 @@ class InitialSyncHandler:
         time_now = self.clock.time_msec()
         serialize_options = SerializeEventConfig(requester=requester)
         # Don't bundle aggregations as this is a deprecated API.
-        state = self._event_serializer.serialize_events(
+        state = await self._event_serializer.serialize_events(
             current_state.values(),
             time_now,
             config=serialize_options,
@@ -497,7 +497,7 @@ class InitialSyncHandler:
             "messages": {
                 "chunk": (
                     # Don't bundle aggregations as this is a deprecated API.
-                    self._event_serializer.serialize_events(
+                    await self._event_serializer.serialize_events(
                         messages, time_now, config=serialize_options
                     )
                 ),
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 41a35ce510..a0b4a93ae8 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -244,7 +244,7 @@ class MessageHandler:
                 )
                 room_state = room_state_events[membership_event_id]
 
-        events = self._event_serializer.serialize_events(
+        events = await self._event_serializer.serialize_events(
             room_state.values(),
             self.clock.time_msec(),
             config=SerializeEventConfig(requester=requester),
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 878f267a4e..87e51bca48 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -657,7 +657,7 @@ class PaginationHandler:
 
         chunk = {
             "chunk": (
-                self._event_serializer.serialize_events(
+                await self._event_serializer.serialize_events(
                     events,
                     time_now,
                     config=serialize_options,
@@ -669,7 +669,7 @@ class PaginationHandler:
         }
 
         if state:
-            chunk["state"] = self._event_serializer.serialize_events(
+            chunk["state"] = await self._event_serializer.serialize_events(
                 state, time_now, config=serialize_options
             )
 
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 9b13448cdd..a15983afae 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -167,7 +167,7 @@ class RelationsHandler:
         now = self._clock.time_msec()
         serialize_options = SerializeEventConfig(requester=requester)
         return_value: JsonDict = {
-            "chunk": self._event_serializer.serialize_events(
+            "chunk": await self._event_serializer.serialize_events(
                 events,
                 now,
                 bundle_aggregations=aggregations,
@@ -177,7 +177,9 @@ class RelationsHandler:
         if include_original_event:
             # Do not bundle aggregations when retrieving the original event because
             # we want the content before relations are applied to it.
-            return_value["original_event"] = self._event_serializer.serialize_event(
+            return_value[
+                "original_event"
+            ] = await self._event_serializer.serialize_event(
                 event,
                 now,
                 bundle_aggregations=None,
@@ -602,7 +604,7 @@ class RelationsHandler:
         )
 
         now = self._clock.time_msec()
-        serialized_events = self._event_serializer.serialize_events(
+        serialized_events = await self._event_serializer.serialize_events(
             events, now, bundle_aggregations=aggregations
         )
 
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index aad4706f14..f51ed9d5bb 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -374,13 +374,13 @@ class SearchHandler:
         serialize_options = SerializeEventConfig(requester=requester)
 
         for context in contexts.values():
-            context["events_before"] = self._event_serializer.serialize_events(
+            context["events_before"] = await self._event_serializer.serialize_events(
                 context["events_before"],
                 time_now,
                 bundle_aggregations=aggregations,
                 config=serialize_options,
             )
-            context["events_after"] = self._event_serializer.serialize_events(
+            context["events_after"] = await self._event_serializer.serialize_events(
                 context["events_after"],
                 time_now,
                 bundle_aggregations=aggregations,
@@ -390,7 +390,7 @@ class SearchHandler:
         results = [
             {
                 "rank": search_result.rank_map[e.event_id],
-                "result": self._event_serializer.serialize_event(
+                "result": await self._event_serializer.serialize_event(
                     e,
                     time_now,
                     bundle_aggregations=aggregations,
@@ -409,7 +409,7 @@ class SearchHandler:
 
         if state_results:
             rooms_cat_res["state"] = {
-                room_id: self._event_serializer.serialize_events(
+                room_id: await self._event_serializer.serialize_events(
                     state_events, time_now, config=serialize_options
                 )
                 for room_id, state_events in state_results.items()
diff --git a/synapse/module_api/__init__.py b/synapse/module_api/__init__.py
index 09ea6bdecb..755c59274c 100644
--- a/synapse/module_api/__init__.py
+++ b/synapse/module_api/__init__.py
@@ -48,6 +48,7 @@ from synapse.events.presence_router import (
     GET_USERS_FOR_STATES_CALLBACK,
     PresenceRouter,
 )
+from synapse.events.utils import ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
 from synapse.handlers.account_data import ON_ACCOUNT_DATA_UPDATED_CALLBACK
 from synapse.handlers.auth import (
     CHECK_3PID_AUTH_CALLBACK,
@@ -259,6 +260,7 @@ class ModuleApi:
         self.custom_template_dir = hs.config.server.custom_template_directory
         self._callbacks = hs.get_module_api_callbacks()
         self.msc3861_oauth_delegation_enabled = hs.config.experimental.msc3861.enabled
+        self._event_serializer = hs.get_event_client_serializer()
 
         try:
             app_name = self._hs.config.email.email_app_name
@@ -490,6 +492,25 @@ class ModuleApi:
         """
         self._hs.register_module_web_resource(path, resource)
 
+    def register_add_extra_fields_to_unsigned_client_event_callbacks(
+        self,
+        *,
+        add_field_to_unsigned_callback: Optional[
+            ADD_EXTRA_FIELDS_TO_UNSIGNED_CLIENT_EVENT_CALLBACK
+        ] = None,
+    ) -> None:
+        """Registers a callback that can be used to add fields to the unsigned
+        section of events.
+
+        The callback is called every time an event is sent down to a client.
+
+        Added in Synapse 1.96.0
+        """
+        if add_field_to_unsigned_callback is not None:
+            self._event_serializer.register_add_extra_fields_to_unsigned_client_event_callback(
+                add_field_to_unsigned_callback
+            )
+
     #########################################################################
     # The following methods can be called by the module at any point in time.
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index 2d4da38db9..0659f22a89 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -444,7 +444,7 @@ class RoomStateRestServlet(RestServlet):
         event_ids = await self._storage_controllers.state.get_current_state_ids(room_id)
         events = await self.store.get_events(event_ids.values())
         now = self.clock.time_msec()
-        room_state = self._event_serializer.serialize_events(events.values(), now)
+        room_state = await self._event_serializer.serialize_events(events.values(), now)
         ret = {"state": room_state}
 
         return HTTPStatus.OK, ret
@@ -789,22 +789,22 @@ class RoomEventContextServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         results = {
-            "events_before": self._event_serializer.serialize_events(
+            "events_before": await self._event_serializer.serialize_events(
                 event_context.events_before,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
             ),
-            "event": self._event_serializer.serialize_event(
+            "event": await self._event_serializer.serialize_event(
                 event_context.event,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
             ),
-            "events_after": self._event_serializer.serialize_events(
+            "events_after": await self._event_serializer.serialize_events(
                 event_context.events_after,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
             ),
-            "state": self._event_serializer.serialize_events(
+            "state": await self._event_serializer.serialize_events(
                 event_context.state, time_now
             ),
             "start": event_context.start,
diff --git a/synapse/rest/client/events.py b/synapse/rest/client/events.py
index 3eca4fe21f..5705f812a5 100644
--- a/synapse/rest/client/events.py
+++ b/synapse/rest/client/events.py
@@ -93,7 +93,7 @@ class EventRestServlet(RestServlet):
         event = await self.event_handler.get_event(requester.user, None, event_id)
 
         if event:
-            result = self._event_serializer.serialize_event(
+            result = await self._event_serializer.serialize_event(
                 event,
                 self.clock.time_msec(),
                 config=SerializeEventConfig(requester=requester),
diff --git a/synapse/rest/client/notifications.py b/synapse/rest/client/notifications.py
index e7fe1332e7..5688d8593d 100644
--- a/synapse/rest/client/notifications.py
+++ b/synapse/rest/client/notifications.py
@@ -87,7 +87,7 @@ class NotificationsServlet(RestServlet):
                 "actions": pa.actions,
                 "ts": pa.received_ts,
                 "event": (
-                    self._event_serializer.serialize_event(
+                    await self._event_serializer.serialize_event(
                         notif_events[pa.event_id],
                         now,
                         config=serialize_options,
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 553938ce9d..96f5726911 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -859,7 +859,7 @@ class RoomEventServlet(RestServlet):
 
             # per MSC2676, /rooms/{roomId}/event/{eventId}, should return the
             # *original* event, rather than the edited version
-            event_dict = self._event_serializer.serialize_event(
+            event_dict = await self._event_serializer.serialize_event(
                 event,
                 self.clock.time_msec(),
                 bundle_aggregations=aggregations,
@@ -911,25 +911,25 @@ class RoomEventContextServlet(RestServlet):
         time_now = self.clock.time_msec()
         serializer_options = SerializeEventConfig(requester=requester)
         results = {
-            "events_before": self._event_serializer.serialize_events(
+            "events_before": await self._event_serializer.serialize_events(
                 event_context.events_before,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
                 config=serializer_options,
             ),
-            "event": self._event_serializer.serialize_event(
+            "event": await self._event_serializer.serialize_event(
                 event_context.event,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
                 config=serializer_options,
             ),
-            "events_after": self._event_serializer.serialize_events(
+            "events_after": await self._event_serializer.serialize_events(
                 event_context.events_after,
                 time_now,
                 bundle_aggregations=event_context.aggregations,
                 config=serializer_options,
             ),
-            "state": self._event_serializer.serialize_events(
+            "state": await self._event_serializer.serialize_events(
                 event_context.state,
                 time_now,
                 config=serializer_options,
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index 42bdd3bb10..33fde6c6f8 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -384,7 +384,7 @@ class SyncRestServlet(RestServlet):
         """
         invited = {}
         for room in rooms:
-            invite = self._event_serializer.serialize_event(
+            invite = await self._event_serializer.serialize_event(
                 room.invite, time_now, config=serialize_options
             )
             unsigned = dict(invite.get("unsigned", {}))
@@ -415,7 +415,7 @@ class SyncRestServlet(RestServlet):
         """
         knocked = {}
         for room in rooms:
-            knock = self._event_serializer.serialize_event(
+            knock = await self._event_serializer.serialize_event(
                 room.knock, time_now, config=serialize_options
             )
 
@@ -506,10 +506,10 @@ class SyncRestServlet(RestServlet):
                     event.room_id,
                 )
 
-        serialized_state = self._event_serializer.serialize_events(
+        serialized_state = await self._event_serializer.serialize_events(
             state_events, time_now, config=serialize_options
         )
-        serialized_timeline = self._event_serializer.serialize_events(
+        serialized_timeline = await self._event_serializer.serialize_events(
             timeline_events,
             time_now,
             config=serialize_options,
diff --git a/synapse/server.py b/synapse/server.py
index 71ead524d6..5bfb4ba4eb 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -786,7 +786,7 @@ class HomeServer(metaclass=abc.ABCMeta):
 
     @cache_in_self
     def get_event_client_serializer(self) -> EventClientSerializer:
-        return EventClientSerializer()
+        return EventClientSerializer(self)
 
     @cache_in_self
     def get_password_policy_handler(self) -> PasswordPolicyHandler:
diff --git a/tests/module_api/test_event_unsigned_addition.py b/tests/module_api/test_event_unsigned_addition.py
new file mode 100644
index 0000000000..b64426b1ac
--- /dev/null
+++ b/tests/module_api/test_event_unsigned_addition.py
@@ -0,0 +1,59 @@
+# Copyright 2023 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from twisted.test.proto_helpers import MemoryReactor
+
+from synapse.events import EventBase
+from synapse.rest import admin, login, room
+from synapse.server import HomeServer
+from synapse.types import JsonDict
+from synapse.util import Clock
+
+from tests.unittest import HomeserverTestCase
+
+
+class EventUnsignedAdditionTestCase(HomeserverTestCase):
+    servlets = [
+        room.register_servlets,
+        admin.register_servlets,
+        login.register_servlets,
+    ]
+
+    def prepare(
+        self, reactor: MemoryReactor, clock: Clock, homeserver: HomeServer
+    ) -> None:
+        self._store = homeserver.get_datastores().main
+        self._module_api = homeserver.get_module_api()
+        self._account_data_mgr = self._module_api.account_data_manager
+
+    def test_annotate_event(self) -> None:
+        """Test that we can annotate an event when we request it from the
+        server.
+        """
+
+        async def add_unsigned_event(event: EventBase) -> JsonDict:
+            return {"test_key": event.event_id}
+
+        self._module_api.register_add_extra_fields_to_unsigned_client_event_callbacks(
+            add_field_to_unsigned_callback=add_unsigned_event
+        )
+
+        user_id = self.register_user("user", "password")
+        token = self.login("user", "password")
+
+        room_id = self.helper.create_room_as(user_id, tok=token)
+        result = self.helper.send(room_id, "Hello!", tok=token)
+        event_id = result["event_id"]
+
+        event_json = self.helper.get_event(room_id, event_id, tok=token)
+        self.assertEqual(event_json["unsigned"].get("test_key"), event_id)
diff --git a/tests/rest/client/test_retention.py b/tests/rest/client/test_retention.py
index d3e06bf6b3..534dc339f3 100644
--- a/tests/rest/client/test_retention.py
+++ b/tests/rest/client/test_retention.py
@@ -243,7 +243,7 @@ class RetentionTestCase(unittest.HomeserverTestCase):
         assert event is not None
 
         time_now = self.clock.time_msec()
-        serialized = self.serializer.serialize_event(event, time_now)
+        serialized = self.get_success(self.serializer.serialize_event(event, time_now))
 
         return serialized