summary refs log tree commit diff
path: root/synapse/storage
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage')
-rw-r--r--synapse/storage/databases/main/relations.py46
-rw-r--r--synapse/storage/relations.py15
2 files changed, 40 insertions, 21 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 7718acbf1c..ad79cc5610 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -39,16 +39,13 @@ from synapse.storage.database import (
 )
 from synapse.storage.databases.main.stream import generate_pagination_where_clause
 from synapse.storage.engines import PostgresEngine
-from synapse.storage.relations import (
-    AggregationPaginationToken,
-    PaginationChunk,
-    RelationPaginationToken,
-)
-from synapse.types import JsonDict
+from synapse.storage.relations import AggregationPaginationToken, PaginationChunk
+from synapse.types import JsonDict, RoomStreamToken, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
+    from synapse.storage.databases.main import DataStore
 
 logger = logging.getLogger(__name__)
 
@@ -98,8 +95,8 @@ class RelationsWorkerStore(SQLBaseStore):
         aggregation_key: Optional[str] = None,
         limit: int = 5,
         direction: str = "b",
-        from_token: Optional[RelationPaginationToken] = None,
-        to_token: Optional[RelationPaginationToken] = None,
+        from_token: Optional[StreamToken] = None,
+        to_token: Optional[StreamToken] = None,
     ) -> PaginationChunk:
         """Get a list of relations for an event, ordered by topological ordering.
 
@@ -138,8 +135,10 @@ class RelationsWorkerStore(SQLBaseStore):
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=attr.astuple(from_token) if from_token else None,  # type: ignore[arg-type]
-            to_token=attr.astuple(to_token) if to_token else None,  # type: ignore[arg-type]
+            from_token=from_token.room_key.as_historical_tuple()
+            if from_token
+            else None,
+            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
             engine=self.database_engine,
         )
 
@@ -177,12 +176,27 @@ class RelationsWorkerStore(SQLBaseStore):
                 last_topo_id = row[1]
                 last_stream_id = row[2]
 
-            next_batch = None
+            # If there are more events, generate the next pagination key.
+            next_token = None
             if len(events) > limit and last_topo_id and last_stream_id:
-                next_batch = RelationPaginationToken(last_topo_id, last_stream_id)
+                next_key = RoomStreamToken(last_topo_id, last_stream_id)
+                if from_token:
+                    next_token = from_token.copy_and_replace("room_key", next_key)
+                else:
+                    next_token = StreamToken(
+                        room_key=next_key,
+                        presence_key=0,
+                        typing_key=0,
+                        receipt_key=0,
+                        account_data_key=0,
+                        push_rules_key=0,
+                        to_device_key=0,
+                        device_list_key=0,
+                        groups_key=0,
+                    )
 
             return PaginationChunk(
-                chunk=list(events[:limit]), next_batch=next_batch, prev_batch=from_token
+                chunk=list(events[:limit]), next_batch=next_token, prev_batch=from_token
             )
 
         return await self.db_pool.runInteraction(
@@ -676,13 +690,15 @@ class RelationsWorkerStore(SQLBaseStore):
 
         annotations = await self.get_aggregation_groups_for_event(event_id, room_id)
         if annotations.chunk:
-            aggregations.annotations = annotations.to_dict()
+            aggregations.annotations = await annotations.to_dict(
+                cast("DataStore", self)
+            )
 
         references = await self.get_relations_for_event(
             event_id, room_id, RelationTypes.REFERENCE, direction="f"
         )
         if references.chunk:
-            aggregations.references = references.to_dict()
+            aggregations.references = await references.to_dict(cast("DataStore", self))
 
         # If this event is the start of a thread, include a summary of the replies.
         if self._msc3440_enabled:
diff --git a/synapse/storage/relations.py b/synapse/storage/relations.py
index b1536c1ca4..36ca2b8273 100644
--- a/synapse/storage/relations.py
+++ b/synapse/storage/relations.py
@@ -13,13 +13,16 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
 
 import attr
 
 from synapse.api.errors import SynapseError
 from synapse.types import JsonDict
 
+if TYPE_CHECKING:
+    from synapse.storage.databases.main import DataStore
+
 logger = logging.getLogger(__name__)
 
 
@@ -39,14 +42,14 @@ class PaginationChunk:
     next_batch: Optional[Any] = None
     prev_batch: Optional[Any] = None
 
-    def to_dict(self) -> Dict[str, Any]:
+    async def to_dict(self, store: "DataStore") -> Dict[str, Any]:
         d = {"chunk": self.chunk}
 
         if self.next_batch:
-            d["next_batch"] = self.next_batch.to_string()
+            d["next_batch"] = await self.next_batch.to_string(store)
 
         if self.prev_batch:
-            d["prev_batch"] = self.prev_batch.to_string()
+            d["prev_batch"] = await self.prev_batch.to_string(store)
 
         return d
 
@@ -75,7 +78,7 @@ class RelationPaginationToken:
         except ValueError:
             raise SynapseError(400, "Invalid relation pagination token")
 
-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.topological, self.stream)
 
     def as_tuple(self) -> Tuple[Any, ...]:
@@ -105,7 +108,7 @@ class AggregationPaginationToken:
         except ValueError:
             raise SynapseError(400, "Invalid aggregation pagination token")
 
-    def to_string(self) -> str:
+    async def to_string(self, store: "DataStore") -> str:
         return "%d-%d" % (self.count, self.stream)
 
     def as_tuple(self) -> Tuple[Any, ...]: