summary refs log tree commit diff
path: root/synapse/handlers/room.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-01-26 08:27:04 -0500
committerGitHub <noreply@github.com>2022-01-26 08:27:04 -0500
commit2897fb6b4fb8bdaea0e919233d5ccaf5dea12742 (patch)
tree86973f5af5bee99ca612fe553372eac4fa7f1080 /synapse/handlers/room.py
parentDon't print HTTPStatus.* in "Processed..." logs (#11827) (diff)
downloadsynapse-2897fb6b4fb8bdaea0e919233d5ccaf5dea12742.tar.xz
Improvements to bundling aggregations. (#11815)
This is some odds and ends found during the review of #11791
and while continuing to work in this code:

* Return attrs classes instead of dictionaries from some methods
  to improve type safety.
* Call `get_bundled_aggregations` fewer times.
* Adds a missing assertion in the tests.
* Do not return empty bundled aggregations for an event (preferring
  to not include the bundle at all, as the docstring states).
Diffstat (limited to 'synapse/handlers/room.py')
-rw-r--r--synapse/handlers/room.py77
1 files changed, 41 insertions, 36 deletions
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index f963078e59..1420d67729 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -30,6 +30,7 @@ from typing import (
     Tuple,
 )
 
+import attr
 from typing_extensions import TypedDict
 
 from synapse.api.constants import (
@@ -60,6 +61,7 @@ from synapse.events.utils import copy_power_levels_contents
 from synapse.federation.federation_client import InvalidResponseError
 from synapse.handlers.federation import get_domains_from_state
 from synapse.rest.admin._base import assert_user_is_admin
+from synapse.storage.databases.main.relations import BundledAggregations
 from synapse.storage.state import StateFilter
 from synapse.streams import EventSource
 from synapse.types import (
@@ -90,6 +92,17 @@ id_server_scheme = "https://"
 FIVE_MINUTES_IN_MS = 5 * 60 * 1000
 
 
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class EventContext:
+    events_before: List[EventBase]
+    event: EventBase
+    events_after: List[EventBase]
+    state: List[EventBase]
+    aggregations: Dict[str, BundledAggregations]
+    start: str
+    end: str
+
+
 class RoomCreationHandler:
     def __init__(self, hs: "HomeServer"):
         self.store = hs.get_datastore()
@@ -1119,7 +1132,7 @@ class RoomContextHandler:
         limit: int,
         event_filter: Optional[Filter],
         use_admin_priviledge: bool = False,
-    ) -> Optional[JsonDict]:
+    ) -> Optional[EventContext]:
         """Retrieves events, pagination tokens and state around a given event
         in a room.
 
@@ -1167,38 +1180,28 @@ class RoomContextHandler:
         results = await self.store.get_events_around(
             room_id, event_id, before_limit, after_limit, event_filter
         )
+        events_before = results.events_before
+        events_after = results.events_after
 
         if event_filter:
-            results["events_before"] = await event_filter.filter(
-                results["events_before"]
-            )
-            results["events_after"] = await event_filter.filter(results["events_after"])
+            events_before = await event_filter.filter(events_before)
+            events_after = await event_filter.filter(events_after)
 
-        results["events_before"] = await filter_evts(results["events_before"])
-        results["events_after"] = await filter_evts(results["events_after"])
+        events_before = await filter_evts(events_before)
+        events_after = await filter_evts(events_after)
         # filter_evts can return a pruned event in case the user is allowed to see that
         # there's something there but not see the content, so use the event that's in
         # `filtered` rather than the event we retrieved from the datastore.
-        results["event"] = filtered[0]
+        event = filtered[0]
 
         # Fetch the aggregations.
         aggregations = await self.store.get_bundled_aggregations(
-            [results["event"]], user.to_string()
+            itertools.chain(events_before, (event,), events_after),
+            user.to_string(),
         )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_before"], user.to_string()
-            )
-        )
-        aggregations.update(
-            await self.store.get_bundled_aggregations(
-                results["events_after"], user.to_string()
-            )
-        )
-        results["aggregations"] = aggregations
 
-        if results["events_after"]:
-            last_event_id = results["events_after"][-1].event_id
+        if events_after:
+            last_event_id = events_after[-1].event_id
         else:
             last_event_id = event_id
 
@@ -1206,9 +1209,9 @@ class RoomContextHandler:
             state_filter = StateFilter.from_lazy_load_member_list(
                 ev.sender
                 for ev in itertools.chain(
-                    results["events_before"],
-                    (results["event"],),
-                    results["events_after"],
+                    events_before,
+                    (event,),
+                    events_after,
                 )
             )
         else:
@@ -1226,21 +1229,23 @@ class RoomContextHandler:
         if event_filter:
             state_events = await event_filter.filter(state_events)
 
-        results["state"] = await filter_evts(state_events)
-
         # We use a dummy token here as we only care about the room portion of
         # the token, which we replace.
         token = StreamToken.START
 
-        results["start"] = await token.copy_and_replace(
-            "room_key", results["start"]
-        ).to_string(self.store)
-
-        results["end"] = await token.copy_and_replace(
-            "room_key", results["end"]
-        ).to_string(self.store)
-
-        return results
+        return EventContext(
+            events_before=events_before,
+            event=event,
+            events_after=events_after,
+            state=await filter_evts(state_events),
+            aggregations=aggregations,
+            start=await token.copy_and_replace("room_key", results.start).to_string(
+                self.store
+            ),
+            end=await token.copy_and_replace("room_key", results.end).to_string(
+                self.store
+            ),
+        )
 
 
 class TimestampLookupHandler: