diff --git a/tests/test_state.py b/tests/test_state.py
index feb84f3d48..429a18cbf7 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -13,18 +13,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests import unittest
+from mock import Mock
+
from twisted.internet import defer
-from synapse.events import FrozenEvent
from synapse.api.auth import Auth
from synapse.api.constants import EventTypes, Membership
-from synapse.state import StateHandler
-
-from .utils import MockClock
+from synapse.events import FrozenEvent
+from synapse.state import StateHandler, StateResolutionHandler
-from mock import Mock
+from tests import unittest
+from .utils import MockClock
_next_event_id = 1000
@@ -80,14 +80,14 @@ class StateGroupStore(object):
return defer.succeed(groups)
- def store_state_groups(self, event, context):
- if context.current_state_ids is None:
- return
+ def store_state_group(self, event_id, room_id, prev_group, delta_ids,
+ current_state_ids):
+ state_group = self._next_group
+ self._next_group += 1
- state_events = dict(context.current_state_ids)
+ self._group_to_state[state_group] = dict(current_state_ids)
- self._group_to_state[context.state_group] = state_events
- self._event_to_state_group[event.event_id] = context.state_group
+ return state_group
def get_events(self, event_ids, **kwargs):
return {
@@ -95,10 +95,19 @@ class StateGroupStore(object):
if e_id in self._event_id_to_event
}
+ def get_state_group_delta(self, name):
+ return (None, None)
+
def register_events(self, events):
for e in events:
self._event_id_to_event[e.event_id] = e
+ def register_event_context(self, event, context):
+ self._event_to_state_group[event.event_id] = context.state_group
+
+ def register_event_id_state_group(self, event_id, state_group):
+ self._event_to_state_group[event_id] = state_group
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -137,25 +146,16 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
- self.store = Mock(
- spec_set=[
- "get_state_groups_ids",
- "add_event_hashes",
- "get_events",
- "get_next_state_group",
- "get_state_group_delta",
- ]
- )
+ self.store = StateGroupStore()
hs = Mock(spec_set=[
"get_datastore", "get_auth", "get_state_handler", "get_clock",
+ "get_state_resolution_handler",
])
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
hs.get_auth.return_value = Auth(hs)
-
- self.store.get_next_state_group.side_effect = Mock
- self.store.get_state_group_delta.return_value = (None, None)
+ hs.get_state_resolution_handler = lambda: StateResolutionHandler(hs)
self.state = StateHandler(hs)
self.event_id = 0
@@ -195,17 +195,17 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
- self.assertEqual(2, len(context_store["D"].prev_state_ids))
+ prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+ self.assertEqual(2, len(prev_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@@ -247,21 +247,20 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
+ prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
self.assertSetEqual(
{"START", "A", "C"},
- {e_id for e_id in context_store["D"].prev_state_ids.values()}
+ {e_id for e_id in prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -313,21 +312,20 @@ class StateTestCase(unittest.TestCase):
}
)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
+ prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
+
self.assertSetEqual(
{"START", "A", "B", "C"},
- {e for e in context_store["E"].prev_state_ids.values()}
+ {e for e in prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -396,21 +394,20 @@ class StateTestCase(unittest.TestCase):
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
- store = StateGroupStore()
- self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
- self.store.get_events = store.get_events
- store.register_events(graph.walk())
+ self.store.register_events(graph.walk())
context_store = {}
for event in graph.walk():
context = yield self.state.compute_event_context(event)
- store.store_state_groups(event, context)
+ self.store.register_event_context(event, context)
context_store[event.event_id] = context
+ prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
+
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
- {e for e in context_store["D"].prev_state_ids.values()}
+ {e for e in prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@@ -439,8 +436,10 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
self.assertEqual(
- set(e.event_id for e in old_state), set(context.current_state_ids.values())
+ set(e.event_id for e in old_state), set(current_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@@ -459,13 +458,19 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
self.assertEqual(
- set(e.event_id for e in old_state), set(context.prev_state_ids.values())
+ set(e.event_id for e in old_state), set(prev_state_ids.values())
)
@defer.inlineCallbacks
def test_trivial_annotate_message(self):
- event = create_event(type="test_message", name="event")
+ prev_event_id = "prev_event_id"
+ event = create_event(
+ type="test_message", name="event2",
+ prev_events=[(prev_event_id, {})],
+ )
old_state = [
create_event(type="test1", state_key="1"),
@@ -473,24 +478,30 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = "group_name_1"
-
- self.store.get_state_groups_ids.return_value = {
- group_name: {(e.type, e.state_key): e.event_id for e in old_state},
- }
+ group_name = self.store.store_state_group(
+ prev_event_id, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
+ self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
+ current_state_ids = yield context.get_current_state_ids(self.store)
+
self.assertEqual(
set([e.event_id for e in old_state]),
- set(context.current_state_ids.values())
+ set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@defer.inlineCallbacks
def test_trivial_annotate_state(self):
- event = create_event(type="state", state_key="", name="event")
+ prev_event_id = "prev_event_id"
+ event = create_event(
+ type="state", state_key="", name="event2",
+ prev_events=[(prev_event_id, {})],
+ )
old_state = [
create_event(type="test1", state_key="1"),
@@ -498,24 +509,31 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- group_name = "group_name_1"
-
- self.store.get_state_groups_ids.return_value = {
- group_name: {(e.type, e.state_key): e.event_id for e in old_state},
- }
+ group_name = self.store.store_state_group(
+ prev_event_id, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state},
+ )
+ self.store.register_event_id_state_group(prev_event_id, group_name)
context = yield self.state.compute_event_context(event)
+ prev_state_ids = yield context.get_prev_state_ids(self.store)
+
self.assertEqual(
set([e.event_id for e in old_state]),
- set(context.prev_state_ids.values())
+ set(prev_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_message_conflict(self):
- event = create_event(type="test_message", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test_message", name="event3",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
creation = create_event(
type=EventTypes.Create, state_key=""
@@ -535,20 +553,27 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
- store = StateGroupStore()
- store.register_events(old_state_1)
- store.register_events(old_state_2)
- self.store.get_events = store.get_events
+ self.store.register_events(old_state_1)
+ self.store.register_events(old_state_2)
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
+
+ current_state_ids = yield context.get_current_state_ids(self.store)
- self.assertEqual(len(context.current_state_ids), 6)
+ self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_resolve_state_conflict(self):
- event = create_event(type="test4", state_key="", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test4", state_key="", name="event",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
creation = create_event(
type=EventTypes.Create, state_key=""
@@ -573,15 +598,24 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
self.store.get_events = store.get_events
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
+
+ current_state_ids = yield context.get_current_state_ids(self.store)
- self.assertEqual(len(context.current_state_ids), 6)
+ self.assertEqual(len(current_state_ids), 6)
self.assertIsNotNone(context.state_group)
@defer.inlineCallbacks
def test_standard_depth_conflict(self):
- event = create_event(type="test4", name="event")
+ prev_event_id1 = "event_id1"
+ prev_event_id2 = "event_id2"
+ event = create_event(
+ type="test4", name="event",
+ prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
+ )
member_event = create_event(
type=EventTypes.Member,
@@ -591,6 +625,14 @@ class StateTestCase(unittest.TestCase):
}
)
+ power_levels = create_event(
+ type=EventTypes.PowerLevels, state_key="",
+ content={"users": {
+ "@foo:bar": "100",
+ "@user_id:example.com": "100",
+ }}
+ )
+
creation = create_event(
type=EventTypes.Create, state_key="",
content={"creator": "@foo:bar"}
@@ -598,12 +640,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
@@ -613,10 +657,14 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
self.store.get_events = store.get_events
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
+
+ current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
- old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_2[3].event_id, current_state_ids[("test1", "1")]
)
# Reverse the depth to make sure we are actually using the depths
@@ -624,12 +672,14 @@ class StateTestCase(unittest.TestCase):
old_state_1 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=2),
]
old_state_2 = [
creation,
+ power_levels,
member_event,
create_event(type="test1", state_key="1", depth=1),
]
@@ -637,19 +687,28 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_1)
store.register_events(old_state_2)
- context = yield self._get_context(event, old_state_1, old_state_2)
+ context = yield self._get_context(
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ )
+
+ current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
- old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ old_state_1[3].event_id, current_state_ids[("test1", "1")]
)
- def _get_context(self, event, old_state_1, old_state_2):
- group_name_1 = "group_name_1"
- group_name_2 = "group_name_2"
+ def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
+ old_state_2):
+ sg1 = self.store.store_state_group(
+ prev_event_id_1, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state_1},
+ )
+ self.store.register_event_id_state_group(prev_event_id_1, sg1)
- self.store.get_state_groups_ids.return_value = {
- group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
- group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
- }
+ sg2 = self.store.store_state_group(
+ prev_event_id_2, event.room_id, None, None,
+ {(e.type, e.state_key): e.event_id for e in old_state_2},
+ )
+ self.store.register_event_id_state_group(prev_event_id_2, sg2)
return self.state.compute_event_context(event)
|