summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12273.bugfix1
-rw-r--r--synapse/events/utils.py55
-rw-r--r--synapse/handlers/relations.py74
-rw-r--r--synapse/storage/databases/main/relations.py11
-rw-r--r--tests/rest/client/test_relations.py108
5 files changed, 198 insertions, 51 deletions
diff --git a/changelog.d/12273.bugfix b/changelog.d/12273.bugfix
new file mode 100644
index 0000000000..f8d7b6c889
--- /dev/null
+++ b/changelog.d/12273.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.48.0 where latest thread reply provided failed to include the proper bundled aggregations.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index a6c48308b3..026dcde8d8 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -426,13 +426,12 @@ class EventClientSerializer:
 
         # Check if there are any bundled aggregations to include with the event.
         if bundle_aggregations:
-            event_aggregations = bundle_aggregations.get(event.event_id)
-            if event_aggregations:
+            if event.event_id in bundle_aggregations:
                 self._inject_bundled_aggregations(
                     event,
                     time_now,
                     config,
-                    event_aggregations,
+                    bundle_aggregations,
                     serialized_event,
                     apply_edits=apply_edits,
                 )
@@ -471,7 +470,7 @@ class EventClientSerializer:
         event: EventBase,
         time_now: int,
         config: SerializeEventConfig,
-        aggregations: "BundledAggregations",
+        bundled_aggregations: Dict[str, "BundledAggregations"],
         serialized_event: JsonDict,
         apply_edits: bool,
     ) -> None:
@@ -481,22 +480,37 @@ class EventClientSerializer:
             event: The event being serialized.
             time_now: The current time in milliseconds
             config: Event serialization config
-            aggregations: The bundled aggregation to serialize.
+            bundled_aggregations: Bundled aggregations to be injected.
+                A map from event_id to aggregation data. Must contain at least an
+                entry for `event`.
+
+                While serializing the bundled aggregations this map may be searched
+                again for additional events in a recursive manner.
             serialized_event: The serialized event which may be modified.
             apply_edits: Whether the content of the event should be modified to reflect
                any replacement in `aggregations.replace`.
         """
+
+        # We have already checked that aggregations exist for this event.
+        event_aggregations = bundled_aggregations[event.event_id]
+
+        # The JSON dictionary to be added under the unsigned property of the event
+        # being serialized.
         serialized_aggregations = {}
 
-        if aggregations.annotations:
-            serialized_aggregations[RelationTypes.ANNOTATION] = aggregations.annotations
+        if event_aggregations.annotations:
+            serialized_aggregations[
+                RelationTypes.ANNOTATION
+            ] = event_aggregations.annotations
 
-        if aggregations.references:
-            serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
+        if event_aggregations.references:
+            serialized_aggregations[
+                RelationTypes.REFERENCE
+            ] = event_aggregations.references
 
-        if aggregations.replace:
+        if event_aggregations.replace:
             # If there is an edit, optionally apply it to the event.
-            edit = aggregations.replace
+            edit = event_aggregations.replace
             if apply_edits:
                 self._apply_edit(event, serialized_event, edit)
 
@@ -507,19 +521,16 @@ class EventClientSerializer:
                 "sender": edit.sender,
             }
 
-        # If this event is the start of a thread, include a summary of the replies.
-        if aggregations.thread:
-            thread = aggregations.thread
+        # Include any threaded replies to this event.
+        if event_aggregations.thread:
+            thread = event_aggregations.thread
 
-            # Don't bundle aggregations as this could recurse forever.
-            serialized_latest_event = serialize_event(
-                thread.latest_event, time_now, config=config
+            serialized_latest_event = self.serialize_event(
+                thread.latest_event,
+                time_now,
+                config=config,
+                bundle_aggregations=bundled_aggregations,
             )
-            # Manually apply an edit, if one exists.
-            if thread.latest_edit:
-                self._apply_edit(
-                    thread.latest_event, serialized_latest_event, thread.latest_edit
-                )
 
             thread_summary = {
                 "latest_event": serialized_latest_event,
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index b5dc9f74b3..cec5740fbd 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -44,8 +44,6 @@ logger = logging.getLogger(__name__)
 class _ThreadAggregation:
     # The latest event in the thread.
     latest_event: EventBase
-    # The latest edit to the latest event in the thread.
-    latest_edit: Optional[EventBase]
     # The total number of events in the thread.
     count: int
     # True if the current user has sent an event to the thread.
@@ -295,7 +293,7 @@ class RelationsHandler:
 
         for event_id, summary in summaries.items():
             if summary:
-                thread_count, latest_thread_event, edit = summary
+                thread_count, latest_thread_event = summary
 
                 # Subtract off the count of any ignored users.
                 for ignored_user in ignored_users:
@@ -340,7 +338,6 @@ class RelationsHandler:
 
                 results[event_id] = _ThreadAggregation(
                     latest_event=latest_thread_event,
-                    latest_edit=edit,
                     count=thread_count,
                     # If there's a thread summary it must also exist in the
                     # participated dictionary.
@@ -359,8 +356,13 @@ class RelationsHandler:
             user_id: The user requesting the bundled aggregations.
 
         Returns:
-            A map of event ID to the bundled aggregation for the event. Not all
-            events may have bundled aggregations in the results.
+            A map of event ID to the bundled aggregations for the event.
+
+            Not all requested events may exist in the results (if they don't have
+            bundled aggregations).
+
+            The results may include additional events which are related to the
+            requested events.
         """
         # De-duplicate events by ID to handle the same event requested multiple times.
         #
@@ -369,22 +371,59 @@ class RelationsHandler:
             event.event_id: event for event in events if not event.is_state()
         }
 
+        # A map of event ID to the relation in that event, if there is one.
+        relations_by_id: Dict[str, str] = {}
+        for event_id, event in events_by_id.items():
+            relates_to = event.content.get("m.relates_to")
+            if isinstance(relates_to, collections.abc.Mapping):
+                relation_type = relates_to.get("rel_type")
+                if isinstance(relation_type, str):
+                    relations_by_id[event_id] = relation_type
+
         # event ID -> bundled aggregation in non-serialized form.
         results: Dict[str, BundledAggregations] = {}
 
         # Fetch any ignored users of the requesting user.
         ignored_users = await self._main_store.ignored_users(user_id)
 
+        # Threads are special as the latest event of a thread might cause additional
+        # events to be fetched. Thus, we check those first!
+
+        # Fetch thread summaries (but only for the directly requested events).
+        threads = await self.get_threads_for_events(
+            # It is not valid to start a thread on an event which itself relates to another event.
+            [eid for eid in events_by_id.keys() if eid not in relations_by_id],
+            user_id,
+            ignored_users,
+        )
+        for event_id, thread in threads.items():
+            results.setdefault(event_id, BundledAggregations()).thread = thread
+
+            # If the latest event in a thread is not already being fetched,
+            # add it. This ensures that the bundled aggregations for the
+            # latest thread event is correct.
+            latest_thread_event = thread.latest_event
+            if latest_thread_event and latest_thread_event.event_id not in events_by_id:
+                events_by_id[latest_thread_event.event_id] = latest_thread_event
+                # Keep relations_by_id in sync with events_by_id:
+                #
+                # We know that the latest event in a thread has a thread relation
+                # (as that is what makes it part of the thread).
+                relations_by_id[latest_thread_event.event_id] = RelationTypes.THREAD
+
         # Fetch other relations per event.
         for event in events_by_id.values():
-            # Do not bundle aggregations for an event which represents an edit or an
-            # annotation. It does not make sense for them to have related events.
-            relates_to = event.content.get("m.relates_to")
-            if isinstance(relates_to, collections.abc.Mapping):
-                relation_type = relates_to.get("rel_type")
-                if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
-                    continue
-
+            # An event which is a replacement (ie edit) or annotation (ie, reaction)
+            # may not have any other event related to it.
+            #
+            # XXX This is buggy, see https://github.com/matrix-org/synapse/issues/12566
+            if relations_by_id.get(event.event_id) in (
+                RelationTypes.ANNOTATION,
+                RelationTypes.REPLACE,
+            ):
+                continue
+
+            # Fetch any annotations (ie, reactions) to bundle with this event.
             annotations = await self.get_annotations_for_event(
                 event.event_id, event.room_id, ignored_users=ignored_users
             )
@@ -393,6 +432,7 @@ class RelationsHandler:
                     event.event_id, BundledAggregations()
                 ).annotations = {"chunk": annotations}
 
+            # Fetch any references to bundle with this event.
             references, next_token = await self.get_relations_for_event(
                 event.event_id,
                 event,
@@ -425,10 +465,4 @@ class RelationsHandler:
         for event_id, edit in edits.items():
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
-        threads = await self.get_threads_for_events(
-            events_by_id.keys(), user_id, ignored_users
-        )
-        for event_id, thread in threads.items():
-            results.setdefault(event_id, BundledAggregations()).thread = thread
-
         return results
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index a5c31f6787..484976ca6b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -445,8 +445,8 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
     async def get_thread_summaries(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
-        """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
+    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
+        """Get the number of threaded replies and the latest reply (if any) for the given events.
 
         Args:
             event_ids: Summarize the thread related to this event ID.
@@ -458,7 +458,6 @@ class RelationsWorkerStore(SQLBaseStore):
             Each summary is a tuple of:
                 The number of events in the thread.
                 The most recent event in the thread.
-                The most recent edit to the most recent event in the thread, if applicable.
         """
 
         def _get_thread_summaries_txn(
@@ -544,9 +543,6 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
-        # Check to see if any of those events are edited.
-        latest_edits = await self.get_applicable_edits(latest_event_ids.values())
-
         # Map to the event IDs to the thread summary.
         #
         # There might not be a summary due to there not being a thread or
@@ -557,8 +553,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
             summary = None
             if latest_event:
-                latest_edit = latest_edits.get(latest_event_id)
-                summary = (counts[parent_event_id], latest_event, latest_edit)
+                summary = (counts[parent_event_id], latest_event)
             summaries[parent_event_id] = summary
 
         return summaries
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 39667e3225..33ce9471b3 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1029,7 +1029,106 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
                 bundled_aggregations.get("latest_event"),
             )
 
-        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10)
+
+    def test_thread_with_bundled_aggregations_for_latest(self) -> None:
+        """
+        Bundled aggregations should get applied to the latest thread event.
+        """
+        self._send_relation(RelationTypes.THREAD, "m.room.test")
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        thread_2 = channel.json_body["event_id"]
+
+        self._send_relation(
+            RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_2
+        )
+
+        def assert_thread(bundled_aggregations: JsonDict) -> None:
+            self.assertEqual(2, bundled_aggregations.get("count"))
+            self.assertTrue(bundled_aggregations.get("current_user_participated"))
+            # The latest thread event has some fields that don't matter.
+            self.assert_dict(
+                {
+                    "content": {
+                        "m.relates_to": {
+                            "event_id": self.parent_id,
+                            "rel_type": RelationTypes.THREAD,
+                        }
+                    },
+                    "event_id": thread_2,
+                    "sender": self.user_id,
+                    "type": "m.room.test",
+                },
+                bundled_aggregations.get("latest_event"),
+            )
+            # Check the unsigned field on the latest event.
+            self.assert_dict(
+                {
+                    "m.relations": {
+                        RelationTypes.ANNOTATION: {
+                            "chunk": [
+                                {"type": "m.reaction", "key": "a", "count": 1},
+                            ]
+                        },
+                    }
+                },
+                bundled_aggregations["latest_event"].get("unsigned"),
+            )
+
+        self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 10)
+
+    def test_nested_thread(self) -> None:
+        """
+        Ensure that a nested thread gets ignored by bundled aggregations, as
+        those are forbidden.
+        """
+
+        # Start a thread.
+        channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+        reply_event_id = channel.json_body["event_id"]
+
+        # Disable the validation to pretend this came over federation, since it is
+        # not an event the Client-Server API will allow..
+        with patch(
+            "synapse.handlers.message.EventCreationHandler._validate_event_relation",
+            new=lambda self, event: make_awaitable(None),
+        ):
+            # Create a sub-thread off the thread, which is not allowed.
+            self._send_relation(
+                RelationTypes.THREAD, "m.room.test", parent_id=reply_event_id
+            )
+
+        # Fetch the thread root, to get the bundled aggregation for the thread.
+        relations_from_event = self._get_bundled_aggregations()
+
+        # Ensure that requesting the room messages also does not return the sub-thread.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        event = self._find_event_in_chunk(channel.json_body["chunk"])
+        relations_from_messages = event["unsigned"]["m.relations"]
+
+        # Check the bundled aggregations from each point.
+        for aggregations, desc in (
+            (relations_from_event, "/event"),
+            (relations_from_messages, "/messages"),
+        ):
+            # The latest event should have bundled aggregations.
+            self.assertIn(RelationTypes.THREAD, aggregations, desc)
+            thread_summary = aggregations[RelationTypes.THREAD]
+            self.assertIn("latest_event", thread_summary, desc)
+            self.assertEqual(
+                thread_summary["latest_event"]["event_id"], reply_event_id, desc
+            )
+
+            # The latest event should not have any bundled aggregations (since the
+            # only relation to it is another thread, which is invalid).
+            self.assertNotIn(
+                "m.relations", thread_summary["latest_event"]["unsigned"], desc
+            )
 
     def test_thread_edit_latest_event(self) -> None:
         """Test that editing the latest event in a thread works."""
@@ -1049,6 +1148,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
             content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
             parent_id=threaded_event_id,
         )
+        edit_event_id = channel.json_body["event_id"]
 
         # Fetch the thread root, to get the bundled aggregation for the thread.
         relations_dict = self._get_bundled_aggregations()
@@ -1061,6 +1161,12 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         self.assertIn("latest_event", thread_summary)
         latest_event_in_thread = thread_summary["latest_event"]
         self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
+        # The latest event in the thread should have the edit appear under the
+        # bundled aggregations.
+        self.assertDictContainsSubset(
+            {"event_id": edit_event_id, "sender": "@alice:test"},
+            latest_event_in_thread["unsigned"]["m.relations"][RelationTypes.REPLACE],
+        )
 
     def test_aggregation_get_event_for_annotation(self) -> None:
         """Test that annotations do not get bundled aggregations included