diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index 2fa529fcd0..c7e3015b5d 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -16,15 +16,23 @@
import logging
from collections import namedtuple
-from typing import Dict, Iterable, List, Optional, Set
-
-from six import iteritems, itervalues
+from typing import (
+ Awaitable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Union,
+ cast,
+ overload,
+)
import attr
from frozendict import frozendict
from prometheus_client import Histogram
-
-from twisted.internet import defer
+from typing_extensions import Literal
from synapse.api.constants import EventTypes
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS, StateResolutionVersions
@@ -32,8 +40,10 @@ from synapse.events import EventBase
from synapse.events.snapshot import EventContext
from synapse.logging.utils import log_function
from synapse.state import v1, v2
-from synapse.storage.data_stores.main.events_worker import EventRedactBehaviour
-from synapse.types import StateMap
+from synapse.storage.databases.main.events_worker import EventRedactBehaviour
+from synapse.storage.roommember import ProfileInfo
+from synapse.types import Collection, MutableStateMap, StateMap
+from synapse.util import Clock
from synapse.util.async_helpers import Linearizer
from synapse.util.caches.expiringcache import ExpiringCache
from synapse.util.metrics import Measure, measure_func
@@ -67,11 +77,17 @@ def _gen_state_id():
return s
-class _StateCacheEntry(object):
+class _StateCacheEntry:
__slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
- def __init__(self, state, state_group, prev_group=None, delta_ids=None):
- # dict[(str, str), str] map from (type, state_key) to event_id
+ def __init__(
+ self,
+ state: StateMap[str],
+ state_group: Optional[int],
+ prev_group: Optional[int] = None,
+ delta_ids: Optional[StateMap[str]] = None,
+ ):
+ # A map from (type, state_key) to event_id.
self.state = frozendict(state)
# the ID of a state group if one and only one is involved.
@@ -97,7 +113,7 @@ class _StateCacheEntry(object):
return len(self.state)
-class StateHandler(object):
+class StateHandler:
"""Fetches bits of state from the stores, and does state resolution
where necessary
"""
@@ -109,114 +125,131 @@ class StateHandler(object):
self.hs = hs
self._state_resolution_handler = hs.get_state_resolution_handler()
- @defer.inlineCallbacks
- def get_current_state(
- self, room_id, event_type=None, state_key="", latest_event_ids=None
- ):
- """ Retrieves the current state for the room. This is done by
+ @overload
+ async def get_current_state(
+ self,
+ room_id: str,
+ event_type: Literal[None] = None,
+ state_key: str = "",
+ latest_event_ids: Optional[List[str]] = None,
+ ) -> StateMap[EventBase]:
+ ...
+
+ @overload
+ async def get_current_state(
+ self,
+ room_id: str,
+ event_type: str,
+ state_key: str = "",
+ latest_event_ids: Optional[List[str]] = None,
+ ) -> Optional[EventBase]:
+ ...
+
+ async def get_current_state(
+ self,
+ room_id: str,
+ event_type: Optional[str] = None,
+ state_key: str = "",
+ latest_event_ids: Optional[List[str]] = None,
+ ) -> Union[Optional[EventBase], StateMap[EventBase]]:
+ """Retrieves the current state for the room. This is done by
calling `get_latest_events_in_room` to get the leading edges of the
event graph and then resolving any of the state conflicts.
This is equivalent to getting the state of an event that were to send
next before receiving any new events.
- If `event_type` is specified, then the method returns only the one
- event (or None) with that `event_type` and `state_key`.
-
Returns:
- map from (type, state_key) to event
+ If `event_type` is specified, then the method returns only the one
+ event (or None) with that `event_type` and `state_key`.
+
+ Otherwise, a map from (type, state_key) to event.
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state")
- ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
state = ret.state
if event_type:
event_id = state.get((event_type, state_key))
event = None
if event_id:
- event = yield self.store.get_event(event_id, allow_none=True)
+ event = await self.store.get_event(event_id, allow_none=True)
return event
- state_map = yield self.store.get_events(
+ state_map = await self.store.get_events(
list(state.values()), get_prev_content=False
)
- state = {
- key: state_map[e_id] for key, e_id in iteritems(state) if e_id in state_map
+ return {
+ key: state_map[e_id] for key, e_id in state.items() if e_id in state_map
}
- return state
-
- @defer.inlineCallbacks
- def get_current_state_ids(self, room_id, latest_event_ids=None):
+ async def get_current_state_ids(
+ self, room_id: str, latest_event_ids: Optional[Iterable[str]] = None
+ ) -> StateMap[str]:
"""Get the current state, or the state at a set of events, for a room
Args:
- room_id (str):
-
- latest_event_ids (iterable[str]|None): if given, the forward
- extremities to resolve. If None, we look them up from the
- database (via a cache)
+ room_id:
+ latest_event_ids: if given, the forward extremities to resolve. If
+ None, we look them up from the database (via a cache).
Returns:
- Deferred[dict[(str, str), str)]]: the state dict, mapping from
- (event_type, state_key) -> event_id
+ the state dict, mapping from (event_type, state_key) -> event_id
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ assert latest_event_ids is not None
logger.debug("calling resolve_state_groups from get_current_state_ids")
- ret = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
- state = ret.state
+ ret = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ return ret.state
- return state
-
- @defer.inlineCallbacks
- def get_current_users_in_room(self, room_id, latest_event_ids=None):
+ async def get_current_users_in_room(
+ self, room_id: str, latest_event_ids: Optional[List[str]] = None
+ ) -> Dict[str, ProfileInfo]:
"""
Get the users who are currently in a room.
Args:
- room_id (str): The ID of the room.
- latest_event_ids (List[str]|None): Precomputed list of latest
- event IDs. Will be computed if None.
+ room_id: The ID of the room.
+ latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
Returns:
- Deferred[Dict[str,ProfileInfo]]: Dictionary of user IDs to their
- profileinfo.
+ Dictionary of user IDs to their profileinfo.
"""
if not latest_event_ids:
- latest_event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
+ latest_event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ assert latest_event_ids is not None
+
logger.debug("calling resolve_state_groups from get_current_users_in_room")
- entry = yield self.resolve_state_groups_for_events(room_id, latest_event_ids)
- joined_users = yield self.store.get_joined_users_from_state(room_id, entry)
- return joined_users
+ entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
+ return await self.store.get_joined_users_from_state(room_id, entry)
- @defer.inlineCallbacks
- def get_current_hosts_in_room(self, room_id):
- event_ids = yield self.store.get_latest_event_ids_in_room(room_id)
- return (yield self.get_hosts_in_room_at_events(room_id, event_ids))
+ async def get_current_hosts_in_room(self, room_id: str) -> Set[str]:
+ event_ids = await self.store.get_latest_event_ids_in_room(room_id)
+ return await self.get_hosts_in_room_at_events(room_id, event_ids)
- @defer.inlineCallbacks
- def get_hosts_in_room_at_events(self, room_id, event_ids):
+ async def get_hosts_in_room_at_events(
+ self, room_id: str, event_ids: List[str]
+ ) -> Set[str]:
"""Get the hosts that were in a room at the given event ids
Args:
- room_id (str):
- event_ids (list[str]):
+ room_id:
+ event_ids:
Returns:
- Deferred[list[str]]: the hosts in the room at the given events
+ The hosts in the room at the given events
"""
- entry = yield self.resolve_state_groups_for_events(room_id, event_ids)
- joined_hosts = yield self.store.get_joined_hosts(room_id, entry)
- return joined_hosts
+ entry = await self.resolve_state_groups_for_events(room_id, event_ids)
+ return await self.store.get_joined_hosts(room_id, entry)
- @defer.inlineCallbacks
- def compute_event_context(
+ async def compute_event_context(
self, event: EventBase, old_state: Optional[Iterable[EventBase]] = None
- ):
+ ) -> EventContext:
"""Build an EventContext structure for the event.
This works out what the current state should be for the event, and
@@ -229,7 +262,7 @@ class StateHandler(object):
when receiving an event from federation where we don't have the
prev events for, e.g. when backfilling.
Returns:
- synapse.events.snapshot.EventContext:
+ The event context.
"""
if event.internal_metadata.is_outlier():
@@ -270,7 +303,7 @@ class StateHandler(object):
# if we're given the state before the event, then we use that
state_ids_before_event = {
(s.type, s.state_key): s.event_id for s in old_state
- }
+ } # type: StateMap[str]
state_group_before_event = None
state_group_before_event_prev_group = None
deltas_to_state_group_before_event = None
@@ -279,7 +312,7 @@ class StateHandler(object):
# otherwise, we'll need to resolve the state across the prev_events.
logger.debug("calling resolve_state_groups from compute_event_context")
- entry = yield self.resolve_state_groups_for_events(
+ entry = await self.resolve_state_groups_for_events(
event.room_id, event.prev_event_ids()
)
@@ -296,7 +329,7 @@ class StateHandler(object):
#
if not state_group_before_event:
- state_group_before_event = yield self.state_store.store_state_group(
+ state_group_before_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event_prev_group,
@@ -336,7 +369,7 @@ class StateHandler(object):
state_ids_after_event[key] = event.event_id
delta_ids = {key: event.event_id}
- state_group_after_event = yield self.state_store.store_state_group(
+ state_group_after_event = await self.state_store.store_state_group(
event.event_id,
event.room_id,
prev_group=state_group_before_event,
@@ -354,27 +387,25 @@ class StateHandler(object):
)
@measure_func()
- @defer.inlineCallbacks
- def resolve_state_groups_for_events(self, room_id, event_ids):
+ async def resolve_state_groups_for_events(
+ self, room_id: str, event_ids: Iterable[str]
+ ) -> _StateCacheEntry:
""" Given a list of event_ids this method fetches the state at each
event, resolves conflicts between them and returns them.
Args:
- room_id (str)
- event_ids (list[str])
- explicit_room_version (str|None): If set uses the the given room
- version to choose the resolution algorithm. If None, then
- checks the database for room version.
+ room_id
+ event_ids
Returns:
- Deferred[_StateCacheEntry]: resolved state
+ The resolved state
"""
logger.debug("resolve_state_groups event_ids %s", event_ids)
# map from state group id to the state in that state group (where
# 'state' is a map from state key to event id)
# dict[int, dict[(str, str), str]]
- state_groups_ids = yield self.state_store.get_state_groups_ids(
+ state_groups_ids = await self.state_store.get_state_groups_ids(
room_id, event_ids
)
@@ -383,7 +414,7 @@ class StateHandler(object):
elif len(state_groups_ids) == 1:
name, state_list = list(state_groups_ids.items()).pop()
- prev_group, delta_ids = yield self.state_store.get_state_group_delta(name)
+ prev_group, delta_ids = await self.state_store.get_state_group_delta(name)
return _StateCacheEntry(
state=state_list,
@@ -392,9 +423,9 @@ class StateHandler(object):
delta_ids=delta_ids,
)
- room_version = yield self.store.get_room_version_id(room_id)
+ room_version = await self.store.get_room_version_id(room_id)
- result = yield self._state_resolution_handler.resolve_state_groups(
+ result = await self._state_resolution_handler.resolve_state_groups(
room_id,
room_version,
state_groups_ids,
@@ -403,8 +434,12 @@ class StateHandler(object):
)
return result
- @defer.inlineCallbacks
- def resolve_events(self, room_version, state_sets, event):
+ async def resolve_events(
+ self,
+ room_version: str,
+ state_sets: Collection[Iterable[EventBase]],
+ event: EventBase,
+ ) -> StateMap[EventBase]:
logger.info(
"Resolving state for %s with %d groups", event.room_id, len(state_sets)
)
@@ -415,7 +450,8 @@ class StateHandler(object):
state_map = {ev.event_id: ev for st in state_sets for ev in st}
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_store(
+ new_state = await resolve_events_with_store(
+ self.clock,
event.room_id,
room_version,
state_set_ids,
@@ -423,12 +459,10 @@ class StateHandler(object):
state_res_store=StateResolutionStore(self.store),
)
- new_state = {key: state_map[ev_id] for key, ev_id in iteritems(new_state)}
-
- return new_state
+ return {key: state_map[ev_id] for key, ev_id in new_state.items()}
-class StateResolutionHandler(object):
+class StateResolutionHandler:
"""Responsible for doing state conflict resolution.
Note that the storage layer depends on this handler, so all functions must
@@ -451,10 +485,14 @@ class StateResolutionHandler(object):
reset_expiry_on_get=True,
)
- @defer.inlineCallbacks
@log_function
- def resolve_state_groups(
- self, room_id, room_version, state_groups_ids, event_map, state_res_store
+ async def resolve_state_groups(
+ self,
+ room_id: str,
+ room_version: str,
+ state_groups_ids: Dict[int, StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_res_store: "StateResolutionStore",
):
"""Resolves conflicts between a set of state groups
@@ -462,13 +500,13 @@ class StateResolutionHandler(object):
not be called for a single state group
Args:
- room_id (str): room we are resolving for (used for logging and sanity checks)
- room_version (str): version of the room
- state_groups_ids (dict[int, dict[(str, str), str]]):
- map from state group id to the state in that state group
+ room_id: room we are resolving for (used for logging and sanity checks)
+ room_version: version of the room
+ state_groups_ids:
+ A map from state group id to the state in that state group
(where 'state' is a map from state key to event id)
- event_map(dict[str,FrozenEvent]|None):
+ event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@@ -476,16 +514,16 @@ class StateResolutionHandler(object):
If None, all events will be fetched via state_res_store.
- state_res_store (StateResolutionStore)
+ state_res_store
Returns:
- Deferred[_StateCacheEntry]: resolved state
+ The resolved state
"""
logger.debug("resolve_state_groups state_groups %s", state_groups_ids.keys())
group_names = frozenset(state_groups_ids.keys())
- with (yield self.resolve_linearizer.queue(group_names)):
+ with (await self.resolve_linearizer.queue(group_names)):
if self._state_cache is not None:
cache = self._state_cache.get(group_names, None)
if cache:
@@ -503,10 +541,10 @@ class StateResolutionHandler(object):
#
# XXX: is this actually worthwhile, or should we just let
# resolve_events_with_store do it?
- new_state = {}
+ new_state = {} # type: MutableStateMap[str]
conflicted_state = False
- for st in itervalues(state_groups_ids):
- for key, e_id in iteritems(st):
+ for st in state_groups_ids.values():
+ for key, e_id in st.items():
if key in new_state:
conflicted_state = True
break
@@ -517,12 +555,20 @@ class StateResolutionHandler(object):
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_store(
- room_id,
- room_version,
- list(itervalues(state_groups_ids)),
- event_map=event_map,
- state_res_store=state_res_store,
+ # resolve_events_with_store returns a StateMap, but we can
+ # treat it as a MutableStateMap as it is above. It isn't
+ # actually mutated anymore (and is frozen in
+ # _make_state_cache_entry below).
+ new_state = cast(
+ MutableStateMap,
+ await resolve_events_with_store(
+ self.clock,
+ room_id,
+ room_version,
+ list(state_groups_ids.values()),
+ event_map=event_map,
+ state_res_store=state_res_store,
+ ),
)
# if the new state matches any of the input state groups, we can
@@ -539,21 +585,22 @@ class StateResolutionHandler(object):
return cache
-def _make_state_cache_entry(new_state, state_groups_ids):
+def _make_state_cache_entry(
+ new_state: StateMap[str], state_groups_ids: Dict[int, StateMap[str]]
+) -> _StateCacheEntry:
"""Given a resolved state, and a set of input state groups, pick one to base
a new state group on (if any), and return an appropriately-constructed
_StateCacheEntry.
Args:
- new_state (dict[(str, str), str]): resolved state map (mapping from
- (type, state_key) to event_id)
+ new_state: resolved state map (mapping from (type, state_key) to event_id)
- state_groups_ids (dict[int, dict[(str, str), str]]):
- map from state group id to the state in that state group
- (where 'state' is a map from state key to event id)
+ state_groups_ids:
+ map from state group id to the state in that state group (where
+ 'state' is a map from state key to event id)
Returns:
- _StateCacheEntry
+ The cache entry.
"""
# if the new state matches any of the input state groups, we can
# use that state group again. Otherwise we will generate a state_id
@@ -561,12 +608,12 @@ def _make_state_cache_entry(new_state, state_groups_ids):
# not get persisted.
# first look for exact matches
- new_state_event_ids = set(itervalues(new_state))
- for sg, state in iteritems(state_groups_ids):
+ new_state_event_ids = set(new_state.values())
+ for sg, state in state_groups_ids.items():
if len(new_state_event_ids) != len(state):
continue
- old_state_event_ids = set(itervalues(state))
+ old_state_event_ids = set(state.values())
if new_state_event_ids == old_state_event_ids:
# got an exact match.
return _StateCacheEntry(state=new_state, state_group=sg)
@@ -579,8 +626,8 @@ def _make_state_cache_entry(new_state, state_groups_ids):
prev_group = None
delta_ids = None
- for old_group, old_state in iteritems(state_groups_ids):
- n_delta_ids = {k: v for k, v in iteritems(new_state) if old_state.get(k) != v}
+ for old_group, old_state in state_groups_ids.items():
+ n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
if not delta_ids or len(n_delta_ids) < len(delta_ids):
prev_group = old_group
delta_ids = n_delta_ids
@@ -591,12 +638,13 @@ def _make_state_cache_entry(new_state, state_groups_ids):
def resolve_events_with_store(
+ clock: Clock,
room_id: str,
room_version: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "StateResolutionStore",
-):
+) -> Awaitable[StateMap[str]]:
"""
Args:
room_id: the room we are working in
@@ -617,8 +665,7 @@ def resolve_events_with_store(
state_res_store: a place to fetch events from
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ a map from (type, state_key) to event_id.
"""
v = KNOWN_ROOM_VERSIONS[room_version]
if v.state_res == StateResolutionVersions.V1:
@@ -627,12 +674,12 @@ def resolve_events_with_store(
)
else:
return v2.resolve_events_with_store(
- room_id, room_version, state_sets, event_map, state_res_store
+ clock, room_id, room_version, state_sets, event_map, state_res_store
)
@attr.s
-class StateResolutionStore(object):
+class StateResolutionStore:
"""Interface that allows state resolution algorithms to access the database
in well defined way.
@@ -642,15 +689,17 @@ class StateResolutionStore(object):
store = attr.ib()
- def get_events(self, event_ids, allow_rejected=False):
+ def get_events(
+ self, event_ids: Iterable[str], allow_rejected: bool = False
+ ) -> Awaitable[Dict[str, EventBase]]:
"""Get events from the database
Args:
- event_ids (list): The event_ids of the events to fetch
- allow_rejected (bool): If True return rejected events.
+ event_ids: The event_ids of the events to fetch
+ allow_rejected: If True return rejected events.
Returns:
- Deferred[dict[str, FrozenEvent]]: Dict from event_id to event.
+ An awaitable which resolves to a dict from event_id to event.
"""
return self.store.get_events(
@@ -660,7 +709,9 @@ class StateResolutionStore(object):
allow_rejected=allow_rejected,
)
- def get_auth_chain_difference(self, state_sets: List[Set[str]]):
+ def get_auth_chain_difference(
+ self, state_sets: List[Set[str]]
+ ) -> Awaitable[Set[str]]:
"""Given sets of state events figure out the auth chain difference (as
per state res v2 algorithm).
@@ -669,7 +720,7 @@ class StateResolutionStore(object):
chain.
Returns:
- Deferred[Set[str]]: Set of event IDs.
+ An awaitable that resolves to a set of event IDs.
"""
return self.store.get_auth_chain_difference(state_sets)
diff --git a/synapse/state/v1.py b/synapse/state/v1.py
index 9bf98d06f2..a493279cbd 100644
--- a/synapse/state/v1.py
+++ b/synapse/state/v1.py
@@ -15,18 +15,24 @@
import hashlib
import logging
-from typing import Callable, Dict, List, Optional
-
-from six import iteritems, iterkeys, itervalues
-
-from twisted.internet import defer
+from typing import (
+ Awaitable,
+ Callable,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+)
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import RoomVersions
from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
logger = logging.getLogger(__name__)
@@ -34,13 +40,12 @@ logger = logging.getLogger(__name__)
POWER_KEY = (EventTypes.PowerLevels, "")
-@defer.inlineCallbacks
-def resolve_events_with_store(
+async def resolve_events_with_store(
room_id: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
- state_map_factory: Callable,
-):
+ state_map_factory: Callable[[Iterable[str]], Awaitable[Dict[str, EventBase]]],
+) -> StateMap[str]:
"""
Args:
room_id: the room we are working in
@@ -58,11 +63,10 @@ def resolve_events_with_store(
state_map_factory: will be called
with a list of event_ids that are needed, and should return with
- a Deferred of dict of event_id to event.
+ an Awaitable that resolves to a dict of event_id to event.
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ A map from (type, state_key) to event_id.
"""
if len(state_sets) == 1:
return state_sets[0]
@@ -70,19 +74,19 @@ def resolve_events_with_store(
unconflicted_state, conflicted_state = _seperate(state_sets)
needed_events = {
- event_id for event_ids in itervalues(conflicted_state) for event_id in event_ids
+ event_id for event_ids in conflicted_state.values() for event_id in event_ids
}
needed_event_count = len(needed_events)
if event_map is not None:
- needed_events -= set(iterkeys(event_map))
+ needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d conflicted events", len(needed_events), needed_event_count
)
- # dict[str, FrozenEvent]: a map from state event id to event. Only includes
- # the state events which are in conflict (and those in event_map)
- state_map = yield state_map_factory(needed_events)
+ # A map from state event id to event. Only includes the state events which
+ # are in conflict (and those in event_map).
+ state_map = await state_map_factory(needed_events)
if event_map is not None:
state_map.update(event_map)
@@ -96,23 +100,21 @@ def resolve_events_with_store(
# get the ids of the auth events which allow us to authenticate the
# conflicted state, picking only from the unconflicting state.
- #
- # dict[(str, str), str]: a map from state key to event id
auth_events = _create_auth_events_from_maps(
unconflicted_state, conflicted_state, state_map
)
- new_needed_events = set(itervalues(auth_events))
+ new_needed_events = set(auth_events.values())
new_needed_event_count = len(new_needed_events)
new_needed_events -= needed_events
if event_map is not None:
- new_needed_events -= set(iterkeys(event_map))
+ new_needed_events -= set(event_map.keys())
logger.info(
"Asking for %d/%d auth events", len(new_needed_events), new_needed_event_count
)
- state_map_new = yield state_map_factory(new_needed_events)
+ state_map_new = await state_map_factory(new_needed_events)
for event in state_map_new.values():
if event.room_id != room_id:
raise Exception(
@@ -127,32 +129,33 @@ def resolve_events_with_store(
)
-def _seperate(state_sets):
+def _seperate(
+ state_sets: Iterable[StateMap[str]],
+) -> Tuple[MutableStateMap[str], MutableStateMap[Set[str]]]:
"""Takes the state_sets and figures out which keys are conflicted and
which aren't. i.e., which have multiple different event_ids associated
with them in different state sets.
Args:
- state_sets(iterable[dict[(str, str), str]]):
+ state_sets:
List of dicts of (type, state_key) -> event_id, which are the
different state groups to resolve.
Returns:
- (dict[(str, str), str], dict[(str, str), set[str]]):
- A tuple of (unconflicted_state, conflicted_state), where:
+ A tuple of (unconflicted_state, conflicted_state), where:
- unconflicted_state is a dict mapping (type, state_key)->event_id
- for unconflicted state keys.
+ unconflicted_state is a dict mapping (type, state_key)->event_id
+ for unconflicted state keys.
- conflicted_state is a dict mapping (type, state_key) to a set of
- event ids for conflicted state keys.
+ conflicted_state is a dict mapping (type, state_key) to a set of
+ event ids for conflicted state keys.
"""
state_set_iterator = iter(state_sets)
unconflicted_state = dict(next(state_set_iterator))
- conflicted_state = {}
+ conflicted_state = {} # type: MutableStateMap[Set[str]]
for state_set in state_set_iterator:
- for key, value in iteritems(state_set):
+ for key, value in state_set.items():
# Check if there is an unconflicted entry for the state key.
unconflicted_value = unconflicted_state.get(key)
if unconflicted_value is None:
@@ -176,25 +179,42 @@ def _seperate(state_sets):
return unconflicted_state, conflicted_state
-def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
+def _create_auth_events_from_maps(
+ unconflicted_state: StateMap[str],
+ conflicted_state: StateMap[Set[str]],
+ state_map: Dict[str, EventBase],
+) -> StateMap[str]:
+ """Gather the auth events needed to authenticate the conflicted state.
+
+ For each auth type required by a conflicted event, picks the matching
+ entry from the unconflicted state, if one exists.
+
+ Args:
+ unconflicted_state: The unconflicted state map.
+ conflicted_state: The conflicted state map.
+ state_map: A map from event id to event, covering the conflicted events.
+
+ Returns:
+ A map from state key to event id.
+ """
auth_events = {}
- for event_ids in itervalues(conflicted_state):
+ for event_ids in conflicted_state.values():
for event_id in event_ids:
if event_id in state_map:
keys = event_auth.auth_types_for_event(state_map[event_id])
for key in keys:
if key not in auth_events:
- event_id = unconflicted_state.get(key, None)
- if event_id:
- auth_events[key] = event_id
+ auth_event_id = unconflicted_state.get(key, None)
+ if auth_event_id:
+ auth_events[key] = auth_event_id
return auth_events
def _resolve_with_state(
- unconflicted_state_ids, conflicted_state_ids, auth_event_ids, state_map
+ unconflicted_state_ids: MutableStateMap[str],
+ conflicted_state_ids: StateMap[Set[str]],
+ auth_event_ids: StateMap[str],
+ state_map: Dict[str, EventBase],
):
conflicted_state = {}
- for key, event_ids in iteritems(conflicted_state_ids):
+ for key, event_ids in conflicted_state_ids.items():
events = [state_map[ev_id] for ev_id in event_ids if ev_id in state_map]
if len(events) > 1:
conflicted_state[key] = events
@@ -203,7 +223,7 @@ def _resolve_with_state(
auth_events = {
key: state_map[ev_id]
- for key, ev_id in iteritems(auth_event_ids)
+ for key, ev_id in auth_event_ids.items()
if ev_id in state_map
}
@@ -214,13 +234,15 @@ def _resolve_with_state(
raise
new_state = unconflicted_state_ids
- for key, event in iteritems(resolved_state):
+ for key, event in resolved_state.items():
new_state[key] = event.event_id
return new_state
-def _resolve_state_events(conflicted_state, auth_events):
+def _resolve_state_events(
+ conflicted_state: StateMap[List[EventBase]], auth_events: MutableStateMap[EventBase]
+) -> StateMap[EventBase]:
""" This is where we actually decide which of the conflicted state to
use.
@@ -238,21 +260,21 @@ def _resolve_state_events(conflicted_state, auth_events):
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.JoinRules:
logger.debug("Resolving conflicted join rules %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key[0] == EventTypes.Member:
logger.debug("Resolving conflicted member lists %r", events)
resolved_state[key] = _resolve_auth_events(events, auth_events)
auth_events.update(resolved_state)
- for key, events in iteritems(conflicted_state):
+ for key, events in conflicted_state.items():
if key not in resolved_state:
logger.debug("Resolving conflicted state %r:%r", key, events)
resolved_state[key] = _resolve_normal_events(events, auth_events)
@@ -260,7 +282,9 @@ def _resolve_state_events(conflicted_state, auth_events):
return resolved_state
-def _resolve_auth_events(events, auth_events):
+def _resolve_auth_events(
+ events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
reverse = list(reversed(_ordered_events(events)))
auth_keys = {
@@ -294,7 +318,9 @@ def _resolve_auth_events(events, auth_events):
return event
-def _resolve_normal_events(events, auth_events):
+def _resolve_normal_events(
+ events: List[EventBase], auth_events: StateMap[EventBase]
+) -> EventBase:
for event in _ordered_events(events):
try:
# The signatures have already been checked at this point
@@ -314,7 +340,7 @@ def _resolve_normal_events(events, auth_events):
return event
-def _ordered_events(events):
+def _ordered_events(events: Iterable[EventBase]) -> List[EventBase]:
def key_func(e):
# we have to use utf-8 rather than ascii here because it turns out we allow
# people to send us events with non-ascii event IDs :/
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index 18484e2fa6..edf94e7ad6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,11 +16,21 @@
import heapq
import itertools
import logging
-from typing import Dict, List, Optional
-
-from six import iteritems, itervalues
-
-from twisted.internet import defer
+from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Sequence,
+ Set,
+ Tuple,
+ overload,
+)
+
+from typing_extensions import Literal
import synapse.state
from synapse import event_auth
@@ -28,29 +38,34 @@ from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
from synapse.events import EventBase
-from synapse.types import StateMap
+from synapse.types import MutableStateMap, StateMap
+from synapse.util import Clock
logger = logging.getLogger(__name__)
-@defer.inlineCallbacks
-def resolve_events_with_store(
+# We want to yield control to the reactor occasionally during state res when
+# dealing with large data sets, so that we don't exhaust the reactor. This is
+# done by awaiting to the reactor during loops, every N iterations.
+_AWAIT_AFTER_ITERATIONS = 100
+
+
+async def resolve_events_with_store(
+ clock: Clock,
room_id: str,
room_version: str,
- state_sets: List[StateMap[str]],
+ state_sets: Sequence[StateMap[str]],
event_map: Optional[Dict[str, EventBase]],
state_res_store: "synapse.state.StateResolutionStore",
-):
+) -> StateMap[str]:
"""Resolves the state using the v2 state resolution algorithm
Args:
+ clock
room_id: the room we are working in
-
room_version: The room version
-
state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
-
event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
@@ -62,8 +77,7 @@ def resolve_events_with_store(
state_res_store:
Returns:
- Deferred[dict[(str, str), str]]:
- a map from (type, state_key) to event_id.
+ A map from (type, state_key) to event_id.
"""
logger.debug("Computing conflicted state")
@@ -83,15 +97,15 @@ def resolve_events_with_store(
# Also fetch all auth events that appear in only some of the state sets'
# auth chains.
- auth_diff = yield _get_auth_chain_difference(state_sets, event_map, state_res_store)
+ auth_diff = await _get_auth_chain_difference(state_sets, event_map, state_res_store)
full_conflicted_set = set(
itertools.chain(
- itertools.chain.from_iterable(itervalues(conflicted_state)), auth_diff
+ itertools.chain.from_iterable(conflicted_state.values()), auth_diff
)
)
- events = yield state_res_store.get_events(
+ events = await state_res_store.get_events(
[eid for eid in full_conflicted_set if eid not in event_map],
allow_rejected=True,
)
@@ -114,14 +128,15 @@ def resolve_events_with_store(
eid for eid in full_conflicted_set if _is_power_event(event_map[eid])
)
- sorted_power_events = yield _reverse_topological_power_sort(
- room_id, power_events, event_map, state_res_store, full_conflicted_set
+ sorted_power_events = await _reverse_topological_power_sort(
+ clock, room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
- resolved_state = yield _iterative_auth_checks(
+ resolved_state = await _iterative_auth_checks(
+ clock,
room_id,
room_version,
sorted_power_events,
@@ -135,20 +150,22 @@ def resolve_events_with_store(
# OK, so we've now resolved the power events. Now sort the remaining
# events using the mainline of the resolved power level.
+ set_power_events = set(sorted_power_events)
leftover_events = [
- ev_id for ev_id in full_conflicted_set if ev_id not in sorted_power_events
+ ev_id for ev_id in full_conflicted_set if ev_id not in set_power_events
]
logger.debug("sorting %d remaining events", len(leftover_events))
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
- leftover_events = yield _mainline_sort(
- room_id, leftover_events, pl, event_map, state_res_store
+ leftover_events = await _mainline_sort(
+ clock, room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
- resolved_state = yield _iterative_auth_checks(
+ resolved_state = await _iterative_auth_checks(
+ clock,
room_id,
room_version,
leftover_events,
@@ -167,25 +184,29 @@ def resolve_events_with_store(
return resolved_state
-@defer.inlineCallbacks
-def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
+async def _get_power_level_for_sender(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
"""Return the power level of the sender of the given event according to
their auth events.
Args:
- room_id (str)
- event_id (str)
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ room_id
+ event_id
+ event_map
+ state_res_store
Returns:
- Deferred[int]
+ The power level.
"""
- event = yield _get_event(room_id, event_id, event_map, state_res_store)
+ event = await _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
@@ -195,7 +216,7 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
@@ -214,38 +235,43 @@ def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
return int(level)
-@defer.inlineCallbacks
-def _get_auth_chain_difference(state_sets, event_map, state_res_store):
+async def _get_auth_chain_difference(
+ state_sets: Sequence[StateMap[str]],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> Set[str]:
"""Compare the auth chains of each state set and return the set of events
that only appear in some but not all of the auth chains.
Args:
- state_sets (list)
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ state_sets
+ event_map
+ state_res_store
Returns:
- Deferred[set[str]]: Set of event IDs
+ Set of event IDs
"""
- difference = yield state_res_store.get_auth_chain_difference(
+ difference = await state_res_store.get_auth_chain_difference(
[set(state_set.values()) for state_set in state_sets]
)
return difference
-def _seperate(state_sets):
+def _seperate(
+ state_sets: Iterable[StateMap[str]],
+) -> Tuple[StateMap[str], StateMap[Set[str]]]:
"""Return the unconflicted and conflicted state. This is different than in
the original algorithm, as this defines a key to be conflicted if one of
the state sets doesn't have that key.
Args:
- state_sets (list)
+ state_sets
Returns:
- tuple[dict, dict]: A tuple of unconflicted and conflicted state. The
- conflicted state dict is a map from type/state_key to set of event IDs
+ A tuple of unconflicted and conflicted state. The conflicted state dict
+ is a map from type/state_key to set of event IDs
"""
unconflicted_state = {}
conflicted_state = {}
@@ -258,18 +284,20 @@ def _seperate(state_sets):
event_ids.discard(None)
conflicted_state[key] = event_ids
- return unconflicted_state, conflicted_state
+ # mypy doesn't understand that discarding None above means that conflicted
+ # state is StateMap[Set[str]], not StateMap[Set[Optional[str]]].
+ return unconflicted_state, conflicted_state # type: ignore
-def _is_power_event(event):
+def _is_power_event(event: EventBase) -> bool:
"""Return whether or not the event is a "power event", as defined by the
v2 state resolution algorithm
Args:
- event (FrozenEvent)
+ event
Returns:
- boolean
+ True if the event is a power event.
"""
if (event.type, event.state_key) in (
(EventTypes.PowerLevels, ""),
@@ -285,21 +313,24 @@ def _is_power_event(event):
return False
-@defer.inlineCallbacks
-def _add_event_and_auth_chain_to_graph(
- graph, room_id, event_id, event_map, state_res_store, auth_diff
-):
+async def _add_event_and_auth_chain_to_graph(
+ graph: Dict[str, Set[str]],
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ auth_diff: Set[str],
+) -> None:
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
Args:
- graph (dict[str, set[str]]): A map from event ID to the events auth
- event IDs
- room_id (str): the room we are working in
- event_id (str): Event to add to the graph
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
- auth_diff (set[str]): Set of event IDs that are in the auth difference.
+ graph: A map from event ID to the events auth event IDs
+ room_id: the room we are working in
+ event_id: Event to add to the graph
+ event_map
+ state_res_store
+ auth_diff: Set of event IDs that are in the auth difference.
"""
state = [event_id]
@@ -307,7 +338,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop()
graph.setdefault(eid, set())
- event = yield _get_event(room_id, eid, event_map, state_res_store)
+ event = await _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@@ -316,37 +347,52 @@ def _add_event_and_auth_chain_to_graph(
graph.setdefault(eid, set()).add(aid)
-@defer.inlineCallbacks
-def _reverse_topological_power_sort(
- room_id, event_ids, event_map, state_res_store, auth_diff
-):
+async def _reverse_topological_power_sort(
+ clock: Clock,
+ room_id: str,
+ event_ids: Iterable[str],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ auth_diff: Set[str],
+) -> List[str]:
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
- room_id (str): the room we are working in
- event_ids (list[str]): The events to sort
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
- auth_diff (set[str]): Set of event IDs that are in the auth difference.
+ clock
+ room_id: the room we are working in
+ event_ids: The events to sort
+ event_map
+ state_res_store
+ auth_diff: Set of event IDs that are in the auth difference.
Returns:
- Deferred[list[str]]: The sorted list
+ The sorted list
"""
- graph = {}
- for event_id in event_ids:
- yield _add_event_and_auth_chain_to_graph(
+ graph = {} # type: Dict[str, Set[str]]
+ for idx, event_id in enumerate(event_ids, start=1):
+ await _add_event_and_auth_chain_to_graph(
graph, room_id, event_id, event_map, state_res_store, auth_diff
)
+ # We await occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
+
event_to_pl = {}
- for event_id in graph:
- pl = yield _get_power_level_for_sender(
+ for idx, event_id in enumerate(graph, start=1):
+ pl = await _get_power_level_for_sender(
room_id, event_id, event_map, state_res_store
)
event_to_pl[event_id] = pl
+ # We await occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
+
def _get_power_order(event_id):
ev = event_map[event_id]
pl = event_to_pl[event_id]
@@ -360,33 +406,39 @@ def _reverse_topological_power_sort(
return sorted_events
-@defer.inlineCallbacks
-def _iterative_auth_checks(
- room_id, room_version, event_ids, base_state, event_map, state_res_store
-):
+async def _iterative_auth_checks(
+ clock: Clock,
+ room_id: str,
+ room_version: str,
+ event_ids: List[str],
+ base_state: StateMap[str],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> MutableStateMap[str]:
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
- room_id (str)
- room_version (str)
- event_ids (list[str]): Ordered list of events to apply auth checks to
- base_state (StateMap[str]): The set of state to start with
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ clock
+ room_id
+ room_version
+ event_ids: Ordered list of events to apply auth checks to
+ base_state: The set of state to start with
+ event_map
+ state_res_store
Returns:
- Deferred[StateMap[str]]: Returns the final updated state
+ Returns the final updated state
"""
- resolved_state = base_state.copy()
+ resolved_state = dict(base_state)
room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
- for event_id in event_ids:
+ for idx, event_id in enumerate(event_ids, start=1):
event = event_map[event_id]
auth_events = {}
for aid in event.auth_event_ids():
- ev = yield _get_event(
+ ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
@@ -401,7 +453,7 @@ def _iterative_auth_checks(
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
- ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
+ ev = await _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
@@ -419,114 +471,173 @@ def _iterative_auth_checks(
except AuthError:
pass
+ # We await occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
+
return resolved_state
-@defer.inlineCallbacks
-def _mainline_sort(
- room_id, event_ids, resolved_power_event_id, event_map, state_res_store
-):
+async def _mainline_sort(
+ clock: Clock,
+ room_id: str,
+ event_ids: List[str],
+ resolved_power_event_id: Optional[str],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> List[str]:
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
- room_id (str): room we're working in
- event_ids (list[str]): Events to sort
- resolved_power_event_id (str): The final resolved power level event ID
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ clock
+ room_id: room we're working in
+ event_ids: Events to sort
+ resolved_power_event_id: The final resolved power level event ID
+ event_map
+ state_res_store
Returns:
- Deferred[list[str]]: The sorted list
+ The sorted list
"""
+ if not event_ids:
+ # It's possible for there to be no event IDs here to sort, so we can
+ # skip calculating the mainline in that case.
+ return []
+
mainline = []
pl = resolved_power_event_id
+ idx = 0
while pl:
mainline.append(pl)
- pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
+ pl_ev = await _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
- ev = yield _get_event(
+ ev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
+ # We await occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx != 0 and idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
+
+ idx += 1
+
mainline_map = {ev_id: i + 1 for i, ev_id in enumerate(reversed(mainline))}
event_ids = list(event_ids)
order_map = {}
- for ev_id in event_ids:
- depth = yield _get_mainline_depth_for_event(
+ for idx, ev_id in enumerate(event_ids, start=1):
+ depth = await _get_mainline_depth_for_event(
event_map[ev_id], mainline_map, event_map, state_res_store
)
order_map[ev_id] = (depth, event_map[ev_id].origin_server_ts, ev_id)
+ # We await occasionally when we're working with large data sets to
+ # ensure that we don't block the reactor loop for too long.
+ if idx % _AWAIT_AFTER_ITERATIONS == 0:
+ await clock.sleep(0)
+
event_ids.sort(key=lambda ev_id: order_map[ev_id])
return event_ids
-@defer.inlineCallbacks
-def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_store):
+async def _get_mainline_depth_for_event(
+ event: EventBase,
+ mainline_map: Dict[str, int],
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+) -> int:
"""Get the mainline depths for the given event based on the mainline map
Args:
- event (FrozenEvent)
- mainline_map (dict[str, int]): Map from event_id to mainline depth for
- events in the mainline.
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
+ event
+ mainline_map: Map from event_id to mainline depth for events in the mainline.
+ event_map
+ state_res_store
Returns:
- Deferred[int]
+ The mainline depth
"""
room_id = event.room_id
+ tmp_event = event # type: Optional[EventBase]
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
- while event:
+ while tmp_event:
- depth = mainline_map.get(event.event_id)
+ depth = mainline_map.get(tmp_event.event_id)
if depth is not None:
return depth
- auth_events = event.auth_event_ids()
- event = None
+ auth_events = tmp_event.auth_event_ids()
+ tmp_event = None
for aid in auth_events:
- aev = yield _get_event(
+ aev = await _get_event(
room_id, aid, event_map, state_res_store, allow_none=True
)
if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
- event = aev
+ tmp_event = aev
break
# Didn't find a power level auth event, so we just return 0
return 0
-@defer.inlineCallbacks
-def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
+@overload
+async def _get_event(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ allow_none: Literal[False] = False,
+) -> EventBase:
+ ...
+
+
+@overload
+async def _get_event(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ allow_none: Literal[True],
+) -> Optional[EventBase]:
+ ...
+
+
+async def _get_event(
+ room_id: str,
+ event_id: str,
+ event_map: Dict[str, EventBase],
+ state_res_store: "synapse.state.StateResolutionStore",
+ allow_none: bool = False,
+) -> Optional[EventBase]:
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
- room_id (str)
- event_id (str)
- event_map (dict[str,FrozenEvent])
- state_res_store (StateResolutionStore)
- allow_none (bool): if the event is not found, return None rather than raising
+ room_id
+ event_id
+ event_map
+ state_res_store
+ allow_none: if the event is not found, return None rather than raising
an exception
Returns:
- Deferred[Optional[FrozenEvent]]
+ The event, or none if the event does not exist (and allow_none is True).
"""
if event_id not in event_map:
- events = yield state_res_store.get_events([event_id], allow_rejected=True)
+ events = await state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
event = event_map.get(event_id)
@@ -543,7 +654,9 @@ def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
return event
-def lexicographical_topological_sort(graph, key):
+def lexicographical_topological_sort(
+ graph: Dict[str, Set[str]], key: Callable[[str], Any]
+) -> Generator[str, None, None]:
"""Performs a lexicographic reverse topological sort on the graph.
This returns a reverse topological sort (i.e. if node A references B then B
@@ -553,26 +666,26 @@ def lexicographical_topological_sort(graph, key):
NOTE: `graph` is modified during the sort.
Args:
- graph (dict[str, set[str]]): A representation of the graph where each
- node is a key in the dict and its value are the nodes edges.
- key (func): A function that takes a node and returns a value that is
- comparable and used to order nodes
+ graph: A representation of the graph where each node is a key in the
+ dict and its value are the nodes edges.
+ key: A function that takes a node and returns a value that is comparable
+ and used to order nodes
Yields:
- str: The next node in the topological sort
+ The next node in the topological sort
"""
# Note, this is basically Kahn's algorithm except we look at nodes with no
# outgoing edges, c.f.
# https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm
outdegree_map = graph
- reverse_graph = {}
+ reverse_graph = {} # type: Dict[str, Set[str]]
# Lists of nodes with zero out degree. Is actually a tuple of
# `(key(node), node)` so that sorting does the right thing
zero_outdegree = []
- for node, edges in iteritems(graph):
+ for node, edges in graph.items():
if len(edges) == 0:
zero_outdegree.append((key(node), node))
|