From 97053c94060ea31d3b9d41a129221ad4b2a76865 Mon Sep 17 00:00:00 2001 From: David Robertson Date: Thu, 9 Jun 2022 09:48:04 +0100 Subject: Type annotations for `test_v2` (#12985) --- synapse/state/v2.py | 57 +++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 42 insertions(+), 15 deletions(-) (limited to 'synapse') diff --git a/synapse/state/v2.py b/synapse/state/v2.py index c618df2fde..0e609114ef 100644 --- a/synapse/state/v2.py +++ b/synapse/state/v2.py @@ -17,12 +17,14 @@ import itertools import logging from typing import ( Any, + Awaitable, Callable, Collection, Dict, Generator, Iterable, List, + Mapping, Optional, Sequence, Set, @@ -30,33 +32,58 @@ from typing import ( overload, ) -from typing_extensions import Literal +from typing_extensions import Literal, Protocol -import synapse.state from synapse import event_auth from synapse.api.constants import EventTypes from synapse.api.errors import AuthError from synapse.api.room_versions import RoomVersion from synapse.events import EventBase from synapse.types import MutableStateMap, StateMap -from synapse.util import Clock logger = logging.getLogger(__name__) +class Clock(Protocol): + # This is usually synapse.util.Clock, but it's replaced with a FakeClock in tests. + # We only ever sleep(0) though, so that other async functions can make forward + # progress without waiting for stateres to complete. + def sleep(self, duration_ms: float) -> Awaitable[None]: + ... + + +class StateResolutionStore(Protocol): + # This is usually synapse.state.StateResolutionStore, but it's replaced with a + # TestStateResolutionStore in tests. + def get_events( + self, event_ids: Collection[str], allow_rejected: bool = False + ) -> Awaitable[Dict[str, EventBase]]: + ... + + def get_auth_chain_difference( + self, room_id: str, state_sets: List[Set[str]] + ) -> Awaitable[Set[str]]: + ... + + # We want to await to the reactor occasionally during state res when dealing # with large data sets, so that we don't exhaust the reactor. This is done by # awaiting to reactor during loops every N iterations. _AWAIT_AFTER_ITERATIONS = 100 +__all__ = [ + "resolve_events_with_store", +] + + async def resolve_events_with_store( clock: Clock, room_id: str, room_version: RoomVersion, state_sets: Sequence[StateMap[str]], event_map: Optional[Dict[str, EventBase]], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, ) -> StateMap[str]: """Resolves the state using the v2 state resolution algorithm @@ -194,7 +221,7 @@ async def _get_power_level_for_sender( room_id: str, event_id: str, event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, ) -> int: """Return the power level of the sender of the given event according to their auth events. @@ -243,9 +270,9 @@ async def _get_power_level_for_sender( async def _get_auth_chain_difference( room_id: str, - state_sets: Sequence[StateMap[str]], + state_sets: Sequence[Mapping[Any, str]], event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, ) -> Set[str]: """Compare the auth chains of each state set and return the set of events that only appear in some but not all of the auth chains. @@ -406,7 +433,7 @@ async def _add_event_and_auth_chain_to_graph( room_id: str, event_id: str, event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, auth_diff: Set[str], ) -> None: """Helper function for _reverse_topological_power_sort that add the event @@ -440,7 +467,7 @@ async def _reverse_topological_power_sort( room_id: str, event_ids: Iterable[str], event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, auth_diff: Set[str], ) -> List[str]: """Returns a list of the event_ids sorted by reverse topological ordering, @@ -501,7 +528,7 @@ async def _iterative_auth_checks( event_ids: List[str], base_state: StateMap[str], event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, ) -> MutableStateMap[str]: """Sequentially apply auth checks to each event in given list, updating the state as it goes along. @@ -570,7 +597,7 @@ async def _mainline_sort( event_ids: List[str], resolved_power_event_id: Optional[str], event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, ) -> List[str]: """Returns a sorted list of event_ids sorted by mainline ordering based on the given event resolved_power_event_id @@ -639,7 +666,7 @@ async def _get_mainline_depth_for_event( event: EventBase, mainline_map: Dict[str, int], event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, ) -> int: """Get the mainline depths for the given event based on the mainline map @@ -683,7 +710,7 @@ async def _get_event( room_id: str, event_id: str, event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, allow_none: Literal[False] = False, ) -> EventBase: ... @@ -694,7 +721,7 @@ async def _get_event( room_id: str, event_id: str, event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, allow_none: Literal[True], ) -> Optional[EventBase]: ... @@ -704,7 +731,7 @@ async def _get_event( room_id: str, event_id: str, event_map: Dict[str, EventBase], - state_res_store: "synapse.state.StateResolutionStore", + state_res_store: StateResolutionStore, allow_none: bool = False, ) -> Optional[EventBase]: """Helper function to look up event in event_map, falling back to looking -- cgit 1.4.1