diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 06721e67c9..9768fb2971 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,7 +21,8 @@ from unittest.mock import patch
from synapse.api.constants import EventTypes, RelationTypes
from synapse.rest import admin
from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken
from tests import unittest
from tests.server import FakeChannel
@@ -200,6 +201,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel.json_body.get("next_batch"), str, channel.json_body
)
+ def _stream_token_to_relation_token(self, token: str) -> str:
+ """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
+ room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
+ return self.get_success(
+ RelationPaginationToken(
+ topological=room_key.topological, stream=room_key.stream
+ ).to_string(self.store)
+ )
+
def test_repeated_paginate_relations(self):
"""Test that if we paginate using a limit and tokens then we get the
expected events.
@@ -213,7 +223,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
self.assertEquals(200, channel.code, channel.json_body)
expected_event_ids.append(channel.json_body["event_id"])
- prev_token: Optional[str] = None
+ prev_token = ""
found_event_ids: List[str] = []
for _ in range(20):
from_token = ""
@@ -222,8 +232,35 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
- % (self.room, self.parent_id, from_token),
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ # Reset and try again, but convert the tokens to the legacy format.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +278,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
found_event_ids.reverse()
self.assertEquals(found_event_ids, expected_event_ids)
+ def test_pagination_from_sync_and_messages(self):
+ """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+ channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+ self.assertEquals(200, channel.code, channel.json_body)
+ annotation_id = channel.json_body["event_id"]
+ # Send an event after the relation events.
+ self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+ # Request /sync, limiting it such that only the latest event is returned
+ # (and not the relation).
+ filter = urllib.parse.quote_plus(
+ '{"room": {"timeline": {"limit": 1}}}'.encode()
+ )
+ channel = self.make_request(
+ "GET", f"/sync?filter={filter}", access_token=self.user_token
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+ sync_prev_batch = room_timeline["prev_batch"]
+ self.assertIsNotNone(sync_prev_batch)
+ # Ensure the relation event is not in the batch returned from /sync.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+ )
+
+ # Request /messages, limiting it such that only the latest event is
+ # returned (and not the relation).
+ channel = self.make_request(
+ "GET",
+ f"/rooms/{self.room}/messages?dir=b&limit=1",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+ messages_end = channel.json_body["end"]
+ self.assertIsNotNone(messages_end)
+ # Ensure the relation event is not in the chunk returned from /messages.
+ self.assertNotIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
+ # Request /relations with the pagination tokens received from both the
+ # /sync and /messages responses above, in turn.
+ #
+ # This is a tiny bit silly since the client wouldn't know the parent ID
+ # from the requests above; consider the parent ID to be known from a
+ # previous /sync.
+ for from_token in (sync_prev_batch, messages_end):
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ # The relation should be in the returned chunk.
+ self.assertIn(
+ annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+ )
+
def test_aggregation_pagination_groups(self):
"""Test that we can paginate annotation groups correctly."""
@@ -337,7 +433,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
self.assertEquals(200, channel.code, channel.json_body)
- prev_token: Optional[str] = None
+ prev_token = ""
found_event_ids: List[str] = []
encoded_key = urllib.parse.quote_plus("👍".encode())
for _ in range(20):
@@ -347,15 +443,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
channel = self.make_request(
"GET",
- "/_matrix/client/unstable/rooms/%s"
- "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
- % (
- self.room,
- self.parent_id,
- RelationTypes.ANNOTATION,
- encoded_key,
- from_token,
- ),
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
+ access_token=self.user_token,
+ )
+ self.assertEquals(200, channel.code, channel.json_body)
+
+ self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+ found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+ next_batch = channel.json_body.get("next_batch")
+
+ self.assertNotEquals(prev_token, next_batch)
+ prev_token = next_batch
+
+ if not prev_token:
+ break
+
+ # We paginated backwards, so reverse
+ found_event_ids.reverse()
+ self.assertEquals(found_event_ids, expected_event_ids)
+
+ # Reset and try again, but convert the tokens to the legacy format.
+ prev_token = ""
+ found_event_ids = []
+ for _ in range(20):
+ from_token = ""
+ if prev_token:
+ from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+ channel = self.make_request(
+ "GET",
+ f"/_matrix/client/unstable/rooms/{self.room}"
+ f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+ f"/m.reaction/{encoded_key}?limit=1{from_token}",
access_token=self.user_token,
)
self.assertEquals(200, channel.code, channel.json_body)
|