summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-11-17 16:13:23 +0000
committerErik Johnston <erik@matrix.org>2022-11-17 16:13:23 +0000
commit0e99b0bbd02c85e7b8f99ea7528527c4ac1e1cd6 (patch)
tree498d4e1ce8517551efa2128b52c6194eb88ea0e8
parentReintroduce #14376, with bugfix for monoliths (#14468) (diff)
downloadsynapse-0e99b0bbd02c85e7b8f99ea7528527c4ac1e1cd6.tar.xz
Implement closure of conflicted state events
-rw-r--r--synapse/state/v2.py46
-rw-r--r--synapse/storage/databases/main/event_federation.py49
2 files changed, 86 insertions, 9 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 1b9d7d8457..e15cfbd4b3 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -61,7 +61,11 @@ class StateResolutionStore(Protocol):
         ...
 
     def get_auth_chain_difference(
-        self, room_id: str, state_sets: List[Set[str]]
+        self,
+        room_id: str,
+        state_sets: List[Set[str]],
+        conflicted_state_ids: Set[str],
+        conflicted_boundary: Set[str],
     ) -> Awaitable[Set[str]]:
         ...
 
@@ -122,10 +126,12 @@ async def resolve_events_with_store(
     logger.debug("%d conflicted state entries", len(conflicted_state))
     logger.debug("Calculating auth chain difference")
 
+    conflicted_state_ids = set(itertools.chain.from_iterable(conflicted_state.values()))
+
     # Also fetch all auth events that appear in only some of the state sets'
     # auth chains.
     auth_diff = await _get_auth_chain_difference(
-        room_id, state_sets, event_map, state_res_store
+        room_id, state_sets, event_map, conflicted_state_ids, state_res_store
     )
 
     full_conflicted_set = set(
@@ -272,6 +278,7 @@ async def _get_auth_chain_difference(
     room_id: str,
     state_sets: Sequence[Mapping[Any, str]],
     unpersisted_events: Dict[str, EventBase],
+    conflicted_state_ids: Set[str],
     state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
@@ -367,15 +374,46 @@ async def _get_auth_chain_difference(
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
         auth_difference_unpersisted_part: Collection[str] = union - intersection
+
+        persisted_conflicted_state_ids = {
+            event_id for event_id in conflicted_state_ids if event_id not in union
+        }
+
+        boundary = state_sets_ids[0].union(*state_sets_ids[1:])
+        conflicted_boundary = set()
+
+        for event_id in persisted_conflicted_state_ids:
+            auth_chain = events_to_auth_chain.get(event_id)
+            if not auth_chain:
+                continue
+
+            conflicted_boundary != auth_chain & boundary
+
     else:
         auth_difference_unpersisted_part = ()
+        conflicted_boundary = set()
+        persisted_conflicted_state_ids = conflicted_state_ids
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
-    difference = await state_res_store.get_auth_chain_difference(
-        room_id, state_sets_ids
+    difference, conflicted_boundary = await state_res_store.get_auth_chain_difference(
+        room_id,
+        state_sets_ids,
+        persisted_conflicted_state_ids,
+        conflicted_boundary,
     )
     difference.update(auth_difference_unpersisted_part)
 
+    unpersisted_conflicted_state_ids = (
+        conflicted_state_ids - persisted_conflicted_state_ids
+    )
+    for boundary_event_id in conflicted_boundary:
+        for conflicted_id in unpersisted_conflicted_state_ids:
+            auth_chain = events_to_auth_chain[conflicted_id]
+            if boundary_event_id not in auth_chain:
+                continue
+
+            # TODO: Include all paths from conflicted_id -> boundary_id in difference.
+
     return difference
 
 
diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py
index 309a4ba664..a8443dc0e2 100644
--- a/synapse/storage/databases/main/event_federation.py
+++ b/synapse/storage/databases/main/event_federation.py
@@ -377,7 +377,11 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         return results
 
     async def get_auth_chain_difference(
-        self, room_id: str, state_sets: List[Set[str]]
+        self,
+        room_id: str,
+        state_sets: List[Set[str]],
+        conflicted_state_ids: Set[str],
+        conflicted_boundary: Set[str],
     ) -> Set[str]:
         """Given sets of state events figure out the auth chain difference (as
         per state res v2 algorithm).
@@ -400,12 +404,17 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                     self._get_auth_chain_difference_using_cover_index_txn,
                     room_id,
                     state_sets,
+                    conflicted_state_ids,
+                    conflicted_boundary,
                 )
             except _NoChainCoverIndex:
                 # For whatever reason we don't actually have a chain cover index
                 # for the events in question, so we fall back to the old method.
                 pass
 
+        if conflicted_boundary:
+            raise NotImplementedError()
+
         return await self.db_pool.runInteraction(
             "get_auth_chain_difference",
             self._get_auth_chain_difference_txn,
@@ -413,8 +422,13 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         )
 
     def _get_auth_chain_difference_using_cover_index_txn(
-        self, txn: LoggingTransaction, room_id: str, state_sets: List[Set[str]]
-    ) -> Set[str]:
+        self,
+        txn: LoggingTransaction,
+        room_id: str,
+        state_sets: List[Set[str]],
+        conflicted_state_ids: Set[str],
+        conflicted_boundary: Set[str],
+    ) -> Tuple[Set[str], Set[str]]:
         """Calculates the auth chain difference using the chain index.
 
         See docs/auth_chain_difference_algorithm.md for details
@@ -521,10 +535,35 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
         # pulled from the database.
         chain_to_gap: Dict[int, Tuple[int, int]] = {}
 
+        conflicted_state_chain_ids: Dict[str, List[int]] = {}
+        for event_id in conflicted_state_ids:
+            chain_id, seq_no = chain_info[event_id]
+            conflicted_state_chain_ids.setdefault(chain_id, []).append(seq_no)
+
+        # Filter down the conflicted boundary to only include events that can
+        # reach conflicted state.
+        conflicted_boundary_reaches_conflicted = set()
+        for event_id in conflicted_boundary:
+            chain_id, seq_no = chain_info[event_id]
+            min_seq_nos = conflicted_state_chain_ids.get(chain_id)
+            if min_seq_nos is not None and seq_no >= min(min_seq_nos):
+                conflicted_boundary_reaches_conflicted.add(event_id)
+
+        for event_id in conflicted_boundary_reaches_conflicted:
+            chain_id, seq_no = chain_info[event_id]
+            conflicted_state_chain_ids[chain_id].append(seq_no)
+
         for chain_id in seen_chains:
             min_seq_no = min(chains.get(chain_id, 0) for chains in set_to_chain)
             max_seq_no = max(chains.get(chain_id, 0) for chains in set_to_chain)
 
+            # Now do the closure by increasing the bounds of the range to the
+            # min and max of those in the conflicted state IDs
+            seq_nos = conflicted_state_chain_ids.get(chain_id)
+            for seq_no in seq_nos:
+                min_seq_no = min(seq_no, min_seq_no)
+                max_seq_no = min(seq_no, max_seq_no)
+
             if min_seq_no < max_seq_no:
                 # We have a non empty gap, try and fill it from the events that
                 # we have, otherwise add them to the list of gaps to pull out
@@ -539,7 +578,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
 
         if not chain_to_gap:
             # If there are no gaps to fetch, we're done!
-            return result
+            return result, conflicted_boundary_reaches_conflicted
 
         if isinstance(self.database_engine, PostgresEngine):
             # We can use `execute_values` to efficiently fetch the gaps when
@@ -569,7 +608,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas
                 txn.execute(sql, (chain_id, min_no, max_no))
                 result.update(r for r, in txn)
 
-        return result
+        return result, conflicted_boundary_reaches_conflicted
 
     def _get_auth_chain_difference_txn(
         self, txn: LoggingTransaction, state_sets: List[Set[str]]