summary refs log tree commit diff
path: root/synapse/storage/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/stream.py')
-rw-r--r--synapse/storage/stream.py194
1 files changed, 133 insertions, 61 deletions
diff --git a/synapse/storage/stream.py b/synapse/storage/stream.py
index d105b6b17d..529ad4ea79 100644
--- a/synapse/storage/stream.py
+++ b/synapse/storage/stream.py
@@ -64,59 +64,135 @@ _EventDictReturn = namedtuple(
 )
 
 
-def lower_bound(token, engine, inclusive=False):
-    inclusive = "=" if inclusive else ""
-    if token.topological is None:
-        return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering")
-    else:
-        if isinstance(engine, PostgresEngine):
-            # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
-            # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
-            # use the later form when running against postgres.
-            return "((%d,%d) <%s (%s,%s))" % (
-                token.topological,
-                token.stream,
-                inclusive,
-                "topological_ordering",
-                "stream_ordering",
+def generate_pagination_where_clause(
+    direction, column_names, from_token, to_token, engine,
+):
+    """Creates an SQL expression to bound the columns by the pagination
+    tokens.
+
+    For example creates an SQL expression like:
+
+        (6, 7) >= (topological_ordering, stream_ordering)
+        AND (5, 3) < (topological_ordering, stream_ordering)
+
+    would be generated for dir=b, from_token=(6, 7) and to_token=(5, 3).
+
+    Note that tokens are considered to be after the row they are in, e.g. if
+    a row A has a token T, then we consider A to be before T. This convention
+    is important when figuring out inequalities for the generated SQL, and
+    produces the following result:
+        - If paginating forwards then we exclude any rows matching the from
+          token, but include those that match the to token.
+        - If paginating backwards then we include any rows matching the from
+          token, but include those that match the to token.
+
+    Args:
+        direction (str): Whether we're paginating backwards("b") or
+            forwards ("f").
+        column_names (tuple[str, str]): The column names to bound. Must *not*
+            be user defined as these get inserted directly into the SQL
+            statement without escapes.
+        from_token (tuple[int, int]|None): The start point for the pagination.
+            This is an exclusive minimum bound if direction is "f", and an
+            inclusive maximum bound if direction is "b".
+        to_token (tuple[int, int]|None): The endpoint point for the pagination.
+            This is an inclusive maximum bound if direction is "f", and an
+            exclusive minimum bound if direction is "b".
+        engine: The database engine to generate the clauses for
+
+    Returns:
+        str: The sql expression
+    """
+    assert direction in ("b", "f")
+
+    where_clause = []
+    if from_token:
+        where_clause.append(
+            _make_generic_sql_bound(
+                bound=">=" if direction == "b" else "<",
+                column_names=column_names,
+                values=from_token,
+                engine=engine,
             )
-        return "(%d < %s OR (%d = %s AND %d <%s %s))" % (
-            token.topological,
-            "topological_ordering",
-            token.topological,
-            "topological_ordering",
-            token.stream,
-            inclusive,
-            "stream_ordering",
-        )
-
-
-def upper_bound(token, engine, inclusive=True):
-    inclusive = "=" if inclusive else ""
-    if token.topological is None:
-        return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering")
-    else:
-        if isinstance(engine, PostgresEngine):
-            # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well
-            # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we
-            # use the later form when running against postgres.
-            return "((%d,%d) >%s (%s,%s))" % (
-                token.topological,
-                token.stream,
-                inclusive,
-                "topological_ordering",
-                "stream_ordering",
+        )
+
+    if to_token:
+        where_clause.append(
+            _make_generic_sql_bound(
+                bound="<" if direction == "b" else ">=",
+                column_names=column_names,
+                values=to_token,
+                engine=engine,
             )
-        return "(%d > %s OR (%d = %s AND %d >%s %s))" % (
-            token.topological,
-            "topological_ordering",
-            token.topological,
-            "topological_ordering",
-            token.stream,
-            inclusive,
-            "stream_ordering",
         )
 
+    return " AND ".join(where_clause)
+
+
+def _make_generic_sql_bound(bound, column_names, values, engine):
+    """Create an SQL expression that bounds the given column names by the
+    values, e.g. create the equivalent of `(1, 2) < (col1, col2)`.
+
+    Only works with two columns.
+
+    Older versions of SQLite don't support that syntax so we have to expand it
+    out manually.
+
+    Args:
+        bound (str): The comparison operator to use. One of ">", "<", ">=",
+            "<=", where the values are on the left and columns on the right.
+        names (tuple[str, str]): The column names. Must *not* be user defined
+            as these get inserted directly into the SQL statement without
+            escapes.
+        values (tuple[int|None, int]): The values to bound the columns by. If
+            the first value is None then only creates a bound on the second
+            column.
+        engine: The database engine to generate the SQL for
+
+    Returns:
+        str
+    """
+
+    assert(bound in (">", "<", ">=", "<="))
+
+    name1, name2 = column_names
+    val1, val2 = values
+
+    if val1 is None:
+        val2 = int(val2)
+        return "(%d %s %s)" % (val2, bound, name2)
+
+    val1 = int(val1)
+    val2 = int(val2)
+
+    if isinstance(engine, PostgresEngine):
+        # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well
+        # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we
+        # use the later form when running against postgres.
+        return "((%d,%d) %s (%s,%s))" % (
+            val1, val2,
+            bound,
+            name1, name2,
+        )
+
+    # We want to generate queries of e.g. the form:
+    #
+    #   (val1 < name1 OR (val1 = name1 AND val2 <= name2))
+    #
+    # which is equivalent to (val1, val2) < (name1, name2)
+
+    return """(
+        {val1:d} {strict_bound} {name1}
+        OR ({val1:d} = {name1} AND {val2:d} {bound} {name2})
+    )""".format(
+        name1=name1,
+        val1=val1,
+        name2=name2,
+        val2=val2,
+        strict_bound=bound[0],  # The first bound must always be strict equality here
+        bound=bound,
+    )
+
 
 def filter_to_clause(event_filter):
     # NB: This may create SQL clauses that don't optimise well (and we don't
@@ -762,20 +838,16 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         args = [False, room_id]
         if direction == 'b':
             order = "DESC"
-            bounds = upper_bound(from_token, self.database_engine)
-            if to_token:
-                bounds = "%s AND %s" % (
-                    bounds,
-                    lower_bound(to_token, self.database_engine),
-                )
         else:
             order = "ASC"
-            bounds = lower_bound(from_token, self.database_engine)
-            if to_token:
-                bounds = "%s AND %s" % (
-                    bounds,
-                    upper_bound(to_token, self.database_engine),
-                )
+
+        bounds = generate_pagination_where_clause(
+            direction=direction,
+            column_names=("topological_ordering", "stream_ordering"),
+            from_token=from_token,
+            to_token=to_token,
+            engine=self.database_engine,
+        )
 
         filter_clause, filter_args = filter_to_clause(event_filter)