summary refs log tree commit diff
path: root/tests/state
diff options
context:
space:
mode:
Diffstat (limited to 'tests/state')
-rw-r--r--tests/state/test_v2.py100
1 files changed, 98 insertions, 2 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index efd85ebe6c..d67f59b2c7 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -544,8 +544,7 @@ class StateTestCase(unittest.TestCase):
                     state_res_store=TestStateResolutionStore(event_map),
                 )
 
-                self.assertTrue(state_d.called)
-                state_before = state_d.result
+                state_before = self.successResultOf(state_d)
 
             state_after = dict(state_before)
             if fake_event.state_key is not None:
@@ -599,6 +598,103 @@ class LexicographicalTestCase(unittest.TestCase):
         self.assertEqual(["o", "l", "n", "m", "p"], res)
 
 
+class SimpleParamStateTestCase(unittest.TestCase):
+    def setUp(self):
+        # We build up a simple DAG.
+
+        event_map = {}
+
+        create_event = FakeEvent(
+            id="CREATE",
+            sender=ALICE,
+            type=EventTypes.Create,
+            state_key="",
+            content={"creator": ALICE},
+        ).to_event([], [])
+        event_map[create_event.event_id] = create_event
+
+        alice_member = FakeEvent(
+            id="IMA",
+            sender=ALICE,
+            type=EventTypes.Member,
+            state_key=ALICE,
+            content=MEMBERSHIP_CONTENT_JOIN,
+        ).to_event([create_event.event_id], [create_event.event_id])
+        event_map[alice_member.event_id] = alice_member
+
+        join_rules = FakeEvent(
+            id="IJR",
+            sender=ALICE,
+            type=EventTypes.JoinRules,
+            state_key="",
+            content={"join_rule": JoinRules.PUBLIC},
+        ).to_event(
+            auth_events=[create_event.event_id, alice_member.event_id],
+            prev_events=[alice_member.event_id],
+        )
+        event_map[join_rules.event_id] = join_rules
+
+        # Bob and Charlie join at the same time, so there is a fork
+        bob_member = FakeEvent(
+            id="IMB",
+            sender=BOB,
+            type=EventTypes.Member,
+            state_key=BOB,
+            content=MEMBERSHIP_CONTENT_JOIN,
+        ).to_event(
+            auth_events=[create_event.event_id, join_rules.event_id],
+            prev_events=[join_rules.event_id],
+        )
+        event_map[bob_member.event_id] = bob_member
+
+        charlie_member = FakeEvent(
+            id="IMC",
+            sender=CHARLIE,
+            type=EventTypes.Member,
+            state_key=CHARLIE,
+            content=MEMBERSHIP_CONTENT_JOIN,
+        ).to_event(
+            auth_events=[create_event.event_id, join_rules.event_id],
+            prev_events=[join_rules.event_id],
+        )
+        event_map[charlie_member.event_id] = charlie_member
+
+        self.event_map = event_map
+        self.create_event = create_event
+        self.alice_member = alice_member
+        self.join_rules = join_rules
+        self.bob_member = bob_member
+        self.charlie_member = charlie_member
+
+        self.state_at_bob = {
+            (e.type, e.state_key): e.event_id
+            for e in [create_event, alice_member, join_rules, bob_member]
+        }
+
+        self.state_at_charlie = {
+            (e.type, e.state_key): e.event_id
+            for e in [create_event, alice_member, join_rules, charlie_member]
+        }
+
+        self.expected_combined_state = {
+            (e.type, e.state_key): e.event_id
+            for e in [create_event, alice_member, join_rules, bob_member, charlie_member]
+        }
+
+    def test_event_map_none(self):
+        # Test that we correctly handle passing `None` as the event_map
+
+        state_d = resolve_events_with_store(
+            [self.state_at_bob, self.state_at_charlie],
+            event_map=None,
+            state_res_store=TestStateResolutionStore(self.event_map),
+        )
+
+        state = self.successResultOf(state_d)
+
+        self.assert_dict(self.expected_combined_state, state)
+
+
 def pairwise(iterable):
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)