diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 36aa1092f6..be1500092b 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -91,10 +91,11 @@ class RelationsWorkerStore(SQLBaseStore):
self._msc3440_enabled = hs.config.experimental.msc3440_enabled
- @cached(tree=True)
+ @cached(uncached_args=("event",), tree=True)
async def get_relations_for_event(
self,
event_id: str,
+ event: EventBase,
room_id: str,
relation_type: Optional[str] = None,
event_type: Optional[str] = None,
@@ -108,6 +109,7 @@ class RelationsWorkerStore(SQLBaseStore):
Args:
event_id: Fetch events that relate to this event ID.
+ event: The matching EventBase to event_id.
room_id: The room the event belongs to.
relation_type: Only fetch events with this relation type, if given.
event_type: Only fetch events with this event type, if given.
@@ -122,9 +124,13 @@ class RelationsWorkerStore(SQLBaseStore):
List of event IDs that match relations requested. The rows are of
the form `{"event_id": "..."}`.
"""
+ # We don't use `event_id`, it's there so that we can cache based on
+ # it. The `event_id` must match the `event.event_id`.
+ assert event.event_id == event_id
where_clause = ["relates_to_id = ?", "room_id = ?"]
- where_args: List[Union[str, int]] = [event_id, room_id]
+ where_args: List[Union[str, int]] = [event.event_id, room_id]
+ is_redacted = event.internal_metadata.is_redacted()
if relation_type is not None:
where_clause.append("relation_type = ?")
@@ -157,7 +163,7 @@ class RelationsWorkerStore(SQLBaseStore):
order = "ASC"
sql = """
- SELECT event_id, topological_ordering, stream_ordering
+ SELECT event_id, relation_type, topological_ordering, stream_ordering
FROM event_relations
INNER JOIN events USING (event_id)
WHERE %s
@@ -178,9 +184,12 @@ class RelationsWorkerStore(SQLBaseStore):
last_stream_id = None
events = []
for row in txn:
- events.append({"event_id": row[0]})
- last_topo_id = row[1]
- last_stream_id = row[2]
+ # Do not include edits for redacted events as they leak event
+ # content.
+ if not is_redacted or row[1] != RelationTypes.REPLACE:
+ events.append({"event_id": row[0]})
+ last_topo_id = row[2]
+ last_stream_id = row[3]
# If there are more events, generate the next pagination key.
next_token = None
@@ -776,7 +785,7 @@ class RelationsWorkerStore(SQLBaseStore):
)
references = await self.get_relations_for_event(
- event_id, room_id, RelationTypes.REFERENCE, direction="f"
+ event_id, event, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
aggregations.references = await references.to_dict(cast("DataStore", self))
@@ -797,41 +806,36 @@ class RelationsWorkerStore(SQLBaseStore):
A map of event ID to the bundled aggregation for the event. Not all
events may have bundled aggregations in the results.
"""
- # The already processed event IDs. Tracked separately from the result
- # since the result omits events which do not have bundled aggregations.
- seen_event_ids = set()
-
- # State events and redacted events do not get bundled aggregations.
- events = [
- event
- for event in events
- if not event.is_state() and not event.internal_metadata.is_redacted()
- ]
+ # De-duplicate events by ID to handle the same event requested multiple times.
+ #
+ # State events do not get bundled aggregations.
+ events_by_id = {
+ event.event_id: event for event in events if not event.is_state()
+ }
# event ID -> bundled aggregation in non-serialized form.
results: Dict[str, BundledAggregations] = {}
# Fetch other relations per event.
- for event in events:
- # De-duplicate events by ID to handle the same event requested multiple
- # times. The caches that _get_bundled_aggregation_for_event use should
- # capture this, but best to reduce work.
- if event.event_id in seen_event_ids:
- continue
- seen_event_ids.add(event.event_id)
-
+ for event in events_by_id.values():
event_result = await self._get_bundled_aggregation_for_event(event, user_id)
if event_result:
results[event.event_id] = event_result
- # Fetch any edits.
- edits = await self._get_applicable_edits(seen_event_ids)
+ # Fetch any edits (but not for redacted events).
+ edits = await self._get_applicable_edits(
+ [
+ event_id
+ for event_id, event in events_by_id.items()
+ if not event.internal_metadata.is_redacted()
+ ]
+ )
for event_id, edit in edits.items():
results.setdefault(event_id, BundledAggregations()).replace = edit
# Fetch thread summaries.
if self._msc3440_enabled:
- summaries = await self._get_thread_summaries(seen_event_ids)
+ summaries = await self._get_thread_summaries(events_by_id.keys())
# Only fetch participated for a limited selection based on what had
# summaries.
participated = await self._get_threads_participated(
|