summary refs log tree commit diff
path: root/tests/rest/client/test_relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rest/client/test_relations.py')
-rw-r--r--tests/rest/client/test_relations.py227
1 files changed, 206 insertions, 21 deletions
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 96ae7790bb..de80aca037 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,7 +21,8 @@ from unittest.mock import patch
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -168,24 +169,28 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         """Tests that calling pagination API correctly the latest relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "a")
         self.assertEquals(200, channel.code, channel.json_body)
+        first_annotation_id = channel.json_body["event_id"]
 
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "b")
         self.assertEquals(200, channel.code, channel.json_body)
-        annotation_id = channel.json_body["event_id"]
+        second_annotation_id = channel.json_body["event_id"]
 
         channel = self.make_request(
             "GET",
-            "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1"
-            % (self.room, self.parent_id),
+            f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1",
             access_token=self.user_token,
         )
         self.assertEquals(200, channel.code, channel.json_body)
 
-        # We expect to get back a single pagination result, which is the full
-        # relation event we sent above.
+        # We expect to get back a single pagination result, which is the latest
+        # full relation event we sent above.
         self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
         self.assert_dict(
-            {"event_id": annotation_id, "sender": self.user_id, "type": "m.reaction"},
+            {
+                "event_id": second_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
             channel.json_body["chunk"][0],
         )
 
@@ -200,6 +205,36 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body.get("next_batch"), str, channel.json_body
         )
 
+        # Request the relations again, but with a different direction.
+        channel = self.make_request(
+            "GET",
+            f"/_matrix/client/unstable/rooms/{self.room}/relations"
+            f"/{self.parent_id}?limit=1&org.matrix.msc3715.dir=f",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+
+        # We expect to get back a single pagination result, which is the earliest
+        # full relation event we sent above.
+        self.assertEquals(len(channel.json_body["chunk"]), 1, channel.json_body)
+        self.assert_dict(
+            {
+                "event_id": first_annotation_id,
+                "sender": self.user_id,
+                "type": "m.reaction",
+            },
+            channel.json_body["chunk"][0],
+        )
+
+    def _stream_token_to_relation_token(self, token: str) -> str:
+        """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
+        room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
+        return self.get_success(
+            RelationPaginationToken(
+                topological=room_key.topological, stream=room_key.stream
+            ).to_string(self.store)
+        )
+
     def test_repeated_paginate_relations(self):
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
@@ -213,7 +248,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
@@ -222,8 +257,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
-                % (self.room, self.parent_id, from_token),
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +275,93 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         found_event_ids.reverse()
         self.assertEquals(found_event_ids, expected_event_ids)
 
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+    def test_pagination_from_sync_and_messages(self):
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(
+            '{"room": {"timeline": {"limit": 1}}}'.encode()
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
     def test_aggregation_pagination_groups(self):
         """Test that we can paginate annotation groups correctly."""
 
@@ -337,7 +458,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
@@ -347,15 +468,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s"
-                "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
-                % (
-                    self.room,
-                    self.parent_id,
-                    RelationTypes.ANNOTATION,
-                    encoded_key,
-                    from_token,
-                ),
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)
@@ -453,7 +601,9 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         )
         self.assertEquals(400, channel.code, channel.json_body)
 
-    @unittest.override_config({"experimental_features": {"msc3440_enabled": True}})
+    @unittest.override_config(
+        {"experimental_features": {"msc3440_enabled": True, "msc3666_enabled": True}}
+    )
     def test_bundled_aggregations(self):
         """
         Test that annotations, references, and threads get correctly bundled.
@@ -579,6 +729,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertTrue(room_timeline["limited"])
         assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
     def test_aggregation_get_event_for_annotation(self):
         """Test that annotations do not get bundled aggregations included
         when directly requested.
@@ -759,6 +926,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertEquals(200, channel.code, channel.json_body)
         self.assertNotIn("m.relations", channel.json_body["unsigned"])
 
+    @unittest.override_config({"experimental_features": {"msc3666_enabled": True}})
     def test_edit(self):
         """Test that a simple edit works."""
 
@@ -825,6 +993,23 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         self.assertTrue(room_timeline["limited"])
         assert_bundle(self._find_event_in_chunk(room_timeline["events"]))
 
+        # Request search.
+        channel = self.make_request(
+            "POST",
+            "/search",
+            # Search term matches the parent message.
+            content={"search_categories": {"room_events": {"search_term": "Hi"}}},
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        chunk = [
+            result["result"]
+            for result in channel.json_body["search_categories"]["room_events"][
+                "results"
+            ]
+        ]
+        assert_bundle(self._find_event_in_chunk(chunk))
+
     def test_multi_edit(self):
         """Test that multiple edits, including attempts by people who
         shouldn't be allowed, are correctly handled.