summary refs log tree commit diff
path: root/synapse/state/v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v2.py')
-rw-r--r--synapse/state/v2.py255
1 files changed, 167 insertions, 88 deletions
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 6634955cdc..0e9ffbd6e6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,7 +16,21 @@
 import heapq
 import itertools
 import logging
-from typing import Dict, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+    overload,
+)
+
+from typing_extensions import Literal
 
 import synapse.state
 from synapse import event_auth
@@ -40,10 +54,10 @@ async def resolve_events_with_store(
     clock: Clock,
     room_id: str,
     room_version: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
     state_res_store: "synapse.state.StateResolutionStore",
-):
+) -> StateMap[str]:
     """Resolves the state using the v2 state resolution algorithm
 
     Args:
@@ -63,8 +77,7 @@ async def resolve_events_with_store(
         state_res_store:
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        A map from (type, state_key) to event_id.
     """
 
     logger.debug("Computing conflicted state")
@@ -171,18 +184,23 @@ async def resolve_events_with_store(
     return resolved_state
 
 
-async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
     """Return the power level of the sender of the given event according to
     their auth events.
 
     Args:
-        room_id (str)
-        event_id (str)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        room_id
+        event_id
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[int]
+        The power level.
     """
     event = await _get_event(room_id, event_id, event_map, state_res_store)
 
@@ -217,17 +235,21 @@ async def _get_power_level_for_sender(room_id, event_id, event_map, state_res_st
         return int(level)
 
 
-async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(
+    state_sets: Sequence[StateMap[str]],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> Set[str]:
     """Compare the auth chains of each state set and return the set of events
     that only appear in some but not all of the auth chains.
 
     Args:
-        state_sets (list)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        state_sets
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[set[str]]: Set of event IDs
+        Set of event IDs
     """
 
     difference = await state_res_store.get_auth_chain_difference(
@@ -237,17 +259,19 @@ async def _get_auth_chain_difference(state_sets, event_map, state_res_store):
     return difference
 
 
-def _seperate(state_sets):
+def _seperate(
+    state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
     """Return the unconflicted and conflicted state. This is different than in
     the original algorithm, as this defines a key to be conflicted if one of
     the state sets doesn't have that key.
 
     Args:
-        state_sets (list)
+        state_sets
 
     Returns:
-        tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
-        conflicted state dict is a map from type/state_key to set of event IDs
+        A tuple of unconflicted and conflicted state. The conflicted state dict
+        is a map from type/state_key to set of event IDs
     """
     unconflicted_state = {}
     conflicted_state = {}
@@ -260,18 +284,20 @@ def _seperate(state_sets):
             event_ids.discard(None)
             conflicted_state[key] = event_ids
 
-    return unconflicted_state, conflicted_state
+    # mypy doesn't understand that discarding None above means that conflicted
+    # state is StateMap[Set[str]], not StateMap[Set[Optional[Str]]].
+    return unconflicted_state, conflicted_state  # type: ignore
 
 
-def _is_power_event(event):
+def _is_power_event(event: EventBase) -> bool:
     """Return whether or not the event is a "power event", as defined by the
     v2 state resolution algorithm
 
     Args:
-        event (FrozenEvent)
+        event
 
     Returns:
-        boolean
+        True if the event is a power event.
     """
     if (event.type, event.state_key) in (
         (EventTypes.PowerLevels, ""),
@@ -288,19 +314,23 @@ def _is_power_event(event):
 
 
 async def _add_event_and_auth_chain_to_graph(
-    graph, room_id, event_id, event_map, state_res_store, auth_diff
-):
+    graph: Dict[str, Set[str]],
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    auth_diff: Set[str],
+) -> None:
     """Helper function for _reverse_topological_power_sort that add the event
     and its auth chain (that is in the auth diff) to the graph
 
     Args:
-        graph (dict[str, set[str]]): A map from event ID to the events auth
-            event IDs
-        room_id (str): the room we are working in
-        event_id (str): Event to add to the graph
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+        graph: A map from event ID to the events auth event IDs
+        room_id: the room we are working in
+        event_id: Event to add to the graph
+        event_map
+        state_res_store
+        auth_diff: Set of event IDs that are in the auth difference.
     """
 
     state = [event_id]
@@ -318,24 +348,29 @@ async def _add_event_and_auth_chain_to_graph(
 
 
 async def _reverse_topological_power_sort(
-    clock, room_id, event_ids, event_map, state_res_store, auth_diff
-):
+    clock: Clock,
+    room_id: str,
+    event_ids: Iterable[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    auth_diff: Set[str],
+) -> List[str]:
     """Returns a list of the event_ids sorted by reverse topological ordering,
     and then by power level and origin_server_ts
 
     Args:
-        clock (Clock)
-        room_id (str): the room we are working in
-        event_ids (list[str]): The events to sort
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        auth_diff (set[str]): Set of event IDs that are in the auth difference.
+        clock
+        room_id: the room we are working in
+        event_ids: The events to sort
+        event_map
+        state_res_store
+        auth_diff: Set of event IDs that are in the auth difference.
 
     Returns:
-        Deferred[list[str]]: The sorted list
+        The sorted list
     """
 
-    graph = {}
+    graph = {}  # type: Dict[str, Set[str]]
     for idx, event_id in enumerate(event_ids, start=1):
         await _add_event_and_auth_chain_to_graph(
             graph, room_id, event_id, event_map, state_res_store, auth_diff
@@ -372,22 +407,28 @@ async def _reverse_topological_power_sort(
 
 
 async def _iterative_auth_checks(
-    clock, room_id, room_version, event_ids, base_state, event_map, state_res_store
-):
+    clock: Clock,
+    room_id: str,
+    room_version: str,
+    event_ids: List[str],
+    base_state: StateMap[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> StateMap[str]:
     """Sequentially apply auth checks to each event in given list, updating the
     state as it goes along.
 
     Args:
-        clock (Clock)
-        room_id (str)
-        room_version (str)
-        event_ids (list[str]): Ordered list of events to apply auth checks to
-        base_state (StateMap[str]): The set of state to start with
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        clock
+        room_id
+        room_version
+        event_ids: Ordered list of events to apply auth checks to
+        base_state: The set of state to start with
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[StateMap[str]]: Returns the final updated state
+        Returns the final updated state
     """
     resolved_state = base_state.copy()
     room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
@@ -439,21 +480,26 @@ async def _iterative_auth_checks(
 
 
 async def _mainline_sort(
-    clock, room_id, event_ids, resolved_power_event_id, event_map, state_res_store
-):
+    clock: Clock,
+    room_id: str,
+    event_ids: List[str],
+    resolved_power_event_id: Optional[str],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> List[str]:
     """Returns a sorted list of event_ids sorted by mainline ordering based on
     the given event resolved_power_event_id
 
     Args:
-        clock (Clock)
-        room_id (str): room we're working in
-        event_ids (list[str]): Events to sort
-        resolved_power_event_id (str): The final resolved power level event ID
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        clock
+        room_id: room we're working in
+        event_ids: Events to sort
+        resolved_power_event_id: The final resolved power level event ID
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[list[str]]: The sorted list
+        The sorted list
     """
     if not event_ids:
         # It's possible for there to be no event IDs here to sort, so we can
@@ -505,59 +551,90 @@ async def _mainline_sort(
 
 
 async def _get_mainline_depth_for_event(
-    event, mainline_map, event_map, state_res_store
-):
+    event: EventBase,
+    mainline_map: Dict[str, int],
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
     """Get the mainline depths for the given event based on the mainline map
 
     Args:
-        event (FrozenEvent)
-        mainline_map (dict[str, int]): Map from event_id to mainline depth for
-            events in the mainline.
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
+        event
+        mainline_map: Map from event_id to mainline depth for events in the mainline.
+        event_map
+        state_res_store
 
     Returns:
-        Deferred[int]
+        The mainline depth
     """
 
     room_id = event.room_id
+    tmp_event = event  # type: Optional[EventBase]
 
     # We do an iterative search, replacing `event with the power level in its
     # auth events (if any)
-    while event:
+    while tmp_event:
         depth = mainline_map.get(event.event_id)
         if depth is not None:
             return depth
 
-        auth_events = event.auth_event_ids()
-        event = None
+        auth_events = tmp_event.auth_event_ids()
+        tmp_event = None
 
         for aid in auth_events:
             aev = await _get_event(
                 room_id, aid, event_map, state_res_store, allow_none=True
             )
             if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
-                event = aev
+                tmp_event = aev
                 break
 
     # Didn't find a power level auth event, so we just return 0
     return 0
 
 
-async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+@overload
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: Literal[False] = False,
+) -> EventBase:
+    ...
+
+
+@overload
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: Literal[True],
+) -> Optional[EventBase]:
+    ...
+
+
+async def _get_event(
+    room_id: str,
+    event_id: str,
+    event_map: Dict[str, EventBase],
+    state_res_store: "synapse.state.StateResolutionStore",
+    allow_none: bool = False,
+) -> Optional[EventBase]:
     """Helper function to look up event in event_map, falling back to looking
     it up in the store
 
     Args:
-        room_id (str)
-        event_id (str)
-        event_map (dict[str,FrozenEvent])
-        state_res_store (StateResolutionStore)
-        allow_none (bool): if the event is not found, return None rather than raising
+        room_id
+        event_id
+        event_map
+        state_res_store
+        allow_none: if the event is not found, return None rather than raising
             an exception
 
     Returns:
-        Deferred[Optional[FrozenEvent]]
+        The event, or none if the event does not exist (and allow_none is True).
     """
     if event_id not in event_map:
         events = await state_res_store.get_events([event_id], allow_rejected=True)
@@ -577,7 +654,9 @@ async def _get_event(room_id, event_id, event_map, state_res_store, allow_none=F
     return event
 
 
-def lexicographical_topological_sort(graph, key):
+def lexicographical_topological_sort(
+    graph: Dict[str, Set[str]], key: Callable[[str], Any]
+) -> Generator[str, None, None]:
     """Performs a lexicographic reverse topological sort on the graph.
 
     This returns a reverse topological sort (i.e. if node A references B then B
@@ -587,20 +666,20 @@ def lexicographical_topological_sort(graph, key):
     NOTE: `graph` is modified during the sort.
 
     Args:
-        graph (dict[str, set[str]]): A representation of the graph where each
-            node is a key in the dict and its value are the nodes edges.
-        key (func): A function that takes a node and returns a value that is
-            comparable and used to order nodes
+        graph: A representation of the graph where each node is a key in the
+            dict and its value are the nodes edges.
+        key: A function that takes a node and returns a value that is comparable
+            and used to order nodes
 
     Yields:
-        str: The next node in the topological sort
+        The next node in the topological sort
     """
 
     # Note, this is basically Kahn's algorithm except we look at nodes with no
     # outgoing edges, c.f.
     # https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
     outdegree_map = graph
-    reverse_graph = {}
+    reverse_graph = {}  # type: Dict[str, Set[str]]
 
     # Lists of nodes with zero out degree. Is actually a tuple of
     # `(key(node), node)` so that sorting does the right thing