summary refs log tree commit diff
path: root/synapse/storage/databases/main/stream.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/stream.py')
-rw-r--r--synapse/storage/databases/main/stream.py123
1 files changed, 122 insertions, 1 deletions
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index e74e0d2e91..b034361aec 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -78,10 +78,11 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
 from synapse.storage.util.id_generators import MultiWriterIdGenerator
-from synapse.types import PersistedEventPosition, RoomStreamToken
+from synapse.types import PersistedEventPosition, RoomStreamToken, StrCollection
 from synapse.util.caches.descriptors import cached
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 from synapse.util.cancellation import cancellable
+from synapse.util.iterutils import batch_iter
 
 if TYPE_CHECKING:
     from synapse.server import HomeServer
@@ -1293,6 +1294,126 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             get_last_event_pos_in_room_before_stream_ordering_txn,
         )
 
+    async def bulk_get_last_event_pos_in_room_before_stream_ordering(
+        self,
+        room_ids: StrCollection,
+        end_token: RoomStreamToken,
+    ) -> Dict[str, int]:
+        """Bulk fetch the stream position of the latest events in the given
+        rooms
+        """
+
+        min_token = end_token.stream
+        max_token = end_token.get_max_stream_pos()
+        results: Dict[str, int] = {}
+
+        # First, we check for the rooms in the stream change cache to see if we
+        # can just use the latest position from it.
+        missing_room_ids: Set[str] = set()
+        for room_id in room_ids:
+            stream_pos = self._events_stream_cache.get_max_pos_of_last_change(room_id)
+            if stream_pos and stream_pos <= min_token:
+                results[room_id] = stream_pos
+            else:
+                missing_room_ids.add(room_id)
+
+        # Next, we query the stream position from the DB. At first we fetch all
+        # positions less than the *max* stream pos in the token, then filter
+        # them down. We do this as a) this is a cheaper query, and b) the vast
+        # majority of rooms will have a latest token from before the min stream
+        # pos.
+
+        def bulk_get_last_event_pos_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            # This query fetches the latest stream position in the rooms before
+            # the given max position.
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, (
+                    SELECT stream_ordering FROM events AS e
+                    LEFT JOIN rejections USING (event_id)
+                    WHERE e.room_id = r.room_id
+                        AND stream_ordering <= ?
+                        AND NOT outlier
+                        AND rejection_reason IS NULL
+                    ORDER BY stream_ordering DESC
+                    LIMIT 1
+                )
+                FROM rooms AS r
+                WHERE {clause}
+            """
+            txn.execute(sql, [max_token] + args)
+            return {row[0]: row[1] for row in txn}
+
+        recheck_rooms: Set[str] = set()
+        for batched in batch_iter(missing_room_ids, 1000):
+            result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering",
+                bulk_get_last_event_pos_txn,
+                batched,
+            )
+
+            # Check that the stream position for the rooms are from before the
+            # minimum position of the token. If not then we need to fetch more
+            # rows.
+            for room_id, stream in result.items():
+                if stream <= min_token:
+                    results[room_id] = stream
+                else:
+                    recheck_rooms.add(room_id)
+
+        if not recheck_rooms:
+            return results
+
+        # For the remaining rooms we need to fetch all rows between the min and
+        # max stream positions in the end token, and filter out the rows that
+        # are after the end token.
+        #
+        # This query should be fast as the range between the min and max should
+        # be small.
+
+        def bulk_get_last_event_pos_recheck_txn(
+            txn: LoggingTransaction, batch_room_ids: StrCollection
+        ) -> Dict[str, int]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "room_id", batch_room_ids
+            )
+            sql = f"""
+                SELECT room_id, instance_name, stream_ordering
+                FROM events
+                WHERE ? < stream_ordering AND stream_ordering <= ?
+                    AND NOT outlier
+                    AND rejection_reason IS NULL
+                    AND {clause}
+                ORDER BY stream_ordering ASC
+            """
+            txn.execute(sql, [min_token, max_token] + args)
+
+            # We take the max stream ordering that is less than the token. Since
+            # we ordered by stream ordering we just need to iterate through and
+            # take the last matching stream ordering.
+            txn_results: Dict[str, int] = {}
+            for row in txn:
+                room_id = row[0]
+                event_pos = PersistedEventPosition(row[1], row[2])
+                if not event_pos.persisted_after(end_token):
+                    txn_results[room_id] = event_pos.stream
+
+            return txn_results
+
+        for batched in batch_iter(recheck_rooms, 1000):
+            recheck_result = await self.db_pool.runInteraction(
+                "bulk_get_last_event_pos_in_room_before_stream_ordering_recheck",
+                bulk_get_last_event_pos_recheck_txn,
+                batched,
+            )
+            results.update(recheck_result)
+
+        return results
+
     async def get_current_room_stream_token_for_room_id(
         self, room_id: str
     ) -> RoomStreamToken: