diff --git a/changelog.d/12519.misc b/changelog.d/12519.misc
new file mode 100644
index 0000000000..9c023d8e3e
--- /dev/null
+++ b/changelog.d/12519.misc
@@ -0,0 +1 @@
+Refactor the relations code for clarity.
diff --git a/synapse/events/utils.py b/synapse/events/utils.py
index 2174b4a094..f8d3ba5456 100644
--- a/synapse/events/utils.py
+++ b/synapse/events/utils.py
@@ -479,9 +479,9 @@ class EventClientSerializer:
Args:
event: The event being serialized.
time_now: The current time in milliseconds
+ config: Event serialization config
aggregations: The bundled aggregation to serialize.
serialized_event: The serialized event which may be modified.
- config: Event serialization config
apply_edits: Whether the content of the event should be modified to reflect
any replacement in `aggregations.replace`.
"""
diff --git a/synapse/handlers/relations.py b/synapse/handlers/relations.py
index 0be2319577..5efb561273 100644
--- a/synapse/handlers/relations.py
+++ b/synapse/handlers/relations.py
@@ -256,64 +256,6 @@ class RelationsHandler:
return filtered_results
- async def _get_bundled_aggregation_for_event(
- self, event: EventBase, ignored_users: FrozenSet[str]
- ) -> Optional[BundledAggregations]:
- """Generate bundled aggregations for an event.
-
- Note that this does not use a cache, but depends on cached methods.
-
- Args:
- event: The event to calculate bundled aggregations for.
- ignored_users: The users ignored by the requesting user.
-
- Returns:
- The bundled aggregations for an event, if bundled aggregations are
- enabled and the event can have bundled aggregations.
- """
-
- # Do not bundle aggregations for an event which represents an edit or an
- # annotation. It does not make sense for them to have related events.
- relates_to = event.content.get("m.relates_to")
- if isinstance(relates_to, (dict, frozendict)):
- relation_type = relates_to.get("rel_type")
- if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
- return None
-
- event_id = event.event_id
- room_id = event.room_id
-
- # The bundled aggregations to include, a mapping of relation type to a
- # type-specific value. Some types include the direct return type here
- # while others need more processing during serialization.
- aggregations = BundledAggregations()
-
- annotations = await self.get_annotations_for_event(
- event_id, room_id, ignored_users=ignored_users
- )
- if annotations:
- aggregations.annotations = {"chunk": annotations}
-
- references, next_token = await self.get_relations_for_event(
- event_id,
- event,
- room_id,
- RelationTypes.REFERENCE,
- ignored_users=ignored_users,
- )
- if references:
- aggregations.references = {
- "chunk": [{"event_id": event.event_id} for event in references]
- }
-
- if next_token:
- aggregations.references["next_batch"] = await next_token.to_string(
- self._main_store
- )
-
- # Store the bundled aggregations in the event metadata for later use.
- return aggregations
-
async def get_threads_for_events(
self, event_ids: Collection[str], user_id: str, ignored_users: FrozenSet[str]
) -> Dict[str, _ThreadAggregation]:
@@ -435,11 +377,39 @@ class RelationsHandler:
# Fetch other relations per event.
for event in events_by_id.values():
- event_result = await self._get_bundled_aggregation_for_event(
- event, ignored_users
+ # Do not bundle aggregations for an event which represents an edit or an
+ # annotation. It does not make sense for them to have related events.
+ relates_to = event.content.get("m.relates_to")
+ if isinstance(relates_to, (dict, frozendict)):
+ relation_type = relates_to.get("rel_type")
+ if relation_type in (RelationTypes.ANNOTATION, RelationTypes.REPLACE):
+ continue
+
+ annotations = await self.get_annotations_for_event(
+ event.event_id, event.room_id, ignored_users=ignored_users
)
- if event_result:
- results[event.event_id] = event_result
+ if annotations:
+ results.setdefault(
+ event.event_id, BundledAggregations()
+ ).annotations = {"chunk": annotations}
+
+ references, next_token = await self.get_relations_for_event(
+ event.event_id,
+ event,
+ event.room_id,
+ RelationTypes.REFERENCE,
+ ignored_users=ignored_users,
+ )
+ if references:
+ aggregations = results.setdefault(event.event_id, BundledAggregations())
+ aggregations.references = {
+ "chunk": [{"event_id": ev.event_id} for ev in references]
+ }
+
+ if next_token:
+ aggregations.references["next_batch"] = await next_token.to_string(
+ self._main_store
+ )
# Fetch any edits (but not for redacted events).
#
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 65743cdf67..39667e3225 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -560,43 +560,6 @@ class RelationsTestCase(BaseRelationsTestCase):
{"event_id": edit_event_id, "sender": self.user_id}, m_replace_dict
)
- def test_edit_thread(self) -> None:
- """Test that editing a thread works."""
-
- # Create a thread and edit the last event.
- channel = self._send_relation(
- RelationTypes.THREAD,
- "m.room.message",
- content={"msgtype": "m.text", "body": "A threaded reply!"},
- )
- threaded_event_id = channel.json_body["event_id"]
-
- new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- self._send_relation(
- RelationTypes.REPLACE,
- "m.room.message",
- content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
- parent_id=threaded_event_id,
- )
-
- # Fetch the thread root, to get the bundled aggregation for the thread.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # We expect that the edit message appears in the thread summary in the
- # unsigned relations section.
- relations_dict = channel.json_body["unsigned"].get("m.relations")
- self.assertIn(RelationTypes.THREAD, relations_dict)
-
- thread_summary = relations_dict[RelationTypes.THREAD]
- self.assertIn("latest_event", thread_summary)
- latest_event_in_thread = thread_summary["latest_event"]
- self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
-
def test_edit_edit(self) -> None:
"""Test that an edit cannot be edited."""
new_body = {"msgtype": "m.text", "body": "Initial edit"}
@@ -1047,7 +1010,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_2 = channel.json_body["event_id"]
- def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ def assert_thread(bundled_aggregations: JsonDict) -> None:
self.assertEqual(2, bundled_aggregations.get("count"))
self.assertTrue(bundled_aggregations.get("current_user_participated"))
# The latest thread event has some fields that don't matter.
@@ -1066,7 +1029,38 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations.get("latest_event"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+
+ def test_thread_edit_latest_event(self) -> None:
+ """Test that editing the latest event in a thread works."""
+
+ # Create a thread and edit the last event.
+ channel = self._send_relation(
+ RelationTypes.THREAD,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "A threaded reply!"},
+ )
+ threaded_event_id = channel.json_body["event_id"]
+
+ new_body = {"msgtype": "m.text", "body": "I've been edited!"}
+ channel = self._send_relation(
+ RelationTypes.REPLACE,
+ "m.room.message",
+ content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
+ parent_id=threaded_event_id,
+ )
+
+ # Fetch the thread root, to get the bundled aggregation for the thread.
+ relations_dict = self._get_bundled_aggregations()
+
+ # We expect that the edit message appears in the thread summary in the
+ # unsigned relations section.
+ self.assertIn(RelationTypes.THREAD, relations_dict)
+
+ thread_summary = relations_dict[RelationTypes.THREAD]
+ self.assertIn("latest_event", thread_summary)
+ latest_event_in_thread = thread_summary["latest_event"]
+ self.assertEqual(latest_event_in_thread["content"]["body"], "I've been edited!")
def test_aggregation_get_event_for_annotation(self) -> None:
"""Test that annotations do not get bundled aggregations included
@@ -1093,7 +1087,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
thread_id = channel.json_body["event_id"]
- # Annotate the annotation.
+ # Annotate the thread.
self._send_relation(
RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
)
|