summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11815.misc1
-rw-r--r--synapse/events/utils.py57
-rw-r--r--synapse/handlers/room.py77
-rw-r--r--synapse/handlers/search.py45
-rw-r--r--synapse/handlers/sync.py3
-rw-r--r--synapse/push/mailer.py2
-rw-r--r--synapse/rest/admin/rooms.py39
-rw-r--r--synapse/rest/client/room.py39
-rw-r--r--synapse/rest/client/sync.py3
-rw-r--r--synapse/storage/databases/main/relations.py61
-rw-r--r--synapse/storage/databases/main/stream.py22
-rw-r--r--tests/rest/client/test_relations.py2
12 files changed, 212 insertions, 139 deletions
diff --git a/changelog.d/11815.misc b/changelog.d/11815.misc
new file mode 100644
index 0000000000..83aa6d6eb0
--- /dev/null
+++ b/changelog.d/11815.misc
@@ -0,0 +1 @@
+Improve type safety of bundled aggregations code.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 918adeecf8..243696b357 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 import collections.abc
 import re
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Mapping,
+    Optional,
+    Union,
+)
 
 from frozendict import frozendict
 
@@ -26,6 +36,10 @@ from synapse.util.frozenutils import unfreeze
 
 from . import EventBase
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main.relations import BundledAggregations
+
+
 # Split strings on "." but not "\." This uses a negative lookbehind assertion for '\'
 # (?<!stuff) matches if the current position in the string is not preceded
 # by a match for 'stuff'.
@@ -376,7 +390,7 @@ class EventClientSerializer:
         event: Union[JsonDict, EventBase],
         time_now: int,
         *,
-        bundle_aggregations: Optional[Dict[str, JsonDict]] = None,
+        bundle_aggregations: Optional[Dict[str, "BundledAggregations"]] = None,
         **kwargs: Any,
     ) -> JsonDict:
         """Serializes a single event.
@@ -415,7 +429,7 @@ class EventClientSerializer:
         self,
         event: EventBase,
         time_now: int,
-        aggregations: JsonDict,
+        aggregations: "BundledAggregations",
         serialized_event: JsonDict,
     ) -> None:
         """Potentially injects bundled aggregations into the unsigned portion of the serialized event.
@@ -427,13 +441,18 @@ class EventClientSerializer:
             serialized_event: The serialized event which may be modified.
 
         """
-        # Make a copy in-case the object is cached.
-        aggregations = aggregations.copy()
+        serialized_aggregations = {}
+
+        if aggregations.annotations:
+            serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
+
+        if aggregations.references:
+            serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
 
-        if RelationTypes.REPLACE in aggregations:
+        if aggregations.replace:
             # If there is an edit replace the content, preserving existing
             # relations.
-            edit = aggregations[RelationTypes.REPLACE]
+            edit = aggregations.replace
 
             # Ensure we take copies of the edit content, otherwise we risk modifying
             # the original event.
@@ -451,24 +470,28 @@ class EventClientSerializer:
             else:
                 serialized_event["content"].pop("m.relates_to", None)
 
-            aggregations[RelationTypes.REPLACE] = {
+            serialized_aggregations[RelationTypes.REPLACE] = {
                 "event_id": edit.event_id,
                 "origin_server_ts": edit.origin_server_ts,
                 "sender": edit.sender,
             }
 
         # If this event is the start of a thread, include a summary of the replies.
-        if RelationTypes.THREAD in aggregations:
-            # Serialize the latest thread event.
-            latest_thread_event = aggregations[RelationTypes.THREAD]["latest_event"]
-
-            # Don't bundle aggregations as this could recurse forever.
-            aggregations[RelationTypes.THREAD]["latest_event"] = self.serialize_event(
-                latest_thread_event, time_now, bundle_aggregations=None
-            )
+        if aggregations.thread:
+            serialized_aggregations[RelationTypes.THREAD] = {
+                # Don't bundle aggregations as this could recurse forever.
+                "latest_event": self.serialize_event(
+                    aggregations.thread.latest_event, time_now, bundle_aggregations=None
+                ),
+                "count": aggregations.thread.count,
+                "current_user_participated": aggregations.thread.current_user_participated,
+            }
 
         # Include the bundled aggregations in the event.
-        serialized_event["unsigned"].setdefault("m.relations", {}).update(aggregations)
+        if serialized_aggregations:
+            serialized_event["unsigned"].setdefault("m.relations", {}).update(
+                serialized_aggregations
+            )
 
     def serialize_events(
         self, events: Iterable[Union[JsonDict, EventBase]], time_now: int, **kwargs: Any
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f963078e59..1420d67729 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -30,6 +30,7 @@ from typing import (
     Tuple,
 )
 
+import attr
 from typing_extensions import TypedDict
 
 from synapse.api.constants import (
@@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
@@ -90,6 +92,17 @@ id_server_scheme = "https://"
 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventContext:
+    events_before: List[EventBase]
+    event: EventBase
+    events_after: List[EventBase]
+    state: List[EventBase]
+    aggregations: Dict[str, BundledAggregations]
+    start: str
+    end: str
+
+
 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ class RoomContextHandler:
         limit: int,
         event_filter: Optional[Filter],
         use_admin_priviledge: bool = False,
-    ) -> Optional[JsonDict]:
+    ) -> Optional[EventContext]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -1167,38 +1180,28 @@ class RoomContextHandler:
         results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
+        events_before = results.events_before
+        events_after = results.events_after
 
         if event_filter:
-            results["events_before"] = await event_filter.filter(
-                results["events_before"]
-            )
-            results["events_after"] = await event_filter.filter(results["events_after"])
+            events_before = await event_filter.filter(events_before)
+            events_after = await event_filter.filter(events_after)
 
-        results["events_before"] = await filter_evts(results["events_before"])
-        results["events_after"] = await filter_evts(results["events_after"])
+        events_before = await filter_evts(events_before)
+        events_after = await filter_evts(events_after)
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
-        results["event"] = filtered[0]
+        event = filtered[0]
 
         # Fetch the aggregations.
         aggregations = await self.store.get_bundled_aggregations(
-            [results["event"]], user.to_string()
+            itertools.chain(events_before, (event,), events_after),
+            user.to_string(),
         )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_before"], user.to_string()
-            )
-        )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_after"], user.to_string()
-            )
-        )
-        results["aggregations"] = aggregations
 
-        if results["events_after"]:
-            last_event_id = results["events_after"][-1].event_id
+        if events_after:
+            last_event_id = events_after[-1].event_id
         else:
             last_event_id = event_id
 
@@ -1206,9 +1209,9 @@ class RoomContextHandler:
             state_filter = StateFilter.from_lazy_load_member_list(
                 ev.sender
                 for ev in itertools.chain(
-                    results["events_before"],
-                    (results["event"],),
-                    results["events_after"],
+                    events_before,
+                    (event,),
+                    events_after,
                 )
             )
         else:
@@ -1226,21 +1229,23 @@ class RoomContextHandler:
         if event_filter:
             state_events = await event_filter.filter(state_events)
 
-        results["state"] = await filter_evts(state_events)
-
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
         token = StreamToken.START
 
-        results["start"] = await token.copy_and_replace(
-            "room_key", results["start"]
-        ).to_string(self.store)
-
-        results["end"] = await token.copy_and_replace(
-            "room_key", results["end"]
-        ).to_string(self.store)
-
-        return results
+        return EventContext(
+            events_before=events_before,
+            event=event,
+            events_after=events_after,
+            state=await filter_evts(state_events),
+            aggregations=aggregations,
+            start=await token.copy_and_replace("room_key", results.start).to_string(
+                self.store
+            ),
+            end=await token.copy_and_replace("room_key", results.end).to_string(
+                self.store
+            ),
+        )
 
 
 class TimestampLookupHandler:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0b153a6822..02bb5ae72f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -361,36 +361,37 @@ class SearchHandler:
 
                 logger.info(
                     "Context for search returned %d and %d events",
-                    len(res["events_before"]),
-                    len(res["events_after"]),
+                    len(res.events_before),
+                    len(res.events_after),
                 )
 
-                res["events_before"] = await filter_events_for_client(
-                    self.storage, user.to_string(), res["events_before"]
+                events_before = await filter_events_for_client(
+                    self.storage, user.to_string(), res.events_before
                 )
 
-                res["events_after"] = await filter_events_for_client(
-                    self.storage, user.to_string(), res["events_after"]
+                events_after = await filter_events_for_client(
+                    self.storage, user.to_string(), res.events_after
                 )
 
-                res["start"] = await now_token.copy_and_replace(
-                    "room_key", res["start"]
-                ).to_string(self.store)
-
-                res["end"] = await now_token.copy_and_replace(
-                    "room_key", res["end"]
-                ).to_string(self.store)
+                context = {
+                    "events_before": events_before,
+                    "events_after": events_after,
+                    "start": await now_token.copy_and_replace(
+                        "room_key", res.start
+                    ).to_string(self.store),
+                    "end": await now_token.copy_and_replace(
+                        "room_key", res.end
+                    ).to_string(self.store),
+                }
 
                 if include_profile:
                     senders = {
                         ev.sender
-                        for ev in itertools.chain(
-                            res["events_before"], [event], res["events_after"]
-                        )
+                        for ev in itertools.chain(events_before, [event], events_after)
                     }
 
-                    if res["events_after"]:
-                        last_event_id = res["events_after"][-1].event_id
+                    if events_after:
+                        last_event_id = events_after[-1].event_id
                     else:
                         last_event_id = event.event_id
 
@@ -402,7 +403,7 @@ class SearchHandler:
                         last_event_id, state_filter
                     )
 
-                    res["profile_info"] = {
+                    context["profile_info"] = {
                         s.state_key: {
                             "displayname": s.content.get("displayname", None),
                             "avatar_url": s.content.get("avatar_url", None),
@@ -411,7 +412,7 @@ class SearchHandler:
                         if s.type == EventTypes.Member and s.state_key in senders
                     }
 
-                contexts[event.event_id] = res
+                contexts[event.event_id] = context
         else:
             contexts = {}
 
@@ -421,10 +422,10 @@ class SearchHandler:
 
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now
+                context["events_before"], time_now  # type: ignore[arg-type]
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now
+                context["events_after"], time_now  # type: ignore[arg-type]
             )
 
         state_results = {}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7e2a892b63..c72ed7c290 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -100,7 +101,7 @@ class TimelineBatch:
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
-    bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
+    bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used
diff --git a/synapse/push/mailer.py b/synapse/push/mailer.py
index dadfc57413..3df8452eec 100644
--- a/synapse/push/mailer.py
+++ b/synapse/push/mailer.py
@@ -455,7 +455,7 @@ class Mailer:
         }
 
         the_events = await filter_events_for_client(
-            self.storage, user_id, results["events_before"]
+            self.storage, user_id, results.events_before
         )
         the_events.append(notif_event)
 
diff --git a/synapse/rest/admin/rooms.py b/synapse/rest/admin/rooms.py
index efe25fe7eb..5b706efbcf 100644
--- a/synapse/rest/admin/rooms.py
+++ b/synapse/rest/admin/rooms.py
@@ -729,7 +729,7 @@ class RoomEventContextServlet(RestServlet):
         else:
             event_filter = None
 
-        results = await self.room_context_handler.get_event_context(
+        event_context = await self.room_context_handler.get_event_context(
             requester,
             room_id,
             event_id,
@@ -738,25 +738,34 @@ class RoomEventContextServlet(RestServlet):
             use_admin_priviledge=True,
         )
 
-        if not results:
+        if not event_context:
             raise SynapseError(
                 HTTPStatus.NOT_FOUND, "Event not found.", errcode=Codes.NOT_FOUND
             )
 
         time_now = self.clock.time_msec()
-        aggregations = results.pop("aggregations", None)
-        results["events_before"] = self._event_serializer.serialize_events(
-            results["events_before"], time_now, bundle_aggregations=aggregations
-        )
-        results["event"] = self._event_serializer.serialize_event(
-            results["event"], time_now, bundle_aggregations=aggregations
-        )
-        results["events_after"] = self._event_serializer.serialize_events(
-            results["events_after"], time_now, bundle_aggregations=aggregations
-        )
-        results["state"] = self._event_serializer.serialize_events(
-            results["state"], time_now
-        )
+        results = {
+            "events_before": self._event_serializer.serialize_events(
+                event_context.events_before,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "event": self._event_serializer.serialize_event(
+                event_context.event,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "events_after": self._event_serializer.serialize_events(
+                event_context.events_after,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "state": self._event_serializer.serialize_events(
+                event_context.state, time_now
+            ),
+            "start": event_context.start,
+            "end": event_context.end,
+        }
 
         return HTTPStatus.OK, results
 
diff --git a/synapse/rest/client/room.py b/synapse/rest/client/room.py
index 90bb9142a0..90355e44b2 100644
--- a/synapse/rest/client/room.py
+++ b/synapse/rest/client/room.py
@@ -706,27 +706,36 @@ class RoomEventContextServlet(RestServlet):
         else:
             event_filter = None
 
-        results = await self.room_context_handler.get_event_context(
+        event_context = await self.room_context_handler.get_event_context(
             requester, room_id, event_id, limit, event_filter
         )
 
-        if not results:
+        if not event_context:
             raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
 
         time_now = self.clock.time_msec()
-        aggregations = results.pop("aggregations", None)
-        results["events_before"] = self._event_serializer.serialize_events(
-            results["events_before"], time_now, bundle_aggregations=aggregations
-        )
-        results["event"] = self._event_serializer.serialize_event(
-            results["event"], time_now, bundle_aggregations=aggregations
-        )
-        results["events_after"] = self._event_serializer.serialize_events(
-            results["events_after"], time_now, bundle_aggregations=aggregations
-        )
-        results["state"] = self._event_serializer.serialize_events(
-            results["state"], time_now
-        )
+        results = {
+            "events_before": self._event_serializer.serialize_events(
+                event_context.events_before,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "event": self._event_serializer.serialize_event(
+                event_context.event,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "events_after": self._event_serializer.serialize_events(
+                event_context.events_after,
+                time_now,
+                bundle_aggregations=event_context.aggregations,
+            ),
+            "state": self._event_serializer.serialize_events(
+                event_context.state, time_now
+            ),
+            "start": event_context.start,
+            "end": event_context.end,
+        }
 
         return 200, results
 
diff --git a/synapse/rest/client/sync.py b/synapse/rest/client/sync.py
index d20ae1421e..f9615da525 100644
--- a/synapse/rest/client/sync.py
+++ b/synapse/rest/client/sync.py
@@ -48,6 +48,7 @@ from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_boolean, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.opentracing import trace
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.types import JsonDict, StreamToken
 from synapse.util import json_decoder
 
@@ -526,7 +527,7 @@ class SyncRestServlet(RestServlet):
 
         def serialize(
             events: Iterable[EventBase],
-            aggregations: Optional[Dict[str, Dict[str, Any]]] = None,
+            aggregations: Optional[Dict[str, BundledAggregations]] = None,
         ) -> List[JsonDict]:
             return self._event_serializer.serialize_events(
                 events,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 2cb5d06c13..a9a5dd5f03 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -13,17 +13,7 @@
 # limitations under the License.
 
 import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    Iterable,
-    List,
-    Optional,
-    Tuple,
-    Union,
-    cast,
-)
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union, cast
 
 import attr
 from frozendict import frozendict
@@ -43,6 +33,7 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -51,6 +42,30 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _ThreadAggregation:
+    latest_event: EventBase
+    count: int
+    current_user_participated: bool
+
+
+@attr.s(slots=True, auto_attribs=True)
+class BundledAggregations:
+    """
+    The bundled aggregations for an event.
+
+    Some values require additional processing during serialization.
+    """
+
+    annotations: Optional[JsonDict] = None
+    references: Optional[JsonDict] = None
+    replace: Optional[EventBase] = None
+    thread: Optional[_ThreadAggregation] = None
+
+    def __bool__(self) -> bool:
+        return bool(self.annotations or self.references or self.replace or self.thread)
+
+
 class RelationsWorkerStore(SQLBaseStore):
     def __init__(
         self,
@@ -585,7 +600,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
     async def _get_bundled_aggregation_for_event(
         self, event: EventBase, user_id: str
-    ) -> Optional[Dict[str, Any]]:
+    ) -> Optional[BundledAggregations]:
         """Generate bundled aggregations for an event.
 
         Note that this does not use a cache, but depends on cached methods.
@@ -616,24 +631,24 @@ class RelationsWorkerStore(SQLBaseStore):
         # The bundled aggregations to include, a mapping of relation type to a
         # type-specific value. Some types include the direct return type here
         # while others need more processing during serialization.
-        aggregations: Dict[str, Any] = {}
+        aggregations = BundledAggregations()
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations[RelationTypes.ANNOTATION] = annotations.to_dict()
+            aggregations.annotations = annotations.to_dict()
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations[RelationTypes.REFERENCE] = references.to_dict()
+            aggregations.references = references.to_dict()
 
         edit = None
         if event.type == EventTypes.Message:
             edit = await self.get_applicable_edit(event_id, room_id)
 
         if edit:
-            aggregations[RelationTypes.REPLACE] = edit
+            aggregations.replace = edit
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
@@ -644,11 +659,11 @@ class RelationsWorkerStore(SQLBaseStore):
                 event_id, room_id, user_id
             )
             if latest_thread_event:
-                aggregations[RelationTypes.THREAD] = {
-                    "latest_event": latest_thread_event,
-                    "count": thread_count,
-                    "current_user_participated": participated,
-                }
+                aggregations.thread = _ThreadAggregation(
+                    latest_event=latest_thread_event,
+                    count=thread_count,
+                    current_user_participated=participated,
+                )
 
         # Store the bundled aggregations in the event metadata for later use.
         return aggregations
@@ -657,7 +672,7 @@ class RelationsWorkerStore(SQLBaseStore):
         self,
         events: Iterable[EventBase],
         user_id: str,
-    ) -> Dict[str, Dict[str, Any]]:
+    ) -> Dict[str, BundledAggregations]:
         """Generate bundled aggregations for events.
 
         Args:
@@ -676,7 +691,7 @@ class RelationsWorkerStore(SQLBaseStore):
         results = {}
         for event in events:
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
-            if event_result is not None:
+            if event_result:
                 results[event.event_id] = event_result
 
         return results
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 319464b1fa..a898f847e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -81,6 +81,14 @@ class _EventDictReturn:
     stream_ordering: int
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class _EventsAround:
+    events_before: List[EventBase]
+    events_after: List[EventBase]
+    start: RoomStreamToken
+    end: RoomStreamToken
+
+
 def generate_pagination_where_clause(
     direction: str,
     column_names: Tuple[str, str],
@@ -846,7 +854,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         before_limit: int,
         after_limit: int,
         event_filter: Optional[Filter] = None,
-    ) -> dict:
+    ) -> _EventsAround:
         """Retrieve events and pagination tokens around a given event in a
         room.
         """
@@ -869,12 +877,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             list(results["after"]["event_ids"]), get_prev_content=True
         )
 
-        return {
-            "events_before": events_before,
-            "events_after": events_after,
-            "start": results["before"]["token"],
-            "end": results["after"]["token"],
-        }
+        return _EventsAround(
+            events_before=events_before,
+            events_after=events_after,
+            start=results["before"]["token"],
+            end=results["after"]["token"],
+        )
 
     def _get_events_around_txn(
         self,
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index c9b220e73d..96ae7790bb 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -577,7 +577,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
         self.assertTrue(room_timeline["limited"])
-        self._find_event_in_chunk(room_timeline["events"])
+        assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
     def test_aggregation_get_event_for_annotation(self):
         """Test that annotations do not get bundled aggregations included