diff --git a/synapse/state.py b/synapse/state.py
index 2ffdb0b01e..1bf4a0df6f 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -531,24 +531,33 @@ class StateResolutionHandler(object):
def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store,
event_map, state_map_factory):
deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas)
+ conflicted_state_keys = set((etype, state_key) for etype, state_key, _ in conflicted_state)
+ conflicted_state_keys.update(deltas)
- to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state)
+ if not deltas:
+ logger.info("No deltas")
+ defer.returnValue({}, conflicted_state)
+
+ to_recalculate = set(event_auth.filter_dependent_state(deltas, conflicted_state))
+ to_recalculate = set((etype, state_key) for etype, state_key, _ in to_recalculate)
+ to_recalculate.update(deltas)
new_groups = list(unchanged_state_groups)
new_groups.extend(g for _, g, _ in changed_groups)
- types = [(etype, state_key) for etype, state_key, _ in to_recalculate]
- state_sets = yield store._get_state_groups_from_groups(new_groups, types=types)
+ state_sets = yield store._get_state_groups_from_groups(new_groups, types=to_recalculate)
state_sets = {
sg: {
key: cs[key]
- for key in types
+ for key in to_recalculate
+ if key in cs
}
for sg, cs in state_sets.iteritems()
}
logger.info("Recalculating: %s", to_recalculate)
+ logger.info("State Sets: %s", state_sets)
group_names = frozenset(new_groups)
with (yield self.resolve_linearizer.queue(group_names)):
@@ -557,9 +566,10 @@ class StateResolutionHandler(object):
cache = self._state_cache.get(group_names, None)
if cache:
+ logger.info("Using cache")
new_state = {
- (etype, state_key): cache.state[(etype, state_key)]
- for etype, state_key, _ in to_recalculate
+ key: cache.state[key]
+ for key in to_recalculate
}
unconflicted_state, _ = _seperate(
state_sets.values(),
@@ -585,13 +595,23 @@ class StateResolutionHandler(object):
_, cs = _seperate(state_sets.values())
needed_state = _get_auth_event_keys(cs, state_map)
- sg = new_groups[0]
- res = yield store._get_state_for_groups((sg,), types=needed_state)
- state = res[sg]
- for s in state_sets.itervalues():
- s.update(state)
- logger.info("Added state: %s", len(needed_state))
+ if needed_state - conflicted_state_keys:
+ res = yield store._get_state_for_groups((new_groups[0],), types=(needed_state - conflicted_state_keys))
+ state = res[new_groups[0]]
+ for s in state_sets.itervalues():
+ s.update(state)
+
+ logger.info("Added unconflicted state: %s", state)
+
+ needed_conflicted_state = needed_state & conflicted_state_keys
+ if needed_conflicted_state:
+ for sg, s in state_sets.iteritems():
+ res = yield store._get_state_for_groups((sg,), types=needed_conflicted_state)
+ state = res[sg]
+ s.update(state)
+
+ logger.info("Added conflicted state: %s", needed_conflicted_state)
new_state, unconflicted_state, _ = yield resolve_events_with_factory(
state_sets.values(),
@@ -601,6 +621,8 @@ class StateResolutionHandler(object):
conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state]
+ logger.info("Returning: %s", new_state)
+
defer.returnValue((new_state, conflicted_state))
|