summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11992.bugfix1
-rw-r--r--synapse/events/utils.py69
-rw-r--r--synapse/storage/databases/main/events_worker.py2
-rw-r--r--synapse/storage/databases/main/relations.py24
-rw-r--r--tests/rest/client/test_relations.py42
5 files changed, 107 insertions, 31 deletions
diff --git a/changelog.d/11992.bugfix b/changelog.d/11992.bugfix
new file mode 100644
index 0000000000..f73c86bb25
--- /dev/null
+++ b/changelog.d/11992.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.48.0 where an edit of the latest event in a thread would not be properly applied to the thread summary.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 243696b357..9386fa29dd 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -425,6 +425,33 @@ class EventClientSerializer:
 
         return serialized_event
 
+    def _apply_edit(
+        self, orig_event: EventBase, serialized_event: JsonDict, edit: EventBase
+    ) -> None:
+        """Replace the content, preserving existing relations of the serialized event.
+
+        Args:
+            orig_event: The original event.
+            serialized_event: The original event, serialized. This is modified.
+            edit: The event which edits the above.
+        """
+
+        # Ensure we take copies of the edit content, otherwise we risk modifying
+        # the original event.
+        edit_content = edit.content.copy()
+
+        # Unfreeze the event content if necessary, so that we may modify it below
+        edit_content = unfreeze(edit_content)
+        serialized_event["content"] = edit_content.get("m.new_content", {})
+
+        # Check for existing relations
+        relates_to = orig_event.content.get("m.relates_to")
+        if relates_to:
+            # Keep the relations, ensuring we use a dict copy of the original
+            serialized_event["content"]["m.relates_to"] = relates_to.copy()
+        else:
+            serialized_event["content"].pop("m.relates_to", None)
+
     def _inject_bundled_aggregations(
         self,
         event: EventBase,
@@ -450,26 +477,11 @@ class EventClientSerializer:
             serialized_aggregations[RelationTypes.REFERENCE] = aggregations.references
 
         if aggregations.replace:
-            # If there is an edit replace the content, preserving existing
-            # relations.
+            # If there is an edit, apply it to the event.
             edit = aggregations.replace
+            self._apply_edit(event, serialized_event, edit)
 
-            # Ensure we take copies of the edit content, otherwise we risk modifying
-            # the original event.
-            edit_content = edit.content.copy()
-
-            # Unfreeze the event content if necessary, so that we may modify it below
-            edit_content = unfreeze(edit_content)
-            serialized_event["content"] = edit_content.get("m.new_content", {})
-
-            # Check for existing relations
-            relates_to = event.content.get("m.relates_to")
-            if relates_to:
-                # Keep the relations, ensuring we use a dict copy of the original
-                serialized_event["content"]["m.relates_to"] = relates_to.copy()
-            else:
-                serialized_event["content"].pop("m.relates_to", None)
-
+            # Include information about it in the relations dict.
             serialized_aggregations[RelationTypes.REPLACE] = {
                 "event_id": edit.event_id,
                 "origin_server_ts": edit.origin_server_ts,
@@ -478,13 +490,22 @@ class EventClientSerializer:
 
         # If this event is the start of a thread, include a summary of the replies.
         if aggregations.thread:
+            thread = aggregations.thread
+
+            # Don't bundle aggregations as this could recurse forever.
+            serialized_latest_event = self.serialize_event(
+                thread.latest_event, time_now, bundle_aggregations=None
+            )
+            # Manually apply an edit, if one exists.
+            if thread.latest_edit:
+                self._apply_edit(
+                    thread.latest_event, serialized_latest_event, thread.latest_edit
+                )
+
             serialized_aggregations[RelationTypes.THREAD] = {
-                # Don't bundle aggregations as this could recurse forever.
-                "latest_event": self.serialize_event(
-                    aggregations.thread.latest_event, time_now, bundle_aggregations=None
-                ),
-                "count": aggregations.thread.count,
-                "current_user_participated": aggregations.thread.current_user_participated,
+                "latest_event": serialized_latest_event,
+                "count": thread.count,
+                "current_user_participated": thread.current_user_participated,
             }
 
         # Include the bundled aggregations in the event.
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index 8d4287045a..712b8ce204 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -408,7 +408,7 @@ class EventsWorkerStore(SQLBaseStore):
                 include the previous states content in the unsigned field.
 
             allow_rejected: If True, return rejected events. Otherwise,
-                omits rejeted events from the response.
+                omits rejected events from the response.
 
         Returns:
             A mapping from event_id to event.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index e2c27e594b..5582029f9f 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -53,8 +53,13 @@ logger = logging.getLogger(__name__)
 
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _ThreadAggregation:
+    # The latest event in the thread.
     latest_event: EventBase
+    # The latest edit to the latest event in the thread.
+    latest_edit: Optional[EventBase]
+    # The total number of events in the thread.
     count: int
+    # True if the current user has sent an event to the thread.
     current_user_participated: bool
 
 
@@ -461,8 +466,8 @@ class RelationsWorkerStore(SQLBaseStore):
     @cachedList(cached_method_name="get_thread_summary", list_name="event_ids")
     async def _get_thread_summaries(
         self, event_ids: Collection[str]
-    ) -> Dict[str, Optional[Tuple[int, EventBase]]]:
-        """Get the number of threaded replies and the latest reply (if any) for the given event.
+    ) -> Dict[str, Optional[Tuple[int, EventBase, Optional[EventBase]]]]:
+        """Get the number of threaded replies, the latest reply (if any), and the latest edit for that reply for the given event.
 
         Args:
             event_ids: Summarize the thread related to this event ID.
@@ -471,8 +476,10 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of the thread summary each event. A missing event implies there
             are no threaded replies.
 
-            Each summary includes the number of items in the thread and the most
-            recent response.
+            Each summary is a tuple of:
+                The number of events in the thread.
+                The most recent event in the thread.
+                The most recent edit to the most recent event in the thread, if applicable.
         """
 
         def _get_thread_summaries_txn(
@@ -558,6 +565,9 @@ class RelationsWorkerStore(SQLBaseStore):
 
         latest_events = await self.get_events(latest_event_ids.values())  # type: ignore[attr-defined]
 
+        # Check to see if any of those events are edited.
+        latest_edits = await self._get_applicable_edits(latest_event_ids.values())
+
         # Map to the event IDs to the thread summary.
         #
         # There might not be a summary due to there not being a thread or
@@ -568,7 +578,8 @@ class RelationsWorkerStore(SQLBaseStore):
 
             summary = None
             if latest_event:
-                summary = (counts[parent_event_id], latest_event)
+                latest_edit = latest_edits.get(latest_event_id)
+                summary = (counts[parent_event_id], latest_event, latest_edit)
             summaries[parent_event_id] = summary
 
         return summaries
@@ -828,11 +839,12 @@ class RelationsWorkerStore(SQLBaseStore):
             )
             for event_id, summary in summaries.items():
                 if summary:
-                    thread_count, latest_thread_event = summary
+                    thread_count, latest_thread_event, edit = summary
                     results.setdefault(
                         event_id, BundledAggregations()
                     ).thread = _ThreadAggregation(
                         latest_event=latest_thread_event,
+                        latest_edit=edit,
                         count=thread_count,
                         # If there's a thread summary it must also exist in the
                         # participated dictionary.
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index de80aca037..dfd9ffcb93 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1123,6 +1123,48 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             {"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
         )
 
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_edit_thread(self):
+        """Test that editing a thread works."""
+
+        # Create a thread and edit the last event.
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "A threaded reply!"},
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        threaded_event_id = channel.json_body["event_id"]
+
+        new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+        channel = self._send_relation(
+            RelationTypes.REPLACE,
+            "m.room.message",
+            content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+            parent_id=threaded_event_id,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # Fetch the thread root, to get the bundled aggregation for the thread.
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect that the edit message appears in the thread summary in the
+        # unsigned relations section.
+        relations_dict = channel.json_body["unsigned"].get("m.relations")
+        self.assertIn(RelationTypes.THREAD, relations_dict)
+
+        thread_summary = relations_dict[RelationTypes.THREAD]
+        self.assertIn("latest_event", thread_summary)
+        latest_event_in_thread = thread_summary["latest_event"]
+        self.assertEquals(
+            latest_event_in_thread["content"]["body"], "I've been edited!"
+        )
+
     def test_edit_edit(self):
         """Test that an edit cannot be edited."""
         new_body = {"msgtype": "m.text", "body": "Initial edit"}