diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8d3845c870..a44960203e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -22,7 +22,7 @@ import attr
from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
from synapse.types import EventID
@@ -58,6 +58,7 @@ class FakeEvent(object):
self.type = type
self.state_key = state_key
self.content = content
+ self.room_id = ROOM_ID
def to_event(self, auth_events, prev_events):
"""Given the auth_events and prev_events, convert to a Frozen Event
@@ -88,7 +89,7 @@ class FakeEvent(object):
if self.state_key is not None:
event_dict["state_key"] = self.state_key
- return FrozenEvent(event_dict)
+ return make_event_from_dict(event_dict)
# All graphs start with this set of events
@@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
state_before = dict(state_at_event[prev_events[0]])
else:
state_d = resolve_events_with_store(
+ ROOM_ID,
RoomVersions.V2.identifier,
[state_at_event[n] for n in prev_events],
event_map=event_map,
@@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
# Test that we correctly handle passing `None` as the event_map
state_d = resolve_events_with_store(
+ ROOM_ID,
RoomVersions.V2.identifier,
[self.state_at_bob, self.state_at_charlie],
event_map=None,
@@ -600,7 +603,7 @@ class TestStateResolutionStore(object):
return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
- def get_auth_chain(self, event_ids):
+ def _get_auth_chain(self, event_ids):
"""Gets the full auth chain for a set of events (including rejected
events).
@@ -614,7 +617,6 @@ class TestStateResolutionStore(object):
Args:
event_ids (list): The event IDs of the events to fetch the auth
chain for. Must be state events.
-
Returns:
Deferred[list[str]]: List of event IDs of the auth chain.
"""
@@ -634,3 +636,9 @@ class TestStateResolutionStore(object):
stack.append(aid)
return list(result)
+
+ def get_auth_chain_difference(self, auth_sets):
+ chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
+
+ common = set(chains[0]).intersection(*chains[1:])
+ return set(chains[0]).union(*chains[1:]) - common
|