summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-01-14 15:18:27 +0000
committerGitHub <noreply@github.com>2021-01-14 15:18:27 +0000
commit7036e24e98fc21855c34876d7024015470721bbe (patch)
tree983414be069d0221ffab4b508cb5733cd28d3ab1
parentSplit OidcProvider out of OidcHandler (#9107) (diff)
downloadsynapse-7036e24e98fc21855c34876d7024015470721bbe.tar.xz
Add background update for add chain cover index (#9029)
-rw-r--r--changelog.d/8868.misc2
-rw-r--r--changelog.d/9029.misc1
-rwxr-xr-xscripts/synapse_port_db2
-rw-r--r--synapse/storage/databases/main/events.py50
-rw-r--r--synapse/storage/databases/main/events_bg_updates.py192
-rw-r--r--synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql17
-rw-r--r--tests/storage/test_event_chain.py114
7 files changed, 360 insertions, 18 deletions
diff --git a/changelog.d/8868.misc b/changelog.d/8868.misc
index 1a11e30944..346741d982 100644
--- a/changelog.d/8868.misc
+++ b/changelog.d/8868.misc
@@ -1 +1 @@
-Improve efficiency of large state resolutions for new rooms.
+Improve efficiency of large state resolutions.
diff --git a/changelog.d/9029.misc b/changelog.d/9029.misc
new file mode 100644
index 0000000000..346741d982
--- /dev/null
+++ b/changelog.d/9029.misc
@@ -0,0 +1 @@
+Improve efficiency of large state resolutions.
diff --git a/scripts/synapse_port_db b/scripts/synapse_port_db
index 22dd169bfb..69bf9110a6 100755
--- a/scripts/synapse_port_db
+++ b/scripts/synapse_port_db
@@ -70,7 +70,7 @@ logger = logging.getLogger("synapse_port_db")
 
 BOOLEAN_COLUMNS = {
     "events": ["processed", "outlier", "contains_url"],
-    "rooms": ["is_public"],
+    "rooms": ["is_public", "has_auth_chain_index"],
     "event_edges": ["is_state"],
     "presence_list": ["accepted"],
     "presence_stream": ["currently_active"],
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 186f064036..e0fbcc58cf 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -466,9 +466,6 @@ class PersistEventsStore:
         if not state_events:
             return
 
-        # Map from event ID to chain ID/sequence number.
-        chain_map = {}  # type: Dict[str, Tuple[int, int]]
-
         # We need to know the type/state_key and auth events of the events we're
         # calculating chain IDs for. We don't rely on having the full Event
         # instances as we'll potentially be pulling more events from the DB and
@@ -479,9 +476,33 @@ class PersistEventsStore:
         event_to_auth_chain = {
             e.event_id: e.auth_event_ids() for e in state_events.values()
         }
+        event_to_room_id = {e.event_id: e.room_id for e in state_events.values()}
+
+        self._add_chain_cover_index(
+            txn, event_to_room_id, event_to_types, event_to_auth_chain
+        )
+
+    def _add_chain_cover_index(
+        self,
+        txn,
+        event_to_room_id: Dict[str, str],
+        event_to_types: Dict[str, Tuple[str, str]],
+        event_to_auth_chain: Dict[str, List[str]],
+    ) -> None:
+        """Calculate the chain cover index for the given events.
+
+        Args:
+            event_to_room_id: Event ID to the room ID of the event
+            event_to_types: Event ID to type and state_key of the event
+            event_to_auth_chain: Event ID to list of auth event IDs of the
+                event (events with no auth events can be excluded).
+        """
+
+        # Map from event ID to chain ID/sequence number.
+        chain_map = {}  # type: Dict[str, Tuple[int, int]]
 
         # Set of event IDs to calculate chain ID/seq numbers for.
-        events_to_calc_chain_id_for = set(state_events)
+        events_to_calc_chain_id_for = set(event_to_room_id)
 
         # We check if there are any events that need to be handled in the rooms
         # we're looking at. These should just be out of band memberships, where
@@ -491,7 +512,7 @@ class PersistEventsStore:
             table="event_auth_chain_to_calculate",
             keyvalues={},
             column="room_id",
-            iterable={e.room_id for e in state_events.values()},
+            iterable=set(event_to_room_id.values()),
             retcols=("event_id", "type", "state_key"),
         )
         for row in rows:
@@ -582,16 +603,17 @@ class PersistEventsStore:
                     # the list of events to calculate chain IDs for next time
                     # around. (Otherwise we will have already added it to the
                     # table).
-                    event = state_events.get(event_id)
-                    if event:
+                    room_id = event_to_room_id.get(event_id)
+                    if room_id:
+                        e_type, state_key = event_to_types[event_id]
                         self.db_pool.simple_insert_txn(
                             txn,
                             table="event_auth_chain_to_calculate",
                             values={
-                                "event_id": event.event_id,
-                                "room_id": event.room_id,
-                                "type": event.type,
-                                "state_key": event.state_key,
+                                "event_id": event_id,
+                                "room_id": room_id,
+                                "type": e_type,
+                                "state_key": state_key,
                             },
                         )
 
@@ -617,7 +639,7 @@ class PersistEventsStore:
             events_to_calc_chain_id_for, event_to_auth_chain
         ):
             existing_chain_id = None
-            for auth_id in event_to_auth_chain[event_id]:
+            for auth_id in event_to_auth_chain.get(event_id, []):
                 if event_to_types.get(event_id) == event_to_types.get(auth_id):
                     existing_chain_id = chain_map[auth_id]
                     break
@@ -730,11 +752,11 @@ class PersistEventsStore:
             # auth events (A, B) to check if B is reachable from A.
             reduction = {
                 a_id
-                for a_id in event_to_auth_chain[event_id]
+                for a_id in event_to_auth_chain.get(event_id, [])
                 if chain_map[a_id][0] != chain_id
             }
             for start_auth_id, end_auth_id in itertools.permutations(
-                event_to_auth_chain[event_id], r=2,
+                event_to_auth_chain.get(event_id, []), r=2,
             ):
                 if chain_links.exists_path_from(
                     chain_map[start_auth_id], chain_map[end_auth_id]
diff --git a/synapse/storage/databases/main/events_bg_updates.py b/synapse/storage/databases/main/events_bg_updates.py
index 7e4b175d08..90a40a92b4 100644
--- a/synapse/storage/databases/main/events_bg_updates.py
+++ b/synapse/storage/databases/main/events_bg_updates.py
@@ -14,13 +14,13 @@
 # limitations under the License.
 
 import logging
-from typing import List, Tuple
+from typing import Dict, List, Optional, Tuple
 
 from synapse.api.constants import EventContentFields
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import make_event_from_dict
 from synapse.storage._base import SQLBaseStore, db_to_json, make_in_list_sql_clause
-from synapse.storage.database import DatabasePool
+from synapse.storage.database import DatabasePool, make_tuple_comparison_clause
 from synapse.storage.types import Cursor
 from synapse.types import JsonDict
 
@@ -108,6 +108,10 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             "rejected_events_metadata", self._rejected_events_metadata,
         )
 
+        self.db_pool.updates.register_background_update_handler(
+            "chain_cover", self._chain_cover_index,
+        )
+
     async def _background_reindex_fields_sender(self, progress, batch_size):
         target_min_stream_id = progress["target_min_stream_id_inclusive"]
         max_stream_id = progress["max_stream_id_exclusive"]
@@ -706,3 +710,187 @@ class EventsBackgroundUpdatesStore(SQLBaseStore):
             )
 
         return len(results)
+
+    async def _chain_cover_index(self, progress: dict, batch_size: int) -> int:
+        """A background updates that iterates over all rooms and generates the
+        chain cover index for them.
+        """
+
+        current_room_id = progress.get("current_room_id", "")
+
+        # Have we finished processing the current room.
+        finished = progress.get("finished", True)
+
+        # Where we've processed up to in the room, defaults to the start of the
+        # room.
+        last_depth = progress.get("last_depth", -1)
+        last_stream = progress.get("last_stream", -1)
+
+        # Have we set the `has_auth_chain_index` for the room yet.
+        has_set_room_has_chain_index = progress.get(
+            "has_set_room_has_chain_index", False
+        )
+
+        if finished:
+            # If we've finished with the previous room (or its our first
+            # iteration) we move on to the next room.
+
+            def _get_next_room(txn: Cursor) -> Optional[str]:
+                sql = """
+                    SELECT room_id FROM rooms
+                    WHERE room_id > ?
+                        AND (
+                            NOT has_auth_chain_index
+                            OR has_auth_chain_index IS NULL
+                        )
+                    ORDER BY room_id
+                    LIMIT 1
+                """
+                txn.execute(sql, (current_room_id,))
+                row = txn.fetchone()
+                if row:
+                    return row[0]
+
+                return None
+
+            current_room_id = await self.db_pool.runInteraction(
+                "_chain_cover_index", _get_next_room
+            )
+            if not current_room_id:
+                await self.db_pool.updates._end_background_update("chain_cover")
+                return 0
+
+            logger.debug("Adding chain cover to %s", current_room_id)
+
+        def _calculate_auth_chain(
+            txn: Cursor, last_depth: int, last_stream: int
+        ) -> Tuple[int, int, int]:
+            # Get the next set of events in the room (that we haven't already
+            # computed chain cover for). We do this in topological order.
+
+            # We want to do a `(topological_ordering, stream_ordering) > (?,?)`
+            # comparison, but that is not supported on older SQLite versions
+            tuple_clause, tuple_args = make_tuple_comparison_clause(
+                self.database_engine,
+                [
+                    ("topological_ordering", last_depth),
+                    ("stream_ordering", last_stream),
+                ],
+            )
+
+            sql = """
+                SELECT
+                    event_id, state_events.type, state_events.state_key,
+                    topological_ordering, stream_ordering
+                FROM events
+                INNER JOIN state_events USING (event_id)
+                LEFT JOIN event_auth_chains USING (event_id)
+                LEFT JOIN event_auth_chain_to_calculate USING (event_id)
+                WHERE events.room_id = ?
+                    AND event_auth_chains.event_id IS NULL
+                    AND event_auth_chain_to_calculate.event_id IS NULL
+                    AND %(tuple_cmp)s
+                ORDER BY topological_ordering, stream_ordering
+                LIMIT ?
+            """ % {
+                "tuple_cmp": tuple_clause,
+            }
+
+            args = [current_room_id]
+            args.extend(tuple_args)
+            args.append(batch_size)
+
+            txn.execute(sql, args)
+            rows = txn.fetchall()
+
+            # Put the results in the necessary format for
+            # `_add_chain_cover_index`
+            event_to_room_id = {row[0]: current_room_id for row in rows}
+            event_to_types = {row[0]: (row[1], row[2]) for row in rows}
+
+            new_last_depth = rows[-1][3] if rows else last_depth  # type: int
+            new_last_stream = rows[-1][4] if rows else last_stream  # type: int
+
+            count = len(rows)
+
+            # We also need to fetch the auth events for them.
+            auth_events = self.db_pool.simple_select_many_txn(
+                txn,
+                table="event_auth",
+                column="event_id",
+                iterable=event_to_room_id,
+                keyvalues={},
+                retcols=("event_id", "auth_id"),
+            )
+
+            event_to_auth_chain = {}  # type: Dict[str, List[str]]
+            for row in auth_events:
+                event_to_auth_chain.setdefault(row["event_id"], []).append(
+                    row["auth_id"]
+                )
+
+            # Calculate and persist the chain cover index for this set of events.
+            #
+            # Annoyingly we need to gut wrench into the persit event store so that
+            # we can reuse the function to calculate the chain cover for rooms.
+            self.hs.get_datastores().persist_events._add_chain_cover_index(
+                txn, event_to_room_id, event_to_types, event_to_auth_chain,
+            )
+
+            return new_last_depth, new_last_stream, count
+
+        last_depth, last_stream, count = await self.db_pool.runInteraction(
+            "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+        )
+
+        total_rows_processed = count
+
+        if count < batch_size and not has_set_room_has_chain_index:
+            # If we've done all the events in the room we flip the
+            # `has_auth_chain_index` in the DB. Note that its possible for
+            # further events to be persisted between the above and setting the
+            # flag without having the chain cover calculated for them. This is
+            # fine as a) the code gracefully handles these cases and b) we'll
+            # calculate them below.
+
+            await self.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": current_room_id},
+                updatevalues={"has_auth_chain_index": True},
+                desc="_chain_cover_index",
+            )
+            has_set_room_has_chain_index = True
+
+            # Handle any events that might have raced with us flipping the
+            # bit above.
+            last_depth, last_stream, count = await self.db_pool.runInteraction(
+                "_chain_cover_index", _calculate_auth_chain, last_depth, last_stream
+            )
+
+            total_rows_processed += count
+
+            # Note that at this point its technically possible that more events
+            # than our `batch_size` have been persisted without their chain
+            # cover, so we need to continue processing this room if the last
+            # count returned was equal to the `batch_size`.
+
+        if count < batch_size:
+            # We've finished calculating the index for this room, move on to the
+            # next room.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover", {"current_room_id": current_room_id, "finished": True},
+            )
+        else:
+            # We still have outstanding events to calculate the index for.
+            await self.db_pool.updates._background_update_progress(
+                "chain_cover",
+                {
+                    "current_room_id": current_room_id,
+                    "last_depth": last_depth,
+                    "last_stream": last_stream,
+                    "has_auth_chain_index": has_set_room_has_chain_index,
+                    "finished": False,
+                },
+            )
+
+        return total_rows_processed
diff --git a/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
new file mode 100644
index 0000000000..fe3dca71dd
--- /dev/null
+++ b/synapse/storage/databases/main/schema/delta/59/06chain_cover_index.sql
@@ -0,0 +1,17 @@
+/* Copyright 2020 The Matrix.org Foundation C.I.C
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+INSERT INTO background_updates (ordering, update_name, progress_json, depends_on) VALUES
+  (5906, 'chain_cover', '{}', 'rejected_events_metadata');
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index 83c377824b..ff67a73749 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -20,7 +20,10 @@ from twisted.trial import unittest
 from synapse.api.constants import EventTypes
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
 from synapse.storage.databases.main.events import _LinkMap
+from synapse.types import create_requester
 
 from tests.unittest import HomeserverTestCase
 
@@ -470,3 +473,114 @@ class LinkMapTestCase(unittest.TestCase):
         self.assertCountEqual(link_map.get_links_between(1, 2), [(1, 1), (3, 3)])
 
         self.assertCountEqual(link_map.get_additions(), [(1, 3, 2, 3), (2, 5, 1, 3)])
+
+
+class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
+
+    servlets = [
+        admin.register_servlets,
+        room.register_servlets,
+        login.register_servlets,
+    ]
+
+    def test_background_update(self):
+        """Test that the background update to calculate auth chains for historic
+        rooms works correctly.
+        """
+
+        # Create a room
+        user_id = self.register_user("foo", "pass")
+        token = self.login("foo", "pass")
+        room_id = self.helper.create_room_as(user_id, tok=token)
+        requester = create_requester(user_id)
+
+        store = self.hs.get_datastore()
+
+        # Mark the room as not having a chain cover index
+        self.get_success(
+            store.db_pool.simple_update(
+                table="rooms",
+                keyvalues={"room_id": room_id},
+                updatevalues={"has_auth_chain_index": False},
+                desc="test",
+            )
+        )
+
+        # Create a fork in the DAG with different events.
+        event_handler = self.hs.get_event_creation_handler()
+        latest_event_ids = self.get_success(store.get_prev_events_for_room(room_id))
+        event, context = self.get_success(
+            event_handler.create_event(
+                requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(requester, event, context)
+        )
+        state1 = list(self.get_success(context.get_current_state_ids()).values())
+
+        event, context = self.get_success(
+            event_handler.create_event(
+                requester,
+                {
+                    "type": "some_state_type",
+                    "state_key": "",
+                    "content": {},
+                    "room_id": room_id,
+                    "sender": user_id,
+                },
+                prev_event_ids=latest_event_ids,
+            )
+        )
+        self.get_success(
+            event_handler.handle_new_client_event(requester, event, context)
+        )
+        state2 = list(self.get_success(context.get_current_state_ids()).values())
+
+        # Delete the chain cover info.
+
+        def _delete_tables(txn):
+            txn.execute("DELETE FROM event_auth_chains")
+            txn.execute("DELETE FROM event_auth_chain_links")
+
+        self.get_success(store.db_pool.runInteraction("test", _delete_tables))
+
+        # Insert and run the background update.
+        self.get_success(
+            store.db_pool.simple_insert(
+                "background_updates",
+                {"update_name": "chain_cover", "progress_json": "{}"},
+            )
+        )
+
+        # Ugh, have to reset this flag
+        store.db_pool.updates._all_done = False
+
+        while not self.get_success(
+            store.db_pool.updates.has_completed_background_updates()
+        ):
+            self.get_success(
+                store.db_pool.updates.do_next_background_update(100), by=0.1
+            )
+
+        # Test that the `has_auth_chain_index` has been set
+        self.assertTrue(self.get_success(store.has_auth_chain_index(room_id)))
+
+        # Test that calculating the auth chain difference using the newly
+        # calculated chain cover works.
+        self.get_success(
+            store.db_pool.runInteraction(
+                "test",
+                store._get_auth_chain_difference_using_cover_index_txn,
+                room_id,
+                [state1, state2],
+            )
+        )