diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index f3741b3001..329690f8f7 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
import itertools
import urllib.parse
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
from unittest.mock import patch
from twisted.test.proto_helpers import MemoryReactor
@@ -155,6 +155,16 @@ class BaseRelationsTestCase(unittest.HomeserverTestCase):
self.assertEqual(200, channel.code, channel.json_body)
return channel.json_body["chunk"]
+ def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
+ """
+ Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
+ """
+ for event in events:
+ if event["event_id"] == self.parent_id:
+ return event
+
+ raise AssertionError(f"Event {self.parent_id} not found in chunk")
+
class RelationsTestCase(BaseRelationsTestCase):
def test_send_relation(self) -> None:
@@ -291,202 +301,6 @@ class RelationsTestCase(BaseRelationsTestCase):
)
self.assertEqual(400, channel.code, channel.json_body)
- @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
- def test_bundled_aggregations(self) -> None:
- """
- Test that annotations, references, and threads get correctly bundled.
-
- Note that this doesn't test against /relations since only thread relations
- get bundled via that API. See test_aggregation_get_event_for_thread.
-
- See test_edit for a similar test for edits.
- """
- # Setup by sending a variety of relations.
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
- )
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- reply_1 = channel.json_body["event_id"]
-
- channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
- reply_2 = channel.json_body["event_id"]
-
- self._send_relation(RelationTypes.THREAD, "m.room.test")
-
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- thread_2 = channel.json_body["event_id"]
-
- def assert_bundle(event_json: JsonDict) -> None:
- """Assert the expected values of the bundled aggregations."""
- relations_dict = event_json["unsigned"].get("m.relations")
-
- # Ensure the fields are as expected.
- self.assertCountEqual(
- relations_dict.keys(),
- (
- RelationTypes.ANNOTATION,
- RelationTypes.REFERENCE,
- RelationTypes.THREAD,
- ),
- )
-
- # Check the values of each field.
- self.assertEqual(
- {
- "chunk": [
- {"type": "m.reaction", "key": "a", "count": 2},
- {"type": "m.reaction", "key": "b", "count": 1},
- ]
- },
- relations_dict[RelationTypes.ANNOTATION],
- )
-
- self.assertEqual(
- {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
- relations_dict[RelationTypes.REFERENCE],
- )
-
- self.assertEqual(
- 2,
- relations_dict[RelationTypes.THREAD].get("count"),
- )
- self.assertTrue(
- relations_dict[RelationTypes.THREAD].get("current_user_participated")
- )
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
- },
- "event_id": thread_2,
- "sender": self.user_id,
- "type": "m.room.test",
- },
- relations_dict[RelationTypes.THREAD].get("latest_event"),
- )
-
- # Request the event directly.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body)
-
- # Request the room messages.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
-
- # Request the room context.
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- assert_bundle(channel.json_body["event"])
-
- # Request sync.
- channel = self.make_request("GET", "/sync", access_token=self.user_token)
- self.assertEqual(200, channel.code, channel.json_body)
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
-
- # Request search.
- channel = self.make_request(
- "POST",
- "/search",
- # Search term matches the parent message.
- content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- chunk = [
- result["result"]
- for result in channel.json_body["search_categories"]["room_events"][
- "results"
- ]
- ]
- assert_bundle(self._find_event_in_chunk(chunk))
-
- def test_aggregation_get_event_for_annotation(self) -> None:
- """Test that annotations do not get bundled aggregations included
- when directly requested.
- """
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
- annotation_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
- )
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{annotation_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
-
- def test_aggregation_get_event_for_thread(self) -> None:
- """Test that threads get bundled aggregations included when directly requested."""
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
- thread_id = channel.json_body["event_id"]
-
- # Annotate the annotation.
- self._send_relation(
- RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
- )
-
- channel = self.make_request(
- "GET",
- f"/rooms/{self.room}/event/{thread_id}",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(
- channel.json_body["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
- # It should also be included when the entire thread is requested.
- channel = self.make_request(
- "GET",
- f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
- access_token=self.user_token,
- )
- self.assertEqual(200, channel.code, channel.json_body)
- self.assertEqual(len(channel.json_body["chunk"]), 1)
-
- thread_message = channel.json_body["chunk"][0]
- self.assertEqual(
- thread_message["unsigned"].get("m.relations"),
- {
- RelationTypes.ANNOTATION: {
- "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
- },
- },
- )
-
def test_ignore_invalid_room(self) -> None:
"""Test that we ignore invalid relations over federation."""
# Create another room and send a message in it.
@@ -796,7 +610,7 @@ class RelationsTestCase(BaseRelationsTestCase):
threaded_event_id = channel.json_body["event_id"]
new_body = {"msgtype": "m.text", "body": "I've been edited!"}
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={"msgtype": "m.text", "body": "foo", "m.new_content": new_body},
@@ -836,7 +650,7 @@ class RelationsTestCase(BaseRelationsTestCase):
edit_event_id = channel.json_body["event_id"]
# Edit the edit event.
- channel = self._send_relation(
+ self._send_relation(
RelationTypes.REPLACE,
"m.room.message",
content={
@@ -912,16 +726,6 @@ class RelationsTestCase(BaseRelationsTestCase):
self.assertEqual(200, channel.code, channel.json_body)
self.assertEqual(channel.json_body["chunk"], [])
- def _find_event_in_chunk(self, events: List[JsonDict]) -> JsonDict:
- """
- Find the parent event in a chunk of events and assert that it has the proper bundled aggregations.
- """
- for event in events:
- if event["event_id"] == self.parent_id:
- return event
-
- raise AssertionError(f"Event {self.parent_id} not found in chunk")
-
def test_background_update(self) -> None:
"""Test the event_arbitrary_relations background update."""
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="👍")
@@ -981,34 +785,6 @@ class RelationsTestCase(BaseRelationsTestCase):
[annotation_event_id_good, thread_event_id],
)
- def test_bundled_aggregations_with_filter(self) -> None:
- """
- If "unsigned" is an omitted field (due to filtering), adding the bundled
- aggregations should not break.
-
- Note that the spec allows for a server to return additional fields beyond
- what is specified.
- """
- self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
-
- # Note that the sync filter does not include "unsigned" as a field.
- filter = urllib.parse.quote_plus(
- b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
- )
- channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
- )
- self.assertEqual(200, channel.code, channel.json_body)
-
- # Ensure the timeline is limited, find the parent event.
- room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
- self.assertTrue(room_timeline["limited"])
- parent_event = self._find_event_in_chunk(room_timeline["events"])
-
- # Ensure there's bundled aggregations on it.
- self.assertIn("unsigned", parent_event)
- self.assertIn("m.relations", parent_event["unsigned"])
-
class RelationPaginationTestCase(BaseRelationsTestCase):
def test_basic_paginate_relations(self) -> None:
@@ -1255,7 +1031,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
idx += 1
# Also send a different type of reaction so that we test we don't see it
- channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
prev_token = ""
found_event_ids: List[str] = []
@@ -1291,6 +1067,263 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
self.assertEqual(found_event_ids, expected_event_ids)
+class BundledAggregationsTestCase(BaseRelationsTestCase):
+ """
+ See RelationsTestCase.test_edit for a similar test for edits.
+
+ Note that this doesn't test against /relations since only thread relations
+ get bundled via that API. See test_aggregation_get_event_for_thread.
+ """
+
+ def _test_bundled_aggregations(
+ self,
+ relation_type: str,
+ assertion_callable: Callable[[JsonDict], None],
+ expected_db_txn_for_event: int,
+ ) -> None:
+ """
+ Makes requests to various endpoints which should include bundled aggregations
+ and then calls an assertion function on the bundled aggregations.
+
+ Args:
+ relation_type: The field to search for in the `m.relations` field in unsigned.
+ assertion_callable: Called with the contents of unsigned["m.relations"][relation_type]
+ for relation-specific assertions.
+ expected_db_txn_for_event: The number of database transactions which
+ are expected for a call to /event/.
+ """
+
+ def assert_bundle(event_json: JsonDict) -> None:
+ """Assert the expected values of the bundled aggregations."""
+ relations_dict = event_json["unsigned"].get("m.relations")
+
+ # Ensure the fields are as expected.
+ self.assertCountEqual(relations_dict.keys(), (relation_type,))
+ assertion_callable(relations_dict[relation_type])
+
+ # Request the event directly.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body)
+ assert channel.resource_usage is not None
+ self.assertEqual(channel.resource_usage.db_txn_count, expected_db_txn_for_event)
+
+ # Request the room messages.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
+
+ # Request the room context.
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/context/{self.parent_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ assert_bundle(channel.json_body["event"])
+
+ # Request sync.
+ filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
+
+ # Request search.
+ channel = self.make_request(
+ "POST",
+ "/search",
+ # Search term matches the parent message.
+ content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ chunk = [
+ result["result"]
+ for result in channel.json_body["search_categories"]["room_events"][
+ "results"
+ ]
+ ]
+ assert_bundle(self._find_event_in_chunk(chunk))
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_annotation(self) -> None:
+ """
+ Test that annotations get correctly bundled.
+ """
+ # Setup by sending a variety of relations.
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", access_token=self.user2_token
+ )
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {
+ "chunk": [
+ {"type": "m.reaction", "key": "a", "count": 2},
+ {"type": "m.reaction", "key": "b", "count": 1},
+ ]
+ },
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.ANNOTATION, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_reference(self) -> None:
+ """
+ Test that references get correctly bundled.
+ """
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_1 = channel.json_body["event_id"]
+
+ channel = self._send_relation(RelationTypes.REFERENCE, "m.room.test")
+ reply_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(
+ {"chunk": [{"event_id": reply_1}, {"event_id": reply_2}]},
+ bundled_aggregations,
+ )
+
+ self._test_bundled_aggregations(RelationTypes.REFERENCE, assert_annotations, 7)
+
+ @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
+ def test_thread(self) -> None:
+ """
+ Test that threads get correctly bundled.
+ """
+ self._send_relation(RelationTypes.THREAD, "m.room.test")
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_2 = channel.json_body["event_id"]
+
+ def assert_annotations(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertTrue(bundled_aggregations.get("current_user_participated"))
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user_id,
+ "type": "m.room.test",
+ },
+ bundled_aggregations.get("latest_event"),
+ )
+
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_annotations, 9)
+
+ def test_aggregation_get_event_for_annotation(self) -> None:
+ """Test that annotations do not get bundled aggregations included
+ when directly requested.
+ """
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+ annotation_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=annotation_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{annotation_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertIsNone(channel.json_body["unsigned"].get("m.relations"))
+
+ def test_aggregation_get_event_for_thread(self) -> None:
+ """Test that threads get bundled aggregations included when directly requested."""
+ channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ thread_id = channel.json_body["event_id"]
+
+ # Annotate the annotation.
+ self._send_relation(
+ RelationTypes.ANNOTATION, "m.reaction", "a", parent_id=thread_id
+ )
+
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/event/{thread_id}",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(
+ channel.json_body["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ # It should also be included when the entire thread is requested.
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+ self.assertEqual(len(channel.json_body["chunk"]), 1)
+
+ thread_message = channel.json_body["chunk"][0]
+ self.assertEqual(
+ thread_message["unsigned"].get("m.relations"),
+ {
+ RelationTypes.ANNOTATION: {
+ "chunk": [{"count": 1, "key": "a", "type": "m.reaction"}]
+ },
+ },
+ )
+
+ def test_bundled_aggregations_with_filter(self) -> None:
+ """
+ If "unsigned" is an omitted field (due to filtering), adding the bundled
+ aggregations should not break.
+
+ Note that the spec allows for a server to return additional fields beyond
+ what is specified.
+ """
+ self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
+
+ # Note that the sync filter does not include "unsigned" as a field.
+ filter = urllib.parse.quote_plus(
+ b'{"event_fields": ["content", "event_id"], "room": {"timeline": {"limit": 3}}}'
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEqual(200, channel.code, channel.json_body)
+
+ # Ensure the timeline is limited, find the parent event.
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ self.assertTrue(room_timeline["limited"])
+ parent_event = self._find_event_in_chunk(room_timeline["events"])
+
+ # Ensure there's bundled aggregations on it.
+ self.assertIn("unsigned", parent_event)
+ self.assertIn("m.relations", parent_event["unsigned"])
+
+
class RelationRedactionTestCase(BaseRelationsTestCase):
"""
Test the behaviour of relations when the parent or child event is redacted.
|