summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2014-09-08 20:13:27 +0100
committerErik Johnston <erik@matrix.org>2014-09-08 20:13:27 +0100
commit942d8412c49a1d481f0bedd189eb1598629b103c (patch)
treead199a4bd331f2594b98b0094f070a8e7bc68d9c
parentAdded number of users in recent rooms. (diff)
downloadsynapse-942d8412c49a1d481f0bedd189eb1598629b103c.tar.xz
Handle the case where we don't have a common ancestor
-rw-r--r--synapse/state.py27
-rw-r--r--tests/test_state.py24
2 files changed, 42 insertions, 9 deletions
diff --git a/synapse/state.py b/synapse/state.py
index e69282860a..0cc1344d51 100644
--- a/synapse/state.py
+++ b/synapse/state.py
@@ -174,7 +174,9 @@ class StateHandler(object):
         n = new_branch[-1]
         c = current_branch[-1]
 
-        if n.pdu_id == c.pdu_id and n.origin == c.origin:
+        common_ancestor = n.pdu_id == c.pdu_id and n.origin == c.origin
+
+        if common_ancestor:
             # We found a common ancestor!
 
             if len(current_branch) == 1:
@@ -185,10 +187,12 @@ class StateHandler(object):
             # We didn't find a common ancestor. This is probably fine.
             pass
 
-        result = self._do_conflict_res(new_branch, current_branch)
+        result = self._do_conflict_res(
+            new_branch, current_branch, common_ancestor
+        )
         defer.returnValue(result)
 
-    def _do_conflict_res(self, new_branch, current_branch):
+    def _do_conflict_res(self, new_branch, current_branch, common_ancestor):
         conflict_res = [
             self._do_power_level_conflict_res,
             self._do_chain_length_conflict_res,
@@ -196,7 +200,9 @@ class StateHandler(object):
         ]
 
         for algo in conflict_res:
-            new_res, curr_res = algo(new_branch, current_branch)
+            new_res, curr_res = algo(
+                new_branch, current_branch, common_ancestor
+            )
 
             if new_res < curr_res:
                 defer.returnValue(False)
@@ -205,23 +211,26 @@ class StateHandler(object):
 
         raise Exception("Conflict resolution failed.")
 
-    def _do_power_level_conflict_res(self, new_branch, current_branch):
+    def _do_power_level_conflict_res(self, new_branch, current_branch,
+                                     common_ancestor):
         max_power_new = max(
-            new_branch[:-1],
+            new_branch[:-1] if common_ancestor else new_branch,
             key=lambda t: t.power_level
         ).power_level
 
         max_power_current = max(
-            current_branch[:-1],
+            current_branch[:-1] if common_ancestor else current_branch,
             key=lambda t: t.power_level
         ).power_level
 
         return (max_power_new, max_power_current)
 
-    def _do_chain_length_conflict_res(self, new_branch, current_branch):
+    def _do_chain_length_conflict_res(self, new_branch, current_branch,
+                                      common_ancestor):
         return (len(new_branch), len(current_branch))
 
-    def _do_hash_conflict_res(self, new_branch, current_branch):
+    def _do_hash_conflict_res(self, new_branch, current_branch,
+                              common_ancestor):
         new_str = "".join([p.pdu_id + p.origin for p in new_branch])
         c_str = "".join([p.pdu_id + p.origin for p in current_branch])
 
diff --git a/tests/test_state.py b/tests/test_state.py
index 4512475ebd..a9fc3fb85c 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -441,6 +441,30 @@ class StateTestCase(unittest.TestCase):
         self.assertEqual(1, self.persistence.update_current_state.call_count)
 
     @defer.inlineCallbacks
+    def test_no_common_ancestor(self):
+        # We do a direct overwriting of the old state, i.e., the new state
+        # points to the old state.
+
+        old_pdu = new_fake_pdu_entry("A", "test", "mem", "x", None, 5)
+        new_pdu = new_fake_pdu_entry("B", "test", "mem", "x", None, 10)
+
+        self.persistence.get_unresolved_state_tree.return_value = (
+            (ReturnType([new_pdu], [old_pdu]), None)
+        )
+
+        is_new = yield self.state.handle_new_state(new_pdu)
+
+        self.assertTrue(is_new)
+
+        self.persistence.get_unresolved_state_tree.assert_called_once_with(
+            new_pdu
+        )
+
+        self.assertEqual(1, self.persistence.update_current_state.call_count)
+
+        self.assertFalse(self.replication.get_pdu.called)
+
+    @defer.inlineCallbacks
     def test_new_event(self):
         event = Mock()
         event.event_id = "12123123@test"