diff --git a/tests/test_state.py b/tests/test_state.py
index 610ec9fb46..66f22f6813 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -20,7 +20,8 @@ from twisted.internet import defer
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
from synapse.api.room_versions import RoomVersions
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
+from synapse.events.snapshot import EventContext
from synapse.state import StateHandler, StateResolutionHandler
from tests import unittest
@@ -65,7 +66,7 @@ def create_event(
d.update(kwargs)
- event = FrozenEvent(d)
+ event = make_event_from_dict(d)
return event
@@ -118,7 +119,7 @@ class StateGroupStore(object):
def register_event_id_state_group(self, event_id, state_group):
self._event_to_state_group[event_id] = state_group
- def get_room_version(self, room_id):
+ def get_room_version_id(self, room_id):
return RoomVersions.V1.identifier
@@ -158,10 +159,12 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = StateGroupStore()
+ storage = Mock(main=self.store, state=self.store)
hs = Mock(
spec_set=[
"config",
"get_datastore",
+ "get_storage",
"get_auth",
"get_state_handler",
"get_clock",
@@ -174,6 +177,7 @@ class StateTestCase(unittest.TestCase):
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
+ hs.get_storage.return_value = storage
self.state = StateHandler(hs)
self.event_id = 0
@@ -195,16 +199,22 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(graph.walk())
- context_store = {}
+ context_store = {} # type: dict[str, EventContext]
for event in graph.walk():
context = yield self.state.compute_event_context(event)
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
self.assertEqual(2, len(prev_state_ids))
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
+
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
graph = Graph(
@@ -238,11 +248,16 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between B and C
- self.assertSetEqual(
- {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
- )
+ ctx_c = context_store["C"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
+ self.assertSetEqual({"START", "A", "C"}, set(prev_state_ids.values()))
+
+ self.assertEqual(ctx_c.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
@defer.inlineCallbacks
def test_branch_have_banned_conflict(self):
@@ -289,11 +304,16 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+ # C ends up winning the resolution between C and D because bans win over other
+ # changes
- self.assertSetEqual(
- {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
- )
+ ctx_c = context_store["C"]
+ ctx_e = context_store["E"]
+
+ prev_state_ids = yield ctx_e.get_prev_state_ids()
+ self.assertSetEqual({"START", "A", "B", "C"}, set(prev_state_ids.values()))
+ self.assertEqual(ctx_c.state_group, ctx_e.state_group_before_event)
+ self.assertEqual(ctx_e.state_group_before_event, ctx_e.state_group)
@defer.inlineCallbacks
def test_branch_have_perms_conflict(self):
@@ -357,11 +377,17 @@ class StateTestCase(unittest.TestCase):
self.store.register_event_context(event, context)
context_store[event.event_id] = context
- prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ # B ends up winning the resolution between B and C because power levels
+ # win over other changes.
- self.assertSetEqual(
- {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
- )
+ ctx_b = context_store["B"]
+ ctx_d = context_store["D"]
+
+ prev_state_ids = yield ctx_d.get_prev_state_ids()
+ self.assertSetEqual({"A1", "A2", "A3", "A5", "B"}, set(prev_state_ids.values()))
+
+ self.assertEqual(ctx_b.state_group, ctx_d.state_group_before_event)
+ self.assertEqual(ctx_d.state_group_before_event, ctx_d.state_group)
def _add_depths(self, nodes, edges):
def _get_depth(ev):
@@ -387,13 +413,16 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(current_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids()
+ self.assertCountEqual(
+ (e.event_id for e in old_state), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertEqual(context.state_group_before_event, context.state_group)
@defer.inlineCallbacks
def test_annotate_with_old_state(self):
@@ -407,12 +436,19 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event, old_state=old_state)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
+ self.assertCountEqual((e.event_id for e in old_state), prev_state_ids.values())
- self.assertEqual(
- set(e.event_id for e in old_state), set(prev_state_ids.values())
+ current_state_ids = yield context.get_current_state_ids()
+ self.assertCountEqual(
+ (e.event_id for e in old_state + [event]), current_state_ids.values()
)
+ self.assertIsNotNone(context.state_group_before_event)
+ self.assertNotEqual(context.state_group_before_event, context.state_group)
+ self.assertEqual(context.state_group_before_event, context.prev_group)
+ self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
+
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
@@ -437,10 +473,10 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(
- set([e.event_id for e in old_state]), set(current_state_ids.values())
+ {e.event_id for e in old_state}, set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -469,11 +505,9 @@ class StateTestCase(unittest.TestCase):
context = yield self.state.compute_event_context(event)
- prev_state_ids = yield context.get_prev_state_ids(self.store)
+ prev_state_ids = yield context.get_prev_state_ids()
- self.assertEqual(
- set([e.event_id for e in old_state]), set(prev_state_ids.values())
- )
+ self.assertEqual({e.event_id for e in old_state}, set(prev_state_ids.values()))
self.assertIsNotNone(context.state_group)
@@ -510,7 +544,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(len(current_state_ids), 6)
@@ -552,7 +586,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(len(current_state_ids), 6)
@@ -607,7 +641,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
@@ -635,7 +669,7 @@ class StateTestCase(unittest.TestCase):
event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
- current_state_ids = yield context.get_current_state_ids(self.store)
+ current_state_ids = yield context.get_current_state_ids()
self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
|