diff --git a/changelog.d/17468.misc b/changelog.d/17468.misc
new file mode 100644
index 0000000000..d908776204
--- /dev/null
+++ b/changelog.d/17468.misc
@@ -0,0 +1 @@
+Speed up sorting of the room list in sliding sync.
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 886d7c7159..554ab59bf3 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -1230,34 +1230,33 @@ class SlidingSyncHandler:
# Assemble a map of room ID to the `stream_ordering` of the last activity that the
# user should see in the room (<= `to_token`)
last_activity_in_room_map: Dict[str, int] = {}
- for room_id, room_for_user in sync_room_map.items():
- # If they are fully-joined to the room, let's find the latest activity
- # at/before the `to_token`.
- if room_for_user.membership == Membership.JOIN:
- last_event_result = (
- await self.store.get_last_event_pos_in_room_before_stream_ordering(
- room_id, to_token.room_key
- )
- )
-
- # If the room has no events at/before the `to_token`, this is probably a
- # mistake in the code that generates the `sync_room_map` since that should
- # only give us rooms that the user had membership in during the token range.
- assert last_event_result is not None
- _, event_pos = last_event_result
-
- last_activity_in_room_map[room_id] = event_pos.stream
- else:
- # Otherwise, if the user has left/been invited/knocked/been banned from
- # a room, they shouldn't see anything past that point.
+ for room_id, room_for_user in sync_room_map.items():
+ if room_for_user.membership != Membership.JOIN:
+ # If the user has left/been invited/knocked/been banned from a
+ # room, they shouldn't see anything past that point.
#
- # FIXME: It's possible that people should see beyond this point in
- # invited/knocked cases if for example the room has
+ # FIXME: It's possible that people should see beyond this point
+ # in invited/knocked cases if for example the room has
# `invite`/`world_readable` history visibility, see
# https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
+ # For fully-joined rooms, we find the latest activity at/before the
+ # `to_token`.
+ joined_room_positions = (
+ await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
+ [
+ room_id
+ for room_id, room_for_user in sync_room_map.items()
+ if room_for_user.membership == Membership.JOIN
+ ],
+ to_token.room_key,
+ )
+ )
+
+ last_activity_in_room_map.update(joined_room_positions)
+
return sorted(
sync_room_map.values(),
# Sort by the last activity (stream_ordering) in the room
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24abab4a23..715846865b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# We want to make the cache more effective, so we clamp to the last
# change before the given ordering.
last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id) # type: ignore[attr-defined]
+ if last_change is None:
+ # If the room isn't in the cache we know that the last change was
+ # somewhere before the earliest known position of the cache, so we
+ # can clamp to that.
+ last_change = self._events_stream_cache.get_earliest_known_position() # type: ignore[attr-defined]
# We don't always have a full stream_to_exterm_id table, e.g. after
# the upgrade that introduced it, so we make sure we never ask for a
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b034361aec 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
if TYPE_CHECKING:
from synapse.server import HomeServer
@@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
get_last_event_pos_in_room_before_stream_ordering_txn,
)
+ async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+ self,
+ room_ids: StrCollection,
+ end_token: RoomStreamToken,
+    ) -> Dict[str, int]:
+        """Bulk fetch the stream position of the latest events in the given
+        rooms.
+
+        Returns:
+            A map from room ID to the stream ordering of the last event in
+            the room at/before the given end token.
+        """
+
+ min_token = end_token.stream
+ max_token = end_token.get_max_stream_pos()
+ results: Dict[str, int] = {}
+
+ # First, we check for the rooms in the stream change cache to see if we
+ # can just use the latest position from it.
+ missing_room_ids: Set[str] = set()
+ for room_id in room_ids:
+ stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+ if stream_pos and stream_pos <= min_token:
+ results[room_id] = stream_pos
+ else:
+ missing_room_ids.add(room_id)
+
+ # Next, we query the stream position from the DB. At first we fetch all
+ # positions less than the *max* stream pos in the token, then filter
+ # them down. We do this as a) this is a cheaper query, and b) the vast
+ # majority of rooms will have a latest token from before the min stream
+ # pos.
+
+ def bulk_get_last_event_pos_txn(
+ txn: LoggingTransaction, batch_room_ids: StrCollection
+ ) -> Dict[str, int]:
+ # This query fetches the latest stream position in the rooms before
+ # the given max position.
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+ sql = f"""
+ SELECT room_id, (
+ SELECT stream_ordering FROM events AS e
+ LEFT JOIN rejections USING (event_id)
+ WHERE e.room_id = r.room_id
+ AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejection_reason IS NULL
+ ORDER BY stream_ordering DESC
+ LIMIT 1
+ )
+ FROM rooms AS r
+ WHERE {clause}
+ """
+ txn.execute(sql, [max_token] + args)
+ return {row[0]: row[1] for row in txn}
+
+ recheck_rooms: Set[str] = set()
+ for batched in batch_iter(missing_room_ids, 1000):
+ result = await self.db_pool.runInteraction(
+ "bulk_get_last_event_pos_in_room_before_stream_ordering",
+ bulk_get_last_event_pos_txn,
+ batched,
+ )
+
+            # Check that the stream positions for the rooms are from before the
+ # minimum position of the token. If not then we need to fetch more
+ # rows.
+ for room_id, stream in result.items():
+ if stream <= min_token:
+ results[room_id] = stream
+ else:
+ recheck_rooms.add(room_id)
+
+ if not recheck_rooms:
+ return results
+
+ # For the remaining rooms we need to fetch all rows between the min and
+ # max stream positions in the end token, and filter out the rows that
+ # are after the end token.
+ #
+ # This query should be fast as the range between the min and max should
+ # be small.
+
+ def bulk_get_last_event_pos_recheck_txn(
+ txn: LoggingTransaction, batch_room_ids: StrCollection
+ ) -> Dict[str, int]:
+ clause, args = make_in_list_sql_clause(
+ self.database_engine, "room_id", batch_room_ids
+ )
+ sql = f"""
+ SELECT room_id, instance_name, stream_ordering
+ FROM events
+ WHERE ? < stream_ordering AND stream_ordering <= ?
+ AND NOT outlier
+ AND rejection_reason IS NULL
+ AND {clause}
+ ORDER BY stream_ordering ASC
+ """
+ txn.execute(sql, [min_token, max_token] + args)
+
+ # We take the max stream ordering that is less than the token. Since
+ # we ordered by stream ordering we just need to iterate through and
+ # take the last matching stream ordering.
+ txn_results: Dict[str, int] = {}
+ for row in txn:
+ room_id = row[0]
+ event_pos = PersistedEventPosition(row[1], row[2])
+ if not event_pos.persisted_after(end_token):
+ txn_results[room_id] = event_pos.stream
+
+ return txn_results
+
+ for batched in batch_iter(recheck_rooms, 1000):
+ recheck_result = await self.db_pool.runInteraction(
+ "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+ bulk_get_last_event_pos_recheck_txn,
+ batched,
+ )
+ results.update(recheck_result)
+
+ return results
+
async def get_current_room_stream_token_for_room_id(
self, room_id: str
) -> RoomStreamToken:
diff --git a/synapse/util/caches/stream_change_cache.py b/synapse/util/caches/stream_change_cache.py
index 91c335f85b..16fcb00206 100644
--- a/synapse/util/caches/stream_change_cache.py
+++ b/synapse/util/caches/stream_change_cache.py
@@ -327,7 +327,7 @@ class StreamChangeCache:
for entity in r:
self._entity_to_key.pop(entity, None)
- def get_max_pos_of_last_change(self, entity: EntityType) -> int:
+ def get_max_pos_of_last_change(self, entity: EntityType) -> Optional[int]:
"""Returns an upper bound of the stream id of the last change to an
entity.
@@ -335,7 +335,11 @@ class StreamChangeCache:
entity: The entity to check.
Return:
- The stream position of the latest change for the given entity or
- the earliest known stream position if the entitiy is unknown.
+ The stream position of the latest change for the given entity, if
+ known
"""
- return self._entity_to_key.get(entity, self._earliest_known_stream_pos)
+ return self._entity_to_key.get(entity)
+
+ def get_earliest_known_position(self) -> int:
+ """Returns the earliest position in the cache."""
+ return self._earliest_known_stream_pos
diff --git a/tests/util/test_stream_change_cache.py b/tests/util/test_stream_change_cache.py
index 5d38718a50..af1199ef8a 100644
--- a/tests/util/test_stream_change_cache.py
+++ b/tests/util/test_stream_change_cache.py
@@ -249,5 +249,5 @@ class StreamChangeCacheTests(unittest.HomeserverTestCase):
self.assertEqual(cache.get_max_pos_of_last_change("bar@baz.net"), 3)
self.assertEqual(cache.get_max_pos_of_last_change("user@elsewhere.org"), 4)
- # Unknown entities will return the stream start position.
- self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), 1)
+ # Unknown entities will return None
+ self.assertEqual(cache.get_max_pos_of_last_change("not@here.website"), None)
|