summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9124.misc1
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py329
-rw-r--r--tests/storage/test_event_chain.py217
3 files changed, 366 insertions, 181 deletions
diff --git a/changelog.d/9124.misc b/changelog.d/9124.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9124.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7128dc1742..e46e44ba54 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -16,6 +16,8 @@
 import logging
 from typing import Dict, List, Optional, Tuple
 
+import attr
+
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
@@ -28,6 +30,25 @@ from synapse.types import JsonDict
 logger = logging.getLogger(__name__)
 
 
+@attr.s(slots=True, frozen=True)
+class _CalculateChainCover:
+    """Return value for _calculate_chain_cover_txn.
+    """
+
+    # The last room_id/depth/stream processed.
+    room_id = attr.ib(type=str)
+    depth = attr.ib(type=int)
+    stream = attr.ib(type=int)
+
+    # Number of rows processed
+    processed_count = attr.ib(type=int)
+
+    # Map from room_id to last depth/stream processed for each room that we have
+    # processed all events for (i.e. the rooms we can flip the
+    # `has_auth_chain_index` for)
+    finished_room_map = attr.ib(type=Dict[str, Tuple[int, int]])
+
+
 class EventsBackgroundUpdatesStore(SQLBaseStore):
 
     EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts"
@@ -719,138 +740,29 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
         current_room_id = progress.get("current_room_id", "")
 
-        # Have we finished processing the current room.
-        finished = progress.get("finished", True)
-
         # Where we've processed up to in the room, defaults to the start of the
         # room.
         last_depth = progress.get("last_depth", -1)
         last_stream = progress.get("last_stream", -1)
 
-        # Have we set the `has_auth_chain_index` for the room yet.
-        has_set_room_has_chain_index = progress.get(
-            "has_set_room_has_chain_index", False
+        result = await self.db_pool.runInteraction(
+            "_chain_cover_index",
+            self._calculate_chain_cover_txn,
+            current_room_id,
+            last_depth,
+            last_stream,
+            batch_size,
+            single_room=False,
         )
 
-        if finished:
-            # If we've finished with the previous room (or its our first
-            # iteration) we move on to the next room.
-
-            def _get_next_room(txn: Cursor) -> Optional[str]:
-                sql = """
-                    SELECT room_id FROM rooms
-                    WHERE room_id > ?
-                        AND (
-                            NOT has_auth_chain_index
-                            OR has_auth_chain_index IS NULL
-                        )
-                    ORDER BY room_id
-                    LIMIT 1
-                """
-                txn.execute(sql, (current_room_id,))
-                row = txn.fetchone()
-                if row:
-                    return row[0]
+        finished = result.processed_count == 0
 
-                return None
-
-            current_room_id = await self.db_pool.runInteraction(
-                "_chain_cover_index", _get_next_room
-            )
-            if not current_room_id:
-                await self.db_pool.updates._end_background_update("chain_cover")
-                return 0
-
-            logger.debug("Adding chain cover to %s", current_room_id)
-
-        def _calculate_auth_chain(
-            txn: Cursor, last_depth: int, last_stream: int
-        ) -> Tuple[int, int, int]:
-            # Get the next set of events in the room (that we haven't already
-            # computed chain cover for). We do this in topological order.
-
-            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
-            # comparison, but that is not supported on older SQLite versions
-            tuple_clause, tuple_args = make_tuple_comparison_clause(
-                self.database_engine,
-                [
-                    ("topological_ordering", last_depth),
-                    ("stream_ordering", last_stream),
-                ],
-            )
+        total_rows_processed = result.processed_count
+        current_room_id = result.room_id
+        last_depth = result.depth
+        last_stream = result.stream
 
-            sql = """
-                SELECT
-                    event_id, state_events.type, state_events.state_key,
-                    topological_ordering, stream_ordering
-                FROM events
-                INNER JOIN state_events USING (event_id)
-                LEFT JOIN event_auth_chains USING (event_id)
-                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
-                WHERE events.room_id = ?
-                    AND event_auth_chains.event_id IS NULL
-                    AND event_auth_chain_to_calculate.event_id IS NULL
-                    AND %(tuple_cmp)s
-                ORDER BY topological_ordering, stream_ordering
-                LIMIT ?
-            """ % {
-                "tuple_cmp": tuple_clause,
-            }
-
-            args = [current_room_id]
-            args.extend(tuple_args)
-            args.append(batch_size)
-
-            txn.execute(sql, args)
-            rows = txn.fetchall()
-
-            # Put the results in the necessary format for
-            # `_add_chain_cover_index`
-            event_to_room_id = {row[0]: current_room_id for row in rows}
-            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
-
-            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
-            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
-
-            count = len(rows)
-
-            # We also need to fetch the auth events for them.
-            auth_events = self.db_pool.simple_select_many_txn(
-                txn,
-                table="event_auth",
-                column="event_id",
-                iterable=event_to_room_id,
-                keyvalues={},
-                retcols=("event_id", "auth_id"),
-            )
-
-            event_to_auth_chain = {}  # type: Dict[str, List[str]]
-            for row in auth_events:
-                event_to_auth_chain.setdefault(row["event_id"], []).append(
-                    row["auth_id"]
-                )
-
-            # Calculate and persist the chain cover index for this set of events.
-            #
-            # Annoyingly we need to gut wrench into the persit event store so that
-            # we can reuse the function to calculate the chain cover for rooms.
-            PersistEventsStore._add_chain_cover_index(
-                txn,
-                self.db_pool,
-                event_to_room_id,
-                event_to_types,
-                event_to_auth_chain,
-            )
-
-            return new_last_depth, new_last_stream, count
-
-        last_depth, last_stream, count = await self.db_pool.runInteraction(
-            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
-        )
-
-        total_rows_processed = count
-
-        if count < batch_size and not has_set_room_has_chain_index:
+        for room_id, (depth, stream) in result.finished_room_map.items():
             # If we've done all the events in the room we flip the
             # `has_auth_chain_index` in the DB. Note that its possible for
             # further events to be persisted between the above and setting the
@@ -860,42 +772,159 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
 
             await self.db_pool.simple_update(
                 table="rooms",
-                keyvalues={"room_id": current_room_id},
+                keyvalues={"room_id": room_id},
                 updatevalues={"has_auth_chain_index": True},
                 desc="_chain_cover_index",
             )
-            has_set_room_has_chain_index = True
 
             # Handle any events that might have raced with us flipping the
             # bit above.
-            last_depth, last_stream, count = await self.db_pool.runInteraction(
-                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            result = await self.db_pool.runInteraction(
+                "_chain_cover_index",
+                self._calculate_chain_cover_txn,
+                room_id,
+                depth,
+                stream,
+                batch_size=None,
+                single_room=True,
             )
 
-            total_rows_processed += count
+            total_rows_processed += result.processed_count
 
-            # Note that at this point its technically possible that more events
-            # than our `batch_size` have been persisted without their chain
-            # cover, so we need to continue processing this room if the last
-            # count returned was equal to the `batch_size`.
+        if finished:
+            await self.db_pool.updates._end_background_update("chain_cover")
+            return total_rows_processed
 
-        if count < batch_size:
-            # We've finished calculating the index for this room, move on to the
-            # next room.
-            await self.db_pool.updates._background_update_progress(
-                "chain_cover", {"current_room_id": current_room_id, "finished": True},
-            )
-        else:
-            # We still have outstanding events to calculate the index for.
-            await self.db_pool.updates._background_update_progress(
-                "chain_cover",
-                {
-                    "current_room_id": current_room_id,
-                    "last_depth": last_depth,
-                    "last_stream": last_stream,
-                    "has_auth_chain_index": has_set_room_has_chain_index,
-                    "finished": False,
-                },
-            )
+        await self.db_pool.updates._background_update_progress(
+            "chain_cover",
+            {
+                "current_room_id": current_room_id,
+                "last_depth": last_depth,
+                "last_stream": last_stream,
+            },
+        )
 
         return total_rows_processed
+
+    def _calculate_chain_cover_txn(
+        self,
+        txn: Cursor,
+        last_room_id: str,
+        last_depth: int,
+        last_stream: int,
+        batch_size: Optional[int],
+        single_room: bool,
+    ) -> _CalculateChainCover:
+        """Calculate the chain cover for `batch_size` events, ordered by
+        `(room_id, depth, stream)`.
+
+        Args:
+            txn,
+            last_room_id, last_depth, last_stream: The `(room_id, depth, stream)`
+                tuple to fetch results after.
+            batch_size: The maximum number of events to process. If None then
+                no limit.
+            single_room: Whether to calculate the index for just the given
+                room.
+        """
+
+        # Get the next set of events in the room (that we haven't already
+        # computed chain cover for). We do this in topological order.
+
+        # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+        # comparison, but that is not supported on older SQLite versions
+        tuple_clause, tuple_args = make_tuple_comparison_clause(
+            self.database_engine,
+            [
+                ("events.room_id", last_room_id),
+                ("topological_ordering", last_depth),
+                ("stream_ordering", last_stream),
+            ],
+        )
+
+        extra_clause = ""
+        if single_room:
+            extra_clause = "AND events.room_id = ?"
+            tuple_args.append(last_room_id)
+
+        sql = """
+            SELECT
+                event_id, state_events.type, state_events.state_key,
+                topological_ordering, stream_ordering,
+                events.room_id
+            FROM events
+            INNER JOIN state_events USING (event_id)
+            LEFT JOIN event_auth_chains USING (event_id)
+            LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+            WHERE event_auth_chains.event_id IS NULL
+                AND event_auth_chain_to_calculate.event_id IS NULL
+                AND %(tuple_cmp)s
+                %(extra)s
+            ORDER BY events.room_id, topological_ordering, stream_ordering
+            %(limit)s
+        """ % {
+            "tuple_cmp": tuple_clause,
+            "limit": "LIMIT ?" if batch_size is not None else "",
+            "extra": extra_clause,
+        }
+
+        if batch_size is not None:
+            tuple_args.append(batch_size)
+
+        txn.execute(sql, tuple_args)
+        rows = txn.fetchall()
+
+        # Put the results in the necessary format for
+        # `_add_chain_cover_index`
+        event_to_room_id = {row[0]: row[5] for row in rows}
+        event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+        # Calculate the new last position we've processed up to.
+        new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+        new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+        new_last_room_id = rows[-1][5] if rows else ""  # type: str
+
+        # Map from room_id to last depth/stream_ordering processed for the room,
+        # excluding the last room (which we're likely still processing). We also
+        # need to include the room passed in if it's not included in the result
+        # set (as we then know we've processed all events in said room).
+        #
+        # This is the set of rooms that we can now safely flip the
+        # `has_auth_chain_index` bit for.
+        finished_rooms = {
+            row[5]: (row[3], row[4]) for row in rows if row[5] != new_last_room_id
+        }
+        if last_room_id not in finished_rooms and last_room_id != new_last_room_id:
+            finished_rooms[last_room_id] = (last_depth, last_stream)
+
+        count = len(rows)
+
+        # We also need to fetch the auth events for them.
+        auth_events = self.db_pool.simple_select_many_txn(
+            txn,
+            table="event_auth",
+            column="event_id",
+            iterable=event_to_room_id,
+            keyvalues={},
+            retcols=("event_id", "auth_id"),
+        )
+
+        event_to_auth_chain = {}  # type: Dict[str, List[str]]
+        for row in auth_events:
+            event_to_auth_chain.setdefault(row["event_id"], []).append(row["auth_id"])
+
+        # Calculate and persist the chain cover index for this set of events.
+        #
+        # Annoyingly we need to gut wrench into the persit event store so that
+        # we can reuse the function to calculate the chain cover for rooms.
+        PersistEventsStore._add_chain_cover_index(
+            txn, self.db_pool, event_to_room_id, event_to_types, event_to_auth_chain,
+        )
+
+        return _CalculateChainCover(
+            room_id=new_last_room_id,
+            depth=new_last_depth,
+            stream=new_last_stream,
+            processed_count=count,
+            finished_room_map=finished_rooms,
+        )
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index ff67a73749..0c46ad595b 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Tuple
+from typing import Dict, List, Set, Tuple
 
 from twisted.trial import unittest
 
@@ -483,22 +483,20 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def test_background_update(self):
-        """Test that the background update to calculate auth chains for historic
-        rooms works correctly.
-        """
-
-        # Create a room
-        user_id = self.register_user("foo", "pass")
-        token = self.login("foo", "pass")
-        room_id = self.helper.create_room_as(user_id, tok=token)
-        requester = create_requester(user_id)
+    def prepare(self, reactor, clock, hs):
+        self.store = hs.get_datastore()
+        self.user_id = self.register_user("foo", "pass")
+        self.token = self.login("foo", "pass")
+        self.requester = create_requester(self.user_id)
 
-        store = self.hs.get_datastore()
+    def _generate_room(self) -> Tuple[str, List[Set[str]]]:
+        """Insert a room without a chain cover index.
+        """
+        room_id = self.helper.create_room_as(self.user_id, tok=self.token)
 
         # Mark the room as not having a chain cover index
         self.get_success(
-            store.db_pool.simple_update(
+            self.store.db_pool.simple_update(
                 table="rooms",
                 keyvalues={"room_id": room_id},
                 updatevalues={"has_auth_chain_index": False},
@@ -508,42 +506,44 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
 
         # Create a fork in the DAG with different events.
         event_handler = self.hs.get_event_creation_handler()
-        latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+        latest_event_ids = self.get_success(
+            self.store.get_prev_events_for_room(room_id)
+        )
         event, context = self.get_success(
             event_handler.create_event(
-                requester,
+                self.requester,
                 {
                     "type": "some_state_type",
                     "state_key": "",
                     "content": {},
                     "room_id": room_id,
-                    "sender": user_id,
+                    "sender": self.user_id,
                 },
                 prev_event_ids=latest_event_ids,
             )
         )
         self.get_success(
-            event_handler.handle_new_client_event(requester, event, context)
+            event_handler.handle_new_client_event(self.requester, event, context)
         )
-        state1 = list(self.get_success(context.get_current_state_ids()).values())
+        state1 = set(self.get_success(context.get_current_state_ids()).values())
 
         event, context = self.get_success(
             event_handler.create_event(
-                requester,
+                self.requester,
                 {
                     "type": "some_state_type",
                     "state_key": "",
                     "content": {},
                     "room_id": room_id,
-                    "sender": user_id,
+                    "sender": self.user_id,
                 },
                 prev_event_ids=latest_event_ids,
             )
         )
         self.get_success(
-            event_handler.handle_new_client_event(requester, event, context)
+            event_handler.handle_new_client_event(self.requester, event, context)
         )
-        state2 = list(self.get_success(context.get_current_state_ids()).values())
+        state2 = set(self.get_success(context.get_current_state_ids()).values())
 
         # Delete the chain cover info.
 
@@ -551,36 +551,191 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
             txn.execute("DELETE FROM event_auth_chains")
             txn.execute("DELETE FROM event_auth_chain_links")
 
-        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+        self.get_success(self.store.db_pool.runInteraction("test", _delete_tables))
+
+        return room_id, [state1, state2]
+
+    def test_background_update_single_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
 
         # Insert and run the background update.
         self.get_success(
-            store.db_pool.simple_insert(
+            self.store.db_pool.simple_insert(
                 "background_updates",
                 {"update_name": "chain_cover", "progress_json": "{}"},
             )
         )
 
         # Ugh, have to reset this flag
-        store.db_pool.updates._all_done = False
+        self.store.db_pool.updates._all_done = False
 
         while not self.get_success(
-            store.db_pool.updates.has_completed_background_updates()
+            self.store.db_pool.updates.has_completed_background_updates()
         ):
             self.get_success(
-                store.db_pool.updates.do_next_background_update(100), by=0.1
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
             )
 
         # Test that the `has_auth_chain_index` has been set
-        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
 
         # Test that calculating the auth chain difference using the newly
         # calculated chain cover works.
         self.get_success(
-            store.db_pool.runInteraction(
+            self.store.db_pool.runInteraction(
                 "test",
-                store._get_auth_chain_difference_using_cover_index_txn,
+                self.store._get_auth_chain_difference_using_cover_index_txn,
                 room_id,
-                [state1, state2],
+                states,
+            )
+        )
+
+    def test_background_update_multiple_rooms(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+        # Create a room
+        room_id1, states1 = self._generate_room()
+        room_id2, states2 = self._generate_room()
+        room_id3, states2 = self._generate_room()
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id3)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id1,
+                states1,
             )
         )
+
+    def test_background_update_single_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        room_id, states = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            self.store.db_pool.runInteraction(
+                "test",
+                self.store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                states,
+            )
+        )
+
+    def test_background_update_multiple_large_room(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create the rooms
+        room_id1, _ = self._generate_room()
+        room_id2, _ = self._generate_room()
+
+        # Add a bunch of state so that it takes multiple iterations of the
+        # background update to process the room.
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id1, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        for i in range(0, 150):
+            self.helper.send_state(
+                room_id2, event_type="m.test", body={"index": i}, tok=self.token
+            )
+
+        # Insert and run the background update.
+        self.get_success(
+            self.store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        self.store.db_pool.updates._all_done = False
+
+        iterations = 0
+        while not self.get_success(
+            self.store.db_pool.updates.has_completed_background_updates()
+        ):
+            iterations += 1
+            self.get_success(
+                self.store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Ensure that we did actually take multiple iterations to process the
+        # room.
+        self.assertGreater(iterations, 1)
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id1)))
+        self.assertTrue(self.get_success(self.store.has_auth_chain_index(room_id2)))