summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11837.feature1
-rw-r--r--synapse/config/experimental.py2
-rw-r--r--synapse/handlers/search.py29
-rw-r--r--synapse/storage/databases/main/relations.py13
-rw-r--r--tests/rest/client/test_relations.py39
5 files changed, 76 insertions, 8 deletions
diff --git a/changelog.d/11837.feature b/changelog.d/11837.feature
new file mode 100644
index 0000000000..62ef707123
--- /dev/null
+++ b/changelog.d/11837.feature
@@ -0,0 +1 @@
+Experimental support for [MSC3666](https://github.com/matrix-org/matrix-doc/pull/3666): including bundled aggregations in server side search results.
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index e4719d19b8..f05a803a71 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -26,6 +26,8 @@ class ExperimentalConfig(Config):
 
         # MSC3440 (thread relation)
         self.msc3440_enabled: bool = experimental.get("msc3440_enabled", False)
+        # MSC3666: including bundled relations in /search.
+        self.msc3666_enabled: bool = experimental.get("msc3666_enabled", False)
 
         # MSC3026 (busy presence state)
         self.msc3026_enabled: bool = experimental.get("msc3026_enabled", False)
diff --git a/synapse/handlers/search.py b/synapse/handlers/search.py
index 02bb5ae72f..41cb809078 100644
--- a/synapse/handlers/search.py
+++ b/synapse/handlers/search.py
@@ -43,6 +43,8 @@ class SearchHandler:
         self.state_store = self.storage.state
         self.auth = hs.get_auth()
 
+        self._msc3666_enabled = hs.config.experimental.msc3666_enabled
+
     async def get_old_rooms_from_upgraded_room(self, room_id: str) -> Iterable[str]:
         """Retrieves room IDs of old rooms in the history of an upgraded room.
 
@@ -238,8 +240,6 @@ class SearchHandler:
 
             results = search_result["results"]
 
-            results_map = {r["event"].event_id: r for r in results}
-
             rank_map.update({r["event"].event_id: r["rank"] for r in results})
 
             filtered_events = await search_filter.filter([r["event"] for r in results])
@@ -420,12 +420,29 @@ class SearchHandler:
 
         time_now = self.clock.time_msec()
 
+        aggregations = None
+        if self._msc3666_enabled:
+            aggregations = await self.store.get_bundled_aggregations(
+                # Generate an iterable of EventBase for all the events that will be
+                # returned, including contextual events.
+                itertools.chain(
+                    # The events_before and events_after for each context.
+                    itertools.chain.from_iterable(
+                        itertools.chain(context["events_before"], context["events_after"])  # type: ignore[arg-type]
+                        for context in contexts.values()
+                    ),
+                    # The returned events.
+                    allowed_events,
+                ),
+                user.to_string(),
+            )
+
         for context in contexts.values():
             context["events_before"] = self._event_serializer.serialize_events(
-                context["events_before"], time_now  # type: ignore[arg-type]
+                context["events_before"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
             context["events_after"] = self._event_serializer.serialize_events(
-                context["events_after"], time_now  # type: ignore[arg-type]
+                context["events_after"], time_now, bundle_aggregations=aggregations  # type: ignore[arg-type]
             )
 
         state_results = {}
@@ -442,7 +459,9 @@ class SearchHandler:
             results.append(
                 {
                     "rank": rank_map[e.event_id],
-                    "result": self._event_serializer.serialize_event(e, time_now),
+                    "result": self._event_serializer.serialize_event(
+                        e, time_now, bundle_aggregations=aggregations
+                    ),
                     "context": contexts.get(e.event_id, {}),
                 }
             )
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 6180b17296..7718acbf1c 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -715,6 +715,9 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of event ID to the bundled aggregation for the event. Not all
             events may have bundled aggregations in the results.
         """
+        # The already processed event IDs. Tracked separately from the result
+        # since the result omits events which do not have bundled aggregations.
+        seen_event_ids = set()
 
         # State events and redacted events do not get bundled aggregations.
         events = [
@@ -728,13 +731,19 @@ class RelationsWorkerStore(SQLBaseStore):
 
         # Fetch other relations per event.
         for event in events:
+            # De-duplicate events by ID to handle the same event requested multiple
+            # times. The caches that _get_bundled_aggregation_for_event use should
+            # capture this, but best to reduce work.
+            if event.event_id in seen_event_ids:
+                continue
+            seen_event_ids.add(event.event_id)
+
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
             if event_result:
                 results[event.event_id] = event_result
 
         # Fetch any edits.
-        event_ids = [event.event_id for event in events]
-        edits = await self._get_applicable_edits(event_ids)
+        edits = await self._get_applicable_edits(seen_event_ids)
         for event_id, edit in edits.items():
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 96ae7790bb..06721e67c9 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -453,7 +453,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(400, channel.code, channel.json_body)
 
-    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    @unittest.override_config(
+        {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
+    )
     def test_bundled_aggregations(self):
         """
         Test that annotations, references, and threads get correctly bundled.
@@ -579,6 +581,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertTrue(room_timeline["limited"])
         assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
     def test_aggregation_get_event_for_annotation(self):
         """Test that annotations do not get bundled aggregations included
         when directly requested.
@@ -759,6 +778,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         self.assertNotIn("m.relations", channel.json_body["unsigned"])
 
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
     def test_edit(self):
         """Test that a simple edit works."""
 
@@ -825,6 +845,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertTrue(room_timeline["limited"])
         assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
     def test_multi_edit(self):
         """Test that multiple edits, including attempts by people who
         shouldn't be allowed, are correctly handled.