diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index bc9cc51b92..62e4db23ef 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -896,6 +896,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
relation_type: str,
assertion_callable: Callable[[JsonDict], None],
expected_db_txn_for_event: int,
+ access_token: Optional[str] = None,
) -> None:
"""
Makes requests to various endpoints which should include bundled aggregations
@@ -907,7 +908,9 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
for relation-specific assertions.
expected_db_txn_for_event: The number of database transactions which
are expected for a call to /event/.
+ access_token: The access token to user, defaults to self.user_token.
"""
+ access_token = access_token or self.user_token
def assert_bundle(event_json: JsonDict) -> None:
"""Assert the expected values of the bundled aggregations."""
@@ -921,7 +924,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/event/{self.parent_id}",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body)
@@ -932,7 +935,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/messages?dir=b",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(self._find_event_in_chunk(channel.json_body["chunk"]))
@@ -941,7 +944,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
channel = self.make_request(
"GET",
f"/rooms/{self.room}/context/{self.parent_id}",
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
assert_bundle(channel.json_body["event"])
@@ -949,7 +952,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# Request sync.
filter = urllib.parse.quote_plus(b'{"room": {"timeline": {"limit": 4}}}')
channel = self.make_request(
- "GET", f"/sync?filter={filter}", access_token=self.user_token
+ "GET", f"/sync?filter={filter}", access_token=access_token
)
self.assertEqual(200, channel.code, channel.json_body)
room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
@@ -962,7 +965,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"/search",
# Search term matches the parent message.
content={"search_categories": {"room_events": {"search_term": "Hi"}}},
- access_token=self.user_token,
+ access_token=access_token,
)
self.assertEqual(200, channel.code, channel.json_body)
chunk = [
@@ -1037,30 +1040,60 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
"""
Test that threads get correctly bundled.
"""
- self._send_relation(RelationTypes.THREAD, "m.room.test")
- channel = self._send_relation(RelationTypes.THREAD, "m.room.test")
+ # The root message is from "user", send replies as "user2".
+ self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
+ channel = self._send_relation(
+ RelationTypes.THREAD, "m.room.test", access_token=self.user2_token
+ )
thread_2 = channel.json_body["event_id"]
- def assert_thread(bundled_aggregations: JsonDict) -> None:
- self.assertEqual(2, bundled_aggregations.get("count"))
- self.assertTrue(bundled_aggregations.get("current_user_participated"))
- # The latest thread event has some fields that don't matter.
- self.assert_dict(
- {
- "content": {
- "m.relates_to": {
- "event_id": self.parent_id,
- "rel_type": RelationTypes.THREAD,
- }
+ # This needs two assertion functions which are identical except for whether
+ # the current_user_participated flag is True, create a factory for the
+ # two versions.
+ def _gen_assert(participated: bool) -> Callable[[JsonDict], None]:
+ def assert_thread(bundled_aggregations: JsonDict) -> None:
+ self.assertEqual(2, bundled_aggregations.get("count"))
+ self.assertEqual(
+ participated, bundled_aggregations.get("current_user_participated")
+ )
+ # The latest thread event has some fields that don't matter.
+ self.assert_dict(
+ {
+ "content": {
+ "m.relates_to": {
+ "event_id": self.parent_id,
+ "rel_type": RelationTypes.THREAD,
+ }
+ },
+ "event_id": thread_2,
+ "sender": self.user2_id,
+ "type": "m.room.test",
},
- "event_id": thread_2,
- "sender": self.user_id,
- "type": "m.room.test",
- },
- bundled_aggregations.get("latest_event"),
- )
+ bundled_aggregations.get("latest_event"),
+ )
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+ return assert_thread
+
+ # The "user" sent the root event and is making queries for the bundled
+ # aggregations: they have participated.
+ self._test_bundled_aggregations(RelationTypes.THREAD, _gen_assert(True), 8)
+ # The "user2" sent replies in the thread and is making queries for the
+ # bundled aggregations: they have participated.
+ #
+ # Note that this re-uses some cached values, so the total number of
+ # queries is much smaller.
+ self._test_bundled_aggregations(
+ RelationTypes.THREAD, _gen_assert(True), 2, access_token=self.user2_token
+ )
+
+ # A user with no interactions with the thread: they have not participated.
+ user3_id, user3_token = self._create_user("charlie")
+ self.helper.join(self.room, user=user3_id, tok=user3_token)
+ self._test_bundled_aggregations(
+ RelationTypes.THREAD, _gen_assert(False), 2, access_token=user3_token
+ )
def test_thread_with_bundled_aggregations_for_latest(self) -> None:
"""
@@ -1106,7 +1139,7 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
bundled_aggregations["latest_event"].get("unsigned"),
)
- self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 9)
+ self._test_bundled_aggregations(RelationTypes.THREAD, assert_thread, 8)
def test_nested_thread(self) -> None:
"""
|