summary refs log tree commit diff
path: root/synapse/handlers/search.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2022-02-08 09:21:20 -0500
committerGitHub <noreply@github.com>2022-02-08 09:21:20 -0500
commit8c94b3abe93fe8c3e2ddd29fa350f54f69714151 (patch)
tree8c60d6bb03ce9fc0f9dd27e2960ffc20f5d34a3a /synapse/handlers/search.py
parentRemove unnecessary ignores due to Twisted upgrade. (#11939) (diff)
downloadsynapse-8c94b3abe93fe8c3e2ddd29fa350f54f69714151.tar.xz
Experimental support to include bundled aggregations in search results (MSC3666) (#11837)
Diffstat (limited to 'synapse/handlers/search.py')
-rw-r--r--synapse/handlers/search.py29
1 files changed, 24 insertions, 5 deletions
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 02bb5ae72f..41cb809078 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -43,6 +43,8 @@ class SearchHandler:
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
 
+        self._msc3666_enabled = hs.config.experimental.msc3666_enabled
+
     async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
         """Retrieves room IDs of old rooms in the history of an upgraded room.
 
@@ -238,8 +240,6 @@ class SearchHandler:
 
             results = search_result["results"]
 
-            results_map = {r["event"].event_id: r for r in results}
-
             rank_map.update({r["event"].event_id: r["rank"] for r in results})
 
             filtered_events = await search_filter.filter([r["event"] for r in results])
@@ -420,12 +420,29 @@ class SearchHandler:
 
         time_now = self.clock.time_msec()
 
+        aggregations = None
+        if self._msc3666_enabled:
+            aggregations = await self.store.get_bundled_aggregations(
+                # Generate an iterable of EventBase for all the events that will be
+                # returned, including contextual events.
+                itertools.chain(
+                    # The events_before and events_after for each context.
+                    itertools.chain.from_iterable(
+                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                        for context in contexts.values()
+                    ),
+                    # The returned events.
+                    allowed_events,
+                ),
+                user.to_string(),
+            )
+
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now  # type: ignore[arg-type]
+                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now  # type: ignore[arg-type]
+                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
 
         state_results = {}
@@ -442,7 +459,9 @@ class SearchHandler:
             results.append(
                 {
                     "rank": rank_map[e.event_id],
-                    "result": self._event_serializer.serialize_event(e, time_now),
+                    "result": self._event_serializer.serialize_event(
+                        e, time_now, bundle_aggregations=aggregations
+                    ),
                     "context": contexts.get(e.event_id, {}),
                 }
             )