summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/13278.bugfix1
-rw-r--r--synapse/state/__init__.py19
-rw-r--r--synapse/storage/controllers/persist_events.py3
-rw-r--r--tests/test_state.py42
4 files changed, 58 insertions, 7 deletions
diff --git a/changelog.d/13278.bugfix b/changelog.d/13278.bugfix
new file mode 100644
index 0000000000..49e9377c79
--- /dev/null
+++ b/changelog.d/13278.bugfix
@@ -0,0 +1 @@
+Fix long-standing bug where in rare instances Synapse could store the incorrect state for a room after a state resolution.
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index dcd272034d..3a65bd0849 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -83,7 +83,7 @@ def _gen_state_id() -> str:
 
 
 class _StateCacheEntry:
-    __slots__ = ["state", "state_group", "prev_group", "delta_ids"]
+    __slots__ = ["_state", "state_group", "prev_group", "delta_ids"]
 
     def __init__(
         self,
@@ -96,7 +96,10 @@ class _StateCacheEntry:
             raise Exception("Either state or state group must be not None")
 
         # A map from (type, state_key) to event_id.
-        self.state = frozendict(state) if state is not None else None
+        #
+        # This can be None if we have a `state_group` (as then we can fetch the
+        # state from the DB.)
+        self._state = frozendict(state) if state is not None else None
 
         # the ID of a state group if one and only one is involved.
         # otherwise, None otherwise?
@@ -114,8 +117,8 @@ class _StateCacheEntry:
         looking up the state group in the DB.
         """
 
-        if self.state is not None:
-            return self.state
+        if self._state is not None:
+            return self._state
 
         assert self.state_group is not None
 
@@ -128,7 +131,7 @@ class _StateCacheEntry:
         # cache eviction purposes. This is why if `self.state` is None it's fine
         # to return 1.
 
-        return len(self.state) if self.state else 1
+        return len(self._state) if self._state else 1
 
 
 class StateHandler:
@@ -743,6 +746,12 @@ def _make_state_cache_entry(
     delta_ids: Optional[StateMap[str]] = None
 
     for old_group, old_state in state_groups_ids.items():
+        if old_state.keys() - new_state.keys():
+            # Currently we don't support deltas that remove keys from the state
+            # map, so we have to ignore this group as a candidate to base the
+            # new group on.
+            continue
+
         n_delta_ids = {k: v for k, v in new_state.items() if old_state.get(k) != v}
         if not delta_ids or len(n_delta_ids) < len(delta_ids):
             prev_group = old_group
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index af65e5913b..cf98b0ab48 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -948,7 +948,8 @@ class EventsPersistenceStorageController:
                 events_context,
             )
 
-        return res.state, None, new_latest_event_ids
+        full_state = await res.get_state(self._state_controller)
+        return full_state, None, new_latest_event_ids
 
     async def _prune_extremities(
         self,
diff --git a/tests/test_state.py b/tests/test_state.py
index 6ca8d8f21d..e2c0013671 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -21,7 +21,7 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.api.room_versions import RoomVersions
 from synapse.events import make_event_from_dict
 from synapse.events.snapshot import EventContext
-from synapse.state import StateHandler, StateResolutionHandler
+from synapse.state import StateHandler, StateResolutionHandler, _make_state_cache_entry
 from synapse.util import Clock
 from synapse.util.macaroons import MacaroonGenerator
 
@@ -760,3 +760,43 @@ class StateTestCase(unittest.TestCase):
 
         result = yield defer.ensureDeferred(self.state.compute_event_context(event))
         return result
+
+    def test_make_state_cache_entry(self):
+        "Test that calculating a prev_group and delta is correct"
+
+        new_state = {
+            ("a", ""): "E",
+            ("b", ""): "E",
+            ("c", ""): "E",
+            ("d", ""): "E",
+        }
+
+        # old_state_1 has fewer differences to new_state than old_state_2, but
+        # the delta involves deleting a key, which isn't allowed in the deltas,
+        # so we should pick old_state_2 as the prev_group.
+
+        # `old_state_1` has two differences: `a` and `e`
+        old_state_1 = {
+            ("a", ""): "F",
+            ("b", ""): "E",
+            ("c", ""): "E",
+            ("d", ""): "E",
+            ("e", ""): "E",
+        }
+
+        # `old_state_2` has three differences: `a`, `c` and `d`
+        old_state_2 = {
+            ("a", ""): "F",
+            ("b", ""): "E",
+            ("c", ""): "F",
+            ("d", ""): "F",
+        }
+
+        entry = _make_state_cache_entry(new_state, {1: old_state_1, 2: old_state_2})
+
+        self.assertEqual(entry.prev_group, 2)
+
+        # There are three changes from `old_state_2` to `new_state`
+        self.assertEqual(
+            entry.delta_ids, {("a", ""): "E", ("c", ""): "E", ("d", ""): "E"}
+        )