diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index b327c86f40..18484e2fa6 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -16,29 +16,42 @@
import heapq
import itertools
import logging
+from typing import Dict, List, Optional
from six import iteritems, itervalues
from twisted.internet import defer
+import synapse.state
from synapse import event_auth
from synapse.api.constants import EventTypes
from synapse.api.errors import AuthError
+from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
+from synapse.events import EventBase
+from synapse.types import StateMap
logger = logging.getLogger(__name__)
@defer.inlineCallbacks
-def resolve_events_with_store(room_version, state_sets, event_map, state_res_store):
+def resolve_events_with_store(
+ room_id: str,
+ room_version: str,
+ state_sets: List[StateMap[str]],
+ event_map: Optional[Dict[str, EventBase]],
+ state_res_store: "synapse.state.StateResolutionStore",
+):
"""Resolves the state using the v2 state resolution algorithm
Args:
- room_version (str): The room version
+ room_id: the room we are working in
+
+ room_version: The room version
- state_sets(list): List of dicts of (type, state_key) -> event_id,
+ state_sets: List of dicts of (type, state_key) -> event_id,
which are the different state groups to resolve.
- event_map(dict[str,FrozenEvent]|None):
+ event_map:
a dict from event_id to event, for any events that we happen to
have in flight (eg, those currently being persisted). This will be
used as a starting point fof finding the state we need; any missing
@@ -46,9 +59,9 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
If None, all events will be fetched via state_res_store.
- state_res_store (StateResolutionStore)
+ state_res_store:
- Returns
+ Returns:
Deferred[dict[(str, str), str]]:
a map from (type, state_key) to event_id.
"""
@@ -84,7 +97,15 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
)
event_map.update(events)
- full_conflicted_set = set(eid for eid in full_conflicted_set if eid in event_map)
+ # everything in the event map should be in the right room
+ for event in event_map.values():
+ if event.room_id != room_id:
+ raise Exception(
+ "Attempting to state-resolve for room %s with event %s which is in %s"
+ % (room_id, event.event_id, event.room_id,)
+ )
+
+ full_conflicted_set = {eid for eid in full_conflicted_set if eid in event_map}
logger.debug("%d full_conflicted_set entries", len(full_conflicted_set))
@@ -94,13 +115,14 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
)
sorted_power_events = yield _reverse_topological_power_sort(
- power_events, event_map, state_res_store, full_conflicted_set
+ room_id, power_events, event_map, state_res_store, full_conflicted_set
)
logger.debug("sorted %d power events", len(sorted_power_events))
# Now sequentially auth each one
resolved_state = yield _iterative_auth_checks(
+ room_id,
room_version,
sorted_power_events,
unconflicted_state,
@@ -121,13 +143,18 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
pl = resolved_state.get((EventTypes.PowerLevels, ""), None)
leftover_events = yield _mainline_sort(
- leftover_events, pl, event_map, state_res_store
+ room_id, leftover_events, pl, event_map, state_res_store
)
logger.debug("resolving remaining events")
resolved_state = yield _iterative_auth_checks(
- room_version, leftover_events, resolved_state, event_map, state_res_store
+ room_id,
+ room_version,
+ leftover_events,
+ resolved_state,
+ event_map,
+ state_res_store,
)
logger.debug("resolved")
@@ -141,11 +168,12 @@ def resolve_events_with_store(room_version, state_sets, event_map, state_res_sto
@defer.inlineCallbacks
-def _get_power_level_for_sender(event_id, event_map, state_res_store):
+def _get_power_level_for_sender(room_id, event_id, event_map, state_res_store):
"""Return the power level of the sender of the given event according to
their auth events.
Args:
+ room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
@@ -153,20 +181,24 @@ def _get_power_level_for_sender(event_id, event_map, state_res_store):
Returns:
Deferred[int]
"""
- event = yield _get_event(event_id, event_map, state_res_store)
+ event = yield _get_event(room_id, event_id, event_map, state_res_store)
pl = None
for aid in event.auth_event_ids():
- aev = yield _get_event(aid, event_map, state_res_store)
- if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+ aev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
pl = aev
break
if pl is None:
# Couldn't find power level. Check if they're the creator of the room
for aid in event.auth_event_ids():
- aev = yield _get_event(aid, event_map, state_res_store)
- if (aev.type, aev.state_key) == (EventTypes.Create, ""):
+ aev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if aev and (aev.type, aev.state_key) == (EventTypes.Create, ""):
if aev.content.get("creator") == event.sender:
return 100
break
@@ -195,36 +227,12 @@ def _get_auth_chain_difference(state_sets, event_map, state_res_store):
Returns:
Deferred[set[str]]: Set of event IDs
"""
- common = set(itervalues(state_sets[0])).intersection(
- *(itervalues(s) for s in state_sets[1:])
- )
-
- auth_sets = []
- for state_set in state_sets:
- auth_ids = set(
- eid
- for key, eid in iteritems(state_set)
- if (
- key[0] in (EventTypes.Member, EventTypes.ThirdPartyInvite)
- or key
- in (
- (EventTypes.PowerLevels, ""),
- (EventTypes.Create, ""),
- (EventTypes.JoinRules, ""),
- )
- )
- and eid not in common
- )
- auth_chain = yield state_res_store.get_auth_chain(auth_ids)
- auth_ids.update(auth_chain)
-
- auth_sets.append(auth_ids)
-
- intersection = set(auth_sets[0]).intersection(*auth_sets[1:])
- union = set().union(*auth_sets)
+ difference = yield state_res_store.get_auth_chain_difference(
+ [set(state_set.values()) for state_set in state_sets]
+ )
- return union - intersection
+ return difference
def _seperate(state_sets):
@@ -243,7 +251,7 @@ def _seperate(state_sets):
conflicted_state = {}
for key in set(itertools.chain.from_iterable(state_sets)):
- event_ids = set(state_set.get(key) for state_set in state_sets)
+ event_ids = {state_set.get(key) for state_set in state_sets}
if len(event_ids) == 1:
unconflicted_state[key] = event_ids.pop()
else:
@@ -279,7 +287,7 @@ def _is_power_event(event):
@defer.inlineCallbacks
def _add_event_and_auth_chain_to_graph(
- graph, event_id, event_map, state_res_store, auth_diff
+ graph, room_id, event_id, event_map, state_res_store, auth_diff
):
"""Helper function for _reverse_topological_power_sort that add the event
and its auth chain (that is in the auth diff) to the graph
@@ -287,6 +295,7 @@ def _add_event_and_auth_chain_to_graph(
Args:
graph (dict[str, set[str]]): A map from event ID to the events auth
event IDs
+ room_id (str): the room we are working in
event_id (str): Event to add to the graph
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
@@ -298,7 +307,7 @@ def _add_event_and_auth_chain_to_graph(
eid = state.pop()
graph.setdefault(eid, set())
- event = yield _get_event(eid, event_map, state_res_store)
+ event = yield _get_event(room_id, eid, event_map, state_res_store)
for aid in event.auth_event_ids():
if aid in auth_diff:
if aid not in graph:
@@ -308,11 +317,14 @@ def _add_event_and_auth_chain_to_graph(
@defer.inlineCallbacks
-def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_diff):
+def _reverse_topological_power_sort(
+ room_id, event_ids, event_map, state_res_store, auth_diff
+):
"""Returns a list of the event_ids sorted by reverse topological ordering,
and then by power level and origin_server_ts
Args:
+ room_id (str): the room we are working in
event_ids (list[str]): The events to sort
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
@@ -325,12 +337,14 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
graph = {}
for event_id in event_ids:
yield _add_event_and_auth_chain_to_graph(
- graph, event_id, event_map, state_res_store, auth_diff
+ graph, room_id, event_id, event_map, state_res_store, auth_diff
)
event_to_pl = {}
for event_id in graph:
- pl = yield _get_power_level_for_sender(event_id, event_map, state_res_store)
+ pl = yield _get_power_level_for_sender(
+ room_id, event_id, event_map, state_res_store
+ )
event_to_pl[event_id] = pl
def _get_power_order(event_id):
@@ -348,44 +362,53 @@ def _reverse_topological_power_sort(event_ids, event_map, state_res_store, auth_
@defer.inlineCallbacks
def _iterative_auth_checks(
- room_version, event_ids, base_state, event_map, state_res_store
+ room_id, room_version, event_ids, base_state, event_map, state_res_store
):
"""Sequentially apply auth checks to each event in given list, updating the
state as it goes along.
Args:
+ room_id (str)
room_version (str)
event_ids (list[str]): Ordered list of events to apply auth checks to
- base_state (dict[tuple[str, str], str]): The set of state to start with
+ base_state (StateMap[str]): The set of state to start with
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
Returns:
- Deferred[dict[tuple[str, str], str]]: Returns the final updated state
+ Deferred[StateMap[str]]: Returns the final updated state
"""
resolved_state = base_state.copy()
+ room_version_obj = KNOWN_ROOM_VERSIONS[room_version]
for event_id in event_ids:
event = event_map[event_id]
auth_events = {}
for aid in event.auth_event_ids():
- ev = yield _get_event(aid, event_map, state_res_store)
+ ev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
- if ev.rejected_reason is None:
- auth_events[(ev.type, ev.state_key)] = ev
+ if not ev:
+ logger.warning(
+ "auth_event id %s for event %s is missing", aid, event_id
+ )
+ else:
+ if ev.rejected_reason is None:
+ auth_events[(ev.type, ev.state_key)] = ev
for key in event_auth.auth_types_for_event(event):
if key in resolved_state:
ev_id = resolved_state[key]
- ev = yield _get_event(ev_id, event_map, state_res_store)
+ ev = yield _get_event(room_id, ev_id, event_map, state_res_store)
if ev.rejected_reason is None:
auth_events[key] = event_map[ev_id]
try:
event_auth.check(
- room_version,
+ room_version_obj,
event,
auth_events,
do_sig_check=False,
@@ -400,11 +423,14 @@ def _iterative_auth_checks(
@defer.inlineCallbacks
-def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_store):
+def _mainline_sort(
+ room_id, event_ids, resolved_power_event_id, event_map, state_res_store
+):
"""Returns a sorted list of event_ids sorted by mainline ordering based on
the given event resolved_power_event_id
Args:
+ room_id (str): room we're working in
event_ids (list[str]): Events to sort
resolved_power_event_id (str): The final resolved power level event ID
event_map (dict[str,FrozenEvent])
@@ -417,12 +443,14 @@ def _mainline_sort(event_ids, resolved_power_event_id, event_map, state_res_stor
pl = resolved_power_event_id
while pl:
mainline.append(pl)
- pl_ev = yield _get_event(pl, event_map, state_res_store)
+ pl_ev = yield _get_event(room_id, pl, event_map, state_res_store)
auth_events = pl_ev.auth_event_ids()
pl = None
for aid in auth_events:
- ev = yield _get_event(aid, event_map, state_res_store)
- if (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
+ ev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if ev and (ev.type, ev.state_key) == (EventTypes.PowerLevels, ""):
pl = aid
break
@@ -457,6 +485,8 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
Deferred[int]
"""
+ room_id = event.room_id
+
# We do an iterative search, replacing `event with the power level in its
# auth events (if any)
while event:
@@ -468,8 +498,10 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
event = None
for aid in auth_events:
- aev = yield _get_event(aid, event_map, state_res_store)
- if (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
+ aev = yield _get_event(
+ room_id, aid, event_map, state_res_store, allow_none=True
+ )
+ if aev and (aev.type, aev.state_key) == (EventTypes.PowerLevels, ""):
event = aev
break
@@ -478,22 +510,37 @@ def _get_mainline_depth_for_event(event, mainline_map, event_map, state_res_stor
@defer.inlineCallbacks
-def _get_event(event_id, event_map, state_res_store):
+def _get_event(room_id, event_id, event_map, state_res_store, allow_none=False):
"""Helper function to look up event in event_map, falling back to looking
it up in the store
Args:
+ room_id (str)
event_id (str)
event_map (dict[str,FrozenEvent])
state_res_store (StateResolutionStore)
+ allow_none (bool): if the event is not found, return None rather than raising
+ an exception
Returns:
- Deferred[FrozenEvent]
+ Deferred[Optional[FrozenEvent]]
"""
if event_id not in event_map:
events = yield state_res_store.get_events([event_id], allow_rejected=True)
event_map.update(events)
- return event_map[event_id]
+ event = event_map.get(event_id)
+
+ if event is None:
+ if allow_none:
+ return None
+ raise Exception("Unknown event %s" % (event_id,))
+
+ if event.room_id != room_id:
+ raise Exception(
+ "In state res for room %s, event %s is in %s"
+ % (room_id, event_id, event.room_id)
+ )
+ return event
def lexicographical_topological_sort(graph, key):
|