summary refs log tree commit diff
path: root/synapse/rest/client/relations.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/relations.py')
-rw-r--r--synapse/rest/client/relations.py61
1 files changed, 43 insertions, 18 deletions
diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 8cf5ebaa07..2cab83c4e6 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -32,14 +32,45 @@ from synapse.storage.relations import (
     PaginationChunk,
     RelationPaginationToken,
 )
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
 
+async def _parse_token(
+    store: "DataStore", token: Optional[str]
+) -> Optional[StreamToken]:
+    """
+    For backwards compatibility support RelationPaginationToken, but new pagination
+    tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
+    """
+    if not token:
+        return None
+    # Luckily the format for StreamToken and RelationPaginationToken differ enough
+    # that they can easily be separated. An "_" appears in the serialization of
+    # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
+    # "-" only for separators.
+    if "_" in token:
+        return await StreamToken.from_string(store, token)
+    else:
+        relation_token = RelationPaginationToken.from_string(token)
+        return StreamToken(
+            room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
+            presence_key=0,
+            typing_key=0,
+            receipt_key=0,
+            account_data_key=0,
+            push_rules_key=0,
+            to_device_key=0,
+            device_list_key=0,
+            groups_key=0,
+        )
+
+
 class RelationPaginationServlet(RestServlet):
     """API to paginate relations on an event by topological ordering, optionally
     filtered by relation type and event type.
@@ -80,6 +111,9 @@ class RelationPaginationServlet(RestServlet):
             raise SynapseError(404, "Unknown parent event.")
 
         limit = parse_integer(request, "limit", default=5)
+        direction = parse_string(
+            request, "org.matrix.msc3715.dir", default="b", allowed_values=["f", "b"]
+        )
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
 
@@ -88,13 +122,8 @@ class RelationPaginationServlet(RestServlet):
             pagination_chunk = PaginationChunk(chunk=[])
         else:
             # Return the relations
-            from_token = None
-            if from_token_str:
-                from_token = RelationPaginationToken.from_string(from_token_str)
-
-            to_token = None
-            if to_token_str:
-                to_token = RelationPaginationToken.from_string(to_token_str)
+            from_token = await _parse_token(self.store, from_token_str)
+            to_token = await _parse_token(self.store, to_token_str)
 
             pagination_chunk = await self.store.get_relations_for_event(
                 event_id=parent_id,
@@ -102,6 +131,7 @@ class RelationPaginationServlet(RestServlet):
                 relation_type=relation_type,
                 event_type=event_type,
                 limit=limit,
+                direction=direction,
                 from_token=from_token,
                 to_token=to_token,
             )
@@ -125,7 +155,7 @@ class RelationPaginationServlet(RestServlet):
             events, now, bundle_aggregations=aggregations
         )
 
-        return_value = pagination_chunk.to_dict()
+        return_value = await pagination_chunk.to_dict(self.store)
         return_value["chunk"] = serialized_events
         return_value["original_event"] = original_event
 
@@ -216,7 +246,7 @@ class RelationAggregationPaginationServlet(RestServlet):
                 to_token=to_token,
             )
 
-        return 200, pagination_chunk.to_dict()
+        return 200, await pagination_chunk.to_dict(self.store)
 
 
 class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -287,13 +317,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         from_token_str = parse_string(request, "from")
         to_token_str = parse_string(request, "to")
 
-        from_token = None
-        if from_token_str:
-            from_token = RelationPaginationToken.from_string(from_token_str)
-
-        to_token = None
-        if to_token_str:
-            to_token = RelationPaginationToken.from_string(to_token_str)
+        from_token = await _parse_token(self.store, from_token_str)
+        to_token = await _parse_token(self.store, to_token_str)
 
         result = await self.store.get_relations_for_event(
             event_id=parent_id,
@@ -313,7 +338,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
         now = self.clock.time_msec()
         serialized_events = self._event_serializer.serialize_events(events, now)
 
-        return_value = result.to_dict()
+        return_value = await result.to_dict(self.store)
         return_value["chunk"] = serialized_events
 
         return 200, return_value