diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..f85124bf81 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
from synapse.util import Clock
logger = logging.getLogger(__name__)
@@ -97,7 +97,9 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(
+ room_id, state_sets, event_map, state_res_store
+ )
full_conflicted_set = set(
itertools.chain(
@@ -236,6 +238,7 @@ async def _get_power_level_for_sender(
async def _get_auth_chain_difference(
+ room_id: str,
state_sets: Sequence[StateMap[str]],
event_map: Dict[str, EventBase],
state_res_store: "synapse.state.StateResolutionStore",
@@ -252,9 +255,90 @@ async def _get_auth_chain_difference(
Set of event IDs
"""
+ # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+ # all events passed to it (and their auth chains) have been persisted
+ # previously. This is not the case for any events in the `event_map`, and so
+ # we need to manually handle those events.
+ #
+ # We do this by:
+ # 1. calculating the auth chain difference for the state sets based on the
+ # events in `event_map` alone
+ # 2. replacing any events in the state_sets that are also in `event_map`
+ # with their auth events (recursively), and then calling
+ # `store.get_auth_chain_difference` as normal
+ # 3. adding the results of 1 and 2 together.
+
+ # Map from event ID in `event_map` to their auth event IDs, and their auth
+ # event IDs if they appear in the `event_map`. This is the intersection of
+ # the event's auth chain with the events in the `event_map` *plus* their
+ # auth event IDs.
+ events_to_auth_chain = {} # type: Dict[str, Set[str]]
+ for event in event_map.values():
+ chain = {event.event_id}
+ events_to_auth_chain[event.event_id] = chain
+
+ to_search = [event]
+ while to_search:
+ for auth_id in to_search.pop().auth_event_ids():
+ chain.add(auth_id)
+ auth_event = event_map.get(auth_id)
+ if auth_event:
+ to_search.append(auth_event)
+
+ # We now a) calculate the auth chain difference for the unpersisted events
+ # and b) work out the state sets to pass to the store.
+ #
+ # Note: If the `event_map` is empty (which is the common case), we can do a
+ # much simpler calculation.
+ if event_map:
+ # The list of state sets to pass to the store, where each state set is a set
+ # of the event ids making up the state. This is similar to `state_sets`,
+ # except that (a) we only have event ids, not the complete
+ # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+ # unpersisted events and replaced them with the persisted events in
+ # their auth chain.
+ state_sets_ids = [] # type: List[Set[str]]
+
+ # For each state set, the unpersisted event IDs reachable (by their auth
+ # chain) from the events in that set.
+ unpersisted_set_ids = [] # type: List[Set[str]]
+
+ for state_set in state_sets:
+ set_ids = set() # type: Set[str]
+ state_sets_ids.append(set_ids)
+
+ unpersisted_ids = set() # type: Set[str]
+ unpersisted_set_ids.append(unpersisted_ids)
+
+ for event_id in state_set.values():
+ event_chain = events_to_auth_chain.get(event_id)
+ if event_chain is not None:
+ # We have an event in `event_map`. We add all the auth
+ # events that it references (that aren't also in `event_map`).
+ set_ids.update(e for e in event_chain if e not in event_map)
+
+ # We also add the full chain of unpersisted event IDs
+ # referenced by this state set, so that we can work out the
+ # auth chain difference of the unpersisted events.
+ unpersisted_ids.update(e for e in event_chain if e in event_map)
+ else:
+ set_ids.add(event_id)
+
+ # The auth chain difference of the unpersisted events of the state sets
+ # is calculated by taking the difference between the union and
+ # intersections.
+ union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+ intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+ difference_from_event_map = union - intersection # type: Collection[str]
+ else:
+ difference_from_event_map = ()
+ state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
difference = await state_res_store.get_auth_chain_difference(
- [set(state_set.values()) for state_set in state_sets]
+ room_id, state_sets_ids
)
+ difference.update(difference_from_event_map)
return difference
|