diff --git a/changelog.d/12130.bugfix b/changelog.d/12130.bugfix
new file mode 100644
index 0000000000..df9b0dc413
--- /dev/null
+++ b/changelog.d/12130.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when redacting events with relations.
diff --git a/changelog.d/12189.bugfix b/changelog.d/12189.bugfix
new file mode 100644
index 0000000000..df9b0dc413
--- /dev/null
+++ b/changelog.d/12189.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug when redacting events with relations.
diff --git a/changelog.d/12189.misc b/changelog.d/12189.misc
deleted file mode 100644
index 015e808e63..0000000000
--- a/changelog.d/12189.misc
+++ /dev/null
@@ -1 +0,0 @@
-Support skipping some arguments when generating cache keys.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 07fa1cdd4c..d9a6be43f7 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -27,7 +27,7 @@ from synapse.http.server import HttpServer
from synapse.http.servlet import RestServlet, parse_integer, parse_string
from synapse.http.site import SynapseRequest
from synapse.rest.client._base import client_patterns
-from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.storage.relations import AggregationPaginationToken
from synapse.types import JsonDict, StreamToken
if TYPE_CHECKING:
@@ -82,28 +82,25 @@ class RelationPaginationServlet(RestServlet):
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
- if event.internal_metadata.is_redacted():
- # If the event is redacted, return an empty list of relations
- pagination_chunk = PaginationChunk(chunk=[])
- else:
- # Return the relations
- from_token = None
- if from_token_str:
- from_token = await StreamToken.from_string(self.store, from_token_str)
- to_token = None
- if to_token_str:
- to_token = await StreamToken.from_string(self.store, to_token_str)
-
- pagination_chunk = await self.store.get_relations_for_event(
- event_id=parent_id,
- room_id=room_id,
- relation_type=relation_type,
- event_type=event_type,
- limit=limit,
- direction=direction,
- from_token=from_token,
- to_token=to_token,
- )
+ # Return the relations
+ from_token = None
+ if from_token_str:
+ from_token = await StreamToken.from_string(self.store, from_token_str)
+ to_token = None
+ if to_token_str:
+ to_token = await StreamToken.from_string(self.store, to_token_str)
+
+ pagination_chunk = await self.store.get_relations_for_event(
+ event_id=parent_id,
+ event=event,
+ room_id=room_id,
+ relation_type=relation_type,
+ event_type=event_type,
+ limit=limit,
+ direction=direction,
+ from_token=from_token,
+ to_token=to_token,
+ )
events = await self.store.get_events_as_list(
[c["event_id"] for c in pagination_chunk.chunk]
@@ -193,27 +190,23 @@ class RelationAggregationPaginationServlet(RestServlet):
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
- if event.internal_metadata.is_redacted():
- # If the event is redacted, return an empty list of relations
- pagination_chunk = PaginationChunk(chunk=[])
- else:
- # Return the relations
- from_token = None
- if from_token_str:
- from_token = AggregationPaginationToken.from_string(from_token_str)
-
- to_token = None
- if to_token_str:
- to_token = AggregationPaginationToken.from_string(to_token_str)
-
- pagination_chunk = await self.store.get_aggregation_groups_for_event(
- event_id=parent_id,
- room_id=room_id,
- event_type=event_type,
- limit=limit,
- from_token=from_token,
- to_token=to_token,
- )
+ # Return the relations
+ from_token = None
+ if from_token_str:
+ from_token = AggregationPaginationToken.from_string(from_token_str)
+
+ to_token = None
+ if to_token_str:
+ to_token = AggregationPaginationToken.from_string(to_token_str)
+
+ pagination_chunk = await self.store.get_aggregation_groups_for_event(
+ event_id=parent_id,
+ room_id=room_id,
+ event_type=event_type,
+ limit=limit,
+ from_token=from_token,
+ to_token=to_token,
+ )
return 200, await pagination_chunk.to_dict(self.store)
@@ -295,6 +288,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
result = await self.store.get_relations_for_event(
event_id=parent_id,
+ event=event,
room_id=room_id,
relation_type=relation_type,
event_type=event_type,
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index abd54c7dc7..d6a2df1afe 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -191,6 +191,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if redacts:
self._invalidate_get_event_cache(redacts)
+ # Caches which might leak edits must be invalidated for the event being
+ # redacted.
+ self.get_relations_for_event.invalidate((redacts,))
+ self.get_applicable_edit.invalidate((redacts,))
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1dc83aa5e3..1a322882bf 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1619,9 +1619,12 @@ class PersistEventsStore:
txn.call_after(prefill)
- def _store_redaction(self, txn, event):
- # invalidate the cache for the redacted event
+ def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
+ # Invalidate the caches for the redacted event, note that these caches
+ # are also cleared as part of event replication in _invalidate_caches_for_event.
txn.call_after(self.store._invalidate_get_event_cache, event.redacts)
+ txn.call_after(self.store.get_relations_for_event.invalidate, (event.redacts,))
+ txn.call_after(self.store.get_applicable_edit.invalidate, (event.redacts,))
self.db_pool.simple_upsert_txn(
txn,
@@ -1812,9 +1815,7 @@ class PersistEventsStore:
txn.call_after(self.store.get_applicable_edit.invalidate, (parent_id,))
if rel_type == RelationTypes.THREAD:
- txn.call_after(
- self.store.get_thread_summary.invalidate, (parent_id, event.room_id)
- )
+ txn.call_after(self.store.get_thread_summary.invalidate, (parent_id,))
# It should be safe to only invalidate the cache if the user has not
# previously participated in the thread, but that's difficult (and
# potentially error-prone) so it is always invalidated.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..be1500092b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore):
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
- @cached(tree=True)
+ @cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
event_id: str,
+ event: EventBase,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
@@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ event: The matching EventBase to event_id.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
@@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore):
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
+ # We don't use `event_id`, it's there so that we can cache based on
+ # it. The `event_id` must match the `event.event_id`.
+ assert event.event_id == event_id
where_clause = ["relates_to_id = ?", "room_id = ?"]
- where_args: List[Union[str, int]] = [event_id, room_id]
+ where_args: List[Union[str, int]] = [event.event_id, room_id]
+ is_redacted = event.internal_metadata.is_redacted()
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore):
last_stream_id = None
events = []
for row in txn:
- events.append({"event_id": row[0]})
- last_topo_id = row[1]
- last_stream_id = row[2]
+ # Do not include edits for redacted events as they leak event
+ # content.
+ if not is_redacted or row[1] != RelationTypes.REPLACE:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[2]
+ last_stream_id = row[3]
# If there are more events, generate the next pagination key.
next_token = None
@@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
references = await self.get_relations_for_event(
- event_id, room_id, RelationTypes.REFERENCE, direction="f"
+ event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
@@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
- # The already processed event IDs. Tracked separately from the result
- # since the result omits events which do not have bundled aggregations.
- seen_event_ids = set()
-
- # State events and redacted events do not get bundled aggregations.
- events = [
- event
- for event in events
- if not event.is_state() and not event.internal_metadata.is_redacted()
- ]
+ # De-duplicate events by ID to handle the same event requested multiple times.
+ #
+ # State events do not get bundled aggregations.
+ events_by_id = {
+ event.event_id: event for event in events if not event.is_state()
+ }
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
- for event in events:
- # De-duplicate events by ID to handle the same event requested multiple
- # times. The caches that _get_bundled_aggregation_for_event use should
- # capture this, but best to reduce work.
- if event.event_id in seen_event_ids:
- continue
- seen_event_ids.add(event.event_id)
-
+ for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
- # Fetch any edits.
- edits = await self._get_applicable_edits(seen_event_ids)
+ # Fetch any edits (but not for redacted events).
+ edits = await self._get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
if self._msc3440_enabled:
- summaries = await self._get_thread_summaries(seen_event_ids)
+ summaries = await self._get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index a40a5de399..f9ae6e663f 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -1475,12 +1475,13 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
self.assertEqual(relations, {})
def test_redact_parent_annotation(self) -> None:
- """Test that annotations of an event are redacted when the original event
+ """Test that annotations of an event are viewable when the original event
is redacted.
"""
# Add a relation
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
self.assertEqual(200, channel.code, channel.json_body)
+ related_event_id = channel.json_body["event_id"]
# The relations should exist.
event_ids, relations = self._make_relation_requests()
@@ -1494,11 +1495,45 @@ class RelationRedactionTestCase(BaseRelationsTestCase):
# Redact the original event.
self._redact(self.parent_id)
- # The relations are not returned.
+ # The relations are returned.
event_ids, relations = self._make_relation_requests()
- self.assertEqual(event_ids, [])
- self.assertEqual(relations, {})
+ self.assertEquals(event_ids, [related_event_id])
+ self.assertEquals(
+ relations["m.annotation"],
+ {"chunk": [{"type": "m.reaction", "key": "👍", "count": 1}]},
+ )
# There's nothing to aggregate.
chunk = self._get_aggregations()
- self.assertEqual(chunk, [])
+ self.assertEqual(chunk, [{"count": 1, "key": "👍", "type": "m.reaction"}])
+
+ @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+ def test_redact_parent_thread(self) -> None:
+ """
+ Test that thread replies are still available when the root event is redacted.
+ """
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ EventTypes.Message,
+ content={"body": "reply 1", "msgtype": "m.text"},
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ related_event_id = channel.json_body["event_id"]
+
+ # Redact one of the reactions.
+ self._redact(self.parent_id)
+
+ # The unredacted relation should still exist.
+ event_ids, relations = self._make_relation_requests()
+ self.assertEquals(len(event_ids), 1)
+ self.assertDictContainsSubset(
+ {
+ "count": 1,
+ "current_user_participated": True,
+ },
+ relations[RelationTypes.THREAD],
+ )
+ self.assertEqual(
+ relations[RelationTypes.THREAD]["latest_event"]["event_id"],
+ related_event_id,
+ )
|