diff options
Diffstat (limited to 'synapse/state.py')
-rw-r--r-- | synapse/state.py | 157 |
1 files changed, 93 insertions, 64 deletions
diff --git a/synapse/state.py b/synapse/state.py index 36d8210eb5..9db84c9b5c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -16,7 +16,7 @@ from twisted.internet import defer -from synapse.federation.pdu_codec import encode_event_id +from synapse.federation.pdu_codec import encode_event_id, decode_event_id from synapse.util.logutils import log_function from collections import namedtuple @@ -87,9 +87,11 @@ class StateHandler(object): # than the power level of the user # power_level = self._get_power_level_for_event(event) + pdu_id, origin = decode_event_id(event.event_id, self.server_name) + yield self.store.update_current_state( - pdu_id=event.event_id, - origin=self.server_name, + pdu_id=pdu_id, + origin=origin, context=key.context, pdu_type=key.type, state_key=key.state_key @@ -113,6 +115,8 @@ class StateHandler(object): is_new = yield self._handle_new_state(new_pdu) + logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if is_new: yield self.store.update_current_state( pdu_id=new_pdu.pdu_id, @@ -132,7 +136,9 @@ class StateHandler(object): @defer.inlineCallbacks @log_function def _handle_new_state(self, new_pdu): - tree = yield self.store.get_unresolved_state_tree(new_pdu) + tree, missing_branch = yield self.store.get_unresolved_state_tree( + new_pdu + ) new_branch, current_branch = tree logger.debug( @@ -140,6 +146,28 @@ class StateHandler(object): new_branch, current_branch ) + if missing_branch is not None: + # We're missing some PDUs. Fetch them. + # TODO (erikj): Limit this. + missing_prev = tree[missing_branch][-1] + + pdu_id = missing_prev.prev_state_id + origin = missing_prev.prev_state_origin + + is_missing = yield self.store.get_pdu(pdu_id, origin) is None + if not is_missing: + raise Exception("Conflict resolution failed") + + yield self._replication.get_pdu( + destination=missing_prev.origin, + pdu_origin=origin, + pdu_id=pdu_id, + outlier=True + ) + + updated_current = yield self._handle_new_state(new_pdu) + defer.returnValue(updated_current) + if not current_branch: # There is no current state defer.returnValue(True) @@ -148,84 +176,85 @@ class StateHandler(object): n = new_branch[-1] c = current_branch[-1] - if n.pdu_id == c.pdu_id and n.origin == c.origin: - # We have all the PDUs we need, so we can just do the conflict - # resolution. + common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin + + if common_ancestor: + # We found a common ancestor! if len(current_branch) == 1: # This is a direct clobber so we can just... defer.returnValue(True) - conflict_res = [ - self._do_power_level_conflict_res, - self._do_chain_length_conflict_res, - self._do_hash_conflict_res, - ] - - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) - - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - - raise Exception("Conflict resolution failed.") - else: - # We need to ask for PDUs. - missing_prev = max( - new_branch[-1], current_branch[-1], - key=lambda x: x.depth - ) - - if not hasattr(missing_prev, "prev_state_id"): - # FIXME Hmm - # temporary fallback - for algo in conflict_res: - new_res, curr_res = algo(new_branch, current_branch) + # We didn't find a common ancestor. This is probably fine. + pass - if new_res < curr_res: - defer.returnValue(False) - elif new_res > curr_res: - defer.returnValue(True) - return + result = yield self._do_conflict_res( + new_branch, current_branch, common_ancestor + ) + defer.returnValue(result) - pdu_id = missing_prev.prev_state_id - origin = missing_prev.prev_state_origin + @defer.inlineCallbacks + def _do_conflict_res(self, new_branch, current_branch, common_ancestor): + conflict_res = [ + self._do_power_level_conflict_res, + self._do_chain_length_conflict_res, + self._do_hash_conflict_res, + ] - is_missing = yield self.store.get_pdu(pdu_id, origin) is None + for algo in conflict_res: + new_res, curr_res = yield defer.maybeDeferred( + algo, + new_branch, current_branch, common_ancestor + ) - if not is_missing: - raise Exception("Conflict resolution failed.") + if new_res < curr_res: + defer.returnValue(False) + elif new_res > curr_res: + defer.returnValue(True) - yield self._replication.get_pdu( - destination=missing_prev.origin, - pdu_origin=origin, - pdu_id=pdu_id, - outlier=True - ) + raise Exception("Conflict resolution failed.") - updated_current = yield self._handle_new_state(new_pdu) - defer.returnValue(updated_current) + @defer.inlineCallbacks + def _do_power_level_conflict_res(self, new_branch, current_branch, + common_ancestor): + new_powers_deferreds = [] + for e in new_branch[:-1] if common_ancestor else new_branch: + if hasattr(e, "user_id"): + new_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + current_powers_deferreds = [] + for e in current_branch[:-1] if common_ancestor else current_branch: + if hasattr(e, "user_id"): + current_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - def _do_power_level_conflict_res(self, new_branch, current_branch): - max_power_new = max( - new_branch[:-1], - key=lambda t: t.power_level - ).power_level + current_powers = yield defer.gatherResults( + current_powers_deferreds, + consumeErrors=True + ) - max_power_current = max( - current_branch[:-1], - key=lambda t: t.power_level - ).power_level + max_power_new = max(new_powers) + max_power_current = max(current_powers) - return (max_power_new, max_power_current) + defer.returnValue( + (max_power_new, max_power_current) + ) - def _do_chain_length_conflict_res(self, new_branch, current_branch): + def _do_chain_length_conflict_res(self, new_branch, current_branch, + common_ancestor): return (len(new_branch), len(current_branch)) - def _do_hash_conflict_res(self, new_branch, current_branch): + def _do_hash_conflict_res(self, new_branch, current_branch, + common_ancestor): new_str = "".join([p.pdu_id + p.origin for p in new_branch]) c_str = "".join([p.pdu_id + p.origin for p in current_branch]) |