summary refs log tree commit diff
path: root/synapse/state/v1.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/state/v1.py')
-rw-r--r--synapse/state/v1.py89
1 files changed, 60 insertions, 29 deletions
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index ab5e24841d..a493279cbd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,14 +15,24 @@
 
 import hashlib
 import logging
-from typing import Awaitable, Callable, Dict, List, Optional
+from typing import (
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Sequence,
+    Set,
+    Tuple,
+)
 
 from synapse import event_auth
 from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import RoomVersions
 from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
 
 logger = logging.getLogger(__name__)
 
@@ -32,10 +42,10 @@ POWER_KEY = (EventTypes.PowerLevels, "")
 
 async def resolve_events_with_store(
     room_id: str,
-    state_sets: List[StateMap[str]],
+    state_sets: Sequence[StateMap[str]],
     event_map: Optional[Dict[str, EventBase]],
-    state_map_factory: Callable[[List[str]], Awaitable],
-):
+    state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+) -> StateMap[str]:
     """
     Args:
         room_id: the room we are working in
@@ -56,8 +66,7 @@ async def resolve_events_with_store(
             an Awaitable that resolves to a dict of event_id to event.
 
     Returns:
-        Deferred[dict[(str, str), str]]:
-            a map from (type, state_key) to event_id.
+        A map from (type, state_key) to event_id.
     """
     if len(state_sets) == 1:
         return state_sets[0]
@@ -75,8 +84,8 @@ async def resolve_events_with_store(
         "Asking for %d/%d conflicted events", len(needed_events), needed_event_count
     )
 
-    # dict[str, FrozenEvent]: a map from state event id to event. Only includes
-    # the state events which are in conflict (and those in event_map)
+    # A map from state event id to event. Only includes the state events which
+    # are in conflict (and those in event_map).
     state_map = await state_map_factory(needed_events)
     if event_map is not None:
         state_map.update(event_map)
@@ -91,8 +100,6 @@ async def resolve_events_with_store(
 
     # get the ids of the auth events which allow us to authenticate the
     # conflicted state, picking only from the unconflicting state.
-    #
-    # dict[(str, str), str]: a map from state key to event id
     auth_events = _create_auth_events_from_maps(
         unconflicted_state, conflicted_state, state_map
     )
@@ -122,29 +129,30 @@ async def resolve_events_with_store(
     )
 
 
-def _seperate(state_sets):
+def _seperate(
+    state_sets: Iterable[StateMap[str]],
+) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
     """Takes the state_sets and figures out which keys are conflicted and
     which aren't. i.e., which have multiple different event_ids associated
     with them in different state sets.
 
     Args:
-        state_sets(iterable[dict[(str, str), str]]):
+        state_sets:
             List of dicts of (type, state_key) -> event_id, which are the
             different state groups to resolve.
 
     Returns:
-        (dict[(str, str), str], dict[(str, str), set[str]]):
-            A tuple of (unconflicted_state, conflicted_state), where:
+        A tuple of (unconflicted_state, conflicted_state), where:
 
-            unconflicted_state is a dict mapping (type, state_key)->event_id
-            for unconflicted state keys.
+        unconflicted_state is a dict mapping (type, state_key)->event_id
+        for unconflicted state keys.
 
-            conflicted_state is a dict mapping (type, state_key) to a set of
-            event ids for conflicted state keys.
+        conflicted_state is a dict mapping (type, state_key) to a set of
+        event ids for conflicted state keys.
     """
     state_set_iterator = iter(state_sets)
     unconflicted_state = dict(next(state_set_iterator))
-    conflicted_state = {}
+    conflicted_state = {}  # type: MutableStateMap[Set[str]]
 
     for state_set in state_set_iterator:
         for key, value in state_set.items():
@@ -171,7 +179,21 @@ def _seperate(state_sets):
     return unconflicted_state, conflicted_state
 
 
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+def _create_auth_events_from_maps(
+    unconflicted_state: StateMap[str],
+    conflicted_state: StateMap[Set[str]],
+    state_map: Dict[str, EventBase],
+) -> StateMap[str]:
+    """
+
+    Args:
+        unconflicted_state: The unconflicted state map.
+        conflicted_state: The conflicted state map.
+        state_map:
+
+    Returns:
+        A map from state key to event id.
+    """
     auth_events = {}
     for event_ids in conflicted_state.values():
         for event_id in event_ids:
@@ -179,14 +201,17 @@ def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_ma
                 keys = event_auth.auth_types_for_event(state_map[event_id])
                 for key in keys:
                     if key not in auth_events:
-                        event_id = unconflicted_state.get(key, None)
-                        if event_id:
-                            auth_events[key] = event_id
+                        auth_event_id = unconflicted_state.get(key, None)
+                        if auth_event_id:
+                            auth_events[key] = auth_event_id
     return auth_events
 
 
 def _resolve_with_state(
-    unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+    unconflicted_state_ids: MutableStateMap[str],
+    conflicted_state_ids: StateMap[Set[str]],
+    auth_event_ids: StateMap[str],
+    state_map: Dict[str, EventBase],
 ):
     conflicted_state = {}
     for key, event_ids in conflicted_state_ids.items():
@@ -215,7 +240,9 @@ def _resolve_with_state(
     return new_state
 
 
-def _resolve_state_events(conflicted_state, auth_events):
+def _resolve_state_events(
+    conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+) -> StateMap[EventBase]:
     """ This is where we actually decide which of the conflicted state to
     use.
 
@@ -255,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
     return resolved_state
 
 
-def _resolve_auth_events(events, auth_events):
+def _resolve_auth_events(
+    events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
     reverse = list(reversed(_ordered_events(events)))
 
     auth_keys = {
@@ -289,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
     return event
 
 
-def _resolve_normal_events(events, auth_events):
+def _resolve_normal_events(
+    events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
     for event in _ordered_events(events):
         try:
             # The signatures have already been checked at this point
@@ -309,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
     return event
 
 
-def _ordered_events(events):
+def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
     def key_func(e):
         # we have to use utf-8 rather than ascii here because it turns out we allow
         # people to send us events with non-ascii event IDs :/