diff --git a/tests/test_state.py b/tests/test_state.py
index 5845358754..55f37c521f 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -204,8 +204,8 @@ class StateTestCase(unittest.TestCase):
nodes={
"START": DictObj(
type=EventTypes.Create,
- state_key="creator",
- content={"membership": "@user_id:example.com"},
+ state_key="",
+ content={"creator": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
@@ -259,8 +259,8 @@ class StateTestCase(unittest.TestCase):
nodes={
"START": DictObj(
type=EventTypes.Create,
- state_key="creator",
- content={"membership": "@user_id:example.com"},
+ state_key="",
+ content={"creator": "@user_id:example.com"},
depth=1,
),
"A": DictObj(
@@ -432,13 +432,19 @@ class StateTestCase(unittest.TestCase):
def test_resolve_message_conflict(self):
event = create_event(type="test_message", name="event")
+ creation = create_event(
+ type=EventTypes.Create, state_key=""
+ )
+
old_state_1 = [
+ creation,
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
]
old_state_2 = [
+ creation,
create_event(type="test1", state_key="1"),
create_event(type="test3", state_key="2"),
create_event(type="test4", state_key=""),
@@ -446,7 +452,7 @@ class StateTestCase(unittest.TestCase):
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 5)
+ self.assertEqual(len(context.current_state), 6)
self.assertIsNone(context.state_group)
@@ -454,13 +460,19 @@ class StateTestCase(unittest.TestCase):
def test_resolve_state_conflict(self):
event = create_event(type="test4", state_key="", name="event")
+ creation = create_event(
+ type=EventTypes.Create, state_key=""
+ )
+
old_state_1 = [
+ creation,
create_event(type="test1", state_key="1"),
create_event(type="test1", state_key="2"),
create_event(type="test2", state_key=""),
]
old_state_2 = [
+ creation,
create_event(type="test1", state_key="1"),
create_event(type="test3", state_key="2"),
create_event(type="test4", state_key=""),
@@ -468,7 +480,7 @@ class StateTestCase(unittest.TestCase):
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 5)
+ self.assertEqual(len(context.current_state), 6)
self.assertIsNone(context.state_group)
@@ -484,36 +496,45 @@ class StateTestCase(unittest.TestCase):
}
)
+ creation = create_event(
+ type=EventTypes.Create, state_key="",
+ content={"creator": "@foo:bar"}
+ )
+
old_state_1 = [
+ creation,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
+ creation,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_2[1], context.current_state[("test1", "1")])
+ self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
old_state_1 = [
+ creation,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
+ creation,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_1[1], context.current_state[("test1", "1")])
+ self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1"
|