summary refs log tree commit diff
path: root/tests/rest/client/test_relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_relations.py')
-rw-r--r--tests/rest/client/test_relations.py151
1 files changed, 137 insertions, 14 deletions
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 06721e67c9..9768fb2971 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,7 +21,8 @@ from unittest.mock import patch
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -200,6 +201,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body.get("next_batch"), str, channel.json_body
         )
 
+    def _stream_token_to_relation_token(self, token: str) -> str:
+        """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
+        room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
+        return self.get_success(
+            RelationPaginationToken(
+                topological=room_key.topological, stream=room_key.stream
+            ).to_string(self.store)
+        )
+
     def test_repeated_paginate_relations(self):
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
@@ -213,7 +223,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
@@ -222,8 +232,35 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
-                % (self.room, self.parent_id, from_token),
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +278,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         found_event_ids.reverse()
         self.assertEquals(found_event_ids, expected_event_ids)
 
+    def test_pagination_from_sync_and_messages(self):
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(
+            '{"room": {"timeline": {"limit": 1}}}'.encode()
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
     def test_aggregation_pagination_groups(self):
         """Test that we can paginate annotation groups correctly."""
 
@@ -337,7 +433,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
@@ -347,15 +443,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s"
-                "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
-                % (
-                    self.room,
-                    self.parent_id,
-                    RelationTypes.ANNOTATION,
-                    encoded_key,
-                    from_token,
-                ),
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)