diff options
Diffstat (limited to 'synapse')
-rw-r--r-- | synapse/federation/units.py | 1 | ||||
-rw-r--r-- | synapse/state.py | 46 | ||||
-rw-r--r-- | synapse/storage/_base.py | 8 | ||||
-rw-r--r-- | synapse/storage/pdu.py | 81 |
4 files changed, 54 insertions, 82 deletions
diff --git a/synapse/federation/units.py b/synapse/federation/units.py index 9740431279..622fe66a8f 100644 --- a/synapse/federation/units.py +++ b/synapse/federation/units.py @@ -69,6 +69,7 @@ class Pdu(JsonEncodedObject): "prev_state_id", "prev_state_origin", "required_power_level", + "user_id", ] internal_keys = [ diff --git a/synapse/state.py b/synapse/state.py index 0cc1344d51..9db84c9b5c 100644 --- a/synapse/state.py +++ b/synapse/state.py @@ -115,6 +115,8 @@ class StateHandler(object): is_new = yield self._handle_new_state(new_pdu) + logger.debug("is_new: %s %s %s", is_new, new_pdu.pdu_id, new_pdu.origin) + if is_new: yield self.store.update_current_state( pdu_id=new_pdu.pdu_id, @@ -187,11 +189,12 @@ class StateHandler(object): # We didn't find a common ancestor. This is probably fine. pass - result = self._do_conflict_res( + result = yield self._do_conflict_res( new_branch, current_branch, common_ancestor ) defer.returnValue(result) + @defer.inlineCallbacks def _do_conflict_res(self, new_branch, current_branch, common_ancestor): conflict_res = [ self._do_power_level_conflict_res, @@ -200,7 +203,8 @@ class StateHandler(object): ] for algo in conflict_res: - new_res, curr_res = algo( + new_res, curr_res = yield defer.maybeDeferred( + algo, new_branch, current_branch, common_ancestor ) @@ -211,19 +215,39 @@ class StateHandler(object): raise Exception("Conflict resolution failed.") + @defer.inlineCallbacks def _do_power_level_conflict_res(self, new_branch, current_branch, common_ancestor): - max_power_new = max( - new_branch[:-1] if common_ancestor else new_branch, - key=lambda t: t.power_level - ).power_level + new_powers_deferreds = [] + for e in new_branch[:-1] if common_ancestor else new_branch: + if hasattr(e, "user_id"): + new_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + current_powers_deferreds = [] + for e in current_branch[:-1] if common_ancestor else current_branch: + if hasattr(e, "user_id"): + current_powers_deferreds.append( + self.store.get_power_level(e.context, e.user_id) + ) + + new_powers = yield defer.gatherResults( + new_powers_deferreds, + consumeErrors=True + ) - max_power_current = max( - current_branch[:-1] if common_ancestor else current_branch, - key=lambda t: t.power_level - ).power_level + current_powers = yield defer.gatherResults( + current_powers_deferreds, + consumeErrors=True + ) + + max_power_new = max(new_powers) + max_power_current = max(current_powers) - return (max_power_new, max_power_current) + defer.returnValue( + (max_power_new, max_power_current) + ) def _do_chain_length_conflict_res(self, new_branch, current_branch, common_ancestor): diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py index 8037225079..8deaaf93bd 100644 --- a/synapse/storage/_base.py +++ b/synapse/storage/_base.py @@ -17,6 +17,7 @@ import logging from twisted.internet import defer from synapse.api.errors import StoreError +from synapse.util.logutils import log_function import collections import copy @@ -91,6 +92,7 @@ class SQLBaseStore(object): self._simple_insert_txn, table, values, or_replace=or_replace ) + @log_function def _simple_insert_txn(self, txn, table, values, or_replace=False): sql = "%s INTO %s (%s) VALUES(%s)" % ( ("INSERT OR REPLACE" if or_replace else "INSERT"), @@ -98,6 +100,12 @@ class SQLBaseStore(object): ", ".join(k for k in values), ", ".join("?" for k in values) ) + + logger.debug( + "[SQL] %s Args=%s Func=%s", + sql, values.values(), + ) + txn.execute(sql, values.values()) return txn.lastrowid diff --git a/synapse/storage/pdu.py b/synapse/storage/pdu.py index f780111b3b..3c859fdeac 100644 --- a/synapse/storage/pdu.py +++ b/synapse/storage/pdu.py @@ -17,6 +17,7 @@ from twisted.internet import defer from ._base import SQLBaseStore, Table, JoinHelper +from synapse.federation.units import Pdu from synapse.util.logutils import log_function from collections import namedtuple @@ -625,53 +626,6 @@ class StatePduStore(SQLBaseStore): return result - def get_next_missing_pdu(self, new_pdu): - """When we get a new state pdu we need to check whether we need to do - any conflict resolution, if we do then we need to check if we need - to go back and request some more state pdus that we haven't seen yet. - - Args: - txn - new_pdu - - Returns: - PduIdTuple: A pdu that we are missing, or None if we have all the - pdus required to do the conflict resolution. - """ - return self._db_pool.runInteraction( - self._get_next_missing_pdu, new_pdu - ) - - def _get_next_missing_pdu(self, txn, new_pdu): - logger.debug( - "get_next_missing_pdu %s %s", - new_pdu.pdu_id, new_pdu.origin - ) - - current = self._get_current_interaction( - txn, - new_pdu.context, new_pdu.pdu_type, new_pdu.state_key - ) - - if (not current or not current.prev_state_id - or not current.prev_state_origin): - return None - - # Oh look, it's a straight clobber, so wooooo almost no-op. - if (new_pdu.prev_state_id == current.pdu_id - and new_pdu.prev_state_origin == current.origin): - return None - - enum_branches = self._enumerate_state_branches(txn, new_pdu, current) - for branch, prev_state, state in enum_branches: - if not state: - return PduIdTuple( - prev_state.prev_state_id, - prev_state.prev_state_origin - ) - - return None - def handle_new_state(self, new_pdu): """Actually perform conflict resolution on the new_pdu on the assumption we have all the pdus required to perform it. @@ -755,24 +709,11 @@ class StatePduStore(SQLBaseStore): return is_current - @classmethod @log_function - def _enumerate_state_branches(cls, txn, pdu_a, pdu_b): + def _enumerate_state_branches(self, txn, pdu_a, pdu_b): branch_a = pdu_a branch_b = pdu_b - get_query = ( - "SELECT %(fields)s FROM %(pdus)s as p " - "LEFT JOIN %(state)s as s " - "ON p.pdu_id = s.pdu_id AND p.origin = s.origin " - "WHERE p.pdu_id = ? AND p.origin = ? " - ) % { - "fields": _pdu_state_joiner.get_fields( - PdusTable="p", StatePdusTable="s"), - "pdus": PdusTable.table_name, - "state": StatePdusTable.table_name, - } - while True: if (branch_a.pdu_id == branch_b.pdu_id and branch_a.origin == branch_b.origin): @@ -804,13 +745,12 @@ class StatePduStore(SQLBaseStore): branch_a.prev_state_origin ) - logger.debug("getting branch_a prev %s", pdu_tuple) - txn.execute(get_query, pdu_tuple) - prev_branch = branch_a - res = txn.fetchone() - branch_a = PduEntry(*res) if res else None + logger.debug("getting branch_a prev %s", pdu_tuple) + branch_a = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_a: + branch_a = Pdu.from_pdu_tuple(branch_a) logger.debug("branch_a=%s", branch_a) @@ -823,14 +763,13 @@ class StatePduStore(SQLBaseStore): branch_b.prev_state_id, branch_b.prev_state_origin ) - txn.execute(get_query, pdu_tuple) - - logger.debug("getting branch_b prev %s", pdu_tuple) prev_branch = branch_b - res = txn.fetchone() - branch_b = PduEntry(*res) if res else None + logger.debug("getting branch_b prev %s", pdu_tuple) + branch_b = self._get_pdu_tuple(txn, *pdu_tuple) + if branch_b: + branch_b = Pdu.from_pdu_tuple(branch_b) logger.debug("branch_b=%s", branch_b) |