diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 9bf98d06f2..a493279cbd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,18 +15,24 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional
-
-from six import iteritems, iterkeys, itervalues
-
-from twisted.internet import defer
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
logger = logging.getLogger(__name__)
@@ -34,13 +40,12 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
room_id: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable,
-):
+ state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+) -> StateMap[str]:
"""
Args:
room_id: the room we are working in
@@ -58,11 +63,10 @@ def resolve_events_with_store(
state_map_factory: will be called
with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event.
+ an Awaitable that resolves to a dict of event_id to event.
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ A map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -70,19 +74,19 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets)
needed_events = {
- event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
+ event_id for event_ids in conflicted_state.values() for event_id in event_ids
}
needed_event_count = len(needed_events)
if event_map is not None:
- needed_events -= set(iterkeys(event_map))
+ needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
- # dict[str, FrozenEvent]: a map from state event id to event. Only includes
- # the state events which are in conflict (and those in event_map)
- state_map = yield state_map_factory(needed_events)
+ # A map from state event id to event. Only includes the state events which
+ # are in conflict (and those in event_map).
+ state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -96,23 +100,21 @@ def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
- #
- # dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(itervalues(auth_events))
+ new_needed_events = set(auth_events.values())
new_needed_event_count = len(new_needed_events)
new_needed_events -= needed_events
if event_map is not None:
- new_needed_events -= set(iterkeys(event_map))
+ new_needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
- state_map_new = yield state_map_factory(new_needed_events)
+ state_map_new = await state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(
@@ -127,32 +129,33 @@ def resolve_events_with_store(
)
-def _seperate(state_sets):
+def _seperate(
+ state_sets: Iterable[StateMap[str]],
+) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
- state_sets(iterable[dict[(str, str), str]]):
+ state_sets:
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
- (dict[(str, str), str], dict[(str, str), set[str]]):
- A tuple of (unconflicted_state, conflicted_state), where:
+ A tuple of (unconflicted_state, conflicted_state), where:
- unconflicted_state is a dict mapping (type, state_key)->event_id
- for unconflicted state keys.
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
- conflicted_state is a dict mapping (type, state_key) to a set of
- event ids for conflicted state keys.
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
- conflicted_state = {}
+ conflicted_state = {} # type: MutableStateMap[Set[str]]
for state_set in state_set_iterator:
- for key, value in iteritems(state_set):
+ for key, value in state_set.items():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -176,25 +179,42 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+def _create_auth_events_from_maps(
+ unconflicted_state: StateMap[str],
+ conflicted_state: StateMap[Set[str]],
+ state_map: Dict[str, EventBase],
+) -> StateMap[str]:
+ """
+
+ Args:
+ unconflicted_state: The unconflicted state map.
+ conflicted_state: The conflicted state map.
+ state_map:
+
+ Returns:
+ A map from state key to event id.
+ """
auth_events = {}
- for event_ids in itervalues(conflicted_state):
+ for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
- event_id = unconflicted_state.get(key, None)
- if event_id:
- auth_events[key] = event_id
+ auth_event_id = unconflicted_state.get(key, None)
+ if auth_event_id:
+ auth_events[key] = auth_event_id
return auth_events
def _resolve_with_state(
- unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+ unconflicted_state_ids: MutableStateMap[str],
+ conflicted_state_ids: StateMap[Set[str]],
+ auth_event_ids: StateMap[str],
+ state_map: Dict[str, EventBase],
):
conflicted_state = {}
- for key, event_ids in iteritems(conflicted_state_ids):
+ for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -203,7 +223,7 @@ def _resolve_with_state(
auth_events = {
key: state_map[ev_id]
- for key, ev_id in iteritems(auth_event_ids)
+ for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
@@ -214,13 +234,15 @@ def _resolve_with_state(
raise
new_state = unconflicted_state_ids
- for key, event in iteritems(resolved_state):
+ for key, event in resolved_state.items():
new_state[key] = event.event_id
return new_state
-def _resolve_state_events(conflicted_state, auth_events):
+def _resolve_state_events(
+ conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.
@@ -238,21 +260,21 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(events, auth_events)
@@ -260,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state
-def _resolve_auth_events(events, auth_events):
+def _resolve_auth_events(
+ events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
@@ -294,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event
-def _resolve_normal_events(events, auth_events):
+def _resolve_normal_events(
+ events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
@@ -314,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event
-def _ordered_events(events):
+def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
|