summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/12130.bugfix1
-rw-r--r--changelog.d/12189.bugfix1
-rw-r--r--changelog.d/12189.misc1
-rw-r--r--synapse/rest/client/relations.py82
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/events.py11
-rw-r--r--synapse/storage/databases/main/relations.py60
-rw-r--r--tests/rest/client/test_relations.py45
8 files changed, 122 insertions, 83 deletions
diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix
new file mode 100644
index 0000000000..df9b0dc413
--- /dev/null
+++ b/changelog.d/12130.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when redacting events with relations.
diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix
new file mode 100644
index 0000000000..df9b0dc413
--- /dev/null
+++ b/changelog.d/12189.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when redacting events with relations.
diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc
deleted file mode 100644
index 015e808e63..0000000000
--- a/changelog.d/12189.misc
+++ /dev/null
@@ -1 +0,0 @@
-Support skipping some arguments when generating cache keys.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 07fa1cdd4c..d9a6be43f7 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -27,7 +27,7 @@ from synapse.http.server import HttpServer
 from synapse.http.servlet import RestServlet, parse_integer, parse_string
 from synapse.http.site import SynapseRequest
 from synapse.rest.client._base import client_patterns
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.storage.relations import AggregationPaginationToken
 from synapse.types import JsonDict, StreamToken
 
 if TYPE_CHECKING:
@@ -82,28 +82,25 @@ class RelationPaginationServlet(RestServlet):
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
 
-        if event.internal_metadata.is_redacted():
-            # If the event is redacted, return an empty list of relations
-            pagination_chunk = PaginationChunk(chunk=[])
-        else:
-            # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = await StreamToken.from_string(self.store, from_token_str)
-            to_token = None
-            if to_token_str:
-                to_token = await StreamToken.from_string(self.store, to_token_str)
-
-            pagination_chunk = await self.store.get_relations_for_event(
-                event_id=parent_id,
-                room_id=room_id,
-                relation_type=relation_type,
-                event_type=event_type,
-                limit=limit,
-                direction=direction,
-                from_token=from_token,
-                to_token=to_token,
-            )
+        # Return the relations
+        from_token = None
+        if from_token_str:
+            from_token = await StreamToken.from_string(self.store, from_token_str)
+        to_token = None
+        if to_token_str:
+            to_token = await StreamToken.from_string(self.store, to_token_str)
+
+        pagination_chunk = await self.store.get_relations_for_event(
+            event_id=parent_id,
+            event=event,
+            room_id=room_id,
+            relation_type=relation_type,
+            event_type=event_type,
+            limit=limit,
+            direction=direction,
+            from_token=from_token,
+            to_token=to_token,
+        )
 
         events = await self.store.get_events_as_list(
             [c["event_id"] for c in pagination_chunk.chunk]
@@ -193,27 +190,23 @@ class RelationAggregationPaginationServlet(RestServlet):
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
 
-        if event.internal_metadata.is_redacted():
-            # If the event is redacted, return an empty list of relations
-            pagination_chunk = PaginationChunk(chunk=[])
-        else:
-            # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = AggregationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = AggregationPaginationToken.from_string(to_token_str)
-
-            pagination_chunk = await self.store.get_aggregation_groups_for_event(
-                event_id=parent_id,
-                room_id=room_id,
-                event_type=event_type,
-                limit=limit,
-                from_token=from_token,
-                to_token=to_token,
-            )
+        # Return the relations
+        from_token = None
+        if from_token_str:
+            from_token = AggregationPaginationToken.from_string(from_token_str)
+
+        to_token = None
+        if to_token_str:
+            to_token = AggregationPaginationToken.from_string(to_token_str)
+
+        pagination_chunk = await self.store.get_aggregation_groups_for_event(
+            event_id=parent_id,
+            room_id=room_id,
+            event_type=event_type,
+            limit=limit,
+            from_token=from_token,
+            to_token=to_token,
+        )
 
         return 200, await pagination_chunk.to_dict(self.store)
 
@@ -295,6 +288,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
+            event=event,
             room_id=room_id,
             relation_type=relation_type,
             event_type=event_type,
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index abd54c7dc7..d6a2df1afe 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -191,6 +191,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         if redacts:
             self._invalidate_get_event_cache(redacts)
+            # Caches which might leak edits must be invalidated for the event being
+            # redacted.
+            self.get_relations_for_event.invalidate((redacts,))
+            self.get_applicable_edit.invalidate((redacts,))
 
         if etype == EventTypes.Member:
             self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1dc83aa5e3..1a322882bf 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1619,9 +1619,12 @@ class PersistEventsStore:
 
         txn.call_after(prefill)
 
-    def _store_redaction(self, txn, event):
-        # invalidate the cache for the redacted event
+    def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
+        # Invalidate the caches for the redacted event, note that these caches
+        # are also cleared as part of event replication in _invalidate_caches_for_event.
         txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+        txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
+        txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
 
         self.db_pool.simple_upsert_txn(
             txn,
@@ -1812,9 +1815,7 @@ class PersistEventsStore:
             txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
 
         if rel_type == RelationTypes.THREAD:
-            txn.call_after(
-                self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
-            )
+            txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
             # It should be safe to only invalidate the cache if the user has not
             # previously participated in the thread, but that's difficult (and
             # potentially error-prone) so it is always invalidated.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..be1500092b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore):
 
         self._msc3440_enabled = hs.config.experimental.msc3440_enabled
 
-    @cached(tree=True)
+    @cached(uncached_args=("event",), tree=True)
     async def get_relations_for_event(
         self,
         event_id: str,
+        event: EventBase,
         room_id: str,
         relation_type: Optional[str] = None,
         event_type: Optional[str] = None,
@@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore):
 
         Args:
             event_id: Fetch events that relate to this event ID.
+            event: The matching EventBase to event_id.
             room_id: The room the event belongs to.
             relation_type: Only fetch events with this relation type, if given.
             event_type: Only fetch events with this event type, if given.
@@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore):
             List of event IDs that match relations requested. The rows are of
             the form `{"event_id": "..."}`.
         """
+        # We don't use `event_id`, it's there so that we can cache based on
+        # it. The `event_id` must match the `event.event_id`.
+        assert event.event_id == event_id
 
         where_clause = ["relates_to_id = ?", "room_id = ?"]
-        where_args: List[Union[str, int]] = [event_id, room_id]
+        where_args: List[Union[str, int]] = [event.event_id, room_id]
+        is_redacted = event.internal_metadata.is_redacted()
 
         if relation_type is not None:
             where_clause.append("relation_type = ?")
@@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
             order = "ASC"
 
         sql = """
-            SELECT event_id, topological_ordering, stream_ordering
+            SELECT event_id, relation_type, topological_ordering, stream_ordering
             FROM event_relations
             INNER JOIN events USING (event_id)
             WHERE %s
@@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore):
             last_stream_id = None
             events = []
             for row in txn:
-                events.append({"event_id": row[0]})
-                last_topo_id = row[1]
-                last_stream_id = row[2]
+                # Do not include edits for redacted events as they leak event
+                # content.
+                if not is_redacted or row[1] != RelationTypes.REPLACE:
+                    events.append({"event_id": row[0]})
+                last_topo_id = row[2]
+                last_stream_id = row[3]
 
             # If there are more events, generate the next pagination key.
             next_token = None
@@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore):
             )
 
         references = await self.get_relations_for_event(
-            event_id, room_id, RelationTypes.REFERENCE, direction="f"
+            event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
             aggregations.references = await references.to_dict(cast("DataStore", self))
@@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore):
             A map of event ID to the bundled aggregation for the event. Not all
             events may have bundled aggregations in the results.
         """
-        # The already processed event IDs. Tracked separately from the result
-        # since the result omits events which do not have bundled aggregations.
-        seen_event_ids = set()
-
-        # State events and redacted events do not get bundled aggregations.
-        events = [
-            event
-            for event in events
-            if not event.is_state() and not event.internal_metadata.is_redacted()
-        ]
+        # De-duplicate events by ID to handle the same event requested multiple times.
+        #
+        # State events do not get bundled aggregations.
+        events_by_id = {
+            event.event_id: event for event in events if not event.is_state()
+        }
 
         # event ID -> bundled aggregation in non-serialized form.
         results: Dict[str, BundledAggregations] = {}
 
         # Fetch other relations per event.
-        for event in events:
-            # De-duplicate events by ID to handle the same event requested multiple
-            # times. The caches that _get_bundled_aggregation_for_event use should
-            # capture this, but best to reduce work.
-            if event.event_id in seen_event_ids:
-                continue
-            seen_event_ids.add(event.event_id)
-
+        for event in events_by_id.values():
             event_result = await self._get_bundled_aggregation_for_event(event, user_id)
             if event_result:
                 results[event.event_id] = event_result
 
-        # Fetch any edits.
-        edits = await self._get_applicable_edits(seen_event_ids)
+        # Fetch any edits (but not for redacted events).
+        edits = await self._get_applicable_edits(
+            [
+                event_id
+                for event_id, event in events_by_id.items()
+                if not event.internal_metadata.is_redacted()
+            ]
+        )
         for event_id, edit in edits.items():
             results.setdefault(event_id, BundledAggregations()).replace = edit
 
         # Fetch thread summaries.
         if self._msc3440_enabled:
-            summaries = await self._get_thread_summaries(seen_event_ids)
+            summaries = await self._get_thread_summaries(events_by_id.keys())
             # Only fetch participated for a limited selection based on what had
             # summaries.
             participated = await self._get_threads_participated(
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index a40a5de399..f9ae6e663f 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1475,12 +1475,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         self.assertEqual(relations, {})
 
     def test_redact_parent_annotation(self) -> None:
-        """Test that annotations of an event are redacted when the original event
+        """Test that annotations of an event are viewable when the original event
         is redacted.
         """
         # Add a relation
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
         self.assertEqual(200, channel.code, channel.json_body)
+        related_event_id = channel.json_body["event_id"]
 
         # The relations should exist.
         event_ids, relations = self._make_relation_requests()
@@ -1494,11 +1495,45 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
         # Redact the original event.
         self._redact(self.parent_id)
 
-        # The relations are not returned.
+        # The relations are returned.
         event_ids, relations = self._make_relation_requests()
-        self.assertEqual(event_ids, [])
-        self.assertEqual(relations, {})
+        self.assertEquals(event_ids, [related_event_id])
+        self.assertEquals(
+            relations["m.annotation"],
+            {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+        )
 
         # There's nothing to aggregate.
         chunk = self._get_aggregations()
-        self.assertEqual(chunk, [])
+        self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    def test_redact_parent_thread(self) -> None:
+        """
+        Test that thread replies are still available when the root event is redacted.
+        """
+        channel = self._send_relation(
+            RelationTypes.THREAD,
+            EventTypes.Message,
+            content={"body": "reply 1", "msgtype": "m.text"},
+        )
+        self.assertEqual(200, channel.code, channel.json_body)
+        related_event_id = channel.json_body["event_id"]
+
+        # Redact one of the reactions.
+        self._redact(self.parent_id)
+
+        # The unredacted relation should still exist.
+        event_ids, relations = self._make_relation_requests()
+        self.assertEquals(len(event_ids), 1)
+        self.assertDictContainsSubset(
+            {
+                "count": 1,
+                "current_user_participated": True,
+            },
+            relations[RelationTypes.THREAD],
+        )
+        self.assertEqual(
+            relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+            related_event_id,
+        )