diff --git a/tests/test_state.py b/tests/test_state.py
index 429a18cbf7..96fdb8636c 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -29,8 +29,15 @@ from .utils import MockClock
_next_event_id = 1000
-def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
- prev_events=[], **kwargs):
+def create_event(
+ name=None,
+ type=None,
+ state_key=None,
+ depth=2,
+ event_id=None,
+ prev_events=[],
+ **kwargs
+):
global _next_event_id
if not event_id:
@@ -39,9 +46,9 @@ def create_event(name=None, type=None, state_key=None, depth=2, event_id=None,
if not name:
if state_key is not None:
- name = "<%s-%s, %s>" % (type, state_key, event_id,)
+ name = "<%s-%s, %s>" % (type, state_key, event_id)
else:
- name = "<%s, %s>" % (type, event_id,)
+ name = "<%s, %s>" % (type, event_id)
d = {
"event_id": event_id,
@@ -80,8 +87,9 @@ class StateGroupStore(object):
return defer.succeed(groups)
- def store_state_group(self, event_id, room_id, prev_group, delta_ids,
- current_state_ids):
+ def store_state_group(
+ self, event_id, room_id, prev_group, delta_ids, current_state_ids
+ ):
state_group = self._next_group
self._next_group += 1
@@ -91,7 +99,8 @@ class StateGroupStore(object):
def get_events(self, event_ids, **kwargs):
return {
- e_id: self._event_id_to_event[e_id] for e_id in event_ids
+ e_id: self._event_id_to_event[e_id]
+ for e_id in event_ids
if e_id in self._event_id_to_event
}
@@ -129,9 +138,7 @@ class Graph(object):
prev_events = []
events[event_id] = create_event(
- event_id=event_id,
- prev_events=prev_events,
- **fields
+ event_id=event_id, prev_events=prev_events, **fields
)
self._leaves = clobbered
@@ -147,10 +154,15 @@ class Graph(object):
class StateTestCase(unittest.TestCase):
def setUp(self):
self.store = StateGroupStore()
- hs = Mock(spec_set=[
- "get_datastore", "get_auth", "get_state_handler", "get_clock",
- "get_state_resolution_handler",
- ])
+ hs = Mock(
+ spec_set=[
+ "get_datastore",
+ "get_auth",
+ "get_state_handler",
+ "get_clock",
+ "get_state_resolution_handler",
+ ]
+ )
hs.get_datastore.return_value = self.store
hs.get_state_handler.return_value = None
hs.get_clock.return_value = MockClock()
@@ -164,35 +176,13 @@ class StateTestCase(unittest.TestCase):
def test_branch_no_conflict(self):
graph = Graph(
nodes={
- "START": DictObj(
- type=EventTypes.Create,
- state_key="",
- depth=1,
- ),
- "A": DictObj(
- type=EventTypes.Message,
- depth=2,
- ),
- "B": DictObj(
- type=EventTypes.Message,
- depth=3,
- ),
- "C": DictObj(
- type=EventTypes.Name,
- state_key="",
- depth=3,
- ),
- "D": DictObj(
- type=EventTypes.Message,
- depth=4,
- ),
+ "START": DictObj(type=EventTypes.Create, state_key="", depth=1),
+ "A": DictObj(type=EventTypes.Message, depth=2),
+ "B": DictObj(type=EventTypes.Message, depth=3),
+ "C": DictObj(type=EventTypes.Name, state_key="", depth=3),
+ "D": DictObj(type=EventTypes.Message, depth=4),
},
- edges={
- "A": ["START"],
- "B": ["A"],
- "C": ["A"],
- "D": ["B", "C"]
- }
+ edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
)
self.store.register_events(graph.walk())
@@ -224,27 +214,11 @@ class StateTestCase(unittest.TestCase):
membership=Membership.JOIN,
depth=2,
),
- "B": DictObj(
- type=EventTypes.Name,
- state_key="",
- depth=3,
- ),
- "C": DictObj(
- type=EventTypes.Name,
- state_key="",
- depth=4,
- ),
- "D": DictObj(
- type=EventTypes.Message,
- depth=5,
- ),
+ "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
+ "C": DictObj(type=EventTypes.Name, state_key="", depth=4),
+ "D": DictObj(type=EventTypes.Message, depth=5),
},
- edges={
- "A": ["START"],
- "B": ["A"],
- "C": ["A"],
- "D": ["B", "C"]
- }
+ edges={"A": ["START"], "B": ["A"], "C": ["A"], "D": ["B", "C"]},
)
self.store.register_events(graph.walk())
@@ -259,8 +233,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual(
- {"START", "A", "C"},
- {e_id for e_id in prev_state_ids.values()}
+ {"START", "A", "C"}, {e_id for e_id in prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -280,11 +253,7 @@ class StateTestCase(unittest.TestCase):
membership=Membership.JOIN,
depth=2,
),
- "B": DictObj(
- type=EventTypes.Name,
- state_key="",
- depth=3,
- ),
+ "B": DictObj(type=EventTypes.Name, state_key="", depth=3),
"C": DictObj(
type=EventTypes.Member,
state_key="@user_id_2:example.com",
@@ -298,18 +267,9 @@ class StateTestCase(unittest.TestCase):
depth=4,
sender="@user_id_2:example.com",
),
- "E": DictObj(
- type=EventTypes.Message,
- depth=5,
- ),
+ "E": DictObj(type=EventTypes.Message, depth=5),
},
- edges={
- "A": ["START"],
- "B": ["A"],
- "C": ["B"],
- "D": ["B"],
- "E": ["C", "D"]
- }
+ edges={"A": ["START"], "B": ["A"], "C": ["B"], "D": ["B"], "E": ["C", "D"]},
)
self.store.register_events(graph.walk())
@@ -324,8 +284,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context_store["E"].get_prev_state_ids(self.store)
self.assertSetEqual(
- {"START", "A", "B", "C"},
- {e for e in prev_state_ids.values()}
+ {"START", "A", "B", "C"}, {e for e in prev_state_ids.values()}
)
@defer.inlineCallbacks
@@ -357,30 +316,17 @@ class StateTestCase(unittest.TestCase):
state_key="",
content={
"events": {"m.room.name": 50},
- "users": {userid1: 100,
- userid2: 60},
+ "users": {userid1: 100, userid2: 60},
},
),
- "A5": DictObj(
- type=EventTypes.Name,
- state_key="",
- ),
+ "A5": DictObj(type=EventTypes.Name, state_key=""),
"B": DictObj(
type=EventTypes.PowerLevels,
state_key="",
- content={
- "events": {"m.room.name": 50},
- "users": {userid2: 30},
- },
- ),
- "C": DictObj(
- type=EventTypes.Name,
- state_key="",
- sender=userid2,
- ),
- "D": DictObj(
- type=EventTypes.Message,
+ content={"events": {"m.room.name": 50}, "users": {userid2: 30}},
),
+ "C": DictObj(type=EventTypes.Name, state_key="", sender=userid2),
+ "D": DictObj(type=EventTypes.Message),
}
edges = {
"A2": ["A1"],
@@ -389,7 +335,7 @@ class StateTestCase(unittest.TestCase):
"A5": ["A4"],
"B": ["A5"],
"C": ["A5"],
- "D": ["B", "C"]
+ "D": ["B", "C"],
}
self._add_depths(nodes, edges)
graph = Graph(nodes, edges)
@@ -406,8 +352,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context_store["D"].get_prev_state_ids(self.store)
self.assertSetEqual(
- {"A1", "A2", "A3", "A5", "B"},
- {e for e in prev_state_ids.values()}
+ {"A1", "A2", "A3", "A5", "B"}, {e for e in prev_state_ids.values()}
)
def _add_depths(self, nodes, edges):
@@ -432,9 +377,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(
- event, old_state=old_state
- )
+ context = yield self.state.compute_event_context(event, old_state=old_state)
current_state_ids = yield context.get_current_state_ids(self.store)
@@ -454,9 +397,7 @@ class StateTestCase(unittest.TestCase):
create_event(type="test2", state_key=""),
]
- context = yield self.state.compute_event_context(
- event, old_state=old_state
- )
+ context = yield self.state.compute_event_context(event, old_state=old_state)
prev_state_ids = yield context.get_prev_state_ids(self.store)
@@ -468,8 +409,7 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_message(self):
prev_event_id = "prev_event_id"
event = create_event(
- type="test_message", name="event2",
- prev_events=[(prev_event_id, {})],
+ type="test_message", name="event2", prev_events=[(prev_event_id, {})]
)
old_state = [
@@ -479,7 +419,10 @@ class StateTestCase(unittest.TestCase):
]
group_name = self.store.store_state_group(
- prev_event_id, event.room_id, None, None,
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
{(e.type, e.state_key): e.event_id for e in old_state},
)
self.store.register_event_id_state_group(prev_event_id, group_name)
@@ -489,8 +432,7 @@ class StateTestCase(unittest.TestCase):
current_state_ids = yield context.get_current_state_ids(self.store)
self.assertEqual(
- set([e.event_id for e in old_state]),
- set(current_state_ids.values())
+ set([e.event_id for e in old_state]), set(current_state_ids.values())
)
self.assertEqual(group_name, context.state_group)
@@ -499,8 +441,7 @@ class StateTestCase(unittest.TestCase):
def test_trivial_annotate_state(self):
prev_event_id = "prev_event_id"
event = create_event(
- type="state", state_key="", name="event2",
- prev_events=[(prev_event_id, {})],
+ type="state", state_key="", name="event2", prev_events=[(prev_event_id, {})]
)
old_state = [
@@ -510,7 +451,10 @@ class StateTestCase(unittest.TestCase):
]
group_name = self.store.store_state_group(
- prev_event_id, event.room_id, None, None,
+ prev_event_id,
+ event.room_id,
+ None,
+ None,
{(e.type, e.state_key): e.event_id for e in old_state},
)
self.store.register_event_id_state_group(prev_event_id, group_name)
@@ -520,8 +464,7 @@ class StateTestCase(unittest.TestCase):
prev_state_ids = yield context.get_prev_state_ids(self.store)
self.assertEqual(
- set([e.event_id for e in old_state]),
- set(prev_state_ids.values())
+ set([e.event_id for e in old_state]), set(prev_state_ids.values())
)
self.assertIsNotNone(context.state_group)
@@ -531,13 +474,12 @@ class StateTestCase(unittest.TestCase):
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
- type="test_message", name="event3",
+ type="test_message",
+ name="event3",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
- creation = create_event(
- type=EventTypes.Create, state_key=""
- )
+ creation = create_event(type=EventTypes.Create, state_key="")
old_state_1 = [
creation,
@@ -557,7 +499,7 @@ class StateTestCase(unittest.TestCase):
self.store.register_events(old_state_2)
context = yield self._get_context(
- event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids(self.store)
@@ -571,13 +513,13 @@ class StateTestCase(unittest.TestCase):
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
- type="test4", state_key="", name="event",
+ type="test4",
+ state_key="",
+ name="event",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
- creation = create_event(
- type=EventTypes.Create, state_key=""
- )
+ creation = create_event(type=EventTypes.Create, state_key="")
old_state_1 = [
creation,
@@ -599,7 +541,7 @@ class StateTestCase(unittest.TestCase):
self.store.get_events = store.get_events
context = yield self._get_context(
- event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids(self.store)
@@ -613,29 +555,25 @@ class StateTestCase(unittest.TestCase):
prev_event_id1 = "event_id1"
prev_event_id2 = "event_id2"
event = create_event(
- type="test4", name="event",
+ type="test4",
+ name="event",
prev_events=[(prev_event_id1, {}), (prev_event_id2, {})],
)
member_event = create_event(
type=EventTypes.Member,
state_key="@user_id:example.com",
- content={
- "membership": Membership.JOIN,
- }
+ content={"membership": Membership.JOIN},
)
power_levels = create_event(
- type=EventTypes.PowerLevels, state_key="",
- content={"users": {
- "@foo:bar": "100",
- "@user_id:example.com": "100",
- }}
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {"@foo:bar": "100", "@user_id:example.com": "100"}},
)
creation = create_event(
- type=EventTypes.Create, state_key="",
- content={"creator": "@foo:bar"}
+ type=EventTypes.Create, state_key="", content={"creator": "@foo:bar"}
)
old_state_1 = [
@@ -658,14 +596,12 @@ class StateTestCase(unittest.TestCase):
self.store.get_events = store.get_events
context = yield self._get_context(
- event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids(self.store)
- self.assertEqual(
- old_state_2[3].event_id, current_state_ids[("test1", "1")]
- )
+ self.assertEqual(old_state_2[3].event_id, current_state_ids[("test1", "1")])
# Reverse the depth to make sure we are actually using the depths
# during state resolution.
@@ -688,25 +624,30 @@ class StateTestCase(unittest.TestCase):
store.register_events(old_state_2)
context = yield self._get_context(
- event, prev_event_id1, old_state_1, prev_event_id2, old_state_2,
+ event, prev_event_id1, old_state_1, prev_event_id2, old_state_2
)
current_state_ids = yield context.get_current_state_ids(self.store)
- self.assertEqual(
- old_state_1[3].event_id, current_state_ids[("test1", "1")]
- )
+ self.assertEqual(old_state_1[3].event_id, current_state_ids[("test1", "1")])
- def _get_context(self, event, prev_event_id_1, old_state_1, prev_event_id_2,
- old_state_2):
+ def _get_context(
+ self, event, prev_event_id_1, old_state_1, prev_event_id_2, old_state_2
+ ):
sg1 = self.store.store_state_group(
- prev_event_id_1, event.room_id, None, None,
+ prev_event_id_1,
+ event.room_id,
+ None,
+ None,
{(e.type, e.state_key): e.event_id for e in old_state_1},
)
self.store.register_event_id_state_group(prev_event_id_1, sg1)
sg2 = self.store.store_state_group(
- prev_event_id_2, event.room_id, None, None,
+ prev_event_id_2,
+ event.room_id,
+ None,
+ None,
{(e.type, e.state_key): e.event_id for e in old_state_2},
)
self.store.register_event_id_state_group(prev_event_id_2, sg2)
|