diff --git a/synapse/state.py b/synapse/state.py
index c75499c3e0..90b14e758c 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -337,7 +337,7 @@ class StateHandler(object):
for st in state_groups_ids.values()
]
with Measure(self.clock, "state._resolve_events"):
- new_state, _ = Resolver.resolve_events(
+ new_state, _ = resolve_events(
state_sets, event_type, state_key
)
new_state = {
@@ -392,11 +392,11 @@ class StateHandler(object):
)
with Measure(self.clock, "state._resolve_events"):
if event.is_state():
- return Resolver.resolve_events(
+ return resolve_events(
state_sets, event.type, event.state_key
)
else:
- return Resolver.resolve_events(state_sets)
+ return resolve_events(state_sets)
def _ordered_events(events):
@@ -406,138 +406,136 @@ def _ordered_events(events):
return sorted(events, key=key_func)
-class Resolver(object):
- @staticmethod
- def resolve_events(state_sets, event_type=None, state_key=""):
- """
- Returns
- (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
- (new_state, prev_states). new_state is a map from (type, state_key)
- to event. prev_states is a list of event_ids.
- """
- state = {}
- for st in state_sets:
- for e in st:
- state.setdefault(
- (e.type, e.state_key),
- {}
- )[e.event_id] = e
-
- unconflicted_state = {
- k: v.values()[0] for k, v in state.items()
- if len(v.values()) == 1
- }
+def resolve_events(state_sets, event_type=None, state_key=""):
+ """
+ Returns
+ (dict[(str, str), synapse.events.FrozenEvent], list[str]): a tuple
+ (new_state, prev_states). new_state is a map from (type, state_key)
+ to event. prev_states is a list of event_ids.
+ """
+ state = {}
+ for st in state_sets:
+ for e in st:
+ state.setdefault(
+ (e.type, e.state_key),
+ {}
+ )[e.event_id] = e
+
+ unconflicted_state = {
+ k: v.values()[0] for k, v in state.items()
+ if len(v.values()) == 1
+ }
+
+ conflicted_state = {
+ k: v.values()
+ for k, v in state.items()
+ if len(v.values()) > 1
+ }
+
+ if event_type:
+ prev_states_events = conflicted_state.get(
+ (event_type, state_key), []
+ )
+ prev_states = [s.event_id for s in prev_states_events]
+ else:
+ prev_states = []
+
+ auth_events = {
+ k: e for k, e in unconflicted_state.items()
+ if k[0] in AuthEventTypes
+ }
+
+ try:
+ resolved_state = _resolve_state_events(
+ conflicted_state, auth_events
+ )
+ except:
+ logger.exception("Failed to resolve state")
+ raise
- conflicted_state = {
- k: v.values()
- for k, v in state.items()
- if len(v.values()) > 1
- }
+ new_state = unconflicted_state
+ new_state.update(resolved_state)
- if event_type:
- prev_states_events = conflicted_state.get(
- (event_type, state_key), []
+ return new_state, prev_states
+
+
+def _resolve_state_events(conflicted_state, auth_events):
+ """ This is where we actually decide which of the conflicted state to
+ use.
+
+ We resolve conflicts in the following order:
+ 1. power levels
+ 2. join rules
+ 3. memberships
+ 4. other events.
+ """
+ resolved_state = {}
+ power_key = (EventTypes.PowerLevels, "")
+ if power_key in conflicted_state:
+ events = conflicted_state[power_key]
+ logger.debug("Resolving conflicted power levels %r", events)
+ resolved_state[power_key] = _resolve_auth_events(
+ events, auth_events)
+
+ auth_events.update(resolved_state)
+
+ for key, events in conflicted_state.items():
+ if key[0] == EventTypes.JoinRules:
+ logger.debug("Resolving conflicted join rules %r", events)
+ resolved_state[key] = _resolve_auth_events(
+ events,
+ auth_events
)
- prev_states = [s.event_id for s in prev_states_events]
- else:
- prev_states = []
- auth_events = {
- k: e for k, e in unconflicted_state.items()
- if k[0] in AuthEventTypes
- }
+ auth_events.update(resolved_state)
- try:
- resolved_state = Resolver._resolve_state_events(
- conflicted_state, auth_events
+ for key, events in conflicted_state.items():
+ if key[0] == EventTypes.Member:
+ logger.debug("Resolving conflicted member lists %r", events)
+ resolved_state[key] = _resolve_auth_events(
+ events,
+ auth_events
)
- except:
- logger.exception("Failed to resolve state")
- raise
- new_state = unconflicted_state
- new_state.update(resolved_state)
+ auth_events.update(resolved_state)
- return new_state, prev_states
+ for key, events in conflicted_state.items():
+ if key not in resolved_state:
+ logger.debug("Resolving conflicted state %r:%r", key, events)
+ resolved_state[key] = _resolve_normal_events(
+ events, auth_events
+ )
- @staticmethod
- def _resolve_state_events(conflicted_state, auth_events):
- """ This is where we actually decide which of the conflicted state to
- use.
+ return resolved_state
- We resolve conflicts in the following order:
- 1. power levels
- 2. join rules
- 3. memberships
- 4. other events.
- """
- resolved_state = {}
- power_key = (EventTypes.PowerLevels, "")
- if power_key in conflicted_state:
- events = conflicted_state[power_key]
- logger.debug("Resolving conflicted power levels %r", events)
- resolved_state[power_key] = Resolver._resolve_auth_events(
- events, auth_events)
-
- auth_events.update(resolved_state)
-
- for key, events in conflicted_state.items():
- if key[0] == EventTypes.JoinRules:
- logger.debug("Resolving conflicted join rules %r", events)
- resolved_state[key] = Resolver._resolve_auth_events(
- events,
- auth_events
- )
- auth_events.update(resolved_state)
+def _resolve_auth_events(events, auth_events):
+ reverse = [i for i in reversed(_ordered_events(events))]
- for key, events in conflicted_state.items():
- if key[0] == EventTypes.Member:
- logger.debug("Resolving conflicted member lists %r", events)
- resolved_state[key] = Resolver._resolve_auth_events(
- events,
- auth_events
- )
+ auth_events = dict(auth_events)
+
+ prev_event = reverse[0]
+ for event in reverse[1:]:
+ auth_events[(prev_event.type, prev_event.state_key)] = prev_event
+ try:
+ # The signatures have already been checked at this point
+ event_auth.check(event, auth_events, do_sig_check=False)
+ prev_event = event
+ except AuthError:
+ return prev_event
- auth_events.update(resolved_state)
+ return event
- for key, events in conflicted_state.items():
- if key not in resolved_state:
- logger.debug("Resolving conflicted state %r:%r", key, events)
- resolved_state[key] = Resolver._resolve_normal_events(
- events, auth_events
- )
- return resolved_state
-
- @staticmethod
- def _resolve_auth_events(events, auth_events):
- reverse = [i for i in reversed(_ordered_events(events))]
-
- auth_events = dict(auth_events)
-
- prev_event = reverse[0]
- for event in reverse[1:]:
- auth_events[(prev_event.type, prev_event.state_key)] = prev_event
- try:
- # The signatures have already been checked at this point
- event_auth.check(event, auth_events, do_sig_check=False)
- prev_event = event
- except AuthError:
- return prev_event
-
- return event
-
- @staticmethod
- def _resolve_normal_events(events, auth_events):
- for event in _ordered_events(events):
- try:
- # The signatures have already been checked at this point
- event_auth.check(event, auth_events, do_sig_check=False)
- return event
- except AuthError:
- pass
-
- # Use the last event (the one with the least depth) if they all fail
- # the auth check.
- return event
+def _resolve_normal_events(events, auth_events):
+ for event in _ordered_events(events):
+ try:
+ # The signatures have already been checked at this point
+ event_auth.check(event, auth_events, do_sig_check=False)
+ return event
+ except AuthError:
+ pass
+
+ # Use the last event (the one with the least depth) if they all fail
+ # the auth check.
+ return event
|