diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 1b9d7d8457..e15cfbd4b3 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -61,7 +61,11 @@ class StateResolutionStore(Protocol):
...
def get_auth_chain_difference(
- self, room_id: str, state_sets: List[Set[str]]
+ self,
+ room_id: str,
+ state_sets: List[Set[str]],
+ conflicted_state_ids: Set[str],
+ conflicted_boundary: Set[str],
) -> Awaitable[Set[str]]:
...
@@ -122,10 +126,12 @@ async def resolve_events_with_store(
logger.debug("%d conflicted state entries", len(conflicted_state))
logger.debug("Calculating auth chain difference")
+ conflicted_state_ids = set(itertools.chain.from_iterable(conflicted_state.values()))
+
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
auth_diff = await _get_auth_chain_difference(
- room_id, state_sets, event_map, state_res_store
+ room_id, state_sets, event_map, conflicted_state_ids, state_res_store
)
full_conflicted_set = set(
@@ -272,6 +278,7 @@ async def _get_auth_chain_difference(
room_id: str,
state_sets: Sequence[Mapping[Any, str]],
unpersisted_events: Dict[str, EventBase],
+ conflicted_state_ids: Set[str],
state_res_store: StateResolutionStore,
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
@@ -367,15 +374,46 @@ async def _get_auth_chain_difference(
intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
auth_difference_unpersisted_part: Collection[str] = union - intersection
+
+ persisted_conflicted_state_ids = {
+ event_id for event_id in conflicted_state_ids if event_id not in union
+ }
+
+ boundary = state_sets_ids[0].union(*state_sets_ids[1:])
+ conflicted_boundary = set()
+
+ for event_id in persisted_conflicted_state_ids:
+ auth_chain = events_to_auth_chain.get(event_id)
+ if not auth_chain:
+ continue
+
+ conflicted_boundary != auth_chain & boundary
+
else:
auth_difference_unpersisted_part = ()
+ conflicted_boundary = set()
+ persisted_conflicted_state_ids = conflicted_state_ids
state_sets_ids = [set(state_set.values()) for state_set in state_sets]
- difference = await state_res_store.get_auth_chain_difference(
- room_id, state_sets_ids
+ difference, conflicted_boundary = await state_res_store.get_auth_chain_difference(
+ room_id,
+ state_sets_ids,
+ persisted_conflicted_state_ids,
+ conflicted_boundary,
)
difference.update(auth_difference_unpersisted_part)
+ unpersisted_conflicted_state_ids = (
+ conflicted_state_ids - persisted_conflicted_state_ids
+ )
+ for boundary_event_id in conflicted_boundary:
+ for conflicted_id in unpersisted_conflicted_state_ids:
+ auth_chain = events_to_auth_chain[conflicted_id]
+ if boundary_event_id not in auth_chain:
+ continue
+
+ # TODO: Include all paths from conflicted_id -> boundary_id in difference.
+
return difference
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 309a4ba664..a8443dc0e2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -377,7 +377,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
return results
async def get_auth_chain_difference(
- self, room_id: str, state_sets: List[Set[str]]
+ self,
+ room_id: str,
+ state_sets: List[Set[str]],
+ conflicted_state_ids: Set[str],
+ conflicted_boundary: Set[str],
) -> Set[str]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -400,12 +404,17 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
self._get_auth_chain_difference_using_cover_index_txn,
room_id,
state_sets,
+ conflicted_state_ids,
+ conflicted_boundary,
)
except _NoChainCoverIndex:
# For whatever reason we don't actually have a chain cover index
# for the events in question, so we fall back to the old method.
pass
+ if conflicted_boundary:
+ raise NotImplementedError()
+
return await self.db_pool.runInteraction(
"get_auth_chain_difference",
self._get_auth_chain_difference_txn,
@@ -413,8 +422,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
)
def _get_auth_chain_difference_using_cover_index_txn(
- self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
- ) -> Set[str]:
+ self,
+ txn: LoggingTransaction,
+ room_id: str,
+ state_sets: List[Set[str]],
+ conflicted_state_ids: Set[str],
+ conflicted_boundary: Set[str],
+ ) -> Tuple[Set[str], Set[str]]:
"""Calculates the auth chain difference using the chain index.
See docs/auth_chain_difference_algorithm.md for details
@@ -521,10 +535,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
# pulled from the database.
chain_to_gap: Dict[int, Tuple[int, int]] = {}
+ conflicted_state_chain_ids: Dict[str, List[int]] = {}
+ for event_id in conflicted_state_ids:
+ chain_id, seq_no = chain_info[event_id]
+ conflicted_state_chain_ids.setdefault(chain_id, []).append(seq_no)
+
+ # Filter down the conflicted boundary to only include events that can
+ # reach conflicted state.
+ conflicted_boundary_reaches_conflicted = set()
+ for event_id in conflicted_boundary:
+ chain_id, seq_no = chain_info[event_id]
+ min_seq_nos = conflicted_state_chain_ids.get(chain_id)
+ if min_seq_nos is not None and seq_no >= min(min_seq_nos):
+ conflicted_boundary_reaches_conflicted.add(event_id)
+
+ for event_id in conflicted_boundary_reaches_conflicted:
+ chain_id, seq_no = chain_info[event_id]
+ conflicted_state_chain_ids[chain_id].append(seq_no)
+
for chain_id in seen_chains:
min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
+ # Now do the closure by increasing the bounds of the range to the
+ # min and max of those in the conflicted state IDs
+ seq_nos = conflicted_state_chain_ids.get(chain_id)
+ for seq_no in seq_nos:
+ min_seq_no = min(seq_no, min_seq_no)
+ max_seq_no = min(seq_no, max_seq_no)
+
if min_seq_no < max_seq_no:
# We have a non empty gap, try and fill it from the events that
# we have, otherwise add them to the list of gaps to pull out
@@ -539,7 +578,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
if not chain_to_gap:
# If there are no gaps to fetch, we're done!
- return result
+ return result, conflicted_boundary_reaches_conflicted
if isinstance(self.database_engine, PostgresEngine):
# We can use `execute_values` to efficiently fetch the gaps when
@@ -569,7 +608,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
txn.execute(sql, (chain_id, min_no, max_no))
result.update(r for r, in txn)
- return result
+ return result, conflicted_boundary_reaches_conflicted
def _get_auth_chain_difference_txn(
self, txn: LoggingTransaction, state_sets: List[Set[str]]
|