summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/handlers/sliding_sync.py31
-rw-r--r--synapse/storage/databases/main/stream.py92
2 files changed, 106 insertions, 17 deletions
diff --git a/synapse/handlers/sliding_sync.py b/synapse/handlers/sliding_sync.py
index 423f0329d6..06d7d43f0b 100644
--- a/synapse/handlers/sliding_sync.py
+++ b/synapse/handlers/sliding_sync.py
@@ -1217,25 +1217,11 @@ class SlidingSyncHandler:
         # Assemble a map of room ID to the `stream_ordering` of the last activity that the
         # user should see in the room (<= `to_token`)
         last_activity_in_room_map: Dict[str, int] = {}
+
         for room_id, room_for_user in sync_room_map.items():
             # If they are fully-joined to the room, let's find the latest activity
             # at/before the `to_token`.
-            if room_for_user.membership == Membership.JOIN:
-                last_event_result = (
-                    await self.store.get_last_event_pos_in_room_before_stream_ordering(
-                        room_id, to_token.room_key
-                    )
-                )
-
-                # If the room has no events at/before the `to_token`, this is probably a
-                # mistake in the code that generates the `sync_room_map` since that should
-                # only give us rooms that the user had membership in during the token range.
-                assert last_event_result is not None
-
-                _, event_pos = last_event_result
-
-                last_activity_in_room_map[room_id] = event_pos.stream
-            else:
+            if room_for_user.membership != Membership.JOIN:
                 # Otherwise, if the user has left/been invited/knocked/been banned from
                 # a room, they shouldn't see anything past that point.
                 #
@@ -1245,6 +1231,19 @@ class SlidingSyncHandler:
                 # https://github.com/matrix-org/matrix-spec-proposals/pull/3575#discussion_r1653045932
                 last_activity_in_room_map[room_id] = room_for_user.event_pos.stream
 
+        joined_room_positions = (
+            await self.store.bulk_get_last_event_pos_in_room_before_stream_ordering(
+                [
+                    room_id
+                    for room_id, room_for_user in sync_room_map.items()
+                    if room_for_user.membership == Membership.JOIN
+                ],
+                to_token.room_key,
+            )
+        )
+
+        last_activity_in_room_map.update(joined_room_positions)
+
         return sorted(
             sync_room_map.values(),
             # Sort by the last activity (stream_ordering) in the room
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b119acda29 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1293,6 +1294,95 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             get_last_event_pos_in_room_before_stream_ordering_txn,
         )
 
+    async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+        self,
+        room_ids: StrCollection,
+        end_token: RoomStreamToken,
+    ) -> Dict[str, int]:
+        """Bulk fetch the latest event position in the given rooms"""
+
+        min_token = end_token.stream
+        max_token = end_token.get_max_stream_pos()
+        results: Dict[str, int] = {}
+
+        missing_room_ids: Set[str] = set()
+        for room_id in room_ids:
+            stream_pos = self._events_stream_cache._entity_to_key.get(room_id)
+            if stream_pos and stream_pos < max_token:
+                results[room_id] = stream_pos
+            else:
+                missing_room_ids.add(room_id)
+
+        def bulk_get_last_event_pos_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, (
+                    SELECT stream_ordering FROM events AS e
+                    WHERE e.room_id = r.room_id
+                        AND stream_ordering <= ?
+                        AND NOT outlier
+                    ORDER BY stream_ordering DESC
+                    LIMIT 1
+                )
+                FROM rooms AS r
+                WHERE {clause}
+            """
+            txn.execute(sql, [max_token] + args)
+            return {row[0]: row[1] for row in txn}
+
+        recheck_rooms: Set[str] = set()
+        for batched in batch_iter(missing_room_ids, 1000):
+            result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering",
+                bulk_get_last_event_pos_txn,
+                batched,
+            )
+
+            for room_id, stream in result.items():
+                if min_token < stream:
+                    recheck_rooms.add(room_id)
+                else:
+                    results[room_id] = stream
+
+        if not recheck_rooms:
+            return results
+
+        def bulk_get_last_event_pos_recheck_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, instance_name, stream_ordering
+                WHERE ? < stream_ordering AND stream_ordering <= ?
+                    AND {clause}
+                ORDER BY stream_ordering ASC
+            """
+            txn.execute(sql, [min_token, max_token] + args)
+            results: Dict[str, int] = {}
+            for row in txn:
+                room_id = row[0]
+                event_pos = PersistedEventPosition(row[1], row[2])
+                if not event_pos.persisted_after(end_token):
+                    results[room_id] = event_pos.stream
+
+            return results
+
+        for batched in batch_iter(recheck_rooms, 1000):
+            recheck_result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+                bulk_get_last_event_pos_recheck_txn,
+                batched,
+            )
+            results.update(recheck_result)
+
+        return results
+
     async def get_current_room_stream_token_for_room_id(
         self, room_id: str
     ) -> RoomStreamToken: