author     Eric Eastwood <eric.eastwood@beta.gouv.fr>  2024-08-08 09:24:26 -0500
committer  Eric Eastwood <eric.eastwood@beta.gouv.fr>  2024-08-08 09:24:26 -0500
commit     47a8077d8e4f14b941a4a06eaaffd955787e9a47
tree       6587930d4266d830ebe106d332e2c94fc7dced46
parent     Fix test description
parent     Replace deprecated `HTTPAdapter.get_connection` method with `get_connection_...
download   synapse-47a8077d8e4f14b941a4a06eaaffd955787e9a47.tar.xz
Merge branch 'develop' into madlittlemods/sliding-sync-must-await-full-state
Conflicts:
	synapse/handlers/sliding_sync.py
-rw-r--r--  changelog.d/17510.bugfix                  |   1
-rw-r--r--  changelog.d/17536.misc                    |   1
-rw-r--r--  docs/development/room-dag-concepts.md     |   6
-rwxr-xr-x  scripts-dev/federation_client.py          |  25
-rw-r--r--  synapse/handlers/admin.py                 |  10
-rw-r--r--  synapse/handlers/pagination.py            |  32
-rw-r--r--  synapse/handlers/room.py                  |   2
-rw-r--r--  synapse/handlers/sliding_sync.py          |  65
-rw-r--r--  synapse/handlers/sync.py                  |  69
-rw-r--r--  synapse/storage/databases/main/stream.py  | 275
-rw-r--r--  tests/storage/test_stream.py              |   2
11 files changed, 352 insertions, 136 deletions
diff --git a/changelog.d/17510.bugfix b/changelog.d/17510.bugfix
new file mode 100644
index 0000000000..3170c284bd
--- /dev/null
+++ b/changelog.d/17510.bugfix
@@ -0,0 +1 @@
+Fix timeline ordering (using `stream_ordering` instead of topological ordering) in experimental [MSC3575](https://github.com/matrix-org/matrix-spec-proposals/pull/3575) Sliding Sync `/sync` endpoint.
diff --git a/changelog.d/17536.misc b/changelog.d/17536.misc
new file mode 100644
index 0000000000..116ef0c36d
--- /dev/null
+++ b/changelog.d/17536.misc
@@ -0,0 +1 @@
+Replace override of deprecated method `HTTPAdapter.get_connection` with `get_connection_with_tls_context`.
\ No newline at end of file
diff --git a/docs/development/room-dag-concepts.md b/docs/development/room-dag-concepts.md
index 76709487f8..35b667831c 100644
--- a/docs/development/room-dag-concepts.md
+++ b/docs/development/room-dag-concepts.md
@@ -21,8 +21,10 @@ incrementing integer, but backfilled events start with `stream_ordering=-1` and
 
 ---
 
- - `/sync` returns things in the order they arrive at the server (`stream_ordering`).
- - `/messages` (and `/backfill` in the federation API) return them in the order determined by the event graph `(topological_ordering, stream_ordering)`.
+ - Incremental `/sync?since=xxx` returns things in the order they arrive at the server
+   (`stream_ordering`).
+ - Initial `/sync`, `/messages` (and `/backfill` in the federation API) return them in
+   the order determined by the event graph `(topological_ordering, stream_ordering)`.
 
 The general idea is that, if you're following a room in real-time (i.e.
 `/sync`), you probably want to see the messages as they arrive at your server,
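
To illustrate the distinction the doc change draws, here is a minimal Python sketch (made-up event IDs and orderings, not data from this commit) showing how the two orderings can disagree once a remote event arrives late:

# Each event is (event_id, topological_ordering, stream_ordering).
events = [
    ("$A", 5, 50),
    ("$B", 6, 52),  # remote event that arrived late at this server
    ("$C", 7, 51),
]

# Incremental /sync order: purely by arrival at this server.
by_stream = sorted(events, key=lambda e: e[2])
print([e[0] for e in by_stream])  # ['$A', '$C', '$B']

# Initial /sync and /messages order: by position in the event graph,
# tie-broken by arrival order.
by_topological = sorted(events, key=lambda e: (e[1], e[2]))
print([e[0] for e in by_topological])  # ['$A', '$B', '$C']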
diff --git a/scripts-dev/federation_client.py b/scripts-dev/federation_client.py
index 4c758e5424..fb879ef555 100755
--- a/scripts-dev/federation_client.py
+++ b/scripts-dev/federation_client.py
@@ -43,7 +43,7 @@ import argparse
 import base64
 import json
 import sys
-from typing import Any, Dict, Optional, Tuple
+from typing import Any, Dict, Mapping, Optional, Tuple, Union
 from urllib import parse as urlparse
 
 import requests
@@ -75,7 +75,7 @@ def encode_canonical_json(value: object) -> bytes:
         value,
         # Encode code-points outside of ASCII as UTF-8 rather than \u escapes
         ensure_ascii=False,
-        # Remove unecessary white space.
+        # Remove unnecessary white space.
         separators=(",", ":"),
         # Sort the keys of dictionaries.
         sort_keys=True,
@@ -298,12 +298,23 @@ class MatrixConnectionAdapter(HTTPAdapter):
 
         return super().send(request, *args, **kwargs)
 
-    def get_connection(
-        self, url: str, proxies: Optional[Dict[str, str]] = None
+    def get_connection_with_tls_context(
+        self,
+        request: PreparedRequest,
+        verify: Optional[Union[bool, str]],
+        proxies: Optional[Mapping[str, str]] = None,
+        cert: Optional[Union[Tuple[str, str], str]] = None,
     ) -> HTTPConnectionPool:
-        # overrides the get_connection() method in the base class
-        parsed = urlparse.urlsplit(url)
-        (host, port, ssl_server_name) = self._lookup(parsed.netloc)
+        # overrides the get_connection_with_tls_context() method in the base class
+        parsed = urlparse.urlsplit(request.url)
+
+        # Extract the server name from the request URL, and ensure it's a str.
+        hostname = parsed.netloc
+        if isinstance(hostname, bytes):
+            hostname = hostname.decode("utf-8")
+        assert isinstance(hostname, str)
+
+        (host, port, ssl_server_name) = self._lookup(hostname)
         print(
             f"Connecting to {host}:{port} with SNI {ssl_server_name}", file=sys.stderr
         )
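
For reference, a minimal sketch of the override pattern adopted above, assuming requests >= 2.32 (where `get_connection_with_tls_context` was added); the federation-specific `_lookup` resolution is reduced to a logging stub:

import sys
from typing import Mapping, Optional, Tuple, Union

import requests
from requests import PreparedRequest
from requests.adapters import HTTPAdapter
from urllib3.connectionpool import HTTPConnectionPool


class LoggingAdapter(HTTPAdapter):
    """Toy adapter that inspects the request before connecting."""

    def get_connection_with_tls_context(
        self,
        request: PreparedRequest,
        verify: Optional[Union[bool, str]],
        proxies: Optional[Mapping[str, str]] = None,
        cert: Optional[Union[Tuple[str, str], str]] = None,
    ) -> HTTPConnectionPool:
        # A real federation client would resolve the destination here
        # (SRV/.well-known lookup, SNI selection) before connecting.
        print(f"connecting for {request.url}", file=sys.stderr)
        return super().get_connection_with_tls_context(
            request, verify, proxies=proxies, cert=cert
        )


session = requests.Session()
session.mount("https://", LoggingAdapter())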
diff --git a/synapse/handlers/admin.py b/synapse/handlers/admin.py
index ec35784c5f..b44e862493 100644
--- a/synapse/handlers/admin.py
+++ b/synapse/handlers/admin.py
@@ -197,8 +197,14 @@ class AdminHandler:
             # events that we have and then filtering, this isn't the most
             # efficient method perhaps but it does guarantee we get everything.
             while True:
-                events, _ = await self._store.paginate_room_events(
-                    room_id, from_key, to_key, limit=100, direction=Direction.FORWARDS
+                events, _ = (
+                    await self._store.paginate_room_events_by_topological_ordering(
+                        room_id=room_id,
+                        from_key=from_key,
+                        to_key=to_key,
+                        limit=100,
+                        direction=Direction.FORWARDS,
+                    )
                 )
                 if not events:
                     break
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index 872c85fbad..6fd7afa280 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -507,13 +507,15 @@ class PaginationHandler:
 
         # Initially fetch the events from the database. With any luck, we can return
         # these without blocking on backfill (handled below).
-        events, next_key = await self.store.paginate_room_events(
-            room_id=room_id,
-            from_key=from_token.room_key,
-            to_key=to_room_key,
-            direction=pagin_config.direction,
-            limit=pagin_config.limit,
-            event_filter=event_filter,
+        events, next_key = (
+            await self.store.paginate_room_events_by_topological_ordering(
+                room_id=room_id,
+                from_key=from_token.room_key,
+                to_key=to_room_key,
+                direction=pagin_config.direction,
+                limit=pagin_config.limit,
+                event_filter=event_filter,
+            )
         )
 
         if pagin_config.direction == Direction.BACKWARDS:
@@ -582,13 +584,15 @@ class PaginationHandler:
                 # If we did backfill something, refetch the events from the database to
                 # catch anything new that might have been added since we last fetched.
                 if did_backfill:
-                    events, next_key = await self.store.paginate_room_events(
-                        room_id=room_id,
-                        from_key=from_token.room_key,
-                        to_key=to_room_key,
-                        direction=pagin_config.direction,
-                        limit=pagin_config.limit,
-                        event_filter=event_filter,
+                    events, next_key = (
+                        await self.store.paginate_room_events_by_topological_ordering(
+                            room_id=room_id,
+                            from_key=from_token.room_key,
+                            to_key=to_room_key,
+                            direction=pagin_config.direction,
+                            limit=pagin_config.limit,
+                            event_filter=event_filter,
+                        )
                     )
             else:
                 # Otherwise, we can backfill in the background for eventual
diff --git a/synapse/handlers/room.py b/synapse/handlers/room.py
index 262d9f4044..2c6e672ede 100644
--- a/synapse/handlers/room.py
+++ b/synapse/handlers/room.py
@@ -1750,7 +1750,7 @@ class RoomEventSource(EventSource[RoomStreamToken, EventBase]):
                 from_key=from_key,
                 to_key=to_key,
                 limit=limit or 10,
-                order="ASC",
+                direction=Direction.FORWARDS,
             )
 
             events = list(room_events)
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index be08f3851c..77ec09bc85 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -65,7 +65,10 @@ from synapse.storage.databases.main.state import (
     ROOM_UNKNOWN_SENTINEL,
     Sentinel as StateSentinel,
 )
-from synapse.storage.databases.main.stream import CurrentStateDeltaMembership
+from synapse.storage.databases.main.stream import (
+    CurrentStateDeltaMembership,
+    PaginateFunction,
+)
 from synapse.storage.roommember import MemberSummary
 from synapse.types import (
     DeviceListUpdates,
@@ -658,6 +661,23 @@ class SlidingSyncHandler:
                         filtered_sync_room_map, to_token
                     )
 
+                    if not lazy_loading:
+                        # Exclude partially-stated rooms unless the `required_state`
+                        # only has `["m.room.member", "$LAZY"]` for membership
+                        # (lazy-loading room members).
+                        filtered_sync_room_map = {
+                            room_id: room
+                            for room_id, room in filtered_sync_room_map.items()
+                            if not partial_state_room_map.get(room_id)
+                        }
+
+                    all_rooms.update(filtered_sync_room_map)
+
+                    # Sort the list
+                    sorted_room_info = await self.sort_rooms(
+                        filtered_sync_room_map, to_token
+                    )
+
                     ops: List[SlidingSyncResult.SlidingWindowList.Operation] = []
                     if list_config.ranges:
                         for range in list_config.ranges:
@@ -733,6 +753,8 @@ class SlidingSyncHandler:
                     if not room_membership_for_user_at_to_token:
                         continue
 
+                    all_rooms.add(room_id)
+
                     room_membership_for_user_map[room_id] = (
                         room_membership_for_user_at_to_token
                     )
@@ -869,9 +891,9 @@ class SlidingSyncHandler:
                 # positions we can use that.
                 missing_event_map_by_room = (
                     await self.store.get_room_events_stream_for_rooms(
-                        missing_rooms,
-                        from_key=from_token.stream_token.room_key,
-                        to_key=to_token.room_key,
+                        room_ids=missing_rooms,
+                        from_key=to_token.room_key,
+                        to_key=from_token.stream_token.room_key,
                         limit=1,
                     )
                 )
@@ -1971,10 +1993,13 @@ class SlidingSyncHandler:
         # We should return historical messages (before token range) in the
         # following cases because we want clients to be able to show a basic
         # screen of information:
+        #
         #  - Initial sync (because no `from_token` to limit us anyway)
         #  - When users `newly_joined`
         #  - For an incremental sync where we haven't sent it down this
         #    connection before
+        #
+        # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917
         from_bound = None
         initial = True
         if from_token and not room_membership_for_user_at_to_token.newly_joined:
@@ -2035,7 +2060,36 @@ class SlidingSyncHandler:
                     room_membership_for_user_at_to_token.event_pos.to_room_stream_token()
                 )
 
-            timeline_events, new_room_key = await self.store.paginate_room_events(
+            # For initial `/sync` (and other historical scenarios mentioned above), we
+            # want to view a historical section of the timeline, so we fetch events by
+            # `topological_ordering` (the best representation of the room DAG as others
+            # were seeing it at the time). This also aligns with the order that
+            # `/messages` returns events in.
+            #
+            # For incremental `/sync`, we want to get all updates for rooms since the
+            # last `/sync` (regardless of whether those updates arrived late or
+            # happened a while ago), so we fetch events by `stream_ordering` (in the
+            # order they were received by the server).
+            #
+            # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917
+            #
+            # FIXME: Using workaround for mypy,
+            # https://github.com/python/mypy/issues/10740#issuecomment-1997047277 and
+            # https://github.com/python/mypy/issues/17479
+            paginate_room_events_by_topological_ordering: PaginateFunction = (
+                self.store.paginate_room_events_by_topological_ordering
+            )
+            paginate_room_events_by_stream_ordering: PaginateFunction = (
+                self.store.paginate_room_events_by_stream_ordering
+            )
+            pagination_method: PaginateFunction = (
+                # Use `topological_ordering` for historical events
+                paginate_room_events_by_topological_ordering
+                if from_bound is None
+                # Use `stream_ordering` for updates
+                else paginate_room_events_by_stream_ordering
+            )
+            timeline_events, new_room_key = await pagination_method(
                 room_id=room_id,
                 # The bounds are reversed so we can paginate backwards
                 # (from newer to older events) starting at to_bound.
@@ -2046,7 +2100,6 @@ class SlidingSyncHandler:
                 # We add one so we can determine if there are enough events to saturate
                 # the limit or not (see `limited`)
                 limit=room_sync_config.timeline_limit + 1,
-                event_filter=None,
             )
 
             # We want to return the events in ascending order (the last event is the
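
The FIXME above points at a mypy limitation when a conditional expression chooses between two bound methods; a self-contained toy of the same workaround (hypothetical names, not Synapse code):

from typing import Protocol


class Pager(Protocol):
    def __call__(self, *, room_id: str) -> str: ...


class Store:
    def by_topological(self, *, room_id: str) -> str:
        return f"{room_id}: topological"

    def by_stream(self, *, room_id: str) -> str:
        return f"{room_id}: stream"


store = Store()
historical = True

# Binding each method to a Protocol-typed name first gives mypy a common
# type for both branches; inlining the bound methods directly in the
# ternary trips the issues linked in the FIXME.
topological: Pager = store.by_topological
stream: Pager = store.by_stream
pagination_method: Pager = topological if historical else stream
print(pagination_method(room_id="!room:example.org"))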
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index ede014180c..6af2eeb75f 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -43,6 +43,7 @@ from prometheus_client import Counter
 
 from synapse.api.constants import (
     AccountDataTypes,
+    Direction,
     EventContentFields,
     EventTypes,
     JoinRules,
@@ -64,6 +65,7 @@ from synapse.logging.opentracing import (
 )
 from synapse.storage.databases.main.event_push_actions import RoomNotifCounts
 from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
+from synapse.storage.databases.main.stream import PaginateFunction
 from synapse.storage.roommember import MemberSummary
 from synapse.types import (
     DeviceListUpdates,
@@ -879,22 +881,49 @@ class SyncHandler:
                 since_key = since_token.room_key
 
             while limited and len(recents) < timeline_limit and max_repeat:
-                # If we have a since_key then we are trying to get any events
-                # that have happened since `since_key` up to `end_key`, so we
-                # can just use `get_room_events_stream_for_room`.
-                # Otherwise, we want to return the last N events in the room
-                # in topological ordering.
-                if since_key:
-                    events, end_key = await self.store.get_room_events_stream_for_room(
-                        room_id,
-                        limit=load_limit + 1,
-                        from_key=since_key,
-                        to_key=end_key,
-                    )
-                else:
-                    events, end_key = await self.store.get_recent_events_for_room(
-                        room_id, limit=load_limit + 1, end_token=end_key
-                    )
+                # For initial `/sync`, we want to view a historical section of the
+                # timeline, so we fetch events by `topological_ordering` (the best
+                # representation of the room DAG as others were seeing it at the time).
+                # This also aligns with the order that `/messages` returns events in.
+                #
+                # For incremental `/sync`, we want to get all updates for rooms since
+                # the last `/sync` (regardless of whether those updates arrived late or
+                # happened a while ago), so we fetch events by `stream_ordering` (in
+                # the order they were received by the server).
+                #
+                # Relevant spec issue: https://github.com/matrix-org/matrix-spec/issues/1917
+                #
+                # FIXME: Using workaround for mypy,
+                # https://github.com/python/mypy/issues/10740#issuecomment-1997047277 and
+                # https://github.com/python/mypy/issues/17479
+                paginate_room_events_by_topological_ordering: PaginateFunction = (
+                    self.store.paginate_room_events_by_topological_ordering
+                )
+                paginate_room_events_by_stream_ordering: PaginateFunction = (
+                    self.store.paginate_room_events_by_stream_ordering
+                )
+                pagination_method: PaginateFunction = (
+                    # Use `topological_ordering` for historical events
+                    paginate_room_events_by_topological_ordering
+                    if since_key is None
+                    # Use `stream_ordering` for updates
+                    else paginate_room_events_by_stream_ordering
+                )
+                events, end_key = await pagination_method(
+                    room_id=room_id,
+                    # The bounds are reversed so we can paginate backwards
+                    # (from newer to older events) starting at `end_key`.
+                    # This ensures we fill the `limit` with the newest events first.
+                    from_key=end_key,
+                    to_key=since_key,
+                    direction=Direction.BACKWARDS,
+                    # We add one so we can determine if there are enough events to saturate
+                    # the limit or not (see `limited`)
+                    limit=load_limit + 1,
+                )
+                # We want to return the events in ascending order (the last event is the
+                # most recent).
+                events.reverse()
 
                 log_kv({"loaded_recents": len(events)})
 
@@ -2641,9 +2670,10 @@ class SyncHandler:
         # a "gap" in the timeline, as described by the spec for /sync.
         room_to_events = await self.store.get_room_events_stream_for_rooms(
             room_ids=sync_result_builder.joined_room_ids,
-            from_key=since_token.room_key,
-            to_key=now_token.room_key,
+            from_key=now_token.room_key,
+            to_key=since_token.room_key,
             limit=timeline_limit + 1,
+            direction=Direction.BACKWARDS,
         )
 
         # We loop through all room ids, even if there are no new events, in case
@@ -2654,6 +2684,9 @@ class SyncHandler:
             newly_joined = room_id in newly_joined_rooms
             if room_entry:
                 events, start_key = room_entry
+                # We want to return the events in ascending order (the last event is the
+                # most recent).
+                events.reverse()
 
                 prev_batch_token = now_token.copy_and_replace(
                     StreamKeyType.ROOM, start_key
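
Both rewritten call sites above share one shape: paginate backwards from the newer bound, over-fetch by one row to detect truncation, then reverse into ascending order. A toy sketch with plain integers standing in for events:

from typing import List, Tuple


def newest_page(events: List[int], limit: int) -> Tuple[List[int], bool]:
    """Newest `limit` events in ascending order, plus a `limited` flag.

    `events` is assumed sorted oldest-to-newest, as a stream would be.
    """
    # Walk backwards (newest first), taking one extra row...
    page = list(reversed(events))[: limit + 1]
    # ...so a full limit+1 page means older events were left behind.
    limited = len(page) > limit
    page = page[:limit]
    # Return ascending: the last event is the most recent.
    page.reverse()
    return page, limited


print(newest_page([1, 2, 3, 4, 5], limit=3))  # ([3, 4, 5], True)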
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 95775e3804..4989c960a6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -51,6 +51,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Protocol,
     Set,
     Tuple,
     cast,
@@ -59,7 +60,7 @@ from typing import (
 
 import attr
 from immutabledict import immutabledict
-from typing_extensions import Literal
+from typing_extensions import Literal, assert_never
 
 from twisted.internet import defer
 
@@ -97,6 +98,18 @@ _STREAM_TOKEN = "stream"
 _TOPOLOGICAL_TOKEN = "topological"
 
 
+class PaginateFunction(Protocol):
+    async def __call__(
+        self,
+        *,
+        room_id: str,
+        from_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: Direction = Direction.BACKWARDS,
+        limit: int = 0,
+    ) -> Tuple[List[EventBase], RoomStreamToken]: ...
+
+
 # Used as return values for pagination APIs
 @attr.s(slots=True, frozen=True, auto_attribs=True)
 class _EventDictReturn:
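
One detail worth noting about the Protocol above: `paginate_room_events_by_topological_ordering` also takes an `event_filter` keyword, yet still satisfies `PaginateFunction`, because a callable with extra defaulted parameters remains assignable. A hedged toy (synchronous for brevity, unlike the real async methods):

from typing import List, Optional, Protocol


class Pager(Protocol):
    def __call__(self, *, room_id: str, limit: int = 0) -> List[str]: ...


def by_stream(*, room_id: str, limit: int = 0) -> List[str]:
    return []


def by_topological(
    *, room_id: str, limit: int = 0, event_filter: Optional[object] = None
) -> List[str]:
    # Extra keyword with a default: any call valid for `Pager` is valid here.
    return []


pagers: List[Pager] = [by_stream, by_topological]  # both type-check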
@@ -280,7 +293,7 @@ def generate_pagination_bounds(
 
 
 def generate_next_token(
-    direction: Direction, last_topo_ordering: int, last_stream_ordering: int
+    direction: Direction, last_topo_ordering: Optional[int], last_stream_ordering: int
 ) -> RoomStreamToken:
     """
     Generate the next room stream token based on the currently returned data.
@@ -447,7 +460,6 @@ def _filter_results_by_stream(
     The `instance_name` arg is optional to handle historic rows, and is
     interpreted as if it was "master".
     """
-
     if instance_name is None:
         instance_name = "master"
 
@@ -660,33 +672,43 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     async def get_room_events_stream_for_rooms(
         self,
+        *,
         room_ids: Collection[str],
         from_key: RoomStreamToken,
-        to_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: Direction = Direction.BACKWARDS,
         limit: int = 0,
-        order: str = "DESC",
     ) -> Dict[str, Tuple[List[EventBase], RoomStreamToken]]:
         """Get new room events in stream ordering since `from_key`.
 
         Args:
             room_ids
-            from_key: Token from which no events are returned before
-            to_key: Token from which no events are returned after. (This
-                is typically the current stream token)
+            from_key: The token to stream from (starting point and heading in the given
+                direction)
+            to_key: The token representing the end stream position (end point)
             limit: Maximum number of events to return
-            order: Either "DESC" or "ASC". Determines which events are
-                returned when the result is limited. If "DESC" then the most
-                recent `limit` events are returned, otherwise returns the
-                oldest `limit` events.
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
 
         Returns:
             A map from room id to a tuple containing:
                 - list of recent events in the room
                 - stream ordering key for the start of the chunk of events returned.
+
+            When Direction.FORWARDS: from_key < x <= to_key (ascending order)
+            When Direction.BACKWARDS: from_key >= x > to_key (descending order)
         """
-        room_ids = self._events_stream_cache.get_entities_changed(
-            room_ids, from_key.stream
-        )
+        if direction == Direction.FORWARDS:
+            room_ids = self._events_stream_cache.get_entities_changed(
+                room_ids, from_key.stream
+            )
+        elif direction == Direction.BACKWARDS:
+            if to_key is not None:
+                room_ids = self._events_stream_cache.get_entities_changed(
+                    room_ids, to_key.stream
+                )
+        else:
+            assert_never(direction)
 
         if not room_ids:
             return {}
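
The bound conventions documented above (exclusive on the starting side, inclusive on the ending side) can be sanity-checked with plain integers standing in for RoomStreamTokens, ignoring multi-writer instance maps:

from enum import Enum


class Direction(Enum):  # stand-in for synapse.api.constants.Direction
    FORWARDS = "f"
    BACKWARDS = "b"


def in_range(x: int, from_key: int, to_key: int, direction: Direction) -> bool:
    if direction is Direction.FORWARDS:
        # from_key < x <= to_key, returned in ascending order
        return from_key < x <= to_key
    # from_key >= x > to_key, returned in descending order
    return to_key < x <= from_key


assert in_range(5, from_key=4, to_key=5, direction=Direction.FORWARDS)
assert not in_range(4, from_key=4, to_key=5, direction=Direction.FORWARDS)
assert in_range(5, from_key=5, to_key=4, direction=Direction.BACKWARDS)
assert not in_range(4, from_key=5, to_key=4, direction=Direction.BACKWARDS)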
@@ -698,12 +720,12 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 defer.gatherResults(
                     [
                         run_in_background(
-                            self.get_room_events_stream_for_room,
-                            room_id,
-                            from_key,
-                            to_key,
-                            limit,
-                            order=order,
+                            self.paginate_room_events_by_stream_ordering,
+                            room_id=room_id,
+                            from_key=from_key,
+                            to_key=to_key,
+                            direction=direction,
+                            limit=limit,
                         )
                         for room_id in rm_ids
                     ],
@@ -727,69 +749,122 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             if self._events_stream_cache.has_entity_changed(room_id, from_id)
         }
 
-    async def get_room_events_stream_for_room(
+    async def paginate_room_events_by_stream_ordering(
         self,
+        *,
         room_id: str,
         from_key: RoomStreamToken,
-        to_key: RoomStreamToken,
+        to_key: Optional[RoomStreamToken] = None,
+        direction: Direction = Direction.BACKWARDS,
         limit: int = 0,
-        order: str = "DESC",
     ) -> Tuple[List[EventBase], RoomStreamToken]:
-        """Get new room events in stream ordering since `from_key`.
+        """
+        Paginate events by `stream_ordering` in the room from the `from_key` in the
+        given `direction` to the `to_key` or `limit`.
 
         Args:
             room_id
-            from_key: Token from which no events are returned before
-            to_key: Token from which no events are returned after. (This
-                is typically the current stream token)
+            from_key: The token to stream from (starting point and heading in the given
+                direction)
+            to_key: The token representing the end stream position (end point)
+            direction: Indicates whether we are paginating forwards or backwards
+                from `from_key`.
             limit: Maximum number of events to return
-            order: Either "DESC" or "ASC". Determines which events are
-                returned when the result is limited. If "DESC" then the most
-                recent `limit` events are returned, otherwise returns the
-                oldest `limit` events.
 
         Returns:
-            The list of events (in ascending stream order) and the token from the start
-            of the chunk of events returned.
+            The results as a list of events and a token that points to the end
+            of the result set. If no events are returned then the end of the
+            stream has been reached (i.e. there are no events between `from_key`
+            and `to_key`).
+
+            When Direction.FORWARDS: from_key < x <= to_key (ascending order)
+            When Direction.BACKWARDS: from_key >= x > to_key (descending order)
         """
-        if from_key == to_key:
-            return [], from_key
 
-        has_changed = self._events_stream_cache.has_entity_changed(
-            room_id, from_key.stream
-        )
+        # FIXME: When going forwards, we should enforce that the `to_key` is not `None`
+        # because we always need an upper bound when querying the events stream (as
+        # otherwise we'll potentially pick up events that are not fully persisted).
+
+        # We should only be working with `stream_ordering` tokens here
+        assert from_key is None or from_key.topological is None
+        assert to_key is None or to_key.topological is None
+
+        # We can bail early if we're looking forwards, and our `to_key` is already
+        # before our `from_key`.
+        if (
+            direction == Direction.FORWARDS
+            and to_key is not None
+            and to_key.is_before_or_eq(from_key)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_key if to_key else from_key
+        # Or vice-versa, if we're looking backwards and our `from_key` is already before
+        # our `to_key`.
+        elif (
+            direction == Direction.BACKWARDS
+            and to_key is not None
+            and from_key.is_before_or_eq(to_key)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_key if to_key else from_key
+
+        # We can do a quick sanity check to see if any events have been sent in the room
+        # since the earlier token.
+        has_changed = True
+        if direction == Direction.FORWARDS:
+            has_changed = self._events_stream_cache.has_entity_changed(
+                room_id, from_key.stream
+            )
+        elif direction == Direction.BACKWARDS:
+            if to_key is not None:
+                has_changed = self._events_stream_cache.has_entity_changed(
+                    room_id, to_key.stream
+                )
+        else:
+            assert_never(direction)
 
         if not has_changed:
-            return [], from_key
+            # Token selection matches what we do below if there are no rows
+            return [], to_key if to_key else from_key
 
-        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
-            # To handle tokens with a non-empty instance_map we fetch more
-            # results than necessary and then filter down
-            min_from_id = from_key.stream
-            max_to_id = to_key.get_max_stream_pos()
+        order, from_bound, to_bound = generate_pagination_bounds(
+            direction, from_key, to_key
+        )
 
-            sql = """
-                SELECT event_id, instance_name, topological_ordering, stream_ordering
+        bounds = generate_pagination_where_clause(
+            direction=direction,
+            # The empty string will shortcut downstream code to only use the
+            # `stream_ordering` column
+            column_names=("", "stream_ordering"),
+            from_token=from_bound,
+            to_token=to_bound,
+            engine=self.database_engine,
+        )
+
+        def f(txn: LoggingTransaction) -> List[_EventDictReturn]:
+            sql = f"""
+                SELECT event_id, instance_name, stream_ordering
                 FROM events
                 WHERE
                     room_id = ?
                     AND not outlier
-                    AND stream_ordering > ? AND stream_ordering <= ?
-                ORDER BY stream_ordering %s LIMIT ?
-            """ % (
-                order,
-            )
-            txn.execute(sql, (room_id, min_from_id, max_to_id, 2 * limit))
+                    AND {bounds}
+                ORDER BY stream_ordering {order} LIMIT ?
+            """
+            txn.execute(sql, (room_id, 2 * limit))
 
             rows = [
                 _EventDictReturn(event_id, None, stream_ordering)
-                for event_id, instance_name, topological_ordering, stream_ordering in txn
-                if _filter_results(
-                    from_key,
-                    to_key,
-                    instance_name,
-                    topological_ordering,
-                    stream_ordering,
+                for event_id, instance_name, stream_ordering in txn
+                if _filter_results_by_stream(
+                    lower_token=(
+                        to_key if direction == Direction.BACKWARDS else from_key
+                    ),
+                    upper_token=(
+                        from_key if direction == Direction.BACKWARDS else to_key
+                    ),
+                    instance_name=instance_name,
+                    stream_ordering=stream_ordering,
                 )
             ][:limit]
             return rows
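
Why `2 * limit` plus a Python-side re-check: a token can carry per-writer positions (an `instance_map`), while the SQL compares only the global `stream_ordering` column, so the query deliberately over-fetches and the rows are re-filtered before truncating. A simplified stand-alone sketch of that shape (the real code uses `_filter_results_by_stream` on RoomStreamTokens):

from typing import Dict, List, Tuple

Row = Tuple[str, str, int]  # (event_id, instance_name, stream_ordering)


def refilter_rows(
    rows: List[Row],
    lower: Dict[str, int],  # exclusive lower bound per writer instance
    upper: Dict[str, int],  # inclusive upper bound per writer instance
    default_lower: int,
    default_upper: int,
    limit: int,
) -> List[Row]:
    kept = [
        (event_id, inst, so)
        for event_id, inst, so in rows
        if lower.get(inst, default_lower) < so <= upper.get(inst, default_upper)
    ]
    # The SQL already over-fetched (LIMIT 2 * limit); truncate after filtering.
    return kept[:limit]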
@@ -800,17 +875,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             [r.event_id for r in rows], get_prev_content=True
         )
 
-        if order.lower() == "desc":
-            ret.reverse()
-
         if rows:
-            key = RoomStreamToken(stream=min(r.stream_ordering for r in rows))
+            next_key = generate_next_token(
+                direction=direction,
+                last_topo_ordering=None,
+                last_stream_ordering=rows[-1].stream_ordering,
+            )
         else:
-            # Assume we didn't get anything because there was nothing to
-            # get.
-            key = from_key
+            # TODO (erikj): We should work out what to do here instead. (same as
+            # `_paginate_room_events_by_topological_ordering_txn(...)`)
+            next_key = to_key if to_key else from_key
 
-        return ret, key
+        return ret, next_key
 
     @trace
     async def get_current_state_delta_membership_changes_for_user(
@@ -1118,7 +1194,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         rows, token = await self.db_pool.runInteraction(
             "get_recent_event_ids_for_room",
-            self._paginate_room_events_txn,
+            self._paginate_room_events_by_topological_ordering_txn,
             room_id,
             from_token=end_token,
             limit=limit,
@@ -1624,7 +1700,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             topological=topological_ordering, stream=stream_ordering
         )
 
-        rows, start_token = self._paginate_room_events_txn(
+        rows, start_token = self._paginate_room_events_by_topological_ordering_txn(
             txn,
             room_id,
             before_token,
@@ -1634,7 +1710,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
         )
         events_before = [r.event_id for r in rows]
 
-        rows, end_token = self._paginate_room_events_txn(
+        rows, end_token = self._paginate_room_events_by_topological_ordering_txn(
             txn,
             room_id,
             after_token,
@@ -1797,14 +1873,14 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     def has_room_changed_since(self, room_id: str, stream_id: int) -> bool:
         return self._events_stream_cache.has_entity_changed(room_id, stream_id)
 
-    def _paginate_room_events_txn(
+    def _paginate_room_events_by_topological_ordering_txn(
         self,
         txn: LoggingTransaction,
         room_id: str,
         from_token: RoomStreamToken,
         to_token: Optional[RoomStreamToken] = None,
         direction: Direction = Direction.BACKWARDS,
-        limit: int = -1,
+        limit: int = 0,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[_EventDictReturn], RoomStreamToken]:
         """Returns list of events before or after a given token.
@@ -1826,6 +1902,24 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             been reached (i.e. there are no events between `from_token` and
             `to_token`), or `limit` is zero.
         """
+        # We can bail early if we're looking forwards, and our `to_key` is already
+        # before our `from_token`.
+        if (
+            direction == Direction.FORWARDS
+            and to_token is not None
+            and to_token.is_before_or_eq(from_token)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_token if to_token else from_token
+        # Or vice-versa, if we're looking backwards and our `from_token` is already before
+        # our `to_token`.
+        elif (
+            direction == Direction.BACKWARDS
+            and to_token is not None
+            and from_token.is_before_or_eq(to_token)
+        ):
+            # Token selection matches what we do below if there are no rows
+            return [], to_token if to_token else from_token
 
         args: List[Any] = [room_id]
 
@@ -1910,7 +2004,6 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             "bounds": bounds,
             "order": order,
         }
-
         txn.execute(sql, args)
 
         # Filter the result set.
@@ -1943,27 +2036,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     @trace
     @tag_args
-    async def paginate_room_events(
+    async def paginate_room_events_by_topological_ordering(
         self,
+        *,
         room_id: str,
         from_key: RoomStreamToken,
         to_key: Optional[RoomStreamToken] = None,
         direction: Direction = Direction.BACKWARDS,
-        limit: int = -1,
+        limit: int = 0,
         event_filter: Optional[Filter] = None,
     ) -> Tuple[List[EventBase], RoomStreamToken]:
-        """Returns list of events before or after a given token.
-
-        When Direction.FORWARDS: from_key < x <= to_key
-        When Direction.BACKWARDS: from_key >= x > to_key
+        """
+        Paginate events by `topological_ordering` (tie-break with `stream_ordering`) in
+        the room from the `from_key` in the given `direction` to the `to_key` or
+        `limit`.
 
         Args:
             room_id
-            from_key: The token used to stream from
-            to_key: A token which if given limits the results to only those before
+            from_key: The token to stream from (starting point and heading in the given
+                direction)
+            to_key: The token representing the end stream position (end point)
             direction: Indicates whether we are paginating forwards or backwards
                 from `from_key`.
-            limit: The maximum number of events to return.
+            limit: Maximum number of events to return
             event_filter: If provided filters the events to those that match the filter.
 
         Returns:
@@ -1971,8 +2066,18 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             of the result set. If no events are returned then the end of the
             stream has been reached (i.e. there are no events between `from_key`
             and `to_key`).
+
+            When Direction.FORWARDS: from_key < x <= to_key (ascending order)
+            When Direction.BACKWARDS: from_key >= x > to_key (descending order)
         """
 
+        # FIXME: When going forwards, we should enforce that the `to_key` is not `None`
+        # because we always need an upper bound when querying the events stream (as
+        # otherwise we'll potentially pick up events that are not fully persisted).
+
+        # We have these checks outside of the transaction function (txn) to save getting
+        # a DB connection and switching threads if we don't need to.
+        #
         # We can bail early if we're looking forwards, and our `to_key` is already
         # before our `from_key`.
         if (
@@ -1995,8 +2100,8 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             return [], to_key if to_key else from_key
 
         rows, token = await self.db_pool.runInteraction(
-            "paginate_room_events",
-            self._paginate_room_events_txn,
+            "paginate_room_events_by_topological_ordering",
+            self._paginate_room_events_by_topological_ordering_txn,
             room_id,
             from_key,
             to_key,
diff --git a/tests/storage/test_stream.py b/tests/storage/test_stream.py
index 9dea1af8ea..7b7590da76 100644
--- a/tests/storage/test_stream.py
+++ b/tests/storage/test_stream.py
@@ -148,7 +148,7 @@ class PaginationTestCase(HomeserverTestCase):
         """Make a request to /messages with a filter, returns the chunk of events."""
 
         events, next_key = self.get_success(
-            self.hs.get_datastores().main.paginate_room_events(
+            self.hs.get_datastores().main.paginate_room_events_by_topological_ordering(
                 room_id=self.room_id,
                 from_key=self.from_token.room_key,
                 to_key=None,