summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13840.bugfix1
-rw-r--r--synapse/storage/databases/main/relations.py38
-rw-r--r--synapse/storage/databases/main/stream.py6
-rw-r--r--tests/rest/client/test_relations.py29
4 files changed, 60 insertions, 14 deletions
diff --git a/changelog.d/13840.bugfix b/changelog.d/13840.bugfix
new file mode 100644
index 0000000000..0f014439a8
--- /dev/null
+++ b/changelog.d/13840.bugfix
@@ -0,0 +1 @@
+Fix a bug introduced in Synapse v1.53.0 where the experimental implementation of [MSC3715](https://github.com/matrix-org/matrix-spec-proposals/pull/3715) would give incorrect results when paginating forward.
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7bd27790eb..898947af95 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -51,6 +51,8 @@ class _RelatedEvent:
     event_id: str
     # The sender of the related event.
     sender: str
+    topological_ordering: Optional[int]
+    stream_ordering: int
 
 
 class RelationsWorkerStore(SQLBaseStore):
@@ -91,6 +93,9 @@ class RelationsWorkerStore(SQLBaseStore):
         # it. The `event_id` must match the `event.event_id`.
         assert event.event_id == event_id
 
+        # Ensure bad limits aren't being passed in.
+        assert limit >= 0
+
         where_clause = ["relates_to_id = ?", "room_id = ?"]
         where_args: List[Union[str, int]] = [event.event_id, room_id]
         is_redacted = event.internal_metadata.is_redacted()
@@ -139,21 +144,34 @@ class RelationsWorkerStore(SQLBaseStore):
         ) -> Tuple[List[_RelatedEvent], Optional[StreamToken]]:
             txn.execute(sql, where_args + [limit + 1])
 
-            last_topo_id = None
-            last_stream_id = None
             events = []
-            for row in txn:
+            for event_id, relation_type, sender, topo_ordering, stream_ordering in txn:
                 # Do not include edits for redacted events as they leak event
                 # content.
-                if not is_redacted or row[1] != RelationTypes.REPLACE:
-                    events.append(_RelatedEvent(row[0], row[2]))
-                last_topo_id = row[3]
-                last_stream_id = row[4]
+                if not is_redacted or relation_type != RelationTypes.REPLACE:
+                    events.append(
+                        _RelatedEvent(event_id, sender, topo_ordering, stream_ordering)
+                    )
 
-            # If there are more events, generate the next pagination key.
+            # If there are more events, generate the next pagination key from the
+            # last event returned.
             next_token = None
-            if len(events) > limit and last_topo_id and last_stream_id:
-                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+            if len(events) > limit:
+                # Instead of using the last row (which tells us there is more
+                # data), use the last row to be returned.
+                events = events[:limit]
+
+                topo = events[-1].topological_ordering
+                token = events[-1].stream_ordering
+                if direction == "b":
+                    # Tokens are positions between events.
+                    # This token points *after* the last event in the chunk.
+                    # We need it to point to the event before it in the chunk
+                    # when we are going backwards so we subtract one from the
+                    # stream part.
+                    token -= 1
+                next_key = RoomStreamToken(topo, token)
+
                 if from_token:
                     next_token = from_token.copy_and_replace(
                         StreamKeyType.ROOM, next_key
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 3f9bfaeac5..530f04e149 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -1334,15 +1334,15 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         if rows:
             topo = rows[-1].topological_ordering
-            toke = rows[-1].stream_ordering
+            token = rows[-1].stream_ordering
             if direction == "b":
                 # Tokens are positions between events.
                 # This token points *after* the last event in the chunk.
                 # We need it to point to the event before it in the chunk
                 # when we are going backwards so we subtract one from the
                 # stream part.
-                toke -= 1
-            next_token = RoomStreamToken(topo, toke)
+                token -= 1
+            next_token = RoomStreamToken(topo, token)
         else:
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 651f4f415d..d33e34d829 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -788,6 +788,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
             channel.json_body["chunk"][0],
         )
 
+    @unittest.override_config({"experimental_features": {"msc3715_enabled": True}})
     def test_repeated_paginate_relations(self) -> None:
         """Test that if we paginate using a limit and tokens then we get the
         expected events.
@@ -809,7 +810,7 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
 
             channel = self.make_request(
                 "GET",
-                f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=1{from_token}",
+                f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?limit=3{from_token}",
                 access_token=self.user_token,
             )
             self.assertEqual(200, channel.code, channel.json_body)
@@ -827,6 +828,32 @@ class RelationPaginationTestCase(BaseRelationsTestCase):
         found_event_ids.reverse()
         self.assertEqual(found_event_ids, expected_event_ids)
 
+        # Test forward pagination.
+        prev_token = ""
+        found_event_ids = []
+        for _ in range(20):
+            from_token = ""
+            if prev_token:
+                from_token = "&from=" + prev_token
+
+            channel = self.make_request(
+                "GET",
+                f"/_matrix/client/v1/rooms/{self.room}/relations/{self.parent_id}?org.matrix.msc3715.dir=f&limit=3{from_token}",
+                access_token=self.user_token,
+            )
+            self.assertEqual(200, channel.code, channel.json_body)
+
+            found_event_ids.extend(e["event_id"] for e in channel.json_body["chunk"])
+            next_batch = channel.json_body.get("next_batch")
+
+            self.assertNotEqual(prev_token, next_batch)
+            prev_token = next_batch
+
+            if not prev_token:
+                break
+
+        self.assertEqual(found_event_ids, expected_event_ids)
+
     def test_pagination_from_sync_and_messages(self) -> None:
         """Pagination tokens from /sync and /messages can be used to paginate /relations."""
         channel = self._send_relation(RelationTypes.ANNOTATION, "m.reaction", "A")