diff options
Diffstat (limited to 'tests/state/test_v2.py')
-rw-r--r-- | tests/state/test_v2.py | 193 |
1 files changed, 190 insertions, 3 deletions
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py index ad9bbef9d2..77c72834f2 100644 --- a/tests/state/test_v2.py +++ b/tests/state/test_v2.py @@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership from synapse.api.room_versions import RoomVersions from synapse.event_auth import auth_types_for_event from synapse.events import make_event_from_dict -from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store +from synapse.state.v2 import ( + _get_auth_chain_difference, + lexicographical_topological_sort, + resolve_events_with_store, +) from synapse.types import EventID from tests import unittest @@ -84,7 +88,7 @@ class FakeEvent: event_dict = { "auth_events": [(a, {}) for a in auth_events], "prev_events": [(p, {}) for p in prev_events], - "event_id": self.node_id, + "event_id": self.event_id, "sender": self.sender, "type": self.type, "content": self.content, @@ -377,6 +381,61 @@ class StateTestCase(unittest.TestCase): self.do_check(events, edges, expected_state_ids) + def test_mainline_sort(self): + """Tests that the mainline ordering works correctly. + """ + + events = [ + FakeEvent( + id="T1", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA1", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T2", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="PA2", + sender=ALICE, + type=EventTypes.PowerLevels, + state_key="", + content={ + "users": {ALICE: 100, BOB: 50}, + "events": {EventTypes.PowerLevels: 100}, + }, + ), + FakeEvent( + id="PB", + sender=BOB, + type=EventTypes.PowerLevels, + state_key="", + content={"users": {ALICE: 100, BOB: 50}}, + ), + FakeEvent( + id="T3", sender=BOB, type=EventTypes.Topic, state_key="", content={} + ), + FakeEvent( + id="T4", sender=ALICE, type=EventTypes.Topic, state_key="", content={} + ), + ] + + edges = [ + ["END", "T3", "PA2", "T2", "PA1", "T1", "START"], + ["END", "T4", "PB", "PA1"], + ] + + # We expect T3 to be picked as the other topics are pointing at older + # power levels. Note that without mainline ordering we'd pick T4 due to + # it being sent *after* T3. + expected_state_ids = ["T3", "PA2"] + + self.do_check(events, edges, expected_state_ids) + def do_check(self, events, edges, expected_state_ids): """Take a list of events and edges and calculate the state of the graph at END, and asserts it matches `expected_state_ids` @@ -587,6 +646,134 @@ class SimpleParamStateTestCase(unittest.TestCase): self.assert_dict(self.expected_combined_state, state) +class AuthChainDifferenceTestCase(unittest.TestCase): + """We test that `_get_auth_chain_difference` correctly handles unpersisted + events. + """ + + def test_simple(self): + # Test getting the auth difference for a simple chain with a single + # unpersisted event: + # + # Unpersisted | Persisted + # | + # C -|-> B -> A + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c} + + state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {c.event_id}) + + def test_multiple_unpersisted_chain(self): + # Test getting the auth difference for a simple chain with multiple + # unpersisted events: + # + # Unpersisted | Persisted + # | + # D -> C -|-> B -> A + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + d = FakeEvent( + id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c, d.event_id: d} + + state_sets = [ + {"a": a.event_id, "b": b.event_id}, + {"c": c.event_id, "d": d.event_id}, + ] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {d.event_id, c.event_id}) + + def test_unpersisted_events_different_sets(self): + # Test getting the auth difference for with multiple unpersisted events + # in different branches: + # + # Unpersisted | Persisted + # | + # D --> C -|-> B -> A + # E ----^ -|---^ + # | + + a = FakeEvent( + id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([], []) + + b = FakeEvent( + id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([a.event_id], []) + + c = FakeEvent( + id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([b.event_id], []) + + d = FakeEvent( + id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id], []) + + e = FakeEvent( + id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={}, + ).to_event([c.event_id, b.event_id], []) + + persisted_events = {a.event_id: a, b.event_id: b} + unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e} + + state_sets = [ + {"a": a.event_id, "b": b.event_id, "e": e.event_id}, + {"c": c.event_id, "d": d.event_id}, + ] + + store = TestStateResolutionStore(persisted_events) + + diff_d = _get_auth_chain_difference( + ROOM_ID, state_sets, unpersited_events, store + ) + difference = self.successResultOf(defer.ensureDeferred(diff_d)) + + self.assertEqual(difference, {d.event_id, e.event_id}) + + def pairwise(iterable): "s -> (s0,s1), (s1,s2), (s2, s3), ..." a, b = itertools.tee(iterable) @@ -647,7 +834,7 @@ class TestStateResolutionStore: return list(result) - def get_auth_chain_difference(self, auth_sets): + def get_auth_chain_difference(self, room_id, auth_sets): chains = [frozenset(self._get_auth_chain(a)) for a in auth_sets] common = set(chains[0]).intersection(*chains[1:]) |