summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12113.bugfix1
-rw-r--r--changelog.d/12113.misc1
-rw-r--r--changelog.d/12121.bugfix1
-rw-r--r--synapse/storage/databases/main/cache.py2
-rw-r--r--synapse/storage/databases/main/events.py38
-rw-r--r--tests/rest/client/test_relations.py207
6 files changed, 202 insertions, 48 deletions
diff --git a/changelog.d/12113.bugfix b/changelog.d/12113.bugfix
new file mode 100644
index 0000000000..df9b0dc413
--- /dev/null
+++ b/changelog.d/12113.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when redacting events with relations.
diff --git a/changelog.d/12113.misc b/changelog.d/12113.misc
deleted file mode 100644
index 102e064053..0000000000
--- a/changelog.d/12113.misc
+++ /dev/null
@@ -1 +0,0 @@
-Refactor the tests for event relations.
diff --git a/changelog.d/12121.bugfix b/changelog.d/12121.bugfix
new file mode 100644
index 0000000000..df9b0dc413
--- /dev/null
+++ b/changelog.d/12121.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when redacting events with relations.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index c428dd5596..abd54c7dc7 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -200,6 +200,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self.get_relations_for_event.invalidate((relates_to,))
             self.get_aggregation_groups_for_event.invalidate((relates_to,))
             self.get_applicable_edit.invalidate((relates_to,))
+            self.get_thread_summary.invalidate((relates_to,))
+            self.get_thread_participated.invalidate((relates_to,))
 
     async def invalidate_cache_and_stream(self, cache_name: str, keys: Tuple[Any, ...]):
         """Invalidates the cache and adds it to the cache stream so slaves
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index ca2a9ba9d1..1dc83aa5e3 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1518,7 +1518,7 @@ class PersistEventsStore:
                 )
 
                 # Remove from relations table.
-                self._handle_redaction(txn, event.redacts)
+                self._handle_redact_relations(txn, event.redacts)
 
         # Update the event_forward_extremities, event_backward_extremities and
         # event_edges tables.
@@ -1943,15 +1943,43 @@ class PersistEventsStore:
 
         txn.execute(sql, (batch_id,))
 
-    def _handle_redaction(self, txn, redacted_event_id):
-        """Handles receiving a redaction and checking whether we need to remove
-        any redacted relations from the database.
+    def _handle_redact_relations(
+        self, txn: LoggingTransaction, redacted_event_id: str
+    ) -> None:
+        """Handles receiving a redaction and checking whether the redacted event
+        has any relations which must be removed from the database.
 
         Args:
             txn
-            redacted_event_id (str): The event that was redacted.
+            redacted_event_id: The event that was redacted.
         """
 
+        # Fetch the current relation of the event being redacted.
+        redacted_relates_to = self.db_pool.simple_select_one_onecol_txn(
+            txn,
+            table="event_relations",
+            keyvalues={"event_id": redacted_event_id},
+            retcol="relates_to_id",
+            allow_none=True,
+        )
+        # Any relation information for the related event must be cleared.
+        if redacted_relates_to is not None:
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_relations_for_event, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_aggregation_groups_for_event, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_applicable_edit, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_thread_summary, (redacted_relates_to,)
+            )
+            self.store._invalidate_cache_and_stream(
+                txn, self.store.get_thread_participated, (redacted_relates_to,)
+            )
+
         self.db_pool.simple_delete_txn(
             txn, table="event_relations", keyvalues={"event_id": redacted_event_id}
         )
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 274f9c44c1..a40a5de399 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1273,7 +1273,21 @@ class RelationsTestCase(BaseRelationsTestCase):
 
 
 class RelationRedactionTestCase(BaseRelationsTestCase):
-    """Test the behaviour of relations when the parent or child event is redacted."""
+    """
+    Test the behaviour of relations when the parent or child event is redacted.
+
+    The behaviour of each relation type is subtly different which causes the tests
+    to be a bit repetitive, they follow a naming scheme of:
+
+        test_redact_(relation|parent)_{relation_type}
+
+    The first bit of "relation" means that the event with the relation defined
+    on it (the child event) is to be redacted. A "parent" means that the target
+    of the relation (the parent event) is to be redacted.
+
+    The relation_type describes which type of relation is under test (i.e. it is
+    related to the value of rel_type in the event content).
+    """
 
     def _redact(self, event_id: str) -> None:
         channel = self.make_request(
@@ -1284,9 +1298,53 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         )
         self.assertEqual(200, channel.code, channel.json_body)
 
+    def _make_relation_requests(self) -> Tuple[List[str], JsonDict]:
+        """
+        Makes requests and ensures they result in a 200 response, returns a
+        tuple of results:
+
+        1. `/relations` -> Returns a list of event IDs.
+        2. `/event` -> Returns the response's m.relations field (from unsigned),
+            if it exists.
+        """
+
+        # Request the relations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        event_ids = [ev["event_id"] for ev in channel.json_body["chunk"]]
+
+        # Fetch the bundled aggregations of the event.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/event/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        bundled_relations = channel.json_body["unsigned"].get("m.relations", {})
+
+        return event_ids, bundled_relations
+
+    def _get_aggregations(self) -> List[JsonDict]:
+        """Request /aggregations on the parent ID and includes the returned chunk."""
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
+            access_token=self.user_token,
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        return channel.json_body["chunk"]
+
     def test_redact_relation_annotation(self) -> None:
-        """Test that annotations of an event are properly handled after the
+        """
+        Test that annotations of an event are properly handled after the
         annotation is redacted.
+
+        The redacted relation should not be included in bundled aggregations or
+        the response to relations.
         """
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEqual(200, channel.code, channel.json_body)
@@ -1296,24 +1354,97 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
             RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
         )
         self.assertEqual(200, channel.code, channel.json_body)
+        unredacted_event_id = channel.json_body["event_id"]
+
+        # Both relations should exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertCountEqual(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "a", "count": 2}]},
+        )
+
+        # Both relations appear in the aggregation.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 2}])
 
         # Redact one of the reactions.
         self._redact(to_redact_event_id)
 
-        # Ensure that the aggregations are correct.
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}",
-            access_token=self.user_token,
+        # The unredacted relation should still exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+        )
+
+        # The unredacted aggregation should still exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "a", "count": 1}])
+
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_redact_relation_thread(self) -> None:
+        """
+        Test that thread replies are properly handled after the thread reply redacted.
+
+        The redacted event should not be included in bundled aggregations or
+        the response to relations.
+        """
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 1", "msgtype": "m.text"},
         )
         self.assertEqual(200, channel.code, channel.json_body)
+        unredacted_event_id = channel.json_body["event_id"]
 
+        # Note that the *last* event in the thread is redacted, as that gets
+        # included in the bundled aggregation.
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 2", "msgtype": "m.text"},
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        to_redact_event_id = channel.json_body["event_id"]
+
+        # Both relations exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEquals(event_ids, [to_redact_event_id, unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 2,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event returned is the event that will be redacted.
         self.assertEqual(
-            channel.json_body,
-            {"chunk": [{"type": "m.reaction", "key": "a", "count": 1}]},
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            to_redact_event_id,
         )
 
-    def test_redact_relation_edit(self) -> None:
+        # Redact one of the reactions.
+        self._redact(to_redact_event_id)
+
+        # The unredacted relation should still exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEquals(event_ids, [unredacted_event_id])
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        # And the latest event is now the unredacted event.
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            unredacted_event_id,
+        )
+
+    def test_redact_parent_edit(self) -> None:
         """Test that edits of an event are redacted when the original event
         is redacted.
         """
@@ -1331,34 +1462,19 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self.assertEqual(200, channel.code, channel.json_body)
 
         # Check the relation is returned
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(len(channel.json_body["chunk"]), 1)
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.REPLACE, relations)
 
         # Redact the original event
         self._redact(self.parent_id)
 
-        # Try to check for remaining m.replace relations
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/relations"
-            f"/{self.parent_id}/m.replace/m.room.message",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
-
-        # Check that no relations are returned
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
+        # The relations are not returned.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(len(event_ids), 0)
+        self.assertEqual(relations, {})
 
-    def test_redact_parent(self) -> None:
+    def test_redact_parent_annotation(self) -> None:
         """Test that annotations of an event are redacted when the original event
         is redacted.
         """
@@ -1366,16 +1482,23 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
         self.assertEqual(200, channel.code, channel.json_body)
 
+        # The relations should exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(len(event_ids), 1)
+        self.assertIn(RelationTypes.ANNOTATION, relations)
+
+        # The aggregation should exist.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [{"type": "m.reaction", "key": "👍", "count": 1}])
+
         # Redact the original event.
         self._redact(self.parent_id)
 
-        # Check that aggregations returns zero
-        channel = self.make_request(
-            "GET",
-            f"/_matrix/client/unstable/rooms/{self.room}/aggregations/{self.parent_id}/m.annotation/m.reaction",
-            access_token=self.user_token,
-        )
-        self.assertEqual(200, channel.code, channel.json_body)
+        # The relations are not returned.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEqual(event_ids, [])
+        self.assertEqual(relations, {})
 
-        self.assertIn("chunk", channel.json_body)
-        self.assertEqual(channel.json_body["chunk"], [])
+        # There's nothing to aggregate.
+        chunk = self._get_aggregations()
+        self.assertEqual(chunk, [])