diff options
author | Erik Johnston <erik@matrix.org> | 2022-11-17 16:13:23 +0000 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2022-11-17 16:13:23 +0000 |
commit | 0e99b0bbd02c85e7b8f99ea7528527c4ac1e1cd6 (patch) | |
tree | 498d4e1ce8517551efa2128b52c6194eb88ea0e8 | |
parent | Reintroduce #14376, with bugfix for monoliths (#14468) (diff) | |
download | synapse-0e99b0bbd02c85e7b8f99ea7528527c4ac1e1cd6.tar.xz |
Implement closure of conflicted state events
-rw-r--r-- | synapse/state/v2.py | 46 | ||||
-rw-r--r-- | synapse/storage/databases/main/event_federation.py | 49 |
2 files changed, 86 insertions, 9 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 1b9d7d8457..e15cfbd4b3 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -61,7 +61,11 @@ class StateResolutionStore(Protocol):
         ...
 
     def get_auth_chain_difference(
-        self, room_id: str, state_sets: List[Set[str]]
+        self,
+        room_id: str,
+        state_sets: List[Set[str]],
+        conflicted_state_ids: Set[str],
+        conflicted_boundary: Set[str],
     ) -> Awaitable[Set[str]]:
         ...
 
@@ -122,10 +126,12 @@ async def resolve_events_with_store(
     logger.debug("%d conflicted state entries", len(conflicted_state))
     logger.debug("Calculating auth chain difference")
 
+    conflicted_state_ids = set(itertools.chain.from_iterable(conflicted_state.values()))
+
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
     auth_diff = await _get_auth_chain_difference(
-        room_id, state_sets, event_map, state_res_store
+        room_id, state_sets, event_map, conflicted_state_ids, state_res_store
     )
 
     full_conflicted_set = set(
@@ -272,6 +278,7 @@ async def _get_auth_chain_difference(
     room_id: str,
     state_sets: Sequence[Mapping[Any, str]],
     unpersisted_events: Dict[str, EventBase],
+    conflicted_state_ids: Set[str],
     state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
@@ -367,15 +374,46 @@ async def _get_auth_chain_difference(
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
         auth_difference_unpersisted_part: Collection[str] = union - intersection
+
+        persisted_conflicted_state_ids = {
+            event_id for event_id in conflicted_state_ids if event_id not in union
+        }
+
+        boundary = state_sets_ids[0].union(*state_sets_ids[1:])
+        conflicted_boundary = set()
+
+        for event_id in persisted_conflicted_state_ids:
+            auth_chain = events_to_auth_chain.get(event_id)
+            if not auth_chain:
+                continue
+
+            conflicted_boundary |= auth_chain & boundary
+
     else:
         auth_difference_unpersisted_part = ()
+        conflicted_boundary = set()
+        persisted_conflicted_state_ids = conflicted_state_ids
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
-    difference = await state_res_store.get_auth_chain_difference(
-        room_id, state_sets_ids
+    difference, conflicted_boundary = await state_res_store.get_auth_chain_difference(
+        room_id,
+        state_sets_ids,
+        persisted_conflicted_state_ids,
+        conflicted_boundary,
     )
     difference.update(auth_difference_unpersisted_part)
 
+    unpersisted_conflicted_state_ids = (
+        conflicted_state_ids - persisted_conflicted_state_ids
+    )
+    for boundary_event_id in conflicted_boundary:
+        for conflicted_id in unpersisted_conflicted_state_ids:
+            auth_chain = events_to_auth_chain[conflicted_id]
+            if boundary_event_id not in auth_chain:
+                continue
+
+            # TODO: Include all paths from conflicted_id -> boundary_id in difference.
+
     return difference
 
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 309a4ba664..a8443dc0e2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -377,7 +377,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return results
 
     async def get_auth_chain_difference(
-        self, room_id: str, state_sets: List[Set[str]]
+        self,
+        room_id: str,
+        state_sets: List[Set[str]],
+        conflicted_state_ids: Set[str],
+        conflicted_boundary: Set[str],
     ) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
@@ -400,12 +404,17 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     self._get_auth_chain_difference_using_cover_index_txn,
                     room_id,
                     state_sets,
+                    conflicted_state_ids,
+                    conflicted_boundary,
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
                 # for the events in question, so we fall back to the old method.
                 pass
 
+        if conflicted_boundary:
+            raise NotImplementedError()
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
@@ -413,8 +422,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_difference_using_cover_index_txn(
-        self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
-    ) -> Set[str]:
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_sets: List[Set[str]],
+        conflicted_state_ids: Set[str],
+        conflicted_boundary: Set[str],
+    ) -> Tuple[Set[str], Set[str]]:
         """Calculates the auth chain difference using the chain index.
 
         See docs/auth_chain_difference_algorithm.md for details
@@ -521,10 +535,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # pulled from the database.
         chain_to_gap: Dict[int, Tuple[int, int]] = {}
 
+        conflicted_state_chain_ids: Dict[int, List[int]] = {}
+        for event_id in conflicted_state_ids:
+            chain_id, seq_no = chain_info[event_id]
+            conflicted_state_chain_ids.setdefault(chain_id, []).append(seq_no)
+
+        # Filter down the conflicted boundary to only include events that can
+        # reach conflicted state.
+        conflicted_boundary_reaches_conflicted = set()
+        for event_id in conflicted_boundary:
+            chain_id, seq_no = chain_info[event_id]
+            min_seq_nos = conflicted_state_chain_ids.get(chain_id)
+            if min_seq_nos is not None and seq_no >= min(min_seq_nos):
+                conflicted_boundary_reaches_conflicted.add(event_id)
+
+        for event_id in conflicted_boundary_reaches_conflicted:
+            chain_id, seq_no = chain_info[event_id]
+            conflicted_state_chain_ids[chain_id].append(seq_no)
+
         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
             max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
 
+            # Now do the closure by increasing the bounds of the range to the
+            # min and max of those in the conflicted state IDs
+            seq_nos = conflicted_state_chain_ids.get(chain_id, ())
+            for seq_no in seq_nos:
+                min_seq_no = min(seq_no, min_seq_no)
+                max_seq_no = max(seq_no, max_seq_no)
+
             if min_seq_no < max_seq_no:
                 # We have a non empty gap, try and fill it from the events that
                 # we have, otherwise add them to the list of gaps to pull out
@@ -539,7 +578,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         if not chain_to_gap:
             # If there are no gaps to fetch, we're done!
-            return result
+            return result, conflicted_boundary_reaches_conflicted
 
         if isinstance(self.database_engine, PostgresEngine):
             # We can use `execute_values` to efficiently fetch the gaps when
@@ -569,7 +608,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 txn.execute(sql, (chain_id, min_no, max_no))
                 result.update(r for r, in txn)
 
-        return result
+        return result, conflicted_boundary_reaches_conflicted
 
     def _get_auth_chain_difference_txn(
         self, txn: LoggingTransaction, state_sets: List[Set[str]]