summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py132
1 files changed, 123 insertions, 9 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index 5845358754..e4e995b756 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -35,7 +35,7 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
 
     if not event_id:
         _next_event_id += 1
-        event_id = str(_next_event_id)
+        event_id = "$%s:test" % (_next_event_id,)
 
     if not name:
         if state_key is not None:
@@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase):
             nodes={
                 "START": DictObj(
                     type=EventTypes.Create,
-                    state_key="creator",
-                    content={"membership": "@user_id:example.com"},
+                    state_key="",
+                    content={"creator": "@user_id:example.com"},
                     depth=1,
                 ),
                 "A": DictObj(
@@ -318,6 +318,99 @@ class StateTestCase(unittest.TestCase):
         )
 
     @defer.inlineCallbacks
+    def test_branch_have_perms_conflict(self):
+        userid1 = "@user_id:example.com"
+        userid2 = "@user_id2:example.com"
+
+        nodes = {
+            "A1": DictObj(
+                type=EventTypes.Create,
+                state_key="",
+                content={"creator": userid1},
+                depth=1,
+            ),
+            "A2": DictObj(
+                type=EventTypes.Member,
+                state_key=userid1,
+                content={"membership": Membership.JOIN},
+                membership=Membership.JOIN,
+            ),
+            "A3": DictObj(
+                type=EventTypes.Member,
+                state_key=userid2,
+                content={"membership": Membership.JOIN},
+                membership=Membership.JOIN,
+            ),
+            "A4": DictObj(
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={
+                    "events": {"m.room.name": 50},
+                    "users": {userid1: 100,
+                              userid2: 60},
+                },
+            ),
+            "A5": DictObj(
+                type=EventTypes.Name,
+                state_key="",
+            ),
+            "B": DictObj(
+                type=EventTypes.PowerLevels,
+                state_key="",
+                content={
+                    "events": {"m.room.name": 50},
+                    "users": {userid2: 30},
+                },
+            ),
+            "C": DictObj(
+                type=EventTypes.Name,
+                state_key="",
+                sender=userid2,
+            ),
+            "D": DictObj(
+                type=EventTypes.Message,
+            ),
+        }
+        edges = {
+            "A2": ["A1"],
+            "A3": ["A2"],
+            "A4": ["A3"],
+            "A5": ["A4"],
+            "B": ["A5"],
+            "C": ["A5"],
+            "D": ["B", "C"]
+        }
+        self._add_depths(nodes, edges)
+        graph = Graph(nodes, edges)
+
+        store = StateGroupStore()
+        self.store.get_state_groups.side_effect = store.get_state_groups
+
+        context_store = {}
+
+        for event in graph.walk():
+            context = yield self.state.compute_event_context(event)
+            store.store_state_groups(event, context)
+            context_store[event.event_id] = context
+
+        self.assertSetEqual(
+            {"A1", "A2", "A3", "A5", "B"},
+            {e.event_id for e in context_store["D"].current_state.values()}
+        )
+
+    def _add_depths(self, nodes, edges):
+        def _get_depth(ev):
+            node = nodes[ev]
+            if 'depth' not in node:
+                prevs = edges[ev]
+                depth = max(_get_depth(prev) for prev in prevs) + 1
+                node['depth'] = depth
+            return node['depth']
+
+        for n in nodes:
+            _get_depth(n)
+
+    @defer.inlineCallbacks
     def test_annotate_with_old_message(self):
         event = create_event(type="test_message", name="event")
 
@@ -432,13 +525,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_message_conflict(self):
         event = create_event(type="test_message", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -446,7 +545,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -454,13 +553,19 @@ class StateTestCase(unittest.TestCase):
     def test_resolve_state_conflict(self):
         event = create_event(type="test4", state_key="", name="event")
 
+        creation = create_event(
+            type=EventTypes.Create, state_key=""
+        )
+
         old_state_1 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test1", state_key="2"),
             create_event(type="test2", state_key=""),
         ]
 
         old_state_2 = [
+            creation,
             create_event(type="test1", state_key="1"),
             create_event(type="test3", state_key="2"),
             create_event(type="test4", state_key=""),
@@ -468,7 +573,7 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 5)
+        self.assertEqual(len(context.current_state), 6)
 
         self.assertIsNone(context.state_group)
 
@@ -484,36 +589,45 @@ class StateTestCase(unittest.TestCase):
             }
         )
 
+        creation = create_event(
+            type=EventTypes.Create, state_key="",
+            content={"creator": "@foo:bar"}
+        )
+
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
 
         old_state_1 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=2),
         ]
 
         old_state_2 = [
+            creation,
             member_event,
             create_event(type="test1", state_key="1", depth=1),
         ]
 
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
+        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"