summary refs log tree commit diff
path: root/tests/state/test_v2.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/state/test_v2.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index 8d3845c870..a44960203e 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -22,7 +22,7 @@ import attr
 from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
-from synapse.events import FrozenEvent
+from synapse.events import make_event_from_dict
 from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
 from synapse.types import EventID
 
@@ -58,6 +58,7 @@ class FakeEvent(object):
         self.type = type
         self.state_key = state_key
         self.content = content
+        self.room_id = ROOM_ID
 
     def to_event(self, auth_events, prev_events):
         """Given the auth_events and prev_events, convert to a Frozen Event
@@ -88,7 +89,7 @@ class FakeEvent(object):
         if self.state_key is not None:
             event_dict["state_key"] = self.state_key
 
-        return FrozenEvent(event_dict)
+        return make_event_from_dict(event_dict)
 
 
 # All graphs start with this set of events
@@ -418,6 +419,7 @@ class StateTestCase(unittest.TestCase):
                 state_before = dict(state_at_event[prev_events[0]])
             else:
                 state_d = resolve_events_with_store(
+                    ROOM_ID,
                     RoomVersions.V2.identifier,
                     [state_at_event[n] for n in prev_events],
                     event_map=event_map,
@@ -565,6 +567,7 @@ class SimpleParamStateTestCase(unittest.TestCase):
         # Test that we correctly handle passing `None` as the event_map
 
         state_d = resolve_events_with_store(
+            ROOM_ID,
             RoomVersions.V2.identifier,
             [self.state_at_bob, self.state_at_charlie],
             event_map=None,
@@ -600,7 +603,7 @@ class TestStateResolutionStore(object):
 
         return {eid: self.event_map[eid] for eid in event_ids if eid in self.event_map}
 
-    def get_auth_chain(self, event_ids):
+    def _get_auth_chain(self, event_ids):
         """Gets the full auth chain for a set of events (including rejected
         events).
 
@@ -614,7 +617,6 @@ class TestStateResolutionStore(object):
         Args:
             event_ids (list): The event IDs of the events to fetch the auth
                 chain for. Must be state events.
-
         Returns:
             Deferred[list[str]]: List of event IDs of the auth chain.
         """
@@ -634,3 +636,9 @@ class TestStateResolutionStore(object):
                 stack.append(aid)
 
         return list(result)
+
+    def get_auth_chain_difference(self, auth_sets):
+        chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
+
+        common = set(chains[0]).intersection(*chains[1:])
+        return set(chains[0]).union(*chains[1:]) - common