summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8827.bugfix1
-rw-r--r--synapse/state/v2.py87
-rw-r--r--tests/state/test_v2.py128
-rw-r--r--tests/storage/test_event_federation.py5
4 files changed, 216 insertions, 5 deletions
diff --git a/changelog.d/8827.bugfix b/changelog.d/8827.bugfix
new file mode 100644
index 0000000000..18195680d3
--- /dev/null
+++ b/changelog.d/8827.bugfix
@@ -0,0 +1 @@
+Fix bug where we might not correctly calculate the current state for rooms with multiple extremities.
diff --git a/synapse/state/v2.py b/synapse/state/v2.py
index f57df0d728..ffc504ce77 100644
--- a/synapse/state/v2.py
+++ b/synapse/state/v2.py
@@ -38,7 +38,7 @@ from synapse.api.constants import EventTypes
 from synapse.api.errors import AuthError
 from synapse.api.room_versions import KNOWN_ROOM_VERSIONS
 from synapse.events import EventBase
-from synapse.types import MutableStateMap, StateMap
+from synapse.types import Collection, MutableStateMap, StateMap
 from synapse.util import Clock
 
 logger = logging.getLogger(__name__)
@@ -252,9 +252,88 @@ async def _get_auth_chain_difference(
         Set of event IDs
     """
 
-    difference = await state_res_store.get_auth_chain_difference(
-        [set(state_set.values()) for state_set in state_sets]
-    )
+    # The `StateResolutionStore.get_auth_chain_difference` function assumes that
+    # all events passed to it (and their auth chains) have been persisted
+    # previously. This is not the case for any events in the `event_map`, and so
+    # we need to manually handle those events.
+    #
+    # We do this by:
+    #   1. calculating the auth chain difference for the state sets based on the
+    #      events in `event_map` alone
+    #   2. replacing any events in the state_sets that are also in `event_map`
+    #      with their auth events (recursively), and then calling
+    #      `store.get_auth_chain_difference` as normal
+    #   3. adding the results of 1 and 2 together.
+
+    # Map from event ID in `event_map` to their auth event IDs, and their auth
+    # event IDs if they appear in the `event_map`. This is the intersection of
+    # the event's auth chain with the events in the `event_map` *plus* their
+    # auth event IDs.
+    events_to_auth_chain = {}  # type: Dict[str, Set[str]]
+    for event in event_map.values():
+        chain = {event.event_id}
+        events_to_auth_chain[event.event_id] = chain
+
+        to_search = [event]
+        while to_search:
+            for auth_id in to_search.pop().auth_event_ids():
+                chain.add(auth_id)
+                auth_event = event_map.get(auth_id)
+                if auth_event:
+                    to_search.append(auth_event)
+
+    # We now a) calculate the auth chain difference for the unpersisted events
+    # and b) work out the state sets to pass to the store.
+    #
+    # Note: If the `event_map` is empty (which is the common case), we can do a
+    # much simpler calculation.
+    if event_map:
+        # The list of state sets to pass to the store, where each state set is a set
+        # of the event ids making up the state. This is similar to `state_sets`,
+        # except that (a) we only have event ids, not the complete
+        # ((type, state_key)->event_id) mappings; and (b) we have stripped out
+        # unpersisted events and replaced them with the persisted events in
+        # their auth chain.
+        state_sets_ids = []  # type: List[Set[str]]
+
+        # For each state set, the unpersisted event IDs reachable (by their auth
+        # chain) from the events in that set.
+        unpersisted_set_ids = []  # type: List[Set[str]]
+
+        for state_set in state_sets:
+            set_ids = set()  # type: Set[str]
+            state_sets_ids.append(set_ids)
+
+            unpersisted_ids = set()  # type: Set[str]
+            unpersisted_set_ids.append(unpersisted_ids)
+
+            for event_id in state_set.values():
+                event_chain = events_to_auth_chain.get(event_id)
+                if event_chain is not None:
+                    # We have an event in `event_map`. We add all the auth
+                    # events that it references (that aren't also in `event_map`).
+                    set_ids.update(e for e in event_chain if e not in event_map)
+
+                    # We also add the full chain of unpersisted event IDs
+                    # referenced by this state set, so that we can work out the
+                    # auth chain difference of the unpersisted events.
+                    unpersisted_ids.update(e for e in event_chain if e in event_map)
+                else:
+                    set_ids.add(event_id)
+
+        # The auth chain difference of the unpersisted events of the state sets
+        # is calculated by taking the difference between the union and
+        # intersections.
+        union = unpersisted_set_ids[0].union(*unpersisted_set_ids[1:])
+        intersection = unpersisted_set_ids[0].intersection(*unpersisted_set_ids[1:])
+
+        difference_from_event_map = union - intersection  # type: Collection[str]
+    else:
+        difference_from_event_map = ()
+        state_sets_ids = [set(state_set.values()) for state_set in state_sets]
+
+    difference = await state_res_store.get_auth_chain_difference(state_sets_ids)
+    difference.update(difference_from_event_map)
 
     return difference
 
diff --git a/tests/state/test_v2.py b/tests/state/test_v2.py
index ad9bbef9d2..f5c6db900d 100644
--- a/tests/state/test_v2.py
+++ b/tests/state/test_v2.py
@@ -24,7 +24,11 @@ from synapse.api.constants import EventTypes, JoinRules, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.event_auth import auth_types_for_event
 from synapse.events import make_event_from_dict
-from synapse.state.v2 import lexicographical_topological_sort, resolve_events_with_store
+from synapse.state.v2 import (
+    _get_auth_chain_difference,
+    lexicographical_topological_sort,
+    resolve_events_with_store,
+)
 from synapse.types import EventID
 
 from tests import unittest
@@ -587,6 +591,128 @@ class SimpleParamStateTestCase(unittest.TestCase):
         self.assert_dict(self.expected_combined_state, state)
 
 
+class AuthChainDifferenceTestCase(unittest.TestCase):
+    """We test that `_get_auth_chain_difference` correctly handles unpersisted
+    events.
+    """
+
+    def test_simple(self):
+        # Test getting the auth difference for a simple chain with a single
+        # unpersisted event:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #           C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c}
+
+        state_sets = [{"a": a.event_id, "b": b.event_id}, {"c": c.event_id}]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {c.event_id})
+
+    def test_multiple_unpersisted_chain(self):
+        # Test getting the auth difference for a simple chain with multiple
+        # unpersisted events:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #      D -> C -|-> B -> A
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, c.event_id})
+
+    def test_unpersisted_events_different_sets(self):
+        # Test getting the auth difference for with multiple unpersisted events
+        # in different branches:
+        #
+        #  Unpersisted | Persisted
+        #              |
+        #     D --> C -|-> B -> A
+        #     E ----^ -|---^
+        #              |
+
+        a = FakeEvent(
+            id="A", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([], [])
+
+        b = FakeEvent(
+            id="B", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([a.event_id], [])
+
+        c = FakeEvent(
+            id="C", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([b.event_id], [])
+
+        d = FakeEvent(
+            id="D", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id], [])
+
+        e = FakeEvent(
+            id="E", sender=ALICE, type=EventTypes.Member, state_key="", content={},
+        ).to_event([c.event_id, b.event_id], [])
+
+        persisted_events = {a.event_id: a, b.event_id: b}
+        unpersited_events = {c.event_id: c, d.event_id: d, e.event_id: e}
+
+        state_sets = [
+            {"a": a.event_id, "b": b.event_id, "e": e.event_id},
+            {"c": c.event_id, "d": d.event_id},
+        ]
+
+        store = TestStateResolutionStore(persisted_events)
+
+        diff_d = _get_auth_chain_difference(state_sets, unpersited_events, store)
+        difference = self.successResultOf(defer.ensureDeferred(diff_d))
+
+        self.assertEqual(difference, {d.event_id, e.event_id})
+
+
 def pairwise(iterable):
     "s -> (s0,s1), (s1,s2), (s2, s3), ..."
     a, b = itertools.tee(iterable)
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index d4c3b867e3..71c21d8c75 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -217,6 +217,11 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
         self.assertSetEqual(difference, {"a", "b", "c"})
 
         difference = self.get_success(
+            self.store.get_auth_chain_difference([{"a", "c"}, {"b", "c"}])
+        )
+        self.assertSetEqual(difference, {"a", "b"})
+
+        difference = self.get_success(
             self.store.get_auth_chain_difference([{"a"}, {"b"}, {"d"}])
         )
         self.assertSetEqual(difference, {"a", "b", "d", "e"})