summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py138
1 files changed, 75 insertions, 63 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index 1a11bbcee0..6454f994e3 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -67,9 +67,11 @@ class StateGroupStore(object):
         self._event_to_state_group = {}
         self._group_to_state = {}
 
+        self._event_id_to_event = {}
+
         self._next_group = 1
 
-    def get_state_groups(self, room_id, event_ids):
+    def get_state_groups_ids(self, room_id, event_ids):
         groups = {}
         for event_id in event_ids:
             group = self._event_to_state_group.get(event_id)
@@ -79,22 +81,23 @@ class StateGroupStore(object):
         return defer.succeed(groups)
 
     def store_state_groups(self, event, context):
-        if context.current_state is None:
+        if context.current_state_ids is None:
             return
 
-        state_events = context.current_state
-
-        if event.is_state():
-            state_events[(event.type, event.state_key)] = event
+        state_events = dict(context.current_state_ids)
 
-        state_group = context.state_group
-        if not state_group:
-            state_group = self._next_group
-            self._next_group += 1
+        self._group_to_state[context.state_group] = state_events
+        self._event_to_state_group[event.event_id] = context.state_group
 
-            self._group_to_state[state_group] = state_events.values()
+    def get_events(self, event_ids, **kwargs):
+        return {
+            e_id: self._event_id_to_event[e_id] for e_id in event_ids
+            if e_id in self._event_id_to_event
+        }
 
-        self._event_to_state_group[event.event_id] = state_group
+    def register_events(self, events):
+        for e in events:
+            self._event_id_to_event[e.event_id] = e
 
 
 class DictObj(dict):
@@ -136,8 +139,10 @@ class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = Mock(
             spec_set=[
-                "get_state_groups",
+                "get_state_groups_ids",
                 "add_event_hashes",
+                "get_events",
+                "get_next_state_group",
             ]
         )
         hs = Mock(spec_set=[
@@ -148,6 +153,8 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
 
+        self.store.get_next_state_group.side_effect = Mock
+
         self.state = StateHandler(hs)
         self.event_id = 0
 
@@ -187,7 +194,7 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
 
         context_store = {}
 
@@ -196,7 +203,7 @@ class StateTestCase(unittest.TestCase):
             store.store_state_groups(event, context)
             context_store[event.event_id] = context
 
-        self.assertEqual(2, len(context_store["D"].current_state))
+        self.assertEqual(2, len(context_store["D"].prev_state_ids))
 
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
@@ -239,7 +246,9 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -250,7 +259,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "C"},
-            {e.event_id for e in context_store["D"].current_state.values()}
+            {e_id for e_id in context_store["D"].prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -303,7 +312,9 @@ class StateTestCase(unittest.TestCase):
         )
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -314,7 +325,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"START", "A", "B", "C"},
-            {e.event_id for e in context_store["E"].current_state.values()}
+            {e for e in context_store["E"].prev_state_ids.values()}
         )
 
     @defer.inlineCallbacks
@@ -384,7 +395,9 @@ class StateTestCase(unittest.TestCase):
         graph = Graph(nodes, edges)
 
         store = StateGroupStore()
-        self.store.get_state_groups.side_effect = store.get_state_groups
+        self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+        self.store.get_events = store.get_events
+        store.register_events(graph.walk())
 
         context_store = {}
 
@@ -395,7 +408,7 @@ class StateTestCase(unittest.TestCase):
 
         self.assertSetEqual(
             {"A1", "A2", "A3", "A5", "B"},
-            {e.event_id for e in context_store["D"].current_state.values()}
+            {e for e in context_store["D"].prev_state_ids.values()}
         )
 
     def _add_depths(self, nodes, edges):
@@ -424,16 +437,11 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
-            set(old_state), set(context.current_state.values())
+            set(e.event_id for e in old_state), set(context.current_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
@@ -449,18 +457,10 @@ class StateTestCase(unittest.TestCase):
             event, old_state=old_state
         )
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
-            set(old_state),
-            set(context.current_state.values())
+            set(e.event_id for e in old_state), set(context.prev_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
-
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         event = create_event(type="test_message", name="event")
@@ -473,20 +473,15 @@ class StateTestCase(unittest.TestCase):
 
         group_name = "group_name_1"
 
-        self.store.get_state_groups.return_value = {
-            group_name: old_state,
+        self.store.get_state_groups_ids.return_value = {
+            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
         }
 
         context = yield self.state.compute_event_context(event)
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set([e.event_id for e in context.current_state.values()])
+            set(context.current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -503,23 +498,18 @@ class StateTestCase(unittest.TestCase):
 
         group_name = "group_name_1"
 
-        self.store.get_state_groups.return_value = {
-            group_name: old_state,
+        self.store.get_state_groups_ids.return_value = {
+            group_name: {(e.type, e.state_key): e.event_id for e in old_state},
         }
 
         context = yield self.state.compute_event_context(event)
 
-        for k, v in context.current_state.items():
-            type, state_key = k
-            self.assertEqual(type, v.type)
-            self.assertEqual(state_key, v.state_key)
-
         self.assertEqual(
             set([e.event_id for e in old_state]),
-            set([e.event_id for e in context.current_state.values()])
+            set(context.prev_state_ids.values())
         )
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_resolve_message_conflict(self):
@@ -543,11 +533,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 6)
+        self.assertEqual(len(context.current_state_ids), 6)
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_resolve_state_conflict(self):
@@ -571,11 +566,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test4", state_key=""),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(len(context.current_state), 6)
+        self.assertEqual(len(context.current_state_ids), 6)
 
-        self.assertIsNone(context.state_group)
+        self.assertIsNotNone(context.state_group)
 
     @defer.inlineCallbacks
     def test_standard_depth_conflict(self):
@@ -606,9 +606,16 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=2),
         ]
 
+        store = StateGroupStore()
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+        self.store.get_events = store.get_events
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
+        self.assertEqual(
+            old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+        )
 
         # Reverse the depth to make sure we are actually using the depths
         # during state resolution.
@@ -625,17 +632,22 @@ class StateTestCase(unittest.TestCase):
             create_event(type="test1", state_key="1", depth=1),
         ]
 
+        store.register_events(old_state_1)
+        store.register_events(old_state_2)
+
         context = yield self._get_context(event, old_state_1, old_state_2)
 
-        self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
+        self.assertEqual(
+            old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+        )
 
     def _get_context(self, event, old_state_1, old_state_2):
         group_name_1 = "group_name_1"
         group_name_2 = "group_name_2"
 
-        self.store.get_state_groups.return_value = {
-            group_name_1: old_state_1,
-            group_name_2: old_state_2,
+        self.store.get_state_groups_ids.return_value = {
+            group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
+            group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
         }
 
         return self.state.compute_event_context(event)