diff --git a/changelog.d/16116.bugfix b/changelog.d/16116.bugfix
new file mode 100644
index 0000000000..f57a26ae39
--- /dev/null
+++ b/changelog.d/16116.bugfix
@@ -0,0 +1 @@
+Fix performance of state resolutions for large, old rooms that did not have the full auth chain persisted.
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 534dc32413..fab7008a8f 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -452,33 +452,56 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# sets.
seen_chains: Set[int] = set()
- sql = """
- SELECT event_id, chain_id, sequence_number
- FROM event_auth_chains
- WHERE %s
- """
- for batch in batch_iter(initial_events, 1000):
- clause, args = make_in_list_sql_clause(
- txn.database_engine, "event_id", batch
- )
- txn.execute(sql % (clause,), args)
+ # Fetch the chain cover index for the initial set of events we're
+ # considering.
+ def fetch_chain_info(events_to_fetch: Collection[str]) -> None:
+ sql = """
+ SELECT event_id, chain_id, sequence_number
+ FROM event_auth_chains
+ WHERE %s
+ """
+ for batch in batch_iter(events_to_fetch, 1000):
+ clause, args = make_in_list_sql_clause(
+ txn.database_engine, "event_id", batch
+ )
+ txn.execute(sql % (clause,), args)
- for event_id, chain_id, sequence_number in txn:
- chain_info[event_id] = (chain_id, sequence_number)
- seen_chains.add(chain_id)
- chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+ for event_id, chain_id, sequence_number in txn:
+ chain_info[event_id] = (chain_id, sequence_number)
+ seen_chains.add(chain_id)
+ chain_to_event.setdefault(chain_id, {})[sequence_number] = event_id
+
+ fetch_chain_info(initial_events)
# Check that we actually have a chain ID for all the events.
events_missing_chain_info = initial_events.difference(chain_info)
+
+ # The result set to return, i.e. the auth chain difference.
+ result: Set[str] = set()
+
if events_missing_chain_info:
- # This can happen due to e.g. downgrade/upgrade of the server. We
- # raise an exception and fall back to the previous algorithm.
- logger.info(
- "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ # For some reason we have events we haven't calculated the chain
+ # index for, so we need to handle those separately. This should only
+ # happen for older rooms where the server doesn't have all the auth
+ # events.
+ result = self._fixup_auth_chain_difference_sets(
+ txn,
room_id,
- events_missing_chain_info,
+ state_sets=state_sets,
+ events_missing_chain_info=events_missing_chain_info,
+ events_that_have_chain_index=chain_info,
)
- raise _NoChainCoverIndex(room_id)
+
+ # We now need to refetch any events that we have added to the state
+ # sets.
+ new_events_to_fetch = {
+ event_id
+ for state_set in state_sets
+ for event_id in state_set
+ if event_id not in initial_events
+ }
+
+ fetch_chain_info(new_events_to_fetch)
# Corresponds to `state_sets`, except as a map from chain ID to max
# sequence number reachable from the state set.
@@ -487,8 +510,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
chains: Dict[int, int] = {}
set_to_chain.append(chains)
- for event_id in state_set:
- chain_id, seq_no = chain_info[event_id]
+ for state_id in state_set:
+ chain_id, seq_no = chain_info[state_id]
chains[chain_id] = max(seq_no, chains.get(chain_id, 0))
@@ -532,7 +555,6 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# from *any* state set and the minimum sequence number reachable from
# *all* state sets. Events in that range are in the auth chain
# difference.
- result = set()
# Mapping from chain ID to the range of sequence numbers that should be
# pulled from the database.
@@ -588,6 +610,122 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return result
+ def _fixup_auth_chain_difference_sets(
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_sets: List[Set[str]],
+ events_missing_chain_info: Set[str],
+ events_that_have_chain_index: Collection[str],
+ ) -> Set[str]:
+ """Helper for `_get_auth_chain_difference_using_cover_index_txn` to
+ handle the case where we haven't calculated the chain cover index for
+ all events.
+
+ This modifies `state_sets` so that they only include events that have a
+ chain cover index, and returns a set of event IDs that are part of the
+ auth difference.
+ """
+
+ # This works similarly to the handling of unpersisted events in
+ # `synapse.state.v2_get_auth_chain_difference`. We uses the observation
+ # that if you can split the set of events into two classes X and Y,
+ # where no events in Y have events in X in their auth chain, then we can
+ # calculate the auth difference by considering X and Y separately.
+ #
+ # We do this in three steps:
+ # 1. Compute the set of events without chain cover index belonging to
+ # the auth difference.
+ # 2. Replacing the un-indexed events in the state_sets with their auth
+ # events, recursively, until the state_sets contain only indexed
+ # events. We can then calculate the auth difference of those state
+ # sets using the chain cover index.
+ # 3. Add the results of 1 and 2 together.
+
+ # By construction we know that all events that we haven't persisted the
+ # chain cover index for are contained in
+ # `event_auth_chain_to_calculate`, so we pull out the events from those
+ # rather than doing recursive queries to walk the auth chain.
+ #
+ # We pull out those events with their auth events, which gives us enough
+ # information to construct the auth chain of an event up to auth events
+ # that have the chain cover index.
+ sql = """
+ SELECT tc.event_id, ea.auth_id, eac.chain_id IS NOT NULL
+ FROM event_auth_chain_to_calculate AS tc
+ LEFT JOIN event_auth AS ea USING (event_id)
+ LEFT JOIN event_auth_chains AS eac ON (ea.auth_id = eac.event_id)
+ WHERE tc.room_id = ?
+ """
+ txn.execute(sql, (room_id,))
+ event_to_auth_ids: Dict[str, Set[str]] = {}
+ events_that_have_chain_index = set(events_that_have_chain_index)
+ for event_id, auth_id, auth_id_has_chain in txn:
+ s = event_to_auth_ids.setdefault(event_id, set())
+ if auth_id is not None:
+ s.add(auth_id)
+ if auth_id_has_chain:
+ events_that_have_chain_index.add(auth_id)
+
+ if events_missing_chain_info - event_to_auth_ids.keys():
+ # Uh oh, we somehow haven't correctly done the chain cover index,
+ # bail and fall back to the old method.
+ logger.info(
+ "Unexpectedly found that events don't have chain IDs in room %s: %s",
+ room_id,
+ events_missing_chain_info - event_to_auth_ids.keys(),
+ )
+ raise _NoChainCoverIndex(room_id)
+
+ # Create a map from event IDs we care about to their partial auth chain.
+ event_id_to_partial_auth_chain: Dict[str, Set[str]] = {}
+ for event_id, auth_ids in event_to_auth_ids.items():
+ if not any(event_id in state_set for state_set in state_sets):
+ continue
+
+ processing = set(auth_ids)
+ to_add = set()
+ while processing:
+ auth_id = processing.pop()
+ to_add.add(auth_id)
+
+ sub_auth_ids = event_to_auth_ids.get(auth_id)
+ if sub_auth_ids is None:
+ continue
+
+ processing.update(sub_auth_ids - to_add)
+
+ event_id_to_partial_auth_chain[event_id] = to_add
+
+ # Now we do two things:
+ # 1. Update the state sets to only include indexed events; and
+ # 2. Create a new list containing the auth chains of the un-indexed
+ # events
+ unindexed_state_sets: List[Set[str]] = []
+ for state_set in state_sets:
+ unindexed_state_set = set()
+ for event_id, auth_chain in event_id_to_partial_auth_chain.items():
+ if event_id not in state_set:
+ continue
+
+ unindexed_state_set.add(event_id)
+
+ state_set.discard(event_id)
+ state_set.difference_update(auth_chain)
+ for auth_id in auth_chain:
+ if auth_id in events_that_have_chain_index:
+ state_set.add(auth_id)
+ else:
+ unindexed_state_set.add(auth_id)
+
+ unindexed_state_sets.append(unindexed_state_set)
+
+ # Calculate and return the auth difference of the un-indexed events.
+ union = unindexed_state_sets[0].union(*unindexed_state_sets[1:])
+ intersection = unindexed_state_sets[0].intersection(*unindexed_state_sets[1:])
+
+ return union - intersection
+
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
) -> Set[str]:
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 9c151a5e62..7a4ecab2d5 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -13,7 +13,19 @@
# limitations under the License.
import datetime
-from typing import Dict, List, Tuple, Union, cast
+from typing import (
+ Collection,
+ Dict,
+ FrozenSet,
+ Iterable,
+ List,
+ Mapping,
+ Set,
+ Tuple,
+ TypeVar,
+ Union,
+ cast,
+)
import attr
from parameterized import parameterized
@@ -38,6 +50,138 @@ from synapse.util import Clock, json_encoder
import tests.unittest
import tests.utils
+# The silly auth graph we use to test the auth difference algorithm,
+# where the top are the most recent events.
+#
+# A B
+# \ /
+# D E
+# \ |
+# ` F C
+# | /|
+# G ´ |
+# | \ |
+# H I
+# | |
+# K J
+
+AUTH_GRAPH: Dict[str, List[str]] = {
+ "a": ["e"],
+ "b": ["e"],
+ "c": ["g", "i"],
+ "d": ["f"],
+ "e": ["f"],
+ "f": ["g"],
+ "g": ["h", "i"],
+ "h": ["k"],
+ "i": ["j"],
+ "k": [],
+ "j": [],
+}
+
+DEPTH_GRAPH = {
+ "a": 7,
+ "b": 7,
+ "c": 4,
+ "d": 6,
+ "e": 6,
+ "f": 5,
+ "g": 3,
+ "h": 2,
+ "i": 2,
+ "k": 1,
+ "j": 1,
+}
+
+T = TypeVar("T")
+
+
+def get_all_topologically_sorted_orders(
+ nodes: Iterable[T],
+ graph: Mapping[T, Collection[T]],
+) -> List[List[T]]:
+ """Given a set of nodes and a graph, return all possible topological
+ orderings.
+ """
+
+ # This is implemented by Kahn's algorithm, and forking execution each time
+ # we have a choice over which node to consider next.
+
+ degree_map = {node: 0 for node in nodes}
+ reverse_graph: Dict[T, Set[T]] = {}
+
+ for node, edges in graph.items():
+ if node not in degree_map:
+ continue
+
+ for edge in set(edges):
+ if edge in degree_map:
+ degree_map[node] += 1
+
+ reverse_graph.setdefault(edge, set()).add(node)
+ reverse_graph.setdefault(node, set())
+
+ zero_degree = [node for node, degree in degree_map.items() if degree == 0]
+
+ return _get_all_topologically_sorted_orders_inner(
+ reverse_graph, zero_degree, degree_map
+ )
+
+
+def _get_all_topologically_sorted_orders_inner(
+ reverse_graph: Dict[T, Set[T]],
+ zero_degree: List[T],
+ degree_map: Dict[T, int],
+) -> List[List[T]]:
+ new_paths = []
+
+ # Rather than only choosing *one* item from the list of nodes with zero
+ # degree, we "fork" execution and run the algorithm for each node in the
+ # zero degree.
+ for node in zero_degree:
+ new_degree_map = degree_map.copy()
+ new_zero_degree = zero_degree.copy()
+ new_zero_degree.remove(node)
+
+ for edge in reverse_graph.get(node, []):
+ if edge in new_degree_map:
+ new_degree_map[edge] -= 1
+ if new_degree_map[edge] == 0:
+ new_zero_degree.append(edge)
+
+ paths = _get_all_topologically_sorted_orders_inner(
+ reverse_graph, new_zero_degree, new_degree_map
+ )
+ for path in paths:
+ path.insert(0, node)
+
+ new_paths.extend(paths)
+
+ if not new_paths:
+ return [[]]
+
+ return new_paths
+
+
+def get_all_topologically_consistent_subsets(
+ nodes: Iterable[T],
+ graph: Mapping[T, Collection[T]],
+) -> Set[FrozenSet[T]]:
+ """Get all subsets of the graph where if node N is in the subgraph, then all
+ nodes that can reach that node (i.e. for all X there exists a path X -> N)
+ are in the subgraph.
+ """
+ all_topological_orderings = get_all_topologically_sorted_orders(nodes, graph)
+
+ graph_subsets = set()
+ for ordering in all_topological_orderings:
+ ordering.reverse()
+
+ for idx in range(len(ordering)):
+ graph_subsets.add(frozenset(ordering[:idx]))
+
+ return graph_subsets
+
@attr.s(auto_attribs=True, frozen=True, slots=True)
class _BackfillSetupInfo:
@@ -172,49 +316,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def _setup_auth_chain(self, use_chain_cover_index: bool) -> str:
room_id = "@ROOM:local"
- # The silly auth graph we use to test the auth difference algorithm,
- # where the top are the most recent events.
- #
- # A B
- # \ /
- # D E
- # \ |
- # ` F C
- # | /|
- # G ´ |
- # | \ |
- # H I
- # | |
- # K J
-
- auth_graph: Dict[str, List[str]] = {
- "a": ["e"],
- "b": ["e"],
- "c": ["g", "i"],
- "d": ["f"],
- "e": ["f"],
- "f": ["g"],
- "g": ["h", "i"],
- "h": ["k"],
- "i": ["j"],
- "k": [],
- "j": [],
- }
-
- depth_map = {
- "a": 7,
- "b": 7,
- "c": 4,
- "d": 6,
- "e": 6,
- "f": 5,
- "g": 3,
- "h": 2,
- "i": 2,
- "k": 1,
- "j": 1,
- }
-
# Mark the room as maybe having a cover index.
def store_room(txn: LoggingTransaction) -> None:
@@ -238,9 +339,9 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
def insert_event(txn: LoggingTransaction) -> None:
stream_ordering = 0
- for event_id in auth_graph:
+ for event_id in AUTH_GRAPH:
stream_ordering += 1
- depth = depth_map[event_id]
+ depth = DEPTH_GRAPH[event_id]
self.store.db_pool.simple_insert_txn(
txn,
@@ -260,8 +361,8 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
self.persist_events._persist_event_auth_chain_txn(
txn,
[
- cast(EventBase, FakeEvent(event_id, room_id, auth_graph[event_id]))
- for event_id in auth_graph
+ cast(EventBase, FakeEvent(event_id, room_id, AUTH_GRAPH[event_id]))
+ for event_id in AUTH_GRAPH
],
)
@@ -344,7 +445,51 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
room_id = self._setup_auth_chain(use_chain_cover_index)
# Now actually test that various combinations give the right result:
+ self.assert_auth_diff_is_expected(room_id)
+
+ @parameterized.expand(
+ [
+ [graph_subset]
+ for graph_subset in get_all_topologically_consistent_subsets(
+ AUTH_GRAPH, AUTH_GRAPH
+ )
+ ]
+ )
+ def test_auth_difference_partial(self, graph_subset: Collection[str]) -> None:
+ """Test that if we only have a chain cover index on a partial subset of
+ the room we still get the correct auth chain difference.
+
+ We do this by removing the chain cover index for every valid subset of the
+ graph.
+ """
+ room_id = self._setup_auth_chain(True)
+
+ for event_id in graph_subset:
+ # Remove chain cover from that event.
+ self.get_success(
+ self.store.db_pool.simple_delete(
+ table="event_auth_chains",
+ keyvalues={"event_id": event_id},
+ desc="test_auth_difference_partial_remove",
+ )
+ )
+ self.get_success(
+ self.store.db_pool.simple_insert(
+ table="event_auth_chain_to_calculate",
+ values={
+ "event_id": event_id,
+ "room_id": room_id,
+ "type": "",
+ "state_key": "",
+ },
+ desc="test_auth_difference_partial_remove",
+ )
+ )
+
+ self.assert_auth_diff_is_expected(room_id)
+ def assert_auth_diff_is_expected(self, room_id: str) -> None:
+ """Assert the auth chain difference returns the correct answers."""
difference = self.get_success(
self.store.get_auth_chain_difference(room_id, [{"a"}, {"b"}])
)
|