summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/databases/main/relations.py38
-rw-r--r--synapse/storage/databases/main/stream.py154
2 files changed, 122 insertions, 70 deletions
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index 84f844b79e..be2242b6ac 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -40,9 +40,13 @@ from synapse.storage.database import (
     LoggingTransaction,
     make_in_list_sql_clause,
 )
-from synapse.storage.databases.main.stream import generate_pagination_where_clause
+from synapse.storage.databases.main.stream import (
+    generate_next_token,
+    generate_pagination_bounds,
+    generate_pagination_where_clause,
+)
 from synapse.storage.engines import PostgresEngine
-from synapse.types import JsonDict, RoomStreamToken, StreamKeyType, StreamToken
+from synapse.types import JsonDict, StreamKeyType, StreamToken
 from synapse.util.caches.descriptors import cached, cachedList
 
 if TYPE_CHECKING:
@@ -207,24 +211,23 @@ class RelationsWorkerStore(SQLBaseStore):
             where_clause.append("type = ?")
             where_args.append(event_type)
 
+        order, from_bound, to_bound = generate_pagination_bounds(
+            direction,
+            from_token.room_key if from_token else None,
+            to_token.room_key if to_token else None,
+        )
+
         pagination_clause = generate_pagination_where_clause(
             direction=direction,
             column_names=("topological_ordering", "stream_ordering"),
-            from_token=from_token.room_key.as_historical_tuple()
-            if from_token
-            else None,
-            to_token=to_token.room_key.as_historical_tuple() if to_token else None,
+            from_token=from_bound,
+            to_token=to_bound,
             engine=self.database_engine,
         )
 
         if pagination_clause:
             where_clause.append(pagination_clause)
 
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
         sql = """
             SELECT event_id, relation_type, sender, topological_ordering, stream_ordering
             FROM event_relations
@@ -266,16 +269,9 @@ class RelationsWorkerStore(SQLBaseStore):
                 topo_orderings = topo_orderings[:limit]
                 stream_orderings = stream_orderings[:limit]
 
-                topo = topo_orderings[-1]
-                token = stream_orderings[-1]
-                if direction == "b":
-                    # Tokens are positions between events.
-                    # This token points *after* the last event in the chunk.
-                    # We need it to point to the event before it in the chunk
-                    # when we are going backwards so we subtract one from the
-                    # stream part.
-                    token -= 1
-                next_key = RoomStreamToken(topo, token)
+                next_key = generate_next_token(
+                    direction, topo_orderings[-1], stream_orderings[-1]
+                )
 
                 if from_token:
                     next_token = from_token.copy_and_replace(
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index d28fc65df9..8977bf33e7 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -170,6 +170,104 @@ def generate_pagination_where_clause(
     return " AND ".join(where_clause)
 
 
+def generate_pagination_bounds(
+    direction: str,
+    from_token: Optional[RoomStreamToken],
+    to_token: Optional[RoomStreamToken],
+) -> Tuple[
+    str, Optional[Tuple[Optional[int], int]], Optional[Tuple[Optional[int], int]]
+]:
+    """
+    Generate a start and end point for this page of events.
+
+    Args:
+        direction: Whether pagination is going forwards or backwards. One of "f" or "b".
+        from_token: The token to start pagination at, or None to start at the first value.
+        to_token: The token to end pagination at, or None to not limit the end point.
+
+    Returns:
+        A three tuple of:
+
+            ASC or DESC for sorting of the query.
+
+            The starting position as a tuple of ints representing
+            (topological position, stream position) or None if no from_token was
+            provided. The topological position may be None for live tokens.
+
+            The end position in the same format as the starting position, or None
+            if no to_token was provided.
+    """
+
+    # Tokens really represent positions between elements, but we use
+    # the convention of pointing to the event before the gap. Hence
+    # we have a bit of asymmetry when it comes to equalities.
+    if direction == "b":
+        order = "DESC"
+    else:
+        order = "ASC"
+
+    # The bounds for the stream tokens are complicated by the fact
+    # that we need to handle the instance_map part of the tokens. We do this
+    # by fetching all events between the min stream token and the maximum
+    # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
+    # then filtering the results.
+    from_bound: Optional[Tuple[Optional[int], int]] = None
+    if from_token:
+        if from_token.topological is not None:
+            from_bound = from_token.as_historical_tuple()
+        elif direction == "b":
+            from_bound = (
+                None,
+                from_token.get_max_stream_pos(),
+            )
+        else:
+            from_bound = (
+                None,
+                from_token.stream,
+            )
+
+    to_bound: Optional[Tuple[Optional[int], int]] = None
+    if to_token:
+        if to_token.topological is not None:
+            to_bound = to_token.as_historical_tuple()
+        elif direction == "b":
+            to_bound = (
+                None,
+                to_token.stream,
+            )
+        else:
+            to_bound = (
+                None,
+                to_token.get_max_stream_pos(),
+            )
+
+    return order, from_bound, to_bound
+
+
+def generate_next_token(
+    direction: str, last_topo_ordering: int, last_stream_ordering: int
+) -> RoomStreamToken:
+    """
+    Generate the next room stream token based on the currently returned data.
+
+    Args:
+        direction: Whether pagination is going forwards or backwards. One of "f" or "b".
+        last_topo_ordering: The last topological ordering being returned.
+        last_stream_ordering: The last stream ordering being returned.
+
+    Returns:
+        A new RoomStreamToken to return to the client.
+    """
+    if direction == "b":
+        # Tokens are positions between events.
+        # This token points *after* the last event in the chunk.
+        # We need it to point to the event before it in the chunk
+        # when we are going backwards so we subtract one from the
+        # stream part.
+        last_stream_ordering -= 1
+    return RoomStreamToken(last_topo_ordering, last_stream_ordering)
+
+
 def _make_generic_sql_bound(
     bound: str,
     column_names: Tuple[str, str],
@@ -1300,47 +1398,11 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             `to_token`), or `limit` is zero.
         """
 
-        # Tokens really represent positions between elements, but we use
-        # the convention of pointing to the event before the gap. Hence
-        # we have a bit of asymmetry when it comes to equalities.
         args = [False, room_id]
-        if direction == "b":
-            order = "DESC"
-        else:
-            order = "ASC"
-
-        # The bounds for the stream tokens are complicated by the fact
-        # that we need to handle the instance_map part of the tokens. We do this
-        # by fetching all events between the min stream token and the maximum
-        # stream token (as returned by `RoomStreamToken.get_max_stream_pos`) and
-        # then filtering the results.
-        if from_token.topological is not None:
-            from_bound: Tuple[Optional[int], int] = from_token.as_historical_tuple()
-        elif direction == "b":
-            from_bound = (
-                None,
-                from_token.get_max_stream_pos(),
-            )
-        else:
-            from_bound = (
-                None,
-                from_token.stream,
-            )
 
-        to_bound: Optional[Tuple[Optional[int], int]] = None
-        if to_token:
-            if to_token.topological is not None:
-                to_bound = to_token.as_historical_tuple()
-            elif direction == "b":
-                to_bound = (
-                    None,
-                    to_token.stream,
-                )
-            else:
-                to_bound = (
-                    None,
-                    to_token.get_max_stream_pos(),
-                )
+        order, from_bound, to_bound = generate_pagination_bounds(
+            direction, from_token, to_token
+        )
 
         bounds = generate_pagination_where_clause(
             direction=direction,
@@ -1436,16 +1498,10 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         ][:limit]
 
         if rows:
-            topo = rows[-1].topological_ordering
-            token = rows[-1].stream_ordering
-            if direction == "b":
-                # Tokens are positions between events.
-                # This token points *after* the last event in the chunk.
-                # We need it to point to the event before it in the chunk
-                # when we are going backwards so we subtract one from the
-                # stream part.
-                token -= 1
-            next_token = RoomStreamToken(topo, token)
+            assert rows[-1].topological_ordering is not None
+            next_token = generate_next_token(
+                direction, rows[-1].topological_ordering, rows[-1].stream_ordering
+            )
         else:
             # TODO (erikj): We should work out what to do here instead.
             next_token = to_token if to_token else from_token