diff options
Diffstat (limited to '')
-rw-r--r-- | tests/state/test_v2.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index 8d3845c870..a44960203e 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -22,7 +22,7 @@ import attr from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event -from synapse.events import FrozenEvent +from synapse.events import make_event_from_dict from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store from synapse.types import EventID @@ -58,6 +58,7 @@ class FakeEvent(object): self.type = type self.state_key = state_key self.content = content + self.room_id = ROOM_ID def to_event(self, auth_events, prev_events): """Given the auth_events and prev_events, convert to a Frozen Event @@ -88,7 +89,7 @@ class FakeEvent(object): if self.state_key is not None: event_dict["state_key"] = self.state_key - return FrozenEvent(event_dict) + return make_event_from_dict(event_dict) # All graphs start with this set of events @@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase): state_before = dict(state_at_event[prev_events[0]]) else: state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [state_at_event[n] for n in prev_events], event_map=event_map, @@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase): # Test that we correctly handle passing `None` as the event_map state_d = resolve_events_with_store( + ROOM_ID, RoomVersions.V2.identifier, [self.state_at_bob, self.state_at_charlie], event_map=None, @@ -600,7 +603,7 @@ class TestStateResolutionStore(object): return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map} - def get_auth_chain(self, event_ids): + def _get_auth_chain(self, event_ids): """Gets the full auth chain for a set of events (including rejected events). @@ -614,7 +617,6 @@ class TestStateResolutionStore(object): Args: event_ids (list): The event IDs of the events to fetch the auth chain for. Must be state events. - Returns: Deferred[list[str]]: List of event IDs of the auth chain. """ @@ -634,3 +636,9 @@ class TestStateResolutionStore(object): stack.append(aid) return list(result) + + def get_auth_chain_difference(self, auth_sets): + chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] + + common = set(chains[0]).intersection(*chains[1:]) + return set(chains[0]).union(*chains[1:]) - common |