summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16116.bugfix1
-rw-r--r--synapse/storage/databases/main/event_federation.py184
-rw-r--r--tests/storage/test_event_federation.py241
3 files changed, 355 insertions, 71 deletions
diff --git a/changelog.d/16116.bugfix b/changelog.d/16116.bugfix
new file mode 100644
index 0000000000..f57a26ae39
--- /dev/null
+++ b/changelog.d/16116.bugfix
@@ -0,0 +1 @@
+Fix performance of state resolutions for large, old rooms that did not have the full auth chain persisted.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 534dc32413..fab7008a8f 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -452,33 +452,56 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # sets.
         seen_chains: Set[int] = set()
 
-        sql = """
-            SELECT event_id, chain_id, sequence_number
-            FROM event_auth_chains
-            WHERE %s
-        """
-        for batch in batch_iter(initial_events, 1000):
-            clause, args = make_in_list_sql_clause(
-                txn.database_engine, "event_id", batch
-            )
-            txn.execute(sql % (clause,), args)
+        # Fetch the chain cover index for the initial set of events we're
+        # considering.
+        def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
+            sql = """
+                SELECT event_id, chain_id, sequence_number
+                FROM event_auth_chains
+                WHERE %s
+            """
+            for batch in batch_iter(events_to_fetch, 1000):
+                clause, args = make_in_list_sql_clause(
+                    txn.database_engine, "event_id", batch
+                )
+                txn.execute(sql % (clause,), args)
 
-            for event_id, chain_id, sequence_number in txn:
-                chain_info[event_id] = (chain_id, sequence_number)
-                seen_chains.add(chain_id)
-                chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+                for event_id, chain_id, sequence_number in txn:
+                    chain_info[event_id] = (chain_id, sequence_number)
+                    seen_chains.add(chain_id)
+                    chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+        fetch_chain_info(initial_events)
 
         # Check that we actually have a chain ID for all the events.
         events_missing_chain_info = initial_events.difference(chain_info)
+
+        # The result set to return, i.e. the auth chain difference.
+        result: Set[str] = set()
+
         if events_missing_chain_info:
-            # This can happen due to e.g. downgrade/upgrade of the server. We
-            # raise an exception and fall back to the previous algorithm.
-            logger.info(
-                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+            # For some reason we have events we haven't calculated the chain
+            # index for, so we need to handle those separately. This should only
+            # happen for older rooms where the server doesn't have all the auth
+            # events.
+            result = self._fixup_auth_chain_difference_sets(
+                txn,
                 room_id,
-                events_missing_chain_info,
+                state_sets=state_sets,
+                events_missing_chain_info=events_missing_chain_info,
+                events_that_have_chain_index=chain_info,
             )
-            raise _NoChainCoverIndex(room_id)
+
+            # We now need to refetch any events that we have added to the state
+            # sets.
+            new_events_to_fetch = {
+                event_id
+                for state_set in state_sets
+                for event_id in state_set
+                if event_id not in initial_events
+            }
+
+            fetch_chain_info(new_events_to_fetch)
 
         # Corresponds to `state_sets`, except as a map from chain ID to max
         # sequence number reachable from the state set.
@@ -487,8 +510,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
             chains: Dict[int, int] = {}
             set_to_chain.append(chains)
 
-            for event_id in state_set:
-                chain_id, seq_no = chain_info[event_id]
+            for state_id in state_set:
+                chain_id, seq_no = chain_info[state_id]
 
                 chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
 
@@ -532,7 +555,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # from *any* state set and the minimum sequence number reachable from
         # *all* state sets. Events in that range are in the auth chain
         # difference.
-        result = set()
 
         # Mapping from chain ID to the range of sequence numbers that should be
         # pulled from the database.
@@ -588,6 +610,122 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         return result
 
+    def _fixup_auth_chain_difference_sets(
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_sets: List[Set[str]],
+        events_missing_chain_info: Set[str],
+        events_that_have_chain_index: Collection[str],
+    ) -> Set[str]:
+        """Helper for `_get_auth_chain_difference_using_cover_index_txn` to
+        handle the case where we haven't calculated the chain cover index for
+        all events.
+
+        This modifies `state_sets` so that they only include events that have a
+        chain cover index, and returns a set of event IDs that are part of the
+        auth difference.
+        """
+
+        # This works similarly to the handling of unpersisted events in
+        # `synapse.state.v2_get_auth_chain_difference`. We uses the observation
+        # that if you can split the set of events into two classes X and Y,
+        # where no events in Y have events in X in their auth chain, then we can
+        # calculate the auth difference by considering X and Y separately.
+        #
+        # We do this in three steps:
+        #   1. Compute the set of events without chain cover index belonging to
+        #      the auth difference.
+        #   2. Replacing the un-indexed events in the state_sets with their auth
+        #      events, recursively, until the state_sets contain only indexed
+        #      events. We can then calculate the auth difference of those state
+        #      sets using the chain cover index.
+        #   3. Add the results of 1 and 2 together.
+
+        # By construction we know that all events that we haven't persisted the
+        # chain cover index for are contained in
+        # `event_auth_chain_to_calculate`, so we pull out the events from those
+        # rather than doing recursive queries to walk the auth chain.
+        #
+        # We pull out those events with their auth events, which gives us enough
+        # information to construct the auth chain of an event up to auth events
+        # that have the chain cover index.
+        sql = """
+            SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL
+            FROM event_auth_chain_to_calculate AS tc
+            LEFT JOIN event_auth AS ea USING (event_id)
+            LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id)
+            WHERE tc.room_id = ?
+        """
+        txn.execute(sql, (room_id,))
+        event_to_auth_ids: Dict[str, Set[str]] = {}
+        events_that_have_chain_index = set(events_that_have_chain_index)
+        for event_id, auth_id, auth_id_has_chain in txn:
+            s = event_to_auth_ids.setdefault(event_id, set())
+            if auth_id is not None:
+                s.add(auth_id)
+                if auth_id_has_chain:
+                    events_that_have_chain_index.add(auth_id)
+
+        if events_missing_chain_info - event_to_auth_ids.keys():
+            # Uh oh, we somehow haven't correctly done the chain cover index,
+            # bail and fall back to the old method.
+            logger.info(
+                "Unexpectedly found that events don't have chain IDs in room %s: %s",
+                room_id,
+                events_missing_chain_info - event_to_auth_ids.keys(),
+            )
+            raise _NoChainCoverIndex(room_id)
+
+        # Create a map from event IDs we care about to their partial auth chain.
+        event_id_to_partial_auth_chain: Dict[str, Set[str]] = {}
+        for event_id, auth_ids in event_to_auth_ids.items():
+            if not any(event_id in state_set for state_set in state_sets):
+                continue
+
+            processing = set(auth_ids)
+            to_add = set()
+            while processing:
+                auth_id = processing.pop()
+                to_add.add(auth_id)
+
+                sub_auth_ids = event_to_auth_ids.get(auth_id)
+                if sub_auth_ids is None:
+                    continue
+
+                processing.update(sub_auth_ids - to_add)
+
+            event_id_to_partial_auth_chain[event_id] = to_add
+
+        # Now we do two things:
+        #   1. Update the state sets to only include indexed events; and
+        #   2. Create a new list containing the auth chains of the un-indexed
+        #      events
+        unindexed_state_sets: List[Set[str]] = []
+        for state_set in state_sets:
+            unindexed_state_set = set()
+            for event_id, auth_chain in event_id_to_partial_auth_chain.items():
+                if event_id not in state_set:
+                    continue
+
+                unindexed_state_set.add(event_id)
+
+                state_set.discard(event_id)
+                state_set.difference_update(auth_chain)
+                for auth_id in auth_chain:
+                    if auth_id in events_that_have_chain_index:
+                        state_set.add(auth_id)
+                    else:
+                        unindexed_state_set.add(auth_id)
+
+            unindexed_state_sets.append(unindexed_state_set)
+
+        # Calculate and return the auth difference of the un-indexed events.
+        union = unindexed_state_sets[0].union(*unindexed_state_sets[1:])
+        intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:])
+
+        return union - intersection
+
     def _get_auth_chain_difference_txn(
         self, txn: LoggingTransaction, state_sets: List[Set[str]]
     ) -> Set[str]:
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9c151a5e62..7a4ecab2d5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,19 @@
 # limitations under the License.
 
 import datetime
-from typing import Dict, List, Tuple, Union, cast
+from typing import (
+    Collection,
+    Dict,
+    FrozenSet,
+    Iterable,
+    List,
+    Mapping,
+    Set,
+    Tuple,
+    TypeVar,
+    Union,
+    cast,
+)
 
 import attr
 from parameterized import parameterized
@@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder
 import tests.unittest
 import tests.utils
 
+# The silly auth graph we use to test the auth difference algorithm,
+# where the top are the most recent events.
+#
+#   A   B
+#    \ /
+#  D  E
+#  \  |
+#   ` F   C
+#     |  /|
+#     G ´ |
+#     | \ |
+#     H   I
+#     |   |
+#     K   J
+
+AUTH_GRAPH: Dict[str, List[str]] = {
+    "a": ["e"],
+    "b": ["e"],
+    "c": ["g", "i"],
+    "d": ["f"],
+    "e": ["f"],
+    "f": ["g"],
+    "g": ["h", "i"],
+    "h": ["k"],
+    "i": ["j"],
+    "k": [],
+    "j": [],
+}
+
+DEPTH_GRAPH = {
+    "a": 7,
+    "b": 7,
+    "c": 4,
+    "d": 6,
+    "e": 6,
+    "f": 5,
+    "g": 3,
+    "h": 2,
+    "i": 2,
+    "k": 1,
+    "j": 1,
+}
+
+T = TypeVar("T")
+
+
+def get_all_topologically_sorted_orders(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> List[List[T]]:
+    """Given a set of nodes and a graph, return all possible topological
+    orderings.
+    """
+
+    # This is implemented by Kahn's algorithm, and forking execution each time
+    # we have a choice over which node to consider next.
+
+    degree_map = {node: 0 for node in nodes}
+    reverse_graph: Dict[T, Set[T]] = {}
+
+    for node, edges in graph.items():
+        if node not in degree_map:
+            continue
+
+        for edge in set(edges):
+            if edge in degree_map:
+                degree_map[node] += 1
+
+            reverse_graph.setdefault(edge, set()).add(node)
+        reverse_graph.setdefault(node, set())
+
+    zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+
+    return _get_all_topologically_sorted_orders_inner(
+        reverse_graph, zero_degree, degree_map
+    )
+
+
+def _get_all_topologically_sorted_orders_inner(
+    reverse_graph: Dict[T, Set[T]],
+    zero_degree: List[T],
+    degree_map: Dict[T, int],
+) -> List[List[T]]:
+    new_paths = []
+
+    # Rather than only choosing *one* item from the list of nodes with zero
+    # degree, we "fork" execution and run the algorithm for each node in the
+    # zero degree.
+    for node in zero_degree:
+        new_degree_map = degree_map.copy()
+        new_zero_degree = zero_degree.copy()
+        new_zero_degree.remove(node)
+
+        for edge in reverse_graph.get(node, []):
+            if edge in new_degree_map:
+                new_degree_map[edge] -= 1
+                if new_degree_map[edge] == 0:
+                    new_zero_degree.append(edge)
+
+        paths = _get_all_topologically_sorted_orders_inner(
+            reverse_graph, new_zero_degree, new_degree_map
+        )
+        for path in paths:
+            path.insert(0, node)
+
+        new_paths.extend(paths)
+
+    if not new_paths:
+        return [[]]
+
+    return new_paths
+
+
+def get_all_topologically_consistent_subsets(
+    nodes: Iterable[T],
+    graph: Mapping[T, Collection[T]],
+) -> Set[FrozenSet[T]]:
+    """Get all subsets of the graph where if node N is in the subgraph, then all
+    nodes that can reach that node (i.e. for all X there exists a path X -> N)
+    are in the subgraph.
+    """
+    all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph)
+
+    graph_subsets = set()
+    for ordering in all_topological_orderings:
+        ordering.reverse()
+
+        for idx in range(len(ordering)):
+            graph_subsets.add(frozenset(ordering[:idx]))
+
+    return graph_subsets
+
 
 @attr.s(auto_attribs=True, frozen=True, slots=True)
 class _BackfillSetupInfo:
@@ -172,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
     def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
         room_id = "@ROOM:local"
 
-        # The silly auth graph we use to test the auth difference algorithm,
-        # where the top are the most recent events.
-        #
-        #   A   B
-        #    \ /
-        #  D  E
-        #  \  |
-        #   ` F   C
-        #     |  /|
-        #     G ´ |
-        #     | \ |
-        #     H   I
-        #     |   |
-        #     K   J
-
-        auth_graph: Dict[str, List[str]] = {
-            "a": ["e"],
-            "b": ["e"],
-            "c": ["g", "i"],
-            "d": ["f"],
-            "e": ["f"],
-            "f": ["g"],
-            "g": ["h", "i"],
-            "h": ["k"],
-            "i": ["j"],
-            "k": [],
-            "j": [],
-        }
-
-        depth_map = {
-            "a": 7,
-            "b": 7,
-            "c": 4,
-            "d": 6,
-            "e": 6,
-            "f": 5,
-            "g": 3,
-            "h": 2,
-            "i": 2,
-            "k": 1,
-            "j": 1,
-        }
-
         # Mark the room as maybe having a cover index.
 
         def store_room(txn: LoggingTransaction) -> None:
@@ -238,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         def insert_event(txn: LoggingTransaction) -> None:
             stream_ordering = 0
 
-            for event_id in auth_graph:
+            for event_id in AUTH_GRAPH:
                 stream_ordering += 1
-                depth = depth_map[event_id]
+                depth = DEPTH_GRAPH[event_id]
 
                 self.store.db_pool.simple_insert_txn(
                     txn,
@@ -260,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
             self.persist_events._persist_event_auth_chain_txn(
                 txn,
                 [
-                    cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
-                    for event_id in auth_graph
+                    cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+                    for event_id in AUTH_GRAPH
                 ],
             )
 
@@ -344,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         room_id = self._setup_auth_chain(use_chain_cover_index)
 
         # Now actually test that various combinations give the right result:
+        self.assert_auth_diff_is_expected(room_id)
+
+    @parameterized.expand(
+        [
+            [graph_subset]
+            for graph_subset in get_all_topologically_consistent_subsets(
+                AUTH_GRAPH, AUTH_GRAPH
+            )
+        ]
+    )
+    def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None:
+        """Test that if we only have a chain cover index on a partial subset of
+        the room we still get the correct auth chain difference.
+
+        We do this by removing the chain cover index for every valid subset of the
+        graph.
+        """
+        room_id = self._setup_auth_chain(True)
+
+        for event_id in graph_subset:
+            # Remove chain cover from that event.
+            self.get_success(
+                self.store.db_pool.simple_delete(
+                    table="event_auth_chains",
+                    keyvalues={"event_id": event_id},
+                    desc="test_auth_difference_partial_remove",
+                )
+            )
+            self.get_success(
+                self.store.db_pool.simple_insert(
+                    table="event_auth_chain_to_calculate",
+                    values={
+                        "event_id": event_id,
+                        "room_id": room_id,
+                        "type": "",
+                        "state_key": "",
+                    },
+                    desc="test_auth_difference_partial_remove",
+                )
+            )
+
+        self.assert_auth_diff_is_expected(room_id)
 
+    def assert_auth_diff_is_expected(self, room_id: str) -> None:
+        """Assert the auth chain difference returns the correct answers."""
         difference = self.get_success(
             self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
         )