diff --git a/changelog.d/17606.misc b/changelog.d/17606.misc
new file mode 100644
index 0000000000..47634b1305
--- /dev/null
+++ b/changelog.d/17606.misc
@@ -0,0 +1 @@
+Speed up incremental syncs in sliding sync by adding some more caching.
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 63624f3e8f..246d2acc2f 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -313,6 +313,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
"get_unread_event_push_actions_by_room_for_user", (room_id,)
)
+ self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))
+
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
self._attempt_to_invalidate_cache("_get_membership_from_event_id", (event_id,))
@@ -404,6 +406,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
)
self._attempt_to_invalidate_cache("get_relations_for_event", (room_id,))
+ self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))
+
self._attempt_to_invalidate_cache("_get_membership_from_event_id", None)
self._attempt_to_invalidate_cache("get_applicable_edit", None)
self._attempt_to_invalidate_cache("get_thread_id", None)
@@ -476,6 +480,8 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
self._attempt_to_invalidate_cache("get_room_type", (room_id,))
self._attempt_to_invalidate_cache("get_room_encryption", (room_id,))
+ self._attempt_to_invalidate_cache("_get_max_event_pos", (room_id,))
+
# And delete state caches.
self._invalidate_state_caches_all(room_id)
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 4989c960a6..e33a8cfe97 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -50,6 +50,7 @@ from typing import (
Dict,
Iterable,
List,
+ Mapping,
Optional,
Protocol,
Set,
@@ -80,7 +81,7 @@ from synapse.storage.databases.main.events_worker import EventsWorkerStore
from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
from synapse.storage.util.id_generators import MultiWriterIdGenerator
from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.caches.stream_change_cache import StreamChangeCache
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -1381,8 +1382,52 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
rooms
"""
+        # First we just get the latest positions for the rooms, as the vast
+        # majority of them will be before the given end token anyway. By doing
+        # this the results for most rooms can be served from the cache.
+ uncapped_results = await self._bulk_get_max_event_pos(room_ids)
+
+        # Check that the stream position for each room is from before the
+        # minimum position of the token. If not, then we need to fetch more
+        # rows.
+ results: Dict[str, int] = {}
+ recheck_rooms: Set[str] = set()
min_token = end_token.stream
- max_token = end_token.get_max_stream_pos()
+ for room_id, stream in uncapped_results.items():
+ if stream <= min_token:
+ results[room_id] = stream
+ else:
+ recheck_rooms.add(room_id)
+
+ if not recheck_rooms:
+ return results
+
+ # There shouldn't be many rooms that we need to recheck, so we do them
+ # one-by-one.
+ for room_id in recheck_rooms:
+ result = await self.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, end_token
+ )
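+            # A `None` result means there is no event in the room before the
+            # token, so we leave the room out of the results.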
+ if result is not None:
+ results[room_id] = result[1].stream
+
+ return results
+
+ @cached()
+ async def _get_max_event_pos(self, room_id: str) -> int:
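+        """Fetch the max stream position of a persisted event in the room.
+
+        This only exists as the per-room cache entry for
+        `_bulk_get_max_event_pos` below (and is what gets invalidated
+        per-room in `cache.py`); it is never called directly, hence the
+        `NotImplementedError`.
+        """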
+ raise NotImplementedError()
+
+ @cachedList(cached_method_name="_get_max_event_pos", list_name="room_ids")
+ async def _bulk_get_max_event_pos(
+ self, room_ids: StrCollection
+ ) -> Mapping[str, int]:
+        """Fetch the max position of a persisted event in each of the given
+        rooms."""
+
+        # We need to be careful not to return positions ahead of the current
+        # token, so we get the current token now and cap our queries to it.
+ now_token = self.get_room_max_token()
+ max_pos = now_token.get_max_stream_pos()
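+        # Note that positions between `now_token.stream` and `max_pos` may not
+        # be before the current token (e.g. with multiple event persisters),
+        # which is why any such results get double-checked further down.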
+
results: Dict[str, int] = {}
# First, we check for the rooms in the stream change cache to see if we
@@ -1390,31 +1435,32 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
missing_room_ids: Set[str] = set()
for room_id in room_ids:
stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
- if stream_pos and stream_pos <= min_token:
+ if stream_pos is not None:
results[room_id] = stream_pos
else:
missing_room_ids.add(room_id)
+ if not missing_room_ids:
+ return results
+
# Next, we query the stream position from the DB. At first we fetch all
# positions less than the *max* stream pos in the token, then filter
# them down. We do this as a) this is a cheaper query, and b) the vast
# majority of rooms will have a latest token from before the min stream
# pos.
- def bulk_get_last_event_pos_txn(
- txn: LoggingTransaction, batch_room_ids: StrCollection
+ def bulk_get_max_event_pos_txn(
+ txn: LoggingTransaction, batched_room_ids: StrCollection
) -> Dict[str, int]:
- # This query fetches the latest stream position in the rooms before
- # the given max position.
clause, args = make_in_list_sql_clause(
- self.database_engine, "room_id", batch_room_ids
+ self.database_engine, "room_id", batched_room_ids
)
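+            # For each room, find the stream ordering of the latest
+            # non-outlier, non-rejected event at or before `max_pos`.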
sql = f"""
SELECT room_id, (
SELECT stream_ordering FROM events AS e
LEFT JOIN rejections USING (event_id)
WHERE e.room_id = r.room_id
- AND stream_ordering <= ?
+ AND e.stream_ordering <= ?
AND NOT outlier
AND rejection_reason IS NULL
ORDER BY stream_ordering DESC
@@ -1423,72 +1469,29 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
FROM rooms AS r
WHERE {clause}
"""
- txn.execute(sql, [max_token] + args)
+ txn.execute(sql, [max_pos] + args)
return {row[0]: row[1] for row in txn}
recheck_rooms: Set[str] = set()
- for batched in batch_iter(missing_room_ids, 1000):
- result = await self.db_pool.runInteraction(
- "bulk_get_last_event_pos_in_room_before_stream_ordering",
- bulk_get_last_event_pos_txn,
- batched,
+        for batched in batch_iter(missing_room_ids, 1000):
+ batch_results = await self.db_pool.runInteraction(
+ "_bulk_get_max_event_pos", bulk_get_max_event_pos_txn, batched
)
-
- # Check that the stream position for the rooms are from before the
- # minimum position of the token. If not then we need to fetch more
- # rows.
- for room_id, stream in result.items():
- if stream <= min_token:
- results[room_id] = stream
+ for room_id, stream_ordering in batch_results.items():
+ if stream_ordering <= now_token.stream:
+                    results[room_id] = stream_ordering
else:
recheck_rooms.add(room_id)
- if not recheck_rooms:
- return results
-
- # For the remaining rooms we need to fetch all rows between the min and
- # max stream positions in the end token, and filter out the rows that
- # are after the end token.
- #
- # This query should be fast as the range between the min and max should
- # be small.
-
- def bulk_get_last_event_pos_recheck_txn(
- txn: LoggingTransaction, batch_room_ids: StrCollection
- ) -> Dict[str, int]:
- clause, args = make_in_list_sql_clause(
- self.database_engine, "room_id", batch_room_ids
- )
- sql = f"""
- SELECT room_id, instance_name, stream_ordering
- FROM events
- WHERE ? < stream_ordering AND stream_ordering <= ?
- AND NOT outlier
- AND rejection_reason IS NULL
- AND {clause}
- ORDER BY stream_ordering ASC
- """
- txn.execute(sql, [min_token, max_token] + args)
-
- # We take the max stream ordering that is less than the token. Since
- # we ordered by stream ordering we just need to iterate through and
- # take the last matching stream ordering.
- txn_results: Dict[str, int] = {}
- for row in txn:
- room_id = row[0]
- event_pos = PersistedEventPosition(row[1], row[2])
- if not event_pos.persisted_after(end_token):
- txn_results[room_id] = event_pos.stream
-
- return txn_results
-
- for batched in batch_iter(recheck_rooms, 1000):
- recheck_result = await self.db_pool.runInteraction(
- "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
- bulk_get_last_event_pos_recheck_txn,
- batched,
+        # We now need to handle rooms where the above query returned a stream
+        # position that was potentially too new. This should happen very
+        # rarely, so we just query those rooms one-by-one.
+ for room_id in recheck_rooms:
+ result = await self.get_last_event_pos_in_room_before_stream_ordering(
+ room_id, now_token
)
- results.update(recheck_result)
+ if result is not None:
+ results[room_id] = result[1].stream
return results