diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index e224af8dd8..408d375439 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -124,7 +124,7 @@ async def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
auth_diff = await _get_auth_chain_difference(
- room_id, state_sets, event_map, state_res_store
+ room_id, state_sets, event_map, state_res_store, clock
)
with Measure(clock, "rei_state_res:rews2_b"): # TODO temporary (rei)
@@ -284,6 +284,7 @@ async def _get_auth_chain_difference(
state_sets: Sequence[StateMap[str]],
unpersisted_events: Dict[str, EventBase],
state_res_store: StateResolutionStore,
+ clock: Clock,
) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some, but not all of the auth chains.
@@ -315,77 +316,82 @@ async def _get_auth_chain_difference(
# event IDs if they appear in the `unpersisted_events`. This is the intersection of
# the event's auth chain with the events in `unpersisted_events` *plus* their
# auth event IDs.
- events_to_auth_chain: Dict[str, Set[str]] = {}
- for event in unpersisted_events.values():
- chain = {event.event_id}
- events_to_auth_chain[event.event_id] = chain
-
- to_search = [event]
- while to_search:
- for auth_id in to_search.pop().auth_event_ids():
- chain.add(auth_id)
- auth_event = unpersisted_events.get(auth_id)
- if auth_event:
- to_search.append(auth_event)
+ with Measure(clock, "rei_state_res:rews2_a1"): # TODO temporary (rei)
+ events_to_auth_chain: Dict[str, Set[str]] = {}
+ for event in unpersisted_events.values():
+ chain = {event.event_id}
+ events_to_auth_chain[event.event_id] = chain
+
+ to_search = [event]
+ while to_search:
+ for auth_id in to_search.pop().auth_event_ids():
+ chain.add(auth_id)
+ auth_event = unpersisted_events.get(auth_id)
+ if auth_event:
+ to_search.append(auth_event)
# We now 1) calculate the auth chain difference for the unpersisted events
# and 2) work out the state sets to pass to the store.
#
# Note: If there are no `unpersisted_events` (which is the common case), we can do a
# much simpler calculation.
- if unpersisted_events:
- # The list of state sets to pass to the store, where each state set is a set
- # of the event ids making up the state. This is similar to `state_sets`,
- # except that (a) we only have event ids, not the complete
- # ((type, state_key)->event_id) mappings; and (b) we have stripped out
- # unpersisted events and replaced them with the persisted events in
- # their auth chain.
- state_sets_ids: List[Set[str]] = []
-
- # For each state set, the unpersisted event IDs reachable (by their auth
- # chain) from the events in that set.
- unpersisted_set_ids: List[Set[str]] = []
-
- for state_set in state_sets:
- set_ids: Set[str] = set()
- state_sets_ids.append(set_ids)
-
- unpersisted_ids: Set[str] = set()
- unpersisted_set_ids.append(unpersisted_ids)
-
- for event_id in state_set.values():
- event_chain = events_to_auth_chain.get(event_id)
- if event_chain is not None:
- # We have an unpersisted event. We add all the auth
- # events that it references which are also unpersisted.
- set_ids.update(
- e for e in event_chain if e not in unpersisted_events
- )
-
- # We also add the full chain of unpersisted event IDs
- # referenced by this state set, so that we can work out the
- # auth chain difference of the unpersisted events.
- unpersisted_ids.update(
- e for e in event_chain if e in unpersisted_events
- )
- else:
- set_ids.add(event_id)
-
- # The auth chain difference of the unpersisted events of the state sets
- # is calculated by taking the difference between the union and
- # intersections.
- union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
- intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+ with Measure(clock, "rei_state_res:rews2_a2"): # TODO temporary (rei)
+ if unpersisted_events:
+ # The list of state sets to pass to the store, where each state set is a set
+ # of the event ids making up the state. This is similar to `state_sets`,
+ # except that (a) we only have event ids, not the complete
+ # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+ # unpersisted events and replaced them with the persisted events in
+ # their auth chain.
+ state_sets_ids: List[Set[str]] = []
+
+ # For each state set, the unpersisted event IDs reachable (by their auth
+ # chain) from the events in that set.
+ unpersisted_set_ids: List[Set[str]] = []
+
+ for state_set in state_sets:
+ set_ids: Set[str] = set()
+ state_sets_ids.append(set_ids)
+
+ unpersisted_ids: Set[str] = set()
+ unpersisted_set_ids.append(unpersisted_ids)
+
+ for event_id in state_set.values():
+ event_chain = events_to_auth_chain.get(event_id)
+ if event_chain is not None:
+ # We have an unpersisted event. We add all the auth
+ # events that it references which are also unpersisted.
+ set_ids.update(
+ e for e in event_chain if e not in unpersisted_events
+ )
+
+ # We also add the full chain of unpersisted event IDs
+ # referenced by this state set, so that we can work out the
+ # auth chain difference of the unpersisted events.
+ unpersisted_ids.update(
+ e for e in event_chain if e in unpersisted_events
+ )
+ else:
+ set_ids.add(event_id)
+
+ # The auth chain difference of the unpersisted events of the state sets
+ # is calculated by taking the difference between the union and
+ # intersections.
+ union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+ intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+ auth_difference_unpersisted_part: StrCollection = union - intersection
+ else:
+ auth_difference_unpersisted_part = ()
+ state_sets_ids = [set(state_set.values()) for state_set in state_sets]
- auth_difference_unpersisted_part: StrCollection = union - intersection
- else:
- auth_difference_unpersisted_part = ()
- state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+ with Measure(clock, "rei_state_res:rews2_a3"): # TODO temporary (rei)
+ difference = await state_res_store.get_auth_chain_difference(
+ room_id, state_sets_ids
+ )
- difference = await state_res_store.get_auth_chain_difference(
- room_id, state_sets_ids
- )
- difference.update(auth_difference_unpersisted_part)
+ with Measure(clock, "rei_state_res:rews2_a4"): # TODO temporary (rei)
+ difference.update(auth_difference_unpersisted_part)
return difference
|