diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 24abab4a23..715846865b 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -1313,6 +1313,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
        # We want to make the cache more effective, so we clamp to the last
        # change before the given ordering.
        last_change = self._events_stream_cache.get_max_pos_of_last_change(room_id)  # type: ignore[attr-defined]
+        if last_change is None:
+            # If the room isn't in the cache we know that the last change was
+            # somewhere before the earliest known position of the cache, so we
+            # can clamp to that.
+            last_change = self._events_stream_cache.get_earliest_known_position()  # type: ignore[attr-defined]
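+        # Either way, clamping to a single fixed position means that repeated
+        # calls for quiet rooms can hit the same cache entry for the lookup
+        # below.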

        # We don't always have a full stream_to_exterm_id table, e.g. after
        # the upgrade that introduced it, so we make sure we never ask for a
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b034361aec 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
from synapse.util.caches.descriptors import cached
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter

if TYPE_CHECKING:
    from synapse.server import HomeServer
@@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
            get_last_event_pos_in_room_before_stream_ordering_txn,
        )

+    async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+        self,
+        room_ids: StrCollection,
+        end_token: RoomStreamToken,
+    ) -> Dict[str, int]:
+        """Bulk fetch the stream position of the latest event in each of the
+        given rooms, as of the end token.
+
+        Returns:
+            A map from room ID to the stream ordering of the latest event in
+            that room at or before `end_token`. Rooms with no such event are
+            omitted.
+        """
+
+        min_token = end_token.stream
+        max_token = end_token.get_max_stream_pos()
+        results: Dict[str, int] = {}
+
+        # First, we check the stream change cache to see if we can just use
+        # its latest position for each room.
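+        # (`get_max_pos_of_last_change` is assumed to return None for rooms
+        # the cache has no entry for; those fall through to the database
+        # query below.)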
+        missing_room_ids: Set[str] = set()
+        for room_id in room_ids:
+            stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+            if stream_pos and stream_pos <= min_token:
+                results[room_id] = stream_pos
+            else:
+                missing_room_ids.add(room_id)
+
+        # Next, we query the stream positions from the DB. We first fetch all
+        # positions less than the *max* stream pos in the token and then
+        # filter them down. We do this because a) it is the cheaper query, and
+        # b) the vast majority of rooms will have a latest position from
+        # before the min stream pos.
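+        # For example, with a multi-writer token such as `m5~1.56~2.58` the
+        # min stream pos is 5 and the max is 58; whether an event in between
+        # counts as before the token depends on which writer persisted it.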
+
+        def bulk_get_last_event_pos_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            # This query fetches the latest stream position in the rooms before
+            # the given max position.
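+            # The correlated subquery picks, per room, the single latest
+            # event at or before the max position that is neither an outlier
+            # nor rejected.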
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, (
+                    SELECT stream_ordering FROM events AS e
+                    LEFT JOIN rejections USING (event_id)
+                    WHERE e.room_id = r.room_id
+                        AND stream_ordering <= ?
+                        AND NOT outlier
+                        AND rejection_reason IS NULL
+                    ORDER BY stream_ordering DESC
+                    LIMIT 1
+                )
+                FROM rooms AS r
+                WHERE {clause}
+            """
+            txn.execute(sql, [max_token] + args)
+            # Rooms with no matching event produce a NULL from the subquery;
+            # filter those out rather than returning a None position.
+            return {row[0]: row[1] for row in txn if row[1] is not None}
+
+        recheck_rooms: Set[str] = set()
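+        # Query the missing rooms in batches of 1000 to keep the size of the
+        # IN clause bounded.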
+        for batched in batch_iter(missing_room_ids, 1000):
+            result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering",
+                bulk_get_last_event_pos_txn,
+                batched,
+            )
+
+            # Check that the stream positions for the rooms are from before
+            # the minimum position of the token. If not, we need to fetch
+            # more rows.
+            for room_id, stream in result.items():
+                if stream <= min_token:
+                    results[room_id] = stream
+                else:
+                    recheck_rooms.add(room_id)
+
+        if not recheck_rooms:
+            return results
+
+        # For the remaining rooms we need to fetch all rows between the min
+        # and max stream positions in the end token, and filter out the rows
+        # that are after the end token.
+        #
+        # This query should be fast as the range between the min and max
+        # should be small.
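+        # An event in that range counts as before the token only if the
+        # token's position for the instance that persisted it has reached the
+        # event, which is what the `persisted_after` check below tests.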
+
+        def bulk_get_last_event_pos_recheck_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, instance_name, stream_ordering
+                FROM events
+                WHERE ? < stream_ordering AND stream_ordering <= ?
+                    AND NOT outlier
+                    AND rejection_reason IS NULL
+                    AND {clause}
+                ORDER BY stream_ordering ASC
+            """
+            txn.execute(sql, [min_token, max_token] + args)
+
+            # We take, per room, the max stream ordering that is not after
+            # the token. Since we ordered by stream ordering ascending, we
+            # just iterate through and keep the last matching row.
+            txn_results: Dict[str, int] = {}
+            for row in txn:
+                room_id = row[0]
+                event_pos = PersistedEventPosition(row[1], row[2])
+                if not event_pos.persisted_after(end_token):
+                    txn_results[room_id] = event_pos.stream
+
+            return txn_results
+
+        for batched in batch_iter(recheck_rooms, 1000):
+            recheck_result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+                bulk_get_last_event_pos_recheck_txn,
+                batched,
+            )
+            results.update(recheck_result)
+
+        return results
+
    async def get_current_room_stream_token_for_room_id(
        self, room_id: str
    ) -> RoomStreamToken: