summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/state/v2.py69
1 files changed, 37 insertions, 32 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index cf3045f82e..af03851c71 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -271,40 +271,41 @@ async def _get_power_level_for_sender(
 async def _get_auth_chain_difference(
     room_id: str,
     state_sets: Sequence[Mapping[Any, str]],
-    event_map: Dict[str, EventBase],
+    unpersisted_events: Dict[str, EventBase],
     state_res_store: StateResolutionStore,
 ) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
-    that only appear in some but not all of the auth chains.
+    that only appear in some, but not all of the auth chains.
 
     Args:
-        state_sets
-        event_map
-        state_res_store
+        state_sets: The input state sets we are trying to resolve across.
+        unpersisted_events: A map from event ID to EventBase containing all unpersisted
+            events involved in this resolution.
+        state_res_store:
 
     Returns:
-        Set of event IDs
+        The auth difference of the given state sets, as a set of event IDs.
     """
 
     # The `StateResolutionStore.get_auth_chain_difference` function assumes that
     # all events passed to it (and their auth chains) have been persisted
-    # previously. This is not the case for any events in the `event_map`, and so
-    # we need to manually handle those events.
+    # previously. We need to manually handle any other events that are yet to be
+    # persisted.
     #
-    # We do this by:
-    #   1. calculating the auth chain difference for the state sets based on the
-    #      events in `event_map` alone
-    #   2. replacing any events in the state_sets that are also in `event_map`
-    #      with their auth events (recursively), and then calling
-    #      `store.get_auth_chain_difference` as normal
-    #   3. adding the results of 1 and 2 together.
-
-    # Map from event ID in `event_map` to their auth event IDs, and their auth
-    # event IDs if they appear in the `event_map`. This is the intersection of
-    # the event's auth chain with the events in the `event_map` *plus* their
+    # We do this in three steps:
+    #   1. Compute the set of unpersisted events belonging to the auth difference.
+    #   2. Replacing any unpersisted events in the state_sets with their auth events,
+    #      recursively, until the state_sets contain only persisted events.
+    #      Then we call `store.get_auth_chain_difference` as normal, which computes
+    #      the set of persisted events belonging to the auth difference.
+    #   3. Adding the results of 1 and 2 together.
+
+    # Map from event ID in `unpersisted_events` to their auth event IDs, and their auth
+    # event IDs if they appear in the `unpersisted_events`. This is the intersection of
+    # the event's auth chain with the events in `unpersisted_events` *plus* their
     # auth event IDs.
     events_to_auth_chain: Dict[str, Set[str]] = {}
-    for event in event_map.values():
+    for event in unpersisted_events.values():
         chain = {event.event_id}
         events_to_auth_chain[event.event_id] = chain
 
@@ -312,16 +313,16 @@ async def _get_auth_chain_difference(
         while to_search:
             for auth_id in to_search.pop().auth_event_ids():
                 chain.add(auth_id)
-                auth_event = event_map.get(auth_id)
+                auth_event = unpersisted_events.get(auth_id)
                 if auth_event:
                     to_search.append(auth_event)
 
-    # We now a) calculate the auth chain difference for the unpersisted events
-    # and b) work out the state sets to pass to the store.
+    # We now 1) calculate the auth chain difference for the unpersisted events
+    # and 2) work out the state sets to pass to the store.
     #
-    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # Note: If there are no `unpersisted_events` (which is the common case), we can do a
     # much simpler calculation.
-    if event_map:
+    if unpersisted_events:
         # The list of state sets to pass to the store, where each state set is a set
         # of the event ids making up the state. This is similar to `state_sets`,
         # except that (a) we only have event ids, not the complete
@@ -344,14 +345,18 @@ async def _get_auth_chain_difference(
             for event_id in state_set.values():
                 event_chain = events_to_auth_chain.get(event_id)
                 if event_chain is not None:
-                    # We have an event in `event_map`. We add all the auth
-                    # events that it references (that aren't also in `event_map`).
-                    set_ids.update(e for e in event_chain if e not in event_map)
+                    # We have an unpersisted event. We add all the auth
+                    # events that it references which are also unpersisted.
+                    set_ids.update(
+                        e for e in event_chain if e not in unpersisted_events
+                    )
 
                     # We also add the full chain of unpersisted event IDs
                     # referenced by this state set, so that we can work out the
                     # auth chain difference of the unpersisted events.
-                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                    unpersisted_ids.update(
+                        e for e in event_chain if e in unpersisted_events
+                    )
                 else:
                     set_ids.add(event_id)
 
@@ -361,15 +366,15 @@ async def _get_auth_chain_difference(
         union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
         intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
 
-        difference_from_event_map: Collection[str] = union - intersection
+        auth_difference_unpersisted_part: Collection[str] = union - intersection
     else:
-        difference_from_event_map = ()
+        auth_difference_unpersisted_part = ()
         state_sets_ids = [set(state_set.values()) for state_set in state_sets]
 
     difference = await state_res_store.get_auth_chain_difference(
         room_id, state_sets_ids
     )
-    difference.update(difference_from_event_map)
+    difference.update(auth_difference_unpersisted_part)
 
     return difference