diff options
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r-- | synapse/state/v2.py | 107 |
1 files changed, 49 insertions, 58 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py index bf6caa0946..6634955cdc 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -18,8 +18,6 @@ import itertools import logging from typing import Dict, List, Optional -from twisted.internet import defer - import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes @@ -32,14 +30,13 @@ from synapse.util import Clock logger = logging.getLogger(__name__) -# We want to yield to the reactor occasionally during state res when dealing +# We want to await to the reactor occasionally during state res when dealing # with large data sets, so that we don't exhaust the reactor. This is done by -# yielding to reactor during loops every N iterations. -_YIELD_AFTER_ITERATIONS = 100 +# awaiting to reactor during loops every N iterations. +_AWAIT_AFTER_ITERATIONS = 100 -@defer.inlineCallbacks -def resolve_events_with_store( +async def resolve_events_with_store( clock: Clock, room_id: str, room_version: str, @@ -87,7 +84,7 @@ def resolve_events_with_store( # Also fetch all auth events that appear in only some of the state sets' # auth chains. - auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store) + auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store) full_conflicted_set = set( itertools.chain( @@ -95,7 +92,7 @@ def resolve_events_with_store( ) ) - events = yield state_res_store.get_events( + events = await state_res_store.get_events( [eid for eid in full_conflicted_set if eid not in event_map], allow_rejected=True, ) @@ -118,14 +115,14 @@ def resolve_events_with_store( eid for eid in full_conflicted_set if _is_power_event(event_map[eid]) ) - sorted_power_events = yield _reverse_topological_power_sort( + sorted_power_events = await _reverse_topological_power_sort( clock, room_id, power_events, event_map, state_res_store, full_conflicted_set ) logger.debug("sorted %d power events", len(sorted_power_events)) # Now sequentially auth each one - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -148,13 +145,13 @@ def resolve_events_with_store( logger.debug("sorting %d remaining events", len(leftover_events)) pl = resolved_state.get((EventTypes.PowerLevels, ""), None) - leftover_events = yield _mainline_sort( + leftover_events = await _mainline_sort( clock, room_id, leftover_events, pl, event_map, state_res_store ) logger.debug("resolving remaining events") - resolved_state = yield _iterative_auth_checks( + resolved_state = await _iterative_auth_checks( clock, room_id, room_version, @@ -174,8 +171,7 @@ def resolve_events_with_store( return resolved_state -@defer.inlineCallbacks -def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): +async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): """Return the power level of the sender of the given event according to their auth events. @@ -188,11 +184,11 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): Returns: Deferred[int] """ - event = yield _get_event(room_id, event_id, event_map, state_res_store) + event = await _get_event(room_id, event_id, event_map, state_res_store) pl = None for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -202,7 +198,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): if pl is None: # Couldn't find power level. Check if they're the creator of the room for aid in event.auth_event_ids(): - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""): @@ -221,8 +217,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store): return int(level) -@defer.inlineCallbacks -def _get_auth_chain_difference(state_sets, event_map, state_res_store): +async def _get_auth_chain_difference(state_sets, event_map, state_res_store): """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. @@ -235,7 +230,7 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store): Deferred[set[str]]: Set of event IDs """ - difference = yield state_res_store.get_auth_chain_difference( + difference = await state_res_store.get_auth_chain_difference( [set(state_set.values()) for state_set in state_sets] ) @@ -292,8 +287,7 @@ def _is_power_event(event): return False -@defer.inlineCallbacks -def _add_event_and_auth_chain_to_graph( +async def _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ): """Helper function for _reverse_topological_power_sort that add the event @@ -314,7 +308,7 @@ def _add_event_and_auth_chain_to_graph( eid = state.pop() graph.setdefault(eid, set()) - event = yield _get_event(room_id, eid, event_map, state_res_store) + event = await _get_event(room_id, eid, event_map, state_res_store) for aid in event.auth_event_ids(): if aid in auth_diff: if aid not in graph: @@ -323,8 +317,7 @@ def _add_event_and_auth_chain_to_graph( graph.setdefault(eid, set()).add(aid) -@defer.inlineCallbacks -def _reverse_topological_power_sort( +async def _reverse_topological_power_sort( clock, room_id, event_ids, event_map, state_res_store, auth_diff ): """Returns a list of the event_ids sorted by reverse topological ordering, @@ -344,26 +337,26 @@ def _reverse_topological_power_sort( graph = {} for idx, event_id in enumerate(event_ids, start=1): - yield _add_event_and_auth_chain_to_graph( + await _add_event_and_auth_chain_to_graph( graph, room_id, event_id, event_map, state_res_store, auth_diff ) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_to_pl = {} for idx, event_id in enumerate(graph, start=1): - pl = yield _get_power_level_for_sender( + pl = await _get_power_level_for_sender( room_id, event_id, event_map, state_res_store ) event_to_pl[event_id] = pl - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) def _get_power_order(event_id): ev = event_map[event_id] @@ -378,8 +371,7 @@ def _reverse_topological_power_sort( return sorted_events -@defer.inlineCallbacks -def _iterative_auth_checks( +async def _iterative_auth_checks( clock, room_id, room_version, event_ids, base_state, event_map, state_res_store ): """Sequentially apply auth checks to each event in given list, updating the @@ -405,7 +397,7 @@ def _iterative_auth_checks( auth_events = {} for aid in event.auth_event_ids(): - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) @@ -420,7 +412,7 @@ def _iterative_auth_checks( for key in event_auth.auth_types_for_event(event): if key in resolved_state: ev_id = resolved_state[key] - ev = yield _get_event(room_id, ev_id, event_map, state_res_store) + ev = await _get_event(room_id, ev_id, event_map, state_res_store) if ev.rejected_reason is None: auth_events[key] = event_map[ev_id] @@ -438,16 +430,15 @@ def _iterative_auth_checks( except AuthError: pass - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) return resolved_state -@defer.inlineCallbacks -def _mainline_sort( +async def _mainline_sort( clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store ): """Returns a sorted list of event_ids sorted by mainline ordering based on @@ -474,21 +465,21 @@ def _mainline_sort( idx = 0 while pl: mainline.append(pl) - pl_ev = yield _get_event(room_id, pl, event_map, state_res_store) + pl_ev = await _get_event(room_id, pl, event_map, state_res_store) auth_events = pl_ev.auth_event_ids() pl = None for aid in auth_events: - ev = yield _get_event( + ev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""): pl = aid break - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx != 0 and idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) idx += 1 @@ -498,23 +489,24 @@ def _mainline_sort( order_map = {} for idx, ev_id in enumerate(event_ids, start=1): - depth = yield _get_mainline_depth_for_event( + depth = await _get_mainline_depth_for_event( event_map[ev_id], mainline_map, event_map, state_res_store ) order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id) - # We yield occasionally when we're working with large data sets to + # We await occasionally when we're working with large data sets to # ensure that we don't block the reactor loop for too long. - if idx % _YIELD_AFTER_ITERATIONS == 0: - yield clock.sleep(0) + if idx % _AWAIT_AFTER_ITERATIONS == 0: + await clock.sleep(0) event_ids.sort(key=lambda ev_id: order_map[ev_id]) return event_ids -@defer.inlineCallbacks -def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store): +async def _get_mainline_depth_for_event( + event, mainline_map, event_map, state_res_store +): """Get the mainline depths for the given event based on the mainline map Args: @@ -541,7 +533,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor event = None for aid in auth_events: - aev = yield _get_event( + aev = await _get_event( room_id, aid, event_map, state_res_store, allow_none=True ) if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""): @@ -552,8 +544,7 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor return 0 -@defer.inlineCallbacks -def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): +async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): """Helper function to look up event in event_map, falling back to looking it up in the store @@ -569,7 +560,7 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False): Deferred[Optional[FrozenEvent]] """ if event_id not in event_map: - events = yield state_res_store.get_events([event_id], allow_rejected=True) + events = await state_res_store.get_events([event_id], allow_rejected=True) event_map.update(events) event = event_map.get(event_id) |