diff options
author | Erik Johnston <erik@matrix.org> | 2018-03-28 16:19:26 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2018-03-28 16:19:26 +0100 |
commit | cfc2169b31d0771a808c52eeaf5e09f4222e50fc (patch) | |
tree | 3acefb549c0ad9ec316bfb1ddc45f59abc7884b4 | |
parent | Measure time it takes to calculate state group ID (diff) | |
download | synapse-cfc2169b31d0771a808c52eeaf5e09f4222e50fc.tar.xz |
WIP fast path state
-rw-r--r-- | synapse/event_auth.py | 29 | ||||
-rw-r--r-- | synapse/state.py | 117 |
2 files changed, 134 insertions, 12 deletions
diff --git a/synapse/event_auth.py b/synapse/event_auth.py index cd5627e36a..cf48f692fb 100644 --- a/synapse/event_auth.py +++ b/synapse/event_auth.py @@ -676,3 +676,32 @@ def auth_types_for_event(event): auth_types.append(key) return auth_types + + +def filter_dependent_state(keys, state): + if (EventTypes.Create, "") in keys: + return state + + if (EventTypes.PowerLevels, "") in keys: + return state + + def _filter_state(entry): + etype, state_key, sender = entry + + if (etype, state_key) in keys: + return True + + if (EventTypes.Member, sender) in keys: + return True + + if etype == EventTypes.Member: + if (EventTypes.JoinRules, "") in keys: + return True + + for t, _ in keys: + if t == EventTypes.ThirdPartyInvite: + return True + + return False + + return filter(_filter_state, state) diff --git a/synapse/state.py b/synapse/state.py index 26093c8434..2ffdb0b01e 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -55,9 +55,9 @@ def _gen_state_id(): class _StateCacheEntry(object): - __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids"] + __slots__ = ["state", "state_group", "state_id", "prev_group", "delta_ids", "conflicted_state"] - def __init__(self, state, state_group, prev_group=None, delta_ids=None): + def __init__(self, state, state_group, prev_group=None, delta_ids=None, conflicted_state=None): # dict[(str, str), str] map from (type, state_key) to event_id self.state = frozendict(state) @@ -80,6 +80,8 @@ class _StateCacheEntry(object): else: self.state_id = _gen_state_id() + self.conflicted_state = conflicted_state + def __len__(self): return len(self.state) @@ -375,7 +377,7 @@ class StateHandler(object): } with Measure(self.clock, "state._resolve_events"): - new_state = resolve_events_with_state_map(state_set_ids, state_map) + new_state, _, _ = resolve_events_with_state_map(state_set_ids, state_map) new_state = { key: state_map[ev_id] for key, ev_id in new_state.iteritems() @@ -462,18 +464,17 @@ class StateResolutionHandler(object): for key, e_id in st.iteritems(): state.setdefault(key, set()).add(e_id) - # build a map from state key to the event_ids which set that state, - # including only those where there are state keys in conflict. - conflicted_state = { - k: list(v) - for k, v in state.iteritems() - if len(v) > 1 - } + # Check that there is a conflict between the state groups + conflicted_state = False + for values in state.itervalues(): + if len(values): + conflicted_state = True + break if conflicted_state: logger.info("Resolving conflicted state for %r", room_id) with Measure(self.clock, "state._resolve_events"): - new_state = yield resolve_events_with_factory( + new_state, _, conflicted_state_keys = yield resolve_events_with_factory( state_groups_ids.values(), event_map=event_map, state_map_factory=state_map_factory, @@ -482,6 +483,7 @@ class StateResolutionHandler(object): new_state = { key: e_ids.pop() for key, e_ids in state.iteritems() } + conflicted_state_keys = [] with Measure(self.clock, "state.create_group_ids"): # if the new state matches any of the input state groups, we can @@ -517,6 +519,7 @@ class StateResolutionHandler(object): state_group=state_group, prev_group=prev_group, delta_ids=delta_ids, + conflicted_state=conflicted_state_keys, ) if self._state_cache is not None: @@ -524,6 +527,82 @@ class StateResolutionHandler(object): defer.returnValue(cache) + @defer.inlineCallbacks + def resolve_delta_state(self, unchanged_state_groups, changed_groups, conflicted_state, store, + event_map, state_map_factory): + deltas = set(entry for _, _, group_deltas in changed_groups for entry in group_deltas) + + to_recalculate = event_auth.filter_dependent_state(deltas, conflicted_state) + + new_groups = list(unchanged_state_groups) + new_groups.extend(g for _, g, _ in changed_groups) + + types = [(etype, state_key) for etype, state_key, _ in to_recalculate] + state_sets = yield store._get_state_groups_from_groups(new_groups, types=types) + + state_sets = { + sg: { + key: cs[key] + for key in types + } + for sg, cs in state_sets.iteritems() + } + + logger.info("Recalculating: %s", to_recalculate) + + group_names = frozenset(new_groups) + with (yield self.resolve_linearizer.queue(group_names)): + cache = None + if self._state_cache is not None: + cache = self._state_cache.get(group_names, None) + + if cache: + new_state = { + (etype, state_key): cache.state[(etype, state_key)] + for etype, state_key, _ in to_recalculate + } + unconflicted_state, _ = _seperate( + state_sets.values(), + ) + else: + needed_events = set( + event_id + for state in state_sets.itervalues() + for event_id in state.itervalues() + ) + if event_map is not None: + needed_events -= set(event_map.iterkeys()) + + # logger.info("state_sets: %s", state_sets) + logger.info("Asking for %d conflicted events", len(needed_events)) + + # dict[str, FrozenEvent]: a map from state event id to event. Only includes + # the state events which are in conflict (and those in event_map) + state_map = yield state_map_factory(needed_events) + if event_map is not None: + state_map.update(event_map) + + _, cs = _seperate(state_sets.values()) + + needed_state = _get_auth_event_keys(cs, state_map) + sg = new_groups[0] + res = yield store._get_state_for_groups((sg,), types=needed_state) + state = res[sg] + for s in state_sets.itervalues(): + s.update(state) + + logger.info("Added state: %s", len(needed_state)) + + new_state, unconflicted_state, _ = yield resolve_events_with_factory( + state_sets.values(), + event_map=state_map, + state_map_factory=state_map_factory, + ) + + conflicted_state = [e for e in conflicted_state if (e[0], e[1]) not in unconflicted_state] + + defer.returnValue((new_state, conflicted_state)) + def _ordered_events(events): def key_func(e): @@ -677,6 +756,15 @@ def resolve_events_with_factory(state_sets, event_map, state_map_factory): )) +def _get_auth_event_keys(conflicted_state, state_map): + auth_events = set() + for event_ids in conflicted_state.itervalues(): + for event_id in event_ids: + if event_id in state_map: + auth_events.update(event_auth.auth_types_for_event(state_map[event_id])) + return auth_events + + def _create_auth_events_from_maps(unconflicted_state, conflicted_state, state_map): auth_events = {} for event_ids in conflicted_state.itervalues(): @@ -719,7 +807,12 @@ def _resolve_with_state(unconflicted_state_ids, conflicted_state_ds, auth_event_ for key, event in resolved_state.iteritems(): new_state[key] = event.event_id - return new_state + return new_state, unconflicted_state_ids, [ + (key[0], key[1], event.sender) + for key, evs in conflicted_state.iteritems() + if len(evs) > 1 + for event in evs + ] def _resolve_state_events(conflicted_state, auth_events): |