summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/17468.misc1
-rw-r--r--synapse/handlers/sliding_sync.py43
-rw-r--r--synapse/storage/databases/main/event_federation.py5
-rw-r--r--synapse/storage/databases/main/stream.py123
-rw-r--r--synapse/util/caches/stream_change_cache.py12
-rw-r--r--tests/util/test_stream_change_cache.py4
6 files changed, 159 insertions, 29 deletions
diff --git a/changelog.d/17468.misc b/changelog.d/17468.misc
new file mode 100644
index 0000000000..d908776204
--- /dev/null
+++ b/changelog.d/17468.misc
@@ -0,0 +1 @@
+Speed up sorting of the room list in sliding sync.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 886d7c7159..554ab59bf3 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -1230,34 +1230,33 @@ class SlidingSyncHandler:
         # Assemble a map of room ID to the `stream_ordering` of the last activity that the
         # user should see in the room (<= `to_token`)
         last_activity_in_room_map: Dict[str, int] = {}
-        for room_id, room_for_user in sync_room_map.items():
-            # If they are fully-joined to the room, let's find the latest activity
-            # at/before the `to_token`.
-            if room_for_user.membership == Membership.JOIN:
-                last_event_result = (
-                    await self.store.get_last_event_pos_in_room_before_stream_ordering(
-                        room_id, to_token.room_key
-                    )
-                )
-
-                # If the room has no events at/before the `to_token`, this is probably a
-                # mistake in the code that generates the `sync_room_map` since that should
-                # only give us rooms that the user had membership in during the token range.
-                assert last_event_result is not None
 
-                _, event_pos = last_event_result
-
-                last_activity_in_room_map[room_id] = event_pos.stream
-            else:
-                # Otherwise, if the user has left/been invited/knocked/been banned from
-                # a room, they shouldn't see anything past that point.
+        for room_id, room_for_user in sync_room_map.items():
+            if room_for_user.membership != Membership.JOIN:
+                # If the user has left/been invited/knocked/been banned from a
+                # room, they shouldn't see anything past that point.
                 #
-                # FIXME: It's possible that people should see beyond this point in
-                # invited/knocked cases if for example the room has
+                # FIXME: It's possible that people should see beyond this point
+                # in invited/knocked cases if for example the room has
                 # `invite`/`world_readable` history visibility, see
                 # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
                 last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
 
+        # For fully-joined rooms, we find the latest activity at/before the
+        # `to_token`.
+        joined_room_positions = (
+            await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
+                [
+                    room_id
+                    for room_id, room_for_user in sync_room_map.items()
+                    if room_for_user.membership == Membership.JOIN
+                ],
+                to_token.room_key,
+            )
+        )
+
+        last_activity_in_room_map.update(joined_room_positions)
+
         return sorted(
             sync_room_map.values(),
             # Sort by the last activity (stream_ordering) in the room
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24abab4a23..715846865b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # We want to make the cache more effective, so we clamp to the last
         # change before the given ordering.
         last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]
+        if last_change is None:
+            # If the room isn't in the cache we know that the last change was
+            # somewhere before the earliest known position of the cache, so we
+            # can clamp to that.
+            last_change = self._events_stream_cache.get_earliest_known_position()  # type: ignore[attr-defined]
 
         # We don't always have a full stream_to_exterm_id table, e.g. after
         # the upgrade that introduced it, so we make sure we never ask for a
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b034361aec 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             get_last_event_pos_in_room_before_stream_ordering_txn,
         )
 
+    async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+        self,
+        room_ids: StrCollection,
+        end_token: RoomStreamToken,
+    ) -> Dict[str, int]:
+        """Bulk fetch the stream position of the latest events in the given
+        rooms
+        """
+
+        min_token = end_token.stream
+        max_token = end_token.get_max_stream_pos()
+        results: Dict[str, int] = {}
+
+        # First, we check for the rooms in the stream change cache to see if we
+        # can just use the latest position from it.
+        missing_room_ids: Set[str] = set()
+        for room_id in room_ids:
+            stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+            if stream_pos and stream_pos <= min_token:
+                results[room_id] = stream_pos
+            else:
+                missing_room_ids.add(room_id)
+
+        # Next, we query the stream position from the DB. At first we fetch all
+        # positions less than the *max* stream pos in the token, then filter
+        # them down. We do this as a) this is a cheaper query, and b) the vast
+        # majority of rooms will have a latest token from before the min stream
+        # pos.
+
+        def bulk_get_last_event_pos_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            # This query fetches the latest stream position in the rooms before
+            # the given max position.
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, (
+                    SELECT stream_ordering FROM events AS e
+                    LEFT JOIN rejections USING (event_id)
+                    WHERE e.room_id = r.room_id
+                        AND stream_ordering <= ?
+                        AND NOT outlier
+                        AND rejection_reason IS NULL
+                    ORDER BY stream_ordering DESC
+                    LIMIT 1
+                )
+                FROM rooms AS r
+                WHERE {clause}
+            """
+            txn.execute(sql, [max_token] + args)
+            return {row[0]: row[1] for row in txn}
+
+        recheck_rooms: Set[str] = set()
+        for batched in batch_iter(missing_room_ids, 1000):
+            result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering",
+                bulk_get_last_event_pos_txn,
+                batched,
+            )
+
+            # Check that the stream position for the rooms are from before the
+            # minimum position of the token. If not then we need to fetch more
+            # rows.
+            for room_id, stream in result.items():
+                if stream <= min_token:
+                    results[room_id] = stream
+                else:
+                    recheck_rooms.add(room_id)
+
+        if not recheck_rooms:
+            return results
+
+        # For the remaining rooms we need to fetch all rows between the min and
+        # max stream positions in the end token, and filter out the rows that
+        # are after the end token.
+        #
+        # This query should be fast as the range between the min and max should
+        # be small.
+
+        def bulk_get_last_event_pos_recheck_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, instance_name, stream_ordering
+                FROM events
+                WHERE ? < stream_ordering AND stream_ordering <= ?
+                    AND NOT outlier
+                    AND rejection_reason IS NULL
+                    AND {clause}
+                ORDER BY stream_ordering ASC
+            """
+            txn.execute(sql, [min_token, max_token] + args)
+
+            # We take the max stream ordering that is less than the token. Since
+            # we ordered by stream ordering we just need to iterate through and
+            # take the last matching stream ordering.
+            txn_results: Dict[str, int] = {}
+            for row in txn:
+                room_id = row[0]
+                event_pos = PersistedEventPosition(row[1], row[2])
+                if not event_pos.persisted_after(end_token):
+                    txn_results[room_id] = event_pos.stream
+
+            return txn_results
+
+        for batched in batch_iter(recheck_rooms, 1000):
+            recheck_result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+                bulk_get_last_event_pos_recheck_txn,
+                batched,
+            )
+            results.update(recheck_result)
+
+        return results
+
     async def get_current_room_stream_token_for_room_id(
         self, room_id: str
     ) -> RoomStreamToken:
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 91c335f85b..16fcb00206 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -327,7 +327,7 @@ class StreamChangeCache:
             for entity in r:
                 self._entity_to_key.pop(entity, None)
 
-    def get_max_pos_of_last_change(self, entity: EntityType) -> int:
+    def get_max_pos_of_last_change(self, entity: EntityType) -> Optional[int]:
         """Returns an upper bound of the stream id of the last change to an
         entity.
 
@@ -335,7 +335,11 @@ class StreamChangeCache:
             entity: The entity to check.
 
         Return:
-            The stream position of the latest change for the given entity or
-            the earliest known stream position if the entitiy is unknown.
+            The stream position of the latest change for the given entity, if
+            known
         """
-        return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
+        return self._entity_to_key.get(entity)
+
+    def get_earliest_known_position(self) -> int:
+        """Returns the earliest position in the cache."""
+        return self._earliest_known_stream_pos
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 5d38718a50..af1199ef8a 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -249,5 +249,5 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
         self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
         self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
 
-        # Unknown entities will return the stream start position.
-        self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
+        # Unknown entities will return None
+        self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None)