diff --git a/tests/test_state.py b/tests/test_state.py
index 197e35f140..98ad9e54cd 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -38,7 +38,6 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_annotate_with_old_message(self):
event = self.create_event(type="test_message", name="event")
- context = Mock()
old_state = [
self.create_event(type="test1", state_key="1"),
@@ -46,8 +45,8 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""),
]
- yield self.state.annotate_context_with_state(
- event, context, old_state=old_state
+ context = yield self.state.compute_event_context(
+ event, old_state=old_state
)
for k, v in context.current_state.items():
@@ -64,7 +63,6 @@ class StateTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
event = self.create_event(type="state", state_key="", name="event")
- context = Mock()
old_state = [
self.create_event(type="test1", state_key="1"),
@@ -72,8 +70,8 @@ class StateTestCase(unittest.TestCase):
self.create_event(type="test2", state_key=""),
]
- yield self.state.annotate_context_with_state(
- event, context, old_state=old_state
+ context = yield self.state.compute_event_context(
+ event, old_state=old_state
)
for k, v in context.current_state.items():
@@ -92,7 +90,6 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_message(self):
event = self.create_event(type="test_message", name="event")
event.prev_events = []
- context = Mock()
old_state = [
self.create_event(type="test1", state_key="1"),
@@ -106,7 +103,7 @@ class StateTestCase(unittest.TestCase):
group_name: old_state,
}
- yield self.state.annotate_context_with_state(event, context)
+ context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items():
type, state_key = k
@@ -124,7 +121,6 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_state(self):
event = self.create_event(type="state", state_key="", name="event")
event.prev_events = []
- context = Mock()
old_state = [
self.create_event(type="test1", state_key="1"),
@@ -138,7 +134,7 @@ class StateTestCase(unittest.TestCase):
group_name: old_state,
}
- yield self.state.annotate_context_with_state(event, context)
+ context = yield self.state.compute_event_context(event)
for k, v in context.current_state.items():
type, state_key = k
@@ -156,7 +152,6 @@ class StateTestCase(unittest.TestCase):
def test_resolve_message_conflict(self):
event = self.create_event(type="test_message", name="event")
event.prev_events = []
- context = Mock()
old_state_1 = [
self.create_event(type="test1", state_key="1"),
@@ -178,7 +173,7 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2,
}
- yield self.state.annotate_context_with_state(event, context)
+ context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5)
@@ -188,7 +183,6 @@ class StateTestCase(unittest.TestCase):
def test_resolve_state_conflict(self):
event = self.create_event(type="test4", state_key="", name="event")
event.prev_events = []
- context = Mock()
old_state_1 = [
self.create_event(type="test1", state_key="1"),
@@ -210,7 +204,7 @@ class StateTestCase(unittest.TestCase):
group_name_2: old_state_2,
}
- yield self.state.annotate_context_with_state(event, context)
+ context = yield self.state.compute_event_context(event)
self.assertEqual(len(context.current_state), 5)
|