summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r--synapse/state/v2.py87
1 files changed, 83 insertions, 4 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..ffc504ce77 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -252,9 +252,88 @@ async def _get_auth_chain_difference(
         Set of event IDs
     """
 
-    difference = await state_res_store.get_auth_chain_difference(
-        [set(state_set.values()) for state_set in state_sets]
-    )
+    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+    # all events passed to it (and their auth chains) have been persisted
+    # previously. This is not the case for any events in the `event_map`, and so
+    # we need to manually handle those events.
+    #
+    # We do this by:
+    #   1. calculating the auth chain difference for the state sets based on the
+    #      events in `event_map` alone
+    #   2. replacing any events in the state_sets that are also in `event_map`
+    #      with their auth events (recursively), and then calling
+    #      `store.get_auth_chain_difference` as normal
+    #   3. adding the results of 1 and 2 together.
+
+    # Map from event ID in `event_map` to their auth event IDs, and their auth
+    # event IDs if they appear in the `event_map`. This is the intersection of
+    # the event's auth chain with the events in the `event_map` *plus* their
+    # auth event IDs.
+    events_to_auth_chain = {}  # type: Dict[str, Set[str]]
+    for event in event_map.values():
+        chain = {event.event_id}
+        events_to_auth_chain[event.event_id] = chain
+
+        to_search = [event]
+        while to_search:
+            for auth_id in to_search.pop().auth_event_ids():
+                chain.add(auth_id)
+                auth_event = event_map.get(auth_id)
+                if auth_event:
+                    to_search.append(auth_event)
+
+    # We now a) calculate the auth chain difference for the unpersisted events
+    # and b) work out the state sets to pass to the store.
+    #
+    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # much simpler calculation.
+    if event_map:
+        # The list of state sets to pass to the store, where each state set is a set
+        # of the event ids making up the state. This is similar to `state_sets`,
+        # except that (a) we only have event ids, not the complete
+        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+        # unpersisted events and replaced them with the persisted events in
+        # their auth chain.
+        state_sets_ids = []  # type: List[Set[str]]
+
+        # For each state set, the unpersisted event IDs reachable (by their auth
+        # chain) from the events in that set.
+        unpersisted_set_ids = []  # type: List[Set[str]]
+
+        for state_set in state_sets:
+            set_ids = set()  # type: Set[str]
+            state_sets_ids.append(set_ids)
+
+            unpersisted_ids = set()  # type: Set[str]
+            unpersisted_set_ids.append(unpersisted_ids)
+
+            for event_id in state_set.values():
+                event_chain = events_to_auth_chain.get(event_id)
+                if event_chain is not None:
+                    # We have an event in `event_map`. We add all the auth
+                    # events that it references (that aren't also in `event_map`).
+                    set_ids.update(e for e in event_chain if e not in event_map)
+
+                    # We also add the full chain of unpersisted event IDs
+                    # referenced by this state set, so that we can work out the
+                    # auth chain difference of the unpersisted events.
+                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                else:
+                    set_ids.add(event_id)
+
+        # The auth chain difference of the unpersisted events of the state sets
+        # is calculated by taking the difference between the union and
+        # intersections.
+        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+        difference_from_event_map = union - intersection  # type: Collection[str]
+    else:
+        difference_from_event_map = ()
+        state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
+    difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
+    difference.update(difference_from_event_map)
 
     return difference