diff --git a/synapse/event_auth.py b/synapse/event_auth.py
index cd5627e36a..cf48f692fb 100644
--- a/synapse/event_auth.py
+++ b/synapse/event_auth.py
@@ -676,3 +676,32 @@ def auth_types_for_event(event):
auth_types.append(key)
return auth_types
+
+
+def filter_dependent_state(keys, state):
+ if (EventTypes.Create, "") in keys:
+ return state
+
+ if (EventTypes.PowerLevels, "") in keys:
+ return state
+
+ def _filter_state(entry):
+ etype, state_key, sender = entry
+
+ if (etype, state_key) in keys:
+ return True
+
+ if (EventTypes.Member, sender) in keys:
+ return True
+
+ if etype == EventTypes.Member:
+ if (EventTypes.JoinRules, "") in keys:
+ return True
+
+ for t, _ in keys:
+ if t == EventTypes.ThirdPartyInvite:
+ return True
+
+ return False
+
+ return filter(_filter_state, state)
diff --git a/synapse/state.py b/synapse/state.py
index 26093c8434..2ffdb0b01e 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -55,9 +55,9 @@ def _gen_state_id():
class _StateCacheEntry(object):
- __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"]
+ __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids", "conflicted_state"]
- def __init__(self, state, state_group, prev_group=None, delta_ids=None):
+ def __init__(self, state, state_group, prev_group=None, delta_ids=None, conflicted_state=None):
# dict[(str, str), str] map from (type, state_key) to event_id
self.state = frozendict(state)
@@ -80,6 +80,8 @@ class _StateCacheEntry(object):
else:
self.state_id = _gen_state_id()
+ self.conflicted_state = conflicted_state
+
def __len__(self):
return len(self.state)
@@ -375,7 +377,7 @@ class StateHandler(object):
}
with Measure(self.clock, "state._resolve_events"):
- new_state = resolve_events_with_state_map(state_set_ids, state_map)
+ new_state, _, _ = resolve_events_with_state_map(state_set_ids, state_map)
new_state = {
key: state_map[ev_id] for key, ev_id in new_state.iteritems()
@@ -462,18 +464,17 @@ class StateResolutionHandler(object):
for key, e_id in st.iteritems():
state.setdefault(key, set()).add(e_id)
- # build a map from state key to the event_ids which set that state,
- # including only those where there are state keys in conflict.
- conflicted_state = {
- k: list(v)
- for k, v in state.iteritems()
- if len(v) > 1
- }
+ # Check that there is a conflict between the state groups
+ conflicted_state = False
+ for values in state.itervalues():
+ if len(values):
+ conflicted_state = True
+ break
if conflicted_state:
logger.info("Resolving conflicted state for %r", room_id)
with Measure(self.clock, "state._resolve_events"):
- new_state = yield resolve_events_with_factory(
+ new_state, _, conflicted_state_keys = yield resolve_events_with_factory(
state_groups_ids.values(),
event_map=event_map,
state_map_factory=state_map_factory,
@@ -482,6 +483,7 @@ class StateResolutionHandler(object):
new_state = {
key: e_ids.pop() for key, e_ids in state.iteritems()
}
+ conflicted_state_keys = []
with Measure(self.clock, "state.create_group_ids"):
# if the new state matches any of the input state groups, we can
@@ -517,6 +519,7 @@ class StateResolutionHandler(object):
state_group=state_group,
prev_group=prev_group,
delta_ids=delta_ids,
+ conflicted_state=conflicted_state_keys,
)
if self._state_cache is not None:
@@ -524,6 +527,82 @@ class StateResolutionHandler(object):
defer.returnValue(cache)
+ @defer.inlineCallbacks
+ def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store,
+ event_map, state_map_factory):
+ deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas)
+
+ to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state)
+
+ new_groups = list(unchanged_state_groups)
+ new_groups.extend(g for _, g, _ in changed_groups)
+
+ types = [(etype, state_key) for etype, state_key, _ in to_recalculate]
+ state_sets = yield store._get_state_groups_from_groups(new_groups, types=types)
+
+ state_sets = {
+ sg: {
+ key: cs[key]
+ for key in types
+ }
+ for sg, cs in state_sets.iteritems()
+ }
+
+ logger.info("Recalculating: %s", to_recalculate)
+
+ group_names = frozenset(new_groups)
+ with (yield self.resolve_linearizer.queue(group_names)):
+ cache = None
+ if self._state_cache is not None:
+ cache = self._state_cache.get(group_names, None)
+
+ if cache:
+ new_state = {
+ (etype, state_key): cache.state[(etype, state_key)]
+ for etype, state_key, _ in to_recalculate
+ }
+ unconflicted_state, _ = _seperate(
+ state_sets.values(),
+ )
+ else:
+ needed_events = set(
+ event_id
+ for state in state_sets.itervalues()
+ for event_id in state.itervalues()
+ )
+ if event_map is not None:
+ needed_events -= set(event_map.iterkeys())
+
+ # logger.info("state_sets: %s", state_sets)
+ logger.info("Asking for %d conflicted events", len(needed_events))
+
+ # dict[str, FrozenEvent]: a map from state event id to event. Only includes
+ # the state events which are in conflict (and those in event_map)
+ state_map = yield state_map_factory(needed_events)
+ if event_map is not None:
+ state_map.update(event_map)
+
+ _, cs = _seperate(state_sets.values())
+
+ needed_state = _get_auth_event_keys(cs, state_map)
+ sg = new_groups[0]
+ res = yield store._get_state_for_groups((sg,), types=needed_state)
+ state = res[sg]
+ for s in state_sets.itervalues():
+ s.update(state)
+
+ logger.info("Added state: %s", len(needed_state))
+
+ new_state, unconflicted_state, _ = yield resolve_events_with_factory(
+ state_sets.values(),
+ event_map=state_map,
+ state_map_factory=state_map_factory,
+ )
+
+ conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state]
+
+ defer.returnValue((new_state, conflicted_state))
+
def _ordered_events(events):
def key_func(e):
@@ -677,6 +756,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory):
))
+def _get_auth_event_keys(conflicted_state, state_map):
+ auth_events = set()
+ for event_ids in conflicted_state.itervalues():
+ for event_id in event_ids:
+ if event_id in state_map:
+ auth_events.update(event_auth.auth_types_for_event(state_map[event_id]))
+ return auth_events
+
+
def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map):
auth_events = {}
for event_ids in conflicted_state.itervalues():
@@ -719,7 +807,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_
for key, event in resolved_state.iteritems():
new_state[key] = event.event_id
- return new_state
+ return new_state, unconflicted_state_ids, [
+ (key[0], key[1], event.sender)
+ for key, evs in conflicted_state.iteritems()
+ if len(evs) > 1
+ for event in evs
+ ]
def _resolve_state_events(conflicted_state, auth_events):
|