summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py329
1 files changed, 179 insertions, 150 deletions
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@
 import logging
 from typing import Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
@@ -28,6 +30,25 @@ from synapse.types import JsonDict
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+    """Return value for _calculate_chain_cover_txn.
+    """
+
+    # The last room_id/depth/stream processed.
+    room_id = attr.ib(type=str)
+    depth = attr.ib(type=int)
+    stream = attr.ib(type=int)
+
+    # Number of rows processed
+    processed_count = attr.ib(type=int)
+
+    # Map from room_id to last depth/stream processed for each room that we have
+    # processed all events for (i.e. the rooms we can flip the
+    # `has_auth_chain_index` for)
+    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
 class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         current_room_id = progress.get("current_room_id", "")
 
-        # Have we finished processing the current room.
-        finished = progress.get("finished", True)
-
         # Where we've processed up to in the room, defaults to the start of the
         # room.
         last_depth = progress.get("last_depth", -1)
         last_stream = progress.get("last_stream", -1)
 
-        # Have we set the `has_auth_chain_index` for the room yet.
-        has_set_room_has_chain_index = progress.get(
-            "has_set_room_has_chain_index", False
+        result = await self.db_pool.runInteraction(
+            "_chain_cover_index",
+            self._calculate_chain_cover_txn,
+            current_room_id,
+            last_depth,
+            last_stream,
+            batch_size,
+            single_room=False,
         )
 
-        if finished:
-            # If we've finished with the previous room (or its our first
-            # iteration) we move on to the next room.
-
-            def _get_next_room(txn: Cursor) -> Optional[str]:
-                sql = """
-                    SELECT room_id FROM rooms
-                    WHERE room_id > ?
-                        AND (
-                            NOT has_auth_chain_index
-                            OR has_auth_chain_index IS NULL
-                        )
-                    ORDER BY room_id
-                    LIMIT 1
-                """
-                txn.execute(sql, (current_room_id,))
-                row = txn.fetchone()
-                if row:
-                    return row[0]
+        finished = result.processed_count == 0
 
-                return None
-
-            current_room_id = await self.db_pool.runInteraction(
-                "_chain_cover_index", _get_next_room
-            )
-            if not current_room_id:
-                await self.db_pool.updates._end_background_update("chain_cover")
-                return 0
-
-            logger.debug("Adding chain cover to %s", current_room_id)
-
-        def _calculate_auth_chain(
-            txn: Cursor, last_depth: int, last_stream: int
-        ) -> Tuple[int, int, int]:
-            # Get the next set of events in the room (that we haven't already
-            # computed chain cover for). We do this in topological order.
-
-            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
-            # comparison, but that is not supported on older SQLite versions
-            tuple_clause, tuple_args = make_tuple_comparison_clause(
-                self.database_engine,
-                [
-                    ("topological_ordering", last_depth),
-                    ("stream_ordering", last_stream),
-                ],
-            )
+        total_rows_processed = result.processed_count
+        current_room_id = result.room_id
+        last_depth = result.depth
+        last_stream = result.stream
 
-            sql = """
-                SELECT
-                    event_id, state_events.type, state_events.state_key,
-                    topological_ordering, stream_ordering
-                FROM events
-                INNER JOIN state_events USING (event_id)
-                LEFT JOIN event_auth_chains USING (event_id)
-                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
-                WHERE events.room_id = ?
-                    AND event_auth_chains.event_id IS NULL
-                    AND event_auth_chain_to_calculate.event_id IS NULL
-                    AND %(tuple_cmp)s
-                ORDER BY topological_ordering, stream_ordering
-                LIMIT ?
-            """ % {
-                "tuple_cmp": tuple_clause,
-            }
-
-            args = [current_room_id]
-            args.extend(tuple_args)
-            args.append(batch_size)
-
-            txn.execute(sql, args)
-            rows = txn.fetchall()
-
-            # Put the results in the necessary format for
-            # `_add_chain_cover_index`
-            event_to_room_id = {row[0]: current_room_id for row in rows}
-            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
-
-            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-
-            count = len(rows)
-
-            # We also need to fetch the auth events for them.
-            auth_events = self.db_pool.simple_select_many_txn(
-                txn,
-                table="event_auth",
-                column="event_id",
-                iterable=event_to_room_id,
-                keyvalues={},
-                retcols=("event_id", "auth_id"),
-            )
-
-            event_to_auth_chain = {}  # type: Dict[str, List[str]]
-            for row in auth_events:
-                event_to_auth_chain.setdefault(row["event_id"], []).append(
-                    row["auth_id"]
-                )
-
-            # Calculate and persist the chain cover index for this set of events.
-            #
-            # Annoyingly we need to gut wrench into the persit event store so that
-            # we can reuse the function to calculate the chain cover for rooms.
-            PersistEventsStore._add_chain_cover_index(
-                txn,
-                self.db_pool,
-                event_to_room_id,
-                event_to_types,
-                event_to_auth_chain,
-            )
-
-            return new_last_depth, new_last_stream, count
-
-        last_depth, last_stream, count = await self.db_pool.runInteraction(
-            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
-        )
-
-        total_rows_processed = count
-
-        if count < batch_size and not has_set_room_has_chain_index:
+        for room_id, (depth, stream) in result.finished_room_map.items():
             # If we've done all the events in the room we flip the
             # `has_auth_chain_index` in the DB. Note that its possible for
             # further events to be persisted between the above and setting the
@@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             await self.db_pool.simple_update(
                 table="rooms",
-                keyvalues={"room_id": current_room_id},
+                keyvalues={"room_id": room_id},
                 updatevalues={"has_auth_chain_index": True},
                 desc="_chain_cover_index",
             )
-            has_set_room_has_chain_index = True
 
             # Handle any events that might have raced with us flipping the
             # bit above.
-            last_depth, last_stream, count = await self.db_pool.runInteraction(
-                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            result = await self.db_pool.runInteraction(
+                "_chain_cover_index",
+                self._calculate_chain_cover_txn,
+                room_id,
+                depth,
+                stream,
+                batch_size=None,
+                single_room=True,
             )
 
-            total_rows_processed += count
+            total_rows_processed += result.processed_count
 
-            # Note that at this point its technically possible that more events
-            # than our `batch_size` have been persisted without their chain
-            # cover, so we need to continue processing this room if the last
-            # count returned was equal to the `batch_size`.
+        if finished:
+            await self.db_pool.updates._end_background_update("chain_cover")
+            return total_rows_processed
 
-        if count < batch_size:
-            # We've finished calculating the index for this room, move on to the
-            # next room.
-            await self.db_pool.updates._background_update_progress(
-                "chain_cover", {"current_room_id": current_room_id, "finished": True},
-            )
-        else:
-            # We still have outstanding events to calculate the index for.
-            await self.db_pool.updates._background_update_progress(
-                "chain_cover",
-                {
-                    "current_room_id": current_room_id,
-                    "last_depth": last_depth,
-                    "last_stream": last_stream,
-                    "has_auth_chain_index": has_set_room_has_chain_index,
-                    "finished": False,
-                },
-            )
+        await self.db_pool.updates._background_update_progress(
+            "chain_cover",
+            {
+                "current_room_id": current_room_id,
+                "last_depth": last_depth,
+                "last_stream": last_stream,
+            },
+        )
 
         return total_rows_processed
+
+    def _calculate_chain_cover_txn(
+        self,
+        txn: Cursor,
+        last_room_id: str,
+        last_depth: int,
+        last_stream: int,
+        batch_size: Optional[int],
+        single_room: bool,
+    ) -> _CalculateChainCover:
+        """Calculate the chain cover for `batch_size` events, ordered by
+        `(room_id, depth, stream)`.
+
+        Args:
+            txn,
+            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+                tuple to fetch results after.
+            batch_size: The maximum number of events to process. If None then
+                no limit.
+            single_room: Whether to calculate the index for just the given
+                room.
+        """
+
+        # Get the next set of events in the room (that we haven't already
+        # computed chain cover for). We do this in topological order.
+
+        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+        # comparison, but that is not supported on older SQLite versions
+        tuple_clause, tuple_args = make_tuple_comparison_clause(
+            self.database_engine,
+            [
+                ("events.room_id", last_room_id),
+                ("topological_ordering", last_depth),
+                ("stream_ordering", last_stream),
+            ],
+        )
+
+        extra_clause = ""
+        if single_room:
+            extra_clause = "AND events.room_id = ?"
+            tuple_args.append(last_room_id)
+
+        sql = """
+            SELECT
+                event_id, state_events.type, state_events.state_key,
+                topological_ordering, stream_ordering,
+                events.room_id
+            FROM events
+            INNER JOIN state_events USING (event_id)
+            LEFT JOIN event_auth_chains USING (event_id)
+            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+            WHERE event_auth_chains.event_id IS NULL
+                AND event_auth_chain_to_calculate.event_id IS NULL
+                AND %(tuple_cmp)s
+                %(extra)s
+            ORDER BY events.room_id, topological_ordering, stream_ordering
+            %(limit)s
+        """ % {
+            "tuple_cmp": tuple_clause,
+            "limit": "LIMIT ?" if batch_size is not None else "",
+            "extra": extra_clause,
+        }
+
+        if batch_size is not None:
+            tuple_args.append(batch_size)
+
+        txn.execute(sql, tuple_args)
+        rows = txn.fetchall()
+
+        # Put the results in the necessary format for
+        # `_add_chain_cover_index`
+        event_to_room_id = {row[0]: row[5] for row in rows}
+        event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+        # Calculate the new last position we've processed up to.
+        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+
+        # Map from room_id to last depth/stream_ordering processed for the room,
+        # excluding the last room (which we're likely still processing). We also
+        # need to include the room passed in if it's not included in the result
+        # set (as we then know we've processed all events in said room).
+        #
+        # This is the set of rooms that we can now safely flip the
+        # `has_auth_chain_index` bit for.
+        finished_rooms = {
+            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+        }
+        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+            finished_rooms[last_room_id] = (last_depth, last_stream)
+
+        count = len(rows)
+
+        # We also need to fetch the auth events for them.
+        auth_events = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth",
+            column="event_id",
+            iterable=event_to_room_id,
+            keyvalues={},
+            retcols=("event_id", "auth_id"),
+        )
+
+        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        for row in auth_events:
+            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+        # Calculate and persist the chain cover index for this set of events.
+        #
+        # Annoyingly we need to gut wrench into the persit event store so that
+        # we can reuse the function to calculate the chain cover for rooms.
+        PersistEventsStore._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+        return _CalculateChainCover(
+            room_id=new_last_room_id,
+            depth=new_last_depth,
+            stream=new_last_stream,
+            processed_count=count,
+            finished_room_map=finished_rooms,
+        )