summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py275
1 files changed, 190 insertions, 85 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 95775e3804..4989c960a6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -51,6 +51,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Protocol,
     Set,
     Tuple,
     cast,
@@ -59,7 +60,7 @@ from typing import (
 
 import attr
 from immutabledict import immutabledict
-from typing_extensions import Literal
+from typing_extensions import Literal, assert_never
 
 from twisted.internet import defer
 
@@ -97,6 +98,18 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
+class PaginateFunction(Protocol):
+    async def __call__(
+        self,
+        *,
+        room_id: str,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: Direction = Direction.BACKWARDS,
+        limit: int = 0,
+    ) -> Tuple[List[EventBase], RoomStreamToken]: ...
+
+
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
@@ -280,7 +293,7 @@ def generate_pagination_bounds(
 
 
 def generate_next_token(
-    direction: Direction, last_topo_ordering: int, last_stream_ordering: int
+    direction: Direction, last_topo_ordering: Optional[int], last_stream_ordering: int
 ) -> RoomStreamToken:
     """
     Generate the next room stream token based on the currently returned data.
@@ -447,7 +460,6 @@ def _filter_results_by_stream(
     The `instance_name` arg is optional to handle historic rows, and is
     interpreted as if it was "master".
     """
-
     if instance_name is None:
         instance_name = "master"
 
@@ -660,33 +672,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     async def get_room_events_stream_for_rooms(
         self,
+        *,
         room_ids: Collection[str],
         from_key: RoomStreamToken,
-        to_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: Direction = Direction.BACKWARDS,
         limit: int = 0,
-        order: str = "DESC",
     ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
             room_ids
-            from_key: Token from which no events are returned before
-            to_key: Token from which no events are returned after. (This
-                is typically the current stream token)
+            from_key: The token to stream from (starting point and heading in the given
+                direction)
+            to_key: The token representing the end stream position (end point)
             limit: Maximum number of events to return
-            order: Either "DESC" or "ASC". Determines which events are
-                returned when the result is limited. If "DESC" then the most
-                recent `limit` events are returned, otherwise returns the
-                oldest `limit` events.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
 
         Returns:
             A map from room id to a tuple containing:
                 - list of recent events in the room
                 - stream ordering key for the start of the chunk of events returned.
+
+            When Direction.FORWARDS: from_key < x <= to_key, (ascending order)
+            When Direction.BACKWARDS: from_key >= x > to_key, (descending order)
         """
-        room_ids = self._events_stream_cache.get_entities_changed(
-            room_ids, from_key.stream
-        )
+        if direction == Direction.FORWARDS:
+            room_ids = self._events_stream_cache.get_entities_changed(
+                room_ids, from_key.stream
+            )
+        elif direction == Direction.BACKWARDS:
+            if to_key is not None:
+                room_ids = self._events_stream_cache.get_entities_changed(
+                    room_ids, to_key.stream
+                )
+        else:
+            assert_never(direction)
 
         if not room_ids:
             return {}
@@ -698,12 +720,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 defer.gatherResults(
                     [
                         run_in_background(
-                            self.get_room_events_stream_for_room,
-                            room_id,
-                            from_key,
-                            to_key,
-                            limit,
-                            order=order,
+                            self.paginate_room_events_by_stream_ordering,
+                            room_id=room_id,
+                            from_key=from_key,
+                            to_key=to_key,
+                            direction=direction,
+                            limit=limit,
                         )
                         for room_id in rm_ids
                     ],
@@ -727,69 +749,122 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
-    async def get_room_events_stream_for_room(
+    async def paginate_room_events_by_stream_ordering(
         self,
+        *,
         room_id: str,
         from_key: RoomStreamToken,
-        to_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: Direction = Direction.BACKWARDS,
         limit: int = 0,
-        order: str = "DESC",
     ) -> Tuple[List[EventBase], RoomStreamToken]:
-        """Get new room events in stream ordering since `from_key`.
+        """
+        Paginate events by `stream_ordering` in the room from the `from_key` in the
+        given `direction` to the `to_key` or `limit`.
 
         Args:
             room_id
-            from_key: Token from which no events are returned before
-            to_key: Token from which no events are returned after. (This
-                is typically the current stream token)
+            from_key: The token to stream from (starting point and heading in the given
+                direction)
+            to_key: The token representing the end stream position (end point)
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: Maximum number of events to return
-            order: Either "DESC" or "ASC". Determines which events are
-                returned when the result is limited. If "DESC" then the most
-                recent `limit` events are returned, otherwise returns the
-                oldest `limit` events.
 
         Returns:
-            The list of events (in ascending stream order) and the token from the start
-            of the chunk of events returned.
+            The results as a list of events and a token that points to the end
+            of the result set. If no events are returned then the end of the
+            stream has been reached (i.e. there are no events between `from_key`
+            and `to_key`).
+
+            When Direction.FORWARDS: from_key < x <= to_key, (ascending order)
+            When Direction.BACKWARDS: from_key >= x > to_key, (descending order)
         """
-        if from_key == to_key:
-            return [], from_key
 
-        has_changed = self._events_stream_cache.has_entity_changed(
-            room_id, from_key.stream
-        )
+        # FIXME: When going forwards, we should enforce that the `to_key` is not `None`
+        # because we always need an upper bound when querying the events stream (as
+        # otherwise we'll potentially pick up events that are not fully persisted).
+
+        # We should only be working with `stream_ordering` tokens here
+        assert from_key is None or from_key.topological is None
+        assert to_key is None or to_key.topological is None
+
+        # We can bail early if we're looking forwards, and our `to_key` is already
+        # before our `from_key`.
+        if (
+            direction == Direction.FORWARDS
+            and to_key is not None
+            and to_key.is_before_or_eq(from_key)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_key if to_key else from_key
+        # Or vice-versa, if we're looking backwards and our `from_key` is already before
+        # our `to_key`.
+        elif (
+            direction == Direction.BACKWARDS
+            and to_key is not None
+            and from_key.is_before_or_eq(to_key)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_key if to_key else from_key
+
+        # We can do a quick sanity check to see if any events have been sent in the room
+        # since the earlier token.
+        has_changed = True
+        if direction == Direction.FORWARDS:
+            has_changed = self._events_stream_cache.has_entity_changed(
+                room_id, from_key.stream
+            )
+        elif direction == Direction.BACKWARDS:
+            if to_key is not None:
+                has_changed = self._events_stream_cache.has_entity_changed(
+                    room_id, to_key.stream
+                )
+        else:
+            assert_never(direction)
 
         if not has_changed:
-            return [], from_key
+            # Token selection matches what we do below if there are no rows
+            return [], to_key if to_key else from_key
 
-        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
-            # To handle tokens with a non-empty instance_map we fetch more
-            # results than necessary and then filter down
-            min_from_id = from_key.stream
-            max_to_id = to_key.get_max_stream_pos()
+        order, from_bound, to_bound = generate_pagination_bounds(
+            direction, from_key, to_key
+        )
 
-            sql = """
-                SELECT event_id, instance_name, topological_ordering, stream_ordering
+        bounds = generate_pagination_where_clause(
+            direction=direction,
+            # The empty string will shortcut downstream code to only use the
+            # `stream_ordering` column
+            column_names=("", "stream_ordering"),
+            from_token=from_bound,
+            to_token=to_bound,
+            engine=self.database_engine,
+        )
+
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+            sql = f"""
+                SELECT event_id, instance_name, stream_ordering
                 FROM events
                 WHERE
                     room_id = ?
                     AND not outlier
-                    AND stream_ordering > ? AND stream_ordering <= ?
-                ORDER BY stream_ordering %s LIMIT ?
-            """ % (
-                order,
-            )
-            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+                    AND {bounds}
+                ORDER BY stream_ordering {order} LIMIT ?
+            """
+            txn.execute(sql, (room_id, 2 * limit))
 
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
-                for event_id, instance_name, topological_ordering, stream_ordering in txn
-                if _filter_results(
-                    from_key,
-                    to_key,
-                    instance_name,
-                    topological_ordering,
-                    stream_ordering,
+                for event_id, instance_name, stream_ordering in txn
+                if _filter_results_by_stream(
+                    lower_token=(
+                        to_key if direction == Direction.BACKWARDS else from_key
+                    ),
+                    upper_token=(
+                        from_key if direction == Direction.BACKWARDS else to_key
+                    ),
+                    instance_name=instance_name,
+                    stream_ordering=stream_ordering,
                 )
             ][:limit]
             return rows
@@ -800,17 +875,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        if order.lower() == "desc":
-            ret.reverse()
-
         if rows:
-            key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
+            next_key = generate_next_token(
+                direction=direction,
+                last_topo_ordering=None,
+                last_stream_ordering=rows[-1].stream_ordering,
+            )
         else:
-            # Assume we didn't get anything because there was nothing to
-            # get.
-            key = from_key
+            # TODO (erikj): We should work out what to do here instead. (same as
+            # `_paginate_room_events_by_topological_ordering_txn(...)`)
+            next_key = to_key if to_key else from_key
 
-        return ret, key
+        return ret, next_key
 
     @trace
     async def get_current_state_delta_membership_changes_for_user(
@@ -1118,7 +1194,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
-            self._paginate_room_events_txn,
+            self._paginate_room_events_by_topological_ordering_txn,
             room_id,
             from_token=end_token,
             limit=limit,
@@ -1624,7 +1700,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             topological=topological_ordering, stream=stream_ordering
         )
 
-        rows, start_token = self._paginate_room_events_txn(
+        rows, start_token = self._paginate_room_events_by_topological_ordering_txn(
             txn,
             room_id,
             before_token,
@@ -1634,7 +1710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         events_before = [r.event_id for r in rows]
 
-        rows, end_token = self._paginate_room_events_txn(
+        rows, end_token = self._paginate_room_events_by_topological_ordering_txn(
             txn,
             room_id,
             after_token,
@@ -1797,14 +1873,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
 
-    def _paginate_room_events_txn(
+    def _paginate_room_events_by_topological_ordering_txn(
         self,
         txn: LoggingTransaction,
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
         direction: Direction = Direction.BACKWARDS,
-        limit: int = -1,
+        limit: int = 0,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.
@@ -1826,6 +1902,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             been reached (i.e. there are no events between `from_token` and
             `to_token`), or `limit` is zero.
         """
+        # We can bail early if we're looking forwards, and our `to_key` is already
+        # before our `from_token`.
+        if (
+            direction == Direction.FORWARDS
+            and to_token is not None
+            and to_token.is_before_or_eq(from_token)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_token if to_token else from_token
+        # Or vice-versa, if we're looking backwards and our `from_token` is already before
+        # our `to_token`.
+        elif (
+            direction == Direction.BACKWARDS
+            and to_token is not None
+            and from_token.is_before_or_eq(to_token)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_token if to_token else from_token
 
         args: List[Any] = [room_id]
 
@@ -1910,7 +2004,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "bounds": bounds,
             "order": order,
         }
-
         txn.execute(sql, args)
 
         # Filter the result set.
@@ -1943,27 +2036,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @trace
     @tag_args
-    async def paginate_room_events(
+    async def paginate_room_events_by_topological_ordering(
         self,
+        *,
         room_id: str,
         from_key: RoomStreamToken,
         to_key: Optional[RoomStreamToken] = None,
         direction: Direction = Direction.BACKWARDS,
-        limit: int = -1,
+        limit: int = 0,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
-        """Returns list of events before or after a given token.
-
-        When Direction.FORWARDS: from_key < x <= to_key
-        When Direction.BACKWARDS: from_key >= x > to_key
+        """
+        Paginate events by `topological_ordering` (tie-break with `stream_ordering`) in
+        the room from the `from_key` in the given `direction` to the `to_key` or
+        `limit`.
 
         Args:
             room_id
-            from_key: The token used to stream from
-            to_key: A token which if given limits the results to only those before
+            from_key: The token to stream from (starting point and heading in the given
+                direction)
+            to_key: The token representing the end stream position (end point)
             direction: Indicates whether we are paginating forwards or backwards
                 from `from_key`.
-            limit: The maximum number of events to return.
+            limit: Maximum number of events to return
             event_filter: If provided filters the events to those that match the filter.
 
         Returns:
@@ -1971,8 +2066,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             of the result set. If no events are returned then the end of the
             stream has been reached (i.e. there are no events between `from_key`
             and `to_key`).
+
+            When Direction.FORWARDS: from_key < x <= to_key, (ascending order)
+            When Direction.BACKWARDS: from_key >= x > to_key, (descending order)
         """
 
+        # FIXME: When going forwards, we should enforce that the `to_key` is not `None`
+        # because we always need an upper bound when querying the events stream (as
+        # otherwise we'll potentially pick up events that are not fully persisted).
+
+        # We have these checks outside of the transaction function (txn) to save getting
+        # a DB connection and switching threads if we don't need to.
+        #
         # We can bail early if we're looking forwards, and our `to_key` is already
         # before our `from_key`.
         if (
@@ -1995,8 +2100,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             return [], to_key if to_key else from_key
 
         rows, token = await self.db_pool.runInteraction(
-            "paginate_room_events",
-            self._paginate_room_events_txn,
+            "paginate_room_events_by_topological_ordering",
+            self._paginate_room_events_by_topological_ordering_txn,
             room_id,
             from_key,
             to_key,