summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--tests/test_state.py51
1 files changed, 46 insertions, 5 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index df9362c985..de2d35145a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -67,6 +67,8 @@ class StateGroupStore(object):
         self._event_to_state_group = {}
         self._group_to_state = {}
 
+        self._event_id_to_event = {}
+
         self._next_group = 1
 
     def get_state_groups_ids(self, room_id, event_ids):
@@ -96,6 +98,16 @@ class StateGroupStore(object):
 
         self._event_to_state_group[event.event_id] = state_group
 
+    def get_events(self, event_ids, **kwargs):
+        return {
+            e_id: self._event_id_to_event[e_id] for e_id in event_ids
+            if e_id in self._event_id_to_event
+        }
+
+    def register_events(self, events):
+        for e in events:
+            self._event_id_to_event[e.event_id] = e
+
 
 class DictObj(dict):
     def __init__(self, **kwargs):
@@ -138,6 +150,7 @@ class StateTestCase(unittest.TestCase):
             spec_set=[
                 "get_state_groups_ids",
                 "add_event_hashes",
+                "get_events",
             ]
         )
         hs = Mock(spec_set=[
@@ -240,6 +253,8 @@ class StateTestCase(unittest.TestCase):
 
         store = StateGroupStore()
         self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "C"},
-            {e.event_id for e in context_store["D"].current_state.values()}
+            {e_id for e_id in context_store["D"].current_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -304,6 +319,8 @@ class StateTestCase(unittest.TestCase):
 
         store = StateGroupStore()
         self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "B", "C"},
-            {e.event_id for e in context_store["E"].current_state.values()}
+            {e for e in context_store["E"].current_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -385,6 +402,8 @@ class StateTestCase(unittest.TestCase):
 
         store = StateGroupStore()
         self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"},
-            {e.event_id for e in context_store["D"].current_state.values()}
+            {e for e in context_store["D"].current_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -522,6 +541,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state_ids), 6)
@@ -550,6 +574,11 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
         self.assertEqual(len(context.current_state_ids), 6)
@@ -585,9 +614,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[2].event.id, context.current_state_ids[("test1", "1")])
+        self.assertEqual(
+            old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+        )
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
@@ -604,9 +640,14 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=1),
         ]
 
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[2].event_id, context.current_state_ids[("test1", "1")])
+        self.assertEqual(
+            old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+        )
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"