diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py
index 8cf5ebaa07..9ec425888a 100644
--- a/synapse/rest/client/relations.py
+++ b/synapse/rest/client/relations.py
@@ -32,14 +32,45 @@ from synapse.storage.relations import (
PaginationChunk,
RelationPaginationToken,
)
-from synapse.types import JsonDict
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
+async def _parse_token(
+ store: "DataStore", token: Optional[str]
+) -> Optional[StreamToken]:
+ """
+ For backwards compatibility support RelationPaginationToken, but new pagination
+ tokens are generated as full StreamTokens, to be compatible with /sync and /messages.
+ """
+ if not token:
+ return None
+ # Luckily the format for StreamToken and RelationPaginationToken differ enough
+ # that they can easily be separated. An "_" appears in the serialization of
+ # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses
+ # "-" only for separators.
+ if "_" in token:
+ return await StreamToken.from_string(store, token)
+ else:
+ relation_token = RelationPaginationToken.from_string(token)
+ return StreamToken(
+ room_key=RoomStreamToken(relation_token.topological, relation_token.stream),
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
+ )
+
+
class RelationPaginationServlet(RestServlet):
"""API to paginate relations on an event by topological ordering, optionally
filtered by relation type and event type.
@@ -88,13 +119,8 @@ class RelationPaginationServlet(RestServlet):
pagination_chunk = PaginationChunk(chunk=[])
else:
# Return the relations
- from_token = None
- if from_token_str:
- from_token = RelationPaginationToken.from_string(from_token_str)
-
- to_token = None
- if to_token_str:
- to_token = RelationPaginationToken.from_string(to_token_str)
+ from_token = await _parse_token(self.store, from_token_str)
+ to_token = await _parse_token(self.store, to_token_str)
pagination_chunk = await self.store.get_relations_for_event(
event_id=parent_id,
@@ -125,7 +151,7 @@ class RelationPaginationServlet(RestServlet):
events, now, bundle_aggregations=aggregations
)
- return_value = pagination_chunk.to_dict()
+ return_value = await pagination_chunk.to_dict(self.store)
return_value["chunk"] = serialized_events
return_value["original_event"] = original_event
@@ -216,7 +242,7 @@ class RelationAggregationPaginationServlet(RestServlet):
to_token=to_token,
)
- return 200, pagination_chunk.to_dict()
+ return 200, await pagination_chunk.to_dict(self.store)
class RelationAggregationGroupPaginationServlet(RestServlet):
@@ -287,13 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
from_token_str = parse_string(request, "from")
to_token_str = parse_string(request, "to")
- from_token = None
- if from_token_str:
- from_token = RelationPaginationToken.from_string(from_token_str)
-
- to_token = None
- if to_token_str:
- to_token = RelationPaginationToken.from_string(to_token_str)
+ from_token = await _parse_token(self.store, from_token_str)
+ to_token = await _parse_token(self.store, to_token_str)
result = await self.store.get_relations_for_event(
event_id=parent_id,
@@ -313,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet):
now = self.clock.time_msec()
serialized_events = self._event_serializer.serialize_events(events, now)
- return_value = result.to_dict()
+ return_value = await result.to_dict(self.store)
return_value["chunk"] = serialized_events
return 200, return_value
|