From df36945ff0e4a293a9dac0da07e2c94256835b32 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 10 Feb 2022 10:52:48 -0500 Subject: Support pagination tokens from /sync and /messages in the relations API. (#11952) --- synapse/rest/client/relations.py | 57 ++++++++++++++++++++--------- synapse/storage/databases/main/relations.py | 46 +++++++++++++++-------- synapse/storage/relations.py | 15 +++++--- 3 files changed, 79 insertions(+), 39 deletions(-) (limited to 'synapse') diff --git a/synapse/rest/client/relations.py b/synapse/rest/client/relations.py index 8cf5ebaa07..9ec425888a 100644 --- a/synapse/rest/client/relations.py +++ b/synapse/rest/client/relations.py @@ -32,14 +32,45 @@ from synapse.storage.relations import ( PaginationChunk, RelationPaginationToken, ) -from synapse.types import JsonDict +from synapse.types import JsonDict, RoomStreamToken, StreamToken if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) +async def _parse_token( + store: "DataStore", token: Optional[str] +) -> Optional[StreamToken]: + """ + For backwards compatibility support RelationPaginationToken, but new pagination + tokens are generated as full StreamTokens, to be compatible with /sync and /messages. + """ + if not token: + return None + # Luckily the format for StreamToken and RelationPaginationToken differ enough + # that they can easily be separated. An "_" appears in the serialization of + # RoomStreamToken (as part of StreamToken), but RelationPaginationToken uses + # "-" only for separators. + if "_" in token: + return await StreamToken.from_string(store, token) + else: + relation_token = RelationPaginationToken.from_string(token) + return StreamToken( + room_key=RoomStreamToken(relation_token.topological, relation_token.stream), + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) + + class RelationPaginationServlet(RestServlet): """API to paginate relations on an event by topological ordering, optionally filtered by relation type and event type. @@ -88,13 +119,8 @@ class RelationPaginationServlet(RestServlet): pagination_chunk = PaginationChunk(chunk=[]) else: # Return the relations - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) + from_token = await _parse_token(self.store, from_token_str) + to_token = await _parse_token(self.store, to_token_str) pagination_chunk = await self.store.get_relations_for_event( event_id=parent_id, @@ -125,7 +151,7 @@ class RelationPaginationServlet(RestServlet): events, now, bundle_aggregations=aggregations ) - return_value = pagination_chunk.to_dict() + return_value = await pagination_chunk.to_dict(self.store) return_value["chunk"] = serialized_events return_value["original_event"] = original_event @@ -216,7 +242,7 @@ class RelationAggregationPaginationServlet(RestServlet): to_token=to_token, ) - return 200, pagination_chunk.to_dict() + return 200, await pagination_chunk.to_dict(self.store) class RelationAggregationGroupPaginationServlet(RestServlet): @@ -287,13 +313,8 @@ class RelationAggregationGroupPaginationServlet(RestServlet): from_token_str = parse_string(request, "from") to_token_str = parse_string(request, "to") - from_token = None - if from_token_str: - from_token = RelationPaginationToken.from_string(from_token_str) - - to_token = None - if to_token_str: - to_token = RelationPaginationToken.from_string(to_token_str) + from_token = await _parse_token(self.store, from_token_str) + to_token = await _parse_token(self.store, to_token_str) result = await self.store.get_relations_for_event( event_id=parent_id, @@ -313,7 +334,7 @@ class RelationAggregationGroupPaginationServlet(RestServlet): now = self.clock.time_msec() serialized_events = self._event_serializer.serialize_events(events, now) - return_value = result.to_dict() + return_value = await result.to_dict(self.store) return_value["chunk"] = serialized_events return 200, return_value diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py index 7718acbf1c..ad79cc5610 100644 --- a/synapse/storage/databases/main/relations.py +++ b/synapse/storage/databases/main/relations.py @@ -39,16 +39,13 @@ from synapse.storage.database import ( ) from synapse.storage.databases.main.stream import generate_pagination_where_clause from synapse.storage.engines import PostgresEngine -from synapse.storage.relations import ( - AggregationPaginationToken, - PaginationChunk, - RelationPaginationToken, -) -from synapse.types import JsonDict +from synapse.storage.relations import AggregationPaginationToken, PaginationChunk +from synapse.types import JsonDict, RoomStreamToken, StreamToken from synapse.util.caches.descriptors import cached, cachedList if TYPE_CHECKING: from synapse.server import HomeServer + from synapse.storage.databases.main import DataStore logger = logging.getLogger(__name__) @@ -98,8 +95,8 @@ class RelationsWorkerStore(SQLBaseStore): aggregation_key: Optional[str] = None, limit: int = 5, direction: str = "b", - from_token: Optional[RelationPaginationToken] = None, - to_token: Optional[RelationPaginationToken] = None, + from_token: Optional[StreamToken] = None, + to_token: Optional[StreamToken] = None, ) -> PaginationChunk: """Get a list of relations for an event, ordered by topological ordering. @@ -138,8 +135,10 @@ class RelationsWorkerStore(SQLBaseStore): pagination_clause = generate_pagination_where_clause( direction=direction, column_names=("topological_ordering", "stream_ordering"), - from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type] - to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type] + from_token=from_token.room_key.as_historical_tuple() + if from_token + else None, + to_token=to_token.room_key.as_historical_tuple() if to_token else None, engine=self.database_engine, ) @@ -177,12 +176,27 @@ class RelationsWorkerStore(SQLBaseStore): last_topo_id = row[1] last_stream_id = row[2] - next_batch = None + # If there are more events, generate the next pagination key. + next_token = None if len(events) > limit and last_topo_id and last_stream_id: - next_batch = RelationPaginationToken(last_topo_id, last_stream_id) + next_key = RoomStreamToken(last_topo_id, last_stream_id) + if from_token: + next_token = from_token.copy_and_replace("room_key", next_key) + else: + next_token = StreamToken( + room_key=next_key, + presence_key=0, + typing_key=0, + receipt_key=0, + account_data_key=0, + push_rules_key=0, + to_device_key=0, + device_list_key=0, + groups_key=0, + ) return PaginationChunk( - chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token + chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token ) return await self.db_pool.runInteraction( @@ -676,13 +690,15 @@ class RelationsWorkerStore(SQLBaseStore): annotations = await self.get_aggregation_groups_for_event(event_id, room_id) if annotations.chunk: - aggregations.annotations = annotations.to_dict() + aggregations.annotations = await annotations.to_dict( + cast("DataStore", self) + ) references = await self.get_relations_for_event( event_id, room_id, RelationTypes.REFERENCE, direction="f" ) if references.chunk: - aggregations.references = references.to_dict() + aggregations.references = await references.to_dict(cast("DataStore", self)) # If this event is the start of a thread, include a summary of the replies. if self._msc3440_enabled: diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py index b1536c1ca4..36ca2b8273 100644 --- a/synapse/storage/relations.py +++ b/synapse/storage/relations.py @@ -13,13 +13,16 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import attr from synapse.api.errors import SynapseError from synapse.types import JsonDict +if TYPE_CHECKING: + from synapse.storage.databases.main import DataStore + logger = logging.getLogger(__name__) @@ -39,14 +42,14 @@ class PaginationChunk: next_batch: Optional[Any] = None prev_batch: Optional[Any] = None - def to_dict(self) -> Dict[str, Any]: + async def to_dict(self, store: "DataStore") -> Dict[str, Any]: d = {"chunk": self.chunk} if self.next_batch: - d["next_batch"] = self.next_batch.to_string() + d["next_batch"] = await self.next_batch.to_string(store) if self.prev_batch: - d["prev_batch"] = self.prev_batch.to_string() + d["prev_batch"] = await self.prev_batch.to_string(store) return d @@ -75,7 +78,7 @@ class RelationPaginationToken: except ValueError: raise SynapseError(400, "Invalid relation pagination token") - def to_string(self) -> str: + async def to_string(self, store: "DataStore") -> str: return "%d-%d" % (self.topological, self.stream) def as_tuple(self) -> Tuple[Any, ...]: @@ -105,7 +108,7 @@ class AggregationPaginationToken: except ValueError: raise SynapseError(400, "Invalid aggregation pagination token") - def to_string(self) -> str: + async def to_string(self, store: "DataStore") -> str: return "%d-%d" % (self.count, self.stream) def as_tuple(self) -> Tuple[Any, ...]: -- cgit 1.4.1