summary refs log tree commit diff
path: root/tests/state/test_v2.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/state/test_v2.py')
-rw-r--r--tests/state/test_v2.py36
1 files changed, 22 insertions, 14 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py

index 9c5311d916..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py
@@ -22,7 +22,7 @@ import attr from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.types import EventID @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -88,7 +89,7 @@ class FakeEvent(object): if self.state_key is not None: event_dict["state_key"] = self.state_key - return FrozenEvent(event_dict) + return make_event_from_dict(event_dict) # All graphs start with this set of events @@ -181,7 +182,7 @@ class StateTestCase(unittest.TestCase): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), ] @@ -229,14 +230,14 @@ class StateTestCase(unittest.TestCase): id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 50}}, ), FakeEvent( id="PC", sender=CHARLIE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50, CHARLIE: 0}}, ), ] @@ -256,7 +257,7 @@ class StateTestCase(unittest.TestCase): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -266,14 +267,14 @@ class StateTestCase(unittest.TestCase): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -296,7 +297,7 @@ class StateTestCase(unittest.TestCase): id="PA", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -326,7 +327,7 @@ class StateTestCase(unittest.TestCase): id="PA1", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -336,14 +337,14 @@ class StateTestCase(unittest.TestCase): id="PA2", sender=ALICE, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 0}}, ), FakeEvent( id="PB", sender=BOB, type=EventTypes.PowerLevels, - state_key='', + state_key="", content={"users": {ALICE: 100, BOB: 50}}, ), FakeEvent( @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, @@ -600,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -614,7 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -634,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common