diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7718acbf1c..ad79cc5610 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -39,16 +39,13 @@ from synapse.storage.database import (
)
from synapse.storage.databases.main.stream import generate_pagination_where_clause
from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import (
- AggregationPaginationToken,
- PaginationChunk,
- RelationPaginationToken,
-)
-from synapse.types import JsonDict
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
from synapse.util.caches.descriptors import cached, cachedList
if TYPE_CHECKING:
from synapse.server import HomeServer
+ from synapse.storage.databases.main import DataStore
logger = logging.getLogger(__name__)
@@ -98,8 +95,8 @@ class RelationsWorkerStore(SQLBaseStore):
aggregation_key: Optional[str] = None,
limit: int = 5,
direction: str = "b",
- from_token: Optional[RelationPaginationToken] = None,
- to_token: Optional[RelationPaginationToken] = None,
+ from_token: Optional[StreamToken] = None,
+ to_token: Optional[StreamToken] = None,
) -> PaginationChunk:
"""Get a list of relations for an event, ordered by topological ordering.
@@ -138,8 +135,10 @@ class RelationsWorkerStore(SQLBaseStore):
pagination_clause = generate_pagination_where_clause(
direction=direction,
column_names=("topological_ordering", "stream_ordering"),
- from_token=attr.astuple(from_token) if from_token else None, # type: ignore[arg-type]
- to_token=attr.astuple(to_token) if to_token else None, # type: ignore[arg-type]
+ from_token=from_token.room_key.as_historical_tuple()
+ if from_token
+ else None,
+ to_token=to_token.room_key.as_historical_tuple() if to_token else None,
engine=self.database_engine,
)
@@ -177,12 +176,27 @@ class RelationsWorkerStore(SQLBaseStore):
last_topo_id = row[1]
last_stream_id = row[2]
- next_batch = None
+ # If there are more events, generate the next pagination key.
+ next_token = None
if len(events) > limit and last_topo_id and last_stream_id:
- next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+ next_key = RoomStreamToken(last_topo_id, last_stream_id)
+ if from_token:
+ next_token = from_token.copy_and_replace("room_key", next_key)
+ else:
+ next_token = StreamToken(
+ room_key=next_key,
+ presence_key=0,
+ typing_key=0,
+ receipt_key=0,
+ account_data_key=0,
+ push_rules_key=0,
+ to_device_key=0,
+ device_list_key=0,
+ groups_key=0,
+ )
return PaginationChunk(
- chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+ chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
)
return await self.db_pool.runInteraction(
@@ -676,13 +690,15 @@ class RelationsWorkerStore(SQLBaseStore):
annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
if annotations.chunk:
- aggregations.annotations = annotations.to_dict()
+ aggregations.annotations = await annotations.to_dict(
+ cast("DataStore", self)
+ )
references = await self.get_relations_for_event(
event_id, room_id, RelationTypes.REFERENCE, direction="f"
)
if references.chunk:
- aggregations.references = references.to_dict()
+ aggregations.references = await references.to_dict(cast("DataStore", self))
# If this event is the start of a thread, include a summary of the replies.
if self._msc3440_enabled:
|