summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/11952.bugfix1
-rw-r--r--synapse/rest/client/relations.py57
-rw-r--r--synapse/storage/databases/main/relations.py46
-rw-r--r--synapse/storage/relations.py15
-rw-r--r--tests/rest/client/test_relations.py151
5 files changed, 217 insertions, 53 deletions
diff --git a/changelog.d/11952.bugfix b/changelog.d/11952.bugfix
new file mode 100644
index 0000000000..e38a08f559
--- /dev/null
+++ b/changelog.d/11952.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug where pagination tokens from `/sync` and `/messages` could not be provided to the `/relations` API.
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 8cf5ebaa07..9ec425888a 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -32,14 +32,45 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
+async def _parse_token(
+    store: "DataStore", token: Optional[str]
+) -> Optional[StreamToken]:
+    """
+    For backwards compatibility support RelationPaginationToken, but new pagination
+    tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
+    """
+    if not token:
+        return None
+    # Luckily the format for StreamToken and RelationPaginationToken differ enough
+    # that they can easily be separated. An "_" appears in the serialization of
+    # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
+    # "-" only for separators.
+    if "_" in token:
+        return await StreamToken.from_string(store, token)
+    else:
+        relation_token = RelationPaginationToken.from_string(token)
+        return StreamToken(
+            room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
+            presence_key=0,
+            typing_key=0,
+            receipt_key=0,
+            account_data_key=0,
+            push_rules_key=0,
+            to_device_key=0,
+            device_list_key=0,
+            groups_key=0,
+        )
+
+
 class RelationPaginationServlet(RestServlet):
     """API to paginate relations on an event by topological ordering, optionally
     filtered by relation type and event type.
@@ -88,13 +119,8 @@ class RelationPaginationServlet(RestServlet):
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = RelationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = RelationPaginationToken.from_string(to_token_str)
+            from_token = await _parse_token(self.store, from_token_str)
+            to_token = await _parse_token(self.store, to_token_str)
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
@@ -125,7 +151,7 @@ class RelationPaginationServlet(RestServlet):
             events, now, bundle_aggregations=aggregations
         )
 
-        return_value = pagination_chunk.to_dict()
+        return_value = await pagination_chunk.to_dict(self.store)
         return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event
 
@@ -216,7 +242,7 @@ class RelationAggregationPaginationServlet(RestServlet):
                 to_token=to_token,
             )
 
-        return 200, pagination_chunk.to_dict()
+        return 200, await pagination_chunk.to_dict(self.store)
 
 
 class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -287,13 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
 
-        from_token = None
-        if from_token_str:
-            from_token = RelationPaginationToken.from_string(from_token_str)
-
-        to_token = None
-        if to_token_str:
-            to_token = RelationPaginationToken.from_string(to_token_str)
+        from_token = await _parse_token(self.store, from_token_str)
+        to_token = await _parse_token(self.store, to_token_str)
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
@@ -313,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         now = self.clock.time_msec()
         serialized_events = self._event_serializer.serialize_events(events, now)
 
-        return_value = result.to_dict()
+        return_value = await result.to_dict(self.store)
         return_value["chunk"] = serialized_events
 
         return 200, return_value
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7718acbf1c..ad79cc5610 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -39,16 +39,13 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import (
-    AggregationPaginationToken,
-    PaginationChunk,
-    RelationPaginationToken,
-)
-from synapse.types import JsonDict
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -98,8 +95,8 @@ class RelationsWorkerStore(SQLBaseStore):
         aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
-        from_token: Optional[RelationPaginationToken] = None,
-        to_token: Optional[RelationPaginationToken] = None,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
     ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -138,8 +135,10 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -177,12 +176,27 @@ class RelationsWorkerStore(SQLBaseStore):
                 last_topo_id = row[1]
                 last_stream_id = row[2]
 
-            next_batch = None
+            # If there are more events, generate the next pagination key.
+            next_token = None
             if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace("room_key", next_key)
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )
 
             return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+                chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
             )
 
         return await self.db_pool.runInteraction(
@@ -676,13 +690,15 @@ class RelationsWorkerStore(SQLBaseStore):
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations.annotations = annotations.to_dict()
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations.references = references.to_dict()
+            aggregations.references = await references.to_dict(cast("DataStore", self))
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index b1536c1ca4..36ca2b8273 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -13,13 +13,16 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import attr
 
 from synapse.api.errors import SynapseError
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
 
 
@@ -39,14 +42,14 @@ class PaginationChunk:
     next_batch: Optional[Any] = None
     prev_batch: Optional[Any] = None
 
-    def to_dict(self) -> Dict[str, Any]:
+    async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
         d = {"chunk": self.chunk}
 
         if self.next_batch:
-            d["next_batch"] = self.next_batch.to_string()
+            d["next_batch"] = await self.next_batch.to_string(store)
 
         if self.prev_batch:
-            d["prev_batch"] = self.prev_batch.to_string()
+            d["prev_batch"] = await self.prev_batch.to_string(store)
 
         return d
 
@@ -75,7 +78,7 @@ class RelationPaginationToken:
         except ValueError:
             raise SynapseError(400, "Invalid relation pagination token")
 
-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.topological, self.stream)
 
     def as_tuple(self) -> Tuple[Any, ...]:
@@ -105,7 +108,7 @@ class AggregationPaginationToken:
         except ValueError:
             raise SynapseError(400, "Invalid aggregation pagination token")
 
-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.count, self.stream)
 
     def as_tuple(self) -> Tuple[Any, ...]:
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 06721e67c9..9768fb2971 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -21,7 +21,8 @@ from unittest.mock import patch
 from synapse.api.constants import EventTypes, RelationTypes
 from synapse.rest import admin
 from synapse.rest.client import login, register, relations, room, sync
-from synapse.types import JsonDict
+from synapse.storage.relations import RelationPaginationToken
+from synapse.types import JsonDict, StreamToken
 
 from tests import unittest
 from tests.server import FakeChannel
@@ -200,6 +201,15 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             channel.json_body.get("next_batch"), str, channel.json_body
         )
 
+    def _stream_token_to_relation_token(self, token: str) -> str:
+        """Convert a StreamToken into a legacy token (RelationPaginationToken)."""
+        room_key = self.get_success(StreamToken.from_string(self.store, token)).room_key
+        return self.get_success(
+            RelationPaginationToken(
+                topological=room_key.topological, stream=room_key.stream
+            ).to_string(self.store)
+        )
+
     def test_repeated_paginate_relations(self):
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
@@ -213,7 +223,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
             self.assertEquals(200, channel.code, channel.json_body)
             expected_event_ids.append(channel.json_body["event_id"])
 
-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         for _ in range(20):
             from_token = ""
@@ -222,8 +232,35 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s/relations/%s?limit=1%s"
-                % (self.room, self.parent_id, from_token),
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)
@@ -241,6 +278,65 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         found_event_ids.reverse()
         self.assertEquals(found_event_ids, expected_event_ids)
 
+    def test_pagination_from_sync_and_messages(self):
+        """Pagination tokens from /sync and /messages can be used to paginate /relations."""
+        channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")
+        self.assertEquals(200, channel.code, channel.json_body)
+        annotation_id = channel.json_body["event_id"]
+        # Send an event after the relation events.
+        self.helper.send(self.room, body="Latest event", tok=self.user_token)
+
+        # Request /sync, limiting it such that only the latest event is returned
+        # (and not the relation).
+        filter = urllib.parse.quote_plus(
+            '{"room": {"timeline": {"limit": 1}}}'.encode()
+        )
+        channel = self.make_request(
+            "GET", f"/sync?filter={filter}", access_token=self.user_token
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        room_timeline = channel.json_body["rooms"]["join"][self.room]["timeline"]
+        sync_prev_batch = room_timeline["prev_batch"]
+        self.assertIsNotNone(sync_prev_batch)
+        # Ensure the relation event is not in the batch returned from /sync.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in room_timeline["events"]]
+        )
+
+        # Request /messages, limiting it such that only the latest event is
+        # returned (and not the relation).
+        channel = self.make_request(
+            "GET",
+            f"/rooms/{self.room}/messages?dir=b&limit=1",
+            access_token=self.user_token,
+        )
+        self.assertEquals(200, channel.code, channel.json_body)
+        messages_end = channel.json_body["end"]
+        self.assertIsNotNone(messages_end)
+        # Ensure the relation event is not in the chunk returned from /messages.
+        self.assertNotIn(
+            annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+        )
+
+        # Request /relations with the pagination tokens received from both the
+        # /sync and /messages responses above, in turn.
+        #
+        # This is a tiny bit silly since the client wouldn't know the parent ID
+        # from the requests above; consider the parent ID to be known from a
+        # previous /sync.
+        for from_token in (sync_prev_batch, messages_end):
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}/relations/{self.parent_id}?from={from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            # The relation should be in the returned chunk.
+            self.assertIn(
+                annotation_id, [ev["event_id"] for ev in channel.json_body["chunk"]]
+            )
+
     def test_aggregation_pagination_groups(self):
         """Test that we can paginate annotation groups correctly."""
 
@@ -337,7 +433,7 @@ class RelationsTestCase(unittest.HomeserverTestCase):
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", key="a")
         self.assertEquals(200, channel.code, channel.json_body)
 
-        prev_token: Optional[str] = None
+        prev_token = ""
         found_event_ids: List[str] = []
         encoded_key = urllib.parse.quote_plus("👍".encode())
         for _ in range(20):
@@ -347,15 +443,42 @@ class RelationsTestCase(unittest.HomeserverTestCase):
 
             channel = self.make_request(
                 "GET",
-                "/_matrix/client/unstable/rooms/%s"
-                "/aggregations/%s/%s/m.reaction/%s?limit=1%s"
-                % (
-                    self.room,
-                    self.parent_id,
-                    RelationTypes.ANNOTATION,
-                    encoded_key,
-                    from_token,
-                ),
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEquals(200, channel.code, channel.json_body)
+
+            self.assertEqual(len(channel.json_body["chunk"]), 1, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEquals(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        # We paginated backwards, so reverse
+        found_event_ids.reverse()
+        self.assertEquals(found_event_ids, expected_event_ids)
+
+        # Reset and try again, but convert the tokens to the legacy format.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + self._stream_token_to_relation_token(prev_token)
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/unstable/rooms/{self.room}"
+                f"/aggregations/{self.parent_id}/{RelationTypes.ANNOTATION}"
+                f"/m.reaction/{encoded_key}?limit=1{from_token}",
                 access_token=self.user_token,
             )
             self.assertEquals(200, channel.code, channel.json_body)