summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/state.py27
1 files changed, 18 insertions, 9 deletions
diff --git a/synapse/state.py b/synapse/state.py
index e69282860a..0cc1344d51 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -174,7 +174,9 @@ class StateHandler(object):
         n = new_branch[-1]
         c = current_branch[-1]
 
-        if n.pdu_id == c.pdu_id and n.origin == c.origin:
+        common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
+
+        if common_ancestor:
             # We found a common ancestor!
 
             if len(current_branch) == 1:
@@ -185,10 +187,12 @@ class StateHandler(object):
             # We didn't find a common ancestor. This is probably fine.
             pass
 
-        result = self._do_conflict_res(new_branch, current_branch)
+        result = self._do_conflict_res(
+            new_branch, current_branch, common_ancestor
+        )
         defer.returnValue(result)
 
-    def _do_conflict_res(self, new_branch, current_branch):
+    def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
         conflict_res = [
             self._do_power_level_conflict_res,
             self._do_chain_length_conflict_res,
@@ -196,7 +200,9 @@ class StateHandler(object):
         ]
 
         for algo in conflict_res:
-            new_res, curr_res = algo(new_branch, current_branch)
+            new_res, curr_res = algo(
+                new_branch, current_branch, common_ancestor
+            )
 
             if new_res < curr_res:
                 defer.returnValue(False)
@@ -205,23 +211,26 @@ class StateHandler(object):
 
         raise Exception("Conflict resolution failed.")
 
-    def _do_power_level_conflict_res(self, new_branch, current_branch):
+    def _do_power_level_conflict_res(self, new_branch, current_branch,
+                                     common_ancestor):
         max_power_new = max(
-            new_branch[:-1],
+            new_branch[:-1] if common_ancestor else new_branch,
             key=lambda t: t.power_level
         ).power_level
 
         max_power_current = max(
-            current_branch[:-1],
+            current_branch[:-1] if common_ancestor else current_branch,
             key=lambda t: t.power_level
         ).power_level
 
         return (max_power_new, max_power_current)
 
-    def _do_chain_length_conflict_res(self, new_branch, current_branch):
+    def _do_chain_length_conflict_res(self, new_branch, current_branch,
+                                      common_ancestor):
         return (len(new_branch), len(current_branch))
 
-    def _do_hash_conflict_res(self, new_branch, current_branch):
+    def _do_hash_conflict_res(self, new_branch, current_branch,
+                              common_ancestor):
         new_str = "".join([p.pdu_id + p.origin for p in new_branch])
         c_str = "".join([p.pdu_id + p.origin for p in current_branch])