summary refs log tree commit diff
path: root/tests/test_state.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_state.py')
-rw-r--r--tests/test_state.py102
1 files changed, 68 insertions, 34 deletions
diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb46..66f22f6813 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -20,7 +20,8 @@ from twisted.internet import defer
 from synapse.api.auth import Auth
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
+from synapse.events.snapshot import EventContext
 from synapse.state import StateHandler, StateResolutionHandler
 
 from tests import unittest
@@ -65,7 +66,7 @@ def create_event(
 
     d.update(kwargs)
 
-    event = FrozenEvent(d)
+    event = make_event_from_dict(d)
 
     return event
 
@@ -118,7 +119,7 @@ class StateGroupStore(object):
     def register_event_id_state_group(self, event_id, state_group):
         self._event_to_state_group[event_id] = state_group
 
-    def get_room_version(self, room_id):
+    def get_room_version_id(self, room_id):
         return RoomVersions.V1.identifier
 
 
@@ -158,10 +159,12 @@ class Graph(object):
 class StateTestCase(unittest.TestCase):
     def setUp(self):
         self.store = StateGroupStore()
+        storage = Mock(main=self.store, state=self.store)
         hs = Mock(
             spec_set=[
                 "config",
                 "get_datastore",
+                "get_storage",
                 "get_auth",
                 "get_state_handler",
                 "get_clock",
@@ -174,6 +177,7 @@ class StateTestCase(unittest.TestCase):
         hs.get_clock.return_value = MockClock()
         hs.get_auth.return_value = Auth(hs)
         hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+        hs.get_storage.return_value = storage
 
         self.state = StateHandler(hs)
         self.event_id = 0
@@ -195,16 +199,22 @@ class StateTestCase(unittest.TestCase):
 
         self.store.register_events(graph.walk())
 
-        context_store = {}
+        context_store = {}  # type: dict[str, EventContext]
 
         for event in graph.walk():
             context = yield self.state.compute_event_context(event)
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        ctx_c = context_store["C"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
         self.assertEqual(2, len(prev_state_ids))
 
+        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
     @defer.inlineCallbacks
     def test_branch_basic_conflict(self):
         graph = Graph(
@@ -238,11 +248,16 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        # C ends up winning the resolution between B and C
 
-        self.assertSetEqual(
-            {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
-        )
+        ctx_c = context_store["C"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
+
+        self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
 
     @defer.inlineCallbacks
     def test_branch_have_banned_conflict(self):
@@ -289,11 +304,16 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+        # C ends up winning the resolution between C and D because bans win over other
+        # changes
 
-        self.assertSetEqual(
-            {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
-        )
+        ctx_c = context_store["C"]
+        ctx_e = context_store["E"]
+
+        prev_state_ids = yield ctx_e.get_prev_state_ids()
+        self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
+        self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+        self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
 
     @defer.inlineCallbacks
     def test_branch_have_perms_conflict(self):
@@ -357,11 +377,17 @@ class StateTestCase(unittest.TestCase):
             self.store.register_event_context(event, context)
             context_store[event.event_id] = context
 
-        prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+        # B ends up winning the resolution between B and C because power levels
+        # win over other changes.
 
-        self.assertSetEqual(
-            {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
-        )
+        ctx_b = context_store["B"]
+        ctx_d = context_store["D"]
+
+        prev_state_ids = yield ctx_d.get_prev_state_ids()
+        self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
+
+        self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+        self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
 
     def _add_depths(self, nodes, edges):
         def _get_depth(ev):
@@ -387,13 +413,16 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
+        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        self.assertEqual(
-            set(e.event_id for e in old_state), set(current_state_ids.values())
+        current_state_ids = yield context.get_current_state_ids()
+        self.assertCountEqual(
+            (e.event_id for e in old_state), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group)
+        self.assertIsNotNone(context.state_group_before_event)
+        self.assertEqual(context.state_group_before_event, context.state_group)
 
     @defer.inlineCallbacks
     def test_annotate_with_old_state(self):
@@ -407,12 +436,19 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event, old_state=old_state)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
+        self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
 
-        self.assertEqual(
-            set(e.event_id for e in old_state), set(prev_state_ids.values())
+        current_state_ids = yield context.get_current_state_ids()
+        self.assertCountEqual(
+            (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
+        self.assertIsNotNone(context.state_group_before_event)
+        self.assertNotEqual(context.state_group_before_event, context.state_group)
+        self.assertEqual(context.state_group_before_event, context.prev_group)
+        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
     @defer.inlineCallbacks
     def test_trivial_annotate_message(self):
         prev_event_id = "prev_event_id"
@@ -437,10 +473,10 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(
-            set([e.event_id for e in old_state]), set(current_state_ids.values())
+            {e.event_id for e in old_state}, set(current_state_ids.values())
         )
 
         self.assertEqual(group_name, context.state_group)
@@ -469,11 +505,9 @@ class StateTestCase(unittest.TestCase):
 
         context = yield self.state.compute_event_context(event)
 
-        prev_state_ids = yield context.get_prev_state_ids(self.store)
+        prev_state_ids = yield context.get_prev_state_ids()
 
-        self.assertEqual(
-            set([e.event_id for e in old_state]), set(prev_state_ids.values())
-        )
+        self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
 
         self.assertIsNotNone(context.state_group)
 
@@ -510,7 +544,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -552,7 +586,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(len(current_state_ids), 6)
 
@@ -607,7 +641,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
 
@@ -635,7 +669,7 @@ class StateTestCase(unittest.TestCase):
             event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
         )
 
-        current_state_ids = yield context.get_current_state_ids(self.store)
+        current_state_ids = yield context.get_current_state_ids()
 
         self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])