diff --git a/tests/test_state.py b/tests/test_state.py
index 1a11bbcee0..de2d35145a 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -67,9 +67,11 @@ class StateGroupStore(object):
self._event_to_state_group = {}
self._group_to_state = {}
+ self._event_id_to_event = {}
+
self._next_group = 1
- def get_state_groups(self, room_id, event_ids):
+ def get_state_groups_ids(self, room_id, event_ids):
groups = {}
for event_id in event_ids:
group = self._event_to_state_group.get(event_id)
@@ -79,23 +81,33 @@ class StateGroupStore(object):
return defer.succeed(groups)
def store_state_groups(self, event, context):
- if context.current_state is None:
+ if context.current_state_ids is None:
return
- state_events = context.current_state
+ state_events = dict(context.current_state_ids)
if event.is_state():
- state_events[(event.type, event.state_key)] = event
+ state_events[(event.type, event.state_key)] = event.event_id
state_group = context.state_group
if not state_group:
state_group = self._next_group
self._next_group += 1
- self._group_to_state[state_group] = state_events.values()
+ self._group_to_state[state_group] = state_events
self._event_to_state_group[event.event_id] = state_group
+ def get_events(self, event_ids, **kwargs):
+ return {
+ e_id: self._event_id_to_event[e_id] for e_id in event_ids
+ if e_id in self._event_id_to_event
+ }
+
+ def register_events(self, events):
+ for e in events:
+ self._event_id_to_event[e.event_id] = e
+
class DictObj(dict):
def __init__(self, **kwargs):
@@ -136,8 +148,9 @@ class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = Mock(
spec_set=[
- "get_state_groups",
+ "get_state_groups_ids",
"add_event_hashes",
+ "get_events",
]
)
hs = Mock(spec_set=[
@@ -187,7 +200,7 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
context_store = {}
@@ -196,7 +209,7 @@ class StateTestCase(unittest.TestCase):
store.store_state_groups(event, context)
context_store[event.event_id] = context
- self.assertEqual(2, len(context_store["D"].current_state))
+ self.assertEqual(2, len(context_store["D"].current_state_ids))
@defer.inlineCallbacks
def test_branch_basic_conflict(self):
@@ -239,7 +252,9 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -250,7 +265,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "C"},
- {e.event_id for e in context_store["D"].current_state.values()}
+ {e_id for e_id in context_store["D"].current_state_ids.values()}
)
@defer.inlineCallbacks
@@ -303,7 +318,9 @@ class StateTestCase(unittest.TestCase):
)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -314,7 +331,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"START", "A", "B", "C"},
- {e.event_id for e in context_store["E"].current_state.values()}
+ {e for e in context_store["E"].current_state_ids.values()}
)
@defer.inlineCallbacks
@@ -384,7 +401,9 @@ class StateTestCase(unittest.TestCase):
graph = Graph(nodes, edges)
store = StateGroupStore()
- self.store.get_state_groups.side_effect = store.get_state_groups
+ self.store.get_state_groups_ids.side_effect = store.get_state_groups_ids
+ self.store.get_events = store.get_events
+ store.register_events(graph.walk())
context_store = {}
@@ -395,7 +414,7 @@ class StateTestCase(unittest.TestCase):
self.assertSetEqual(
{"A1", "A2", "A3", "A5", "B"},
- {e.event_id for e in context_store["D"].current_state.values()}
+ {e for e in context_store["D"].current_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@@ -424,13 +443,8 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
- set(old_state), set(context.current_state.values())
+ set(e.event_id for e in old_state), set(context.current_state_ids.values())
)
self.assertIsNone(context.state_group)
@@ -449,14 +463,8 @@ class StateTestCase(unittest.TestCase):
event, old_state=old_state
)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
- set(old_state),
- set(context.current_state.values())
+ set(e.event_id for e in old_state), set(context.current_state_ids.values())
)
self.assertIsNone(context.state_group)
@@ -473,20 +481,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1"
- self.store.get_state_groups.return_value = {
- group_name: old_state,
+ self.store.get_state_groups_ids.return_value = {
+ group_name: {(e.type, e.state_key): e.event_id for e in old_state},
}
context = yield self.state.compute_event_context(event)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
set([e.event_id for e in old_state]),
- set([e.event_id for e in context.current_state.values()])
+ set(context.current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -503,20 +506,15 @@ class StateTestCase(unittest.TestCase):
group_name = "group_name_1"
- self.store.get_state_groups.return_value = {
- group_name: old_state,
+ self.store.get_state_groups_ids.return_value = {
+ group_name: {(e.type, e.state_key): e.event_id for e in old_state},
}
context = yield self.state.compute_event_context(event)
- for k, v in context.current_state.items():
- type, state_key = k
- self.assertEqual(type, v.type)
- self.assertEqual(state_key, v.state_key)
-
self.assertEqual(
set([e.event_id for e in old_state]),
- set([e.event_id for e in context.current_state.values()])
+ set(context.current_state_ids.values())
)
self.assertIsNone(context.state_group)
@@ -543,9 +541,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 6)
+ self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group)
@@ -571,9 +574,14 @@ class StateTestCase(unittest.TestCase):
create_event(type="test4", state_key=""),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(len(context.current_state), 6)
+ self.assertEqual(len(context.current_state_ids), 6)
self.assertIsNone(context.state_group)
@@ -606,9 +614,16 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=2),
]
+ store = StateGroupStore()
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+ self.store.get_events = store.get_events
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_2[2], context.current_state[("test1", "1")])
+ self.assertEqual(
+ old_state_2[2].event_id, context.current_state_ids[("test1", "1")]
+ )
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
@@ -625,17 +640,22 @@ class StateTestCase(unittest.TestCase):
create_event(type="test1", state_key="1", depth=1),
]
+ store.register_events(old_state_1)
+ store.register_events(old_state_2)
+
context = yield self._get_context(event, old_state_1, old_state_2)
- self.assertEqual(old_state_1[2], context.current_state[("test1", "1")])
+ self.assertEqual(
+ old_state_1[2].event_id, context.current_state_ids[("test1", "1")]
+ )
def _get_context(self, event, old_state_1, old_state_2):
group_name_1 = "group_name_1"
group_name_2 = "group_name_2"
- self.store.get_state_groups.return_value = {
- group_name_1: old_state_1,
- group_name_2: old_state_2,
+ self.store.get_state_groups_ids.return_value = {
+ group_name_1: {(e.type, e.state_key): e.event_id for e in old_state_1},
+ group_name_2: {(e.type, e.state_key): e.event_id for e in old_state_2},
}
return self.state.compute_event_context(event)
|