diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..77c72834f2 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
from synapse.api.room_versions import RoomVersions
from synapse.event_auth import auth_types_for_event
from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+ _get_auth_chain_difference,
+ lexicographical_topological_sort,
+ resolve_events_with_store,
+)
from synapse.types import EventID
from tests import unittest
@@ -84,7 +88,7 @@ class FakeEvent:
event_dict = {
"auth_events": [(a, {}) for a in auth_events],
"prev_events": [(p, {}) for p in prev_events],
- "event_id": self.node_id,
+ "event_id": self.event_id,
"sender": self.sender,
"type": self.type,
"content": self.content,
@@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase):
self.do_check(events, edges, expected_state_ids)
+ def test_mainline_sort(self):
+ """Tests that the mainline ordering works correctly.
+ """
+
+ events = [
+ FakeEvent(
+ id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA1",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="PA2",
+ sender=ALICE,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={
+ "users": {ALICE: 100, BOB: 50},
+ "events": {EventTypes.PowerLevels: 100},
+ },
+ ),
+ FakeEvent(
+ id="PB",
+ sender=BOB,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ content={"users": {ALICE: 100, BOB: 50}},
+ ),
+ FakeEvent(
+ id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={}
+ ),
+ FakeEvent(
+ id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={}
+ ),
+ ]
+
+ edges = [
+ ["END", "T3", "PA2", "T2", "PA1", "T1", "START"],
+ ["END", "T4", "PB", "PA1"],
+ ]
+
+ # We expect T3 to be picked as the other topics are pointing at older
+ # power levels. Note that without mainline ordering we'd pick T4 due to
+ # it being sent *after* T3.
+ expected_state_ids = ["T3", "PA2"]
+
+ self.do_check(events, edges, expected_state_ids)
+
def do_check(self, events, edges, expected_state_ids):
"""Take a list of events and edges and calculate the state of the
graph at END, and asserts it matches `expected_state_ids`
@@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase):
self.assert_dict(self.expected_combined_state, state)
+class AuthChainDifferenceTestCase(unittest.TestCase):
+ """We test that `_get_auth_chain_difference` correctly handles unpersisted
+ events.
+ """
+
+ def test_simple(self):
+ # Test getting the auth difference for a simple chain with a single
+ # unpersisted event:
+ #
+ # Unpersisted | Persisted
+ # |
+ # C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c}
+
+ state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {c.event_id})
+
+ def test_multiple_unpersisted_chain(self):
+ # Test getting the auth difference for a simple chain with multiple
+ # unpersisted events:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D -> C -|-> B -> A
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, c.event_id})
+
+ def test_unpersisted_events_different_sets(self):
+ # Test getting the auth difference for with multiple unpersisted events
+ # in different branches:
+ #
+ # Unpersisted | Persisted
+ # |
+ # D --> C -|-> B -> A
+ # E ----^ -|---^
+ # |
+
+ a = FakeEvent(
+ id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([], [])
+
+ b = FakeEvent(
+ id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([a.event_id], [])
+
+ c = FakeEvent(
+ id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([b.event_id], [])
+
+ d = FakeEvent(
+ id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id], [])
+
+ e = FakeEvent(
+ id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+ ).to_event([c.event_id, b.event_id], [])
+
+ persisted_events = {a.event_id: a, b.event_id: b}
+ unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+ state_sets = [
+ {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+ {"c": c.event_id, "d": d.event_id},
+ ]
+
+ store = TestStateResolutionStore(persisted_events)
+
+ diff_d = _get_auth_chain_difference(
+ ROOM_ID, state_sets, unpersited_events, store
+ )
+ difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+ self.assertEqual(difference, {d.event_id, e.event_id})
+
+
def pairwise(iterable):
"s -> (s0,s1), (s1,s2), (s2, s3), ..."
a, b = itertools.tee(iterable)
@@ -647,7 +834,7 @@ class TestStateResolutionStore:
return list(result)
- def get_auth_chain_difference(self, auth_sets):
+ def get_auth_chain_difference(self, room_id, auth_sets):
chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets]
common = set(chains[0]).intersection(*chains[1:])
|