summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/room.py77
-rw-r--r--synapse/handlers/search.py45
-rw-r--r--synapse/handlers/sync.py3
3 files changed, 66 insertions, 59 deletions
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f963078e59..1420d67729 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -30,6 +30,7 @@ from typing import (
     Tuple,
 )
 
+import attr
 from typing_extensions import TypedDict
 
 from synapse.api.constants import (
@@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
@@ -90,6 +92,17 @@ id_server_scheme = "https://"
 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventContext:
+    events_before: List[EventBase]
+    event: EventBase
+    events_after: List[EventBase]
+    state: List[EventBase]
+    aggregations: Dict[str, BundledAggregations]
+    start: str
+    end: str
+
+
 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ class RoomContextHandler:
         limit: int,
         event_filter: Optional[Filter],
         use_admin_priviledge: bool = False,
-    ) -> Optional[JsonDict]:
+    ) -> Optional[EventContext]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -1167,38 +1180,28 @@ class RoomContextHandler:
         results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
+        events_before = results.events_before
+        events_after = results.events_after
 
         if event_filter:
-            results["events_before"] = await event_filter.filter(
-                results["events_before"]
-            )
-            results["events_after"] = await event_filter.filter(results["events_after"])
+            events_before = await event_filter.filter(events_before)
+            events_after = await event_filter.filter(events_after)
 
-        results["events_before"] = await filter_evts(results["events_before"])
-        results["events_after"] = await filter_evts(results["events_after"])
+        events_before = await filter_evts(events_before)
+        events_after = await filter_evts(events_after)
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
-        results["event"] = filtered[0]
+        event = filtered[0]
 
         # Fetch the aggregations.
         aggregations = await self.store.get_bundled_aggregations(
-            [results["event"]], user.to_string()
+            itertools.chain(events_before, (event,), events_after),
+            user.to_string(),
         )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_before"], user.to_string()
-            )
-        )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_after"], user.to_string()
-            )
-        )
-        results["aggregations"] = aggregations
 
-        if results["events_after"]:
-            last_event_id = results["events_after"][-1].event_id
+        if events_after:
+            last_event_id = events_after[-1].event_id
         else:
             last_event_id = event_id
 
@@ -1206,9 +1209,9 @@ class RoomContextHandler:
             state_filter = StateFilter.from_lazy_load_member_list(
                 ev.sender
                 for ev in itertools.chain(
-                    results["events_before"],
-                    (results["event"],),
-                    results["events_after"],
+                    events_before,
+                    (event,),
+                    events_after,
                 )
             )
         else:
@@ -1226,21 +1229,23 @@ class RoomContextHandler:
         if event_filter:
             state_events = await event_filter.filter(state_events)
 
-        results["state"] = await filter_evts(state_events)
-
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
         token = StreamToken.START
 
-        results["start"] = await token.copy_and_replace(
-            "room_key", results["start"]
-        ).to_string(self.store)
-
-        results["end"] = await token.copy_and_replace(
-            "room_key", results["end"]
-        ).to_string(self.store)
-
-        return results
+        return EventContext(
+            events_before=events_before,
+            event=event,
+            events_after=events_after,
+            state=await filter_evts(state_events),
+            aggregations=aggregations,
+            start=await token.copy_and_replace("room_key", results.start).to_string(
+                self.store
+            ),
+            end=await token.copy_and_replace("room_key", results.end).to_string(
+                self.store
+            ),
+        )
 
 
 class TimestampLookupHandler:
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 0b153a6822..02bb5ae72f 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -361,36 +361,37 @@ class SearchHandler:
 
                 logger.info(
                     "Context for search returned %d and %d events",
-                    len(res["events_before"]),
-                    len(res["events_after"]),
+                    len(res.events_before),
+                    len(res.events_after),
                 )
 
-                res["events_before"] = await filter_events_for_client(
-                    self.storage, user.to_string(), res["events_before"]
+                events_before = await filter_events_for_client(
+                    self.storage, user.to_string(), res.events_before
                 )
 
-                res["events_after"] = await filter_events_for_client(
-                    self.storage, user.to_string(), res["events_after"]
+                events_after = await filter_events_for_client(
+                    self.storage, user.to_string(), res.events_after
                 )
 
-                res["start"] = await now_token.copy_and_replace(
-                    "room_key", res["start"]
-                ).to_string(self.store)
-
-                res["end"] = await now_token.copy_and_replace(
-                    "room_key", res["end"]
-                ).to_string(self.store)
+                context = {
+                    "events_before": events_before,
+                    "events_after": events_after,
+                    "start": await now_token.copy_and_replace(
+                        "room_key", res.start
+                    ).to_string(self.store),
+                    "end": await now_token.copy_and_replace(
+                        "room_key", res.end
+                    ).to_string(self.store),
+                }
 
                 if include_profile:
                     senders = {
                         ev.sender
-                        for ev in itertools.chain(
-                            res["events_before"], [event], res["events_after"]
-                        )
+                        for ev in itertools.chain(events_before, [event], events_after)
                     }
 
-                    if res["events_after"]:
-                        last_event_id = res["events_after"][-1].event_id
+                    if events_after:
+                        last_event_id = events_after[-1].event_id
                     else:
                         last_event_id = event.event_id
 
@@ -402,7 +403,7 @@ class SearchHandler:
                         last_event_id, state_filter
                     )
 
-                    res["profile_info"] = {
+                    context["profile_info"] = {
                         s.state_key: {
                             "displayname": s.content.get("displayname", None),
                             "avatar_url": s.content.get("avatar_url", None),
@@ -411,7 +412,7 @@ class SearchHandler:
                         if s.type == EventTypes.Member and s.state_key in senders
                     }
 
-                contexts[event.event_id] = res
+                contexts[event.event_id] = context
         else:
             contexts = {}
 
@@ -421,10 +422,10 @@ class SearchHandler:
 
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now
+                context["events_before"], time_now  # type: ignore[arg-type]
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now
+                context["events_after"], time_now  # type: ignore[arg-type]
             )
 
         state_results = {}
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index 7e2a892b63..c72ed7c290 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -37,6 +37,7 @@ from synapse.logging.context import current_context
 from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, start_active_span
 from synapse.push.clientformat import format_push_rules_for_user
 from synapse.storage.databases.main.event_push_actions import NotifCounts
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.roommember import MemberSummary
 from synapse.storage.state import StateFilter
 from synapse.types import (
@@ -100,7 +101,7 @@ class TimelineBatch:
     limited: bool
     # A mapping of event ID to the bundled aggregations for the above events.
     # This is only calculated if limited is true.
-    bundled_aggregations: Optional[Dict[str, Dict[str, Any]]] = None
+    bundled_aggregations: Optional[Dict[str, BundledAggregations]] = None
 
     def __bool__(self) -> bool:
         """Make the result appear empty if there are no updates. This is used