summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12087.bugfix1
-rw-r--r--synapse/federation/federation_server.py7
-rw-r--r--synapse/handlers/federation.py61
3 files changed, 25 insertions, 44 deletions
diff --git a/changelog.d/12087.bugfix b/changelog.d/12087.bugfix
new file mode 100644
index 0000000000..6dacdddd0d
--- /dev/null
+++ b/changelog.d/12087.bugfix
@@ -0,0 +1 @@
+Fix a long-standing bug which caused the `/_matrix/federation/v1/state` and `.../state_ids` endpoints to return incorrect or invalid data when called for an event which we have stored as an "outlier".
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 482bbdd867..af2d0f7d79 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -22,7 +22,6 @@ from typing import (
     Callable,
     Collection,
     Dict,
-    Iterable,
     List,
     Optional,
     Tuple,
@@ -577,10 +576,10 @@ class FederationServer(FederationBase):
     async def _on_context_state_request_compute(
         self, room_id: str, event_id: Optional[str]
     ) -> Dict[str, list]:
+        pdus: Collection[EventBase]
         if event_id:
-            pdus: Iterable[EventBase] = await self.handler.get_state_for_pdu(
-                room_id, event_id
-            )
+            event_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id)
+            pdus = await self.store.get_events_as_list(event_ids)
         else:
             pdus = (await self.state.get_current_state(room_id)).values()
 
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index db39aeabde..350ec9c03a 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -950,54 +950,35 @@ class FederationHandler:
 
         return event
 
-    async def get_state_for_pdu(self, room_id: str, event_id: str) -> List[EventBase]:
-        """Returns the state at the event. i.e. not including said event."""
-
-        event = await self.store.get_event(event_id, check_room_id=room_id)
-
-        state_groups = await self.state_store.get_state_groups(room_id, [event_id])
-
-        if state_groups:
-            _, state = list(state_groups.items()).pop()
-            results = {(e.type, e.state_key): e for e in state}
-
-            if event.is_state():
-                # Get previous state
-                if "replaces_state" in event.unsigned:
-                    prev_id = event.unsigned["replaces_state"]
-                    if prev_id != event.event_id:
-                        prev_event = await self.store.get_event(prev_id)
-                        results[(event.type, event.state_key)] = prev_event
-                else:
-                    del results[(event.type, event.state_key)]
-
-            res = list(results.values())
-            return res
-        else:
-            return []
-
     async def get_state_ids_for_pdu(self, room_id: str, event_id: str) -> List[str]:
         """Returns the state at the event. i.e. not including said event."""
         event = await self.store.get_event(event_id, check_room_id=room_id)
+        if event.internal_metadata.outlier:
+            raise NotFoundError("State not known at event %s" % (event_id,))
 
         state_groups = await self.state_store.get_state_groups_ids(room_id, [event_id])
 
-        if state_groups:
-            _, state = list(state_groups.items()).pop()
-            results = state
+        # get_state_groups_ids should return exactly one result
+        assert len(state_groups) == 1
 
-            if event.is_state():
-                # Get previous state
-                if "replaces_state" in event.unsigned:
-                    prev_id = event.unsigned["replaces_state"]
-                    if prev_id != event.event_id:
-                        results[(event.type, event.state_key)] = prev_id
-                else:
-                    results.pop((event.type, event.state_key), None)
+        state_map = next(iter(state_groups.values()))
 
-            return list(results.values())
-        else:
-            return []
+        state_key = event.get_state_key()
+        if state_key is not None:
+            # the event was not rejected (get_event raises a NotFoundError for rejected
+            # events) so the state at the event should include the event itself.
+            assert (
+                state_map.get((event.type, state_key)) == event.event_id
+            ), "State at event did not include event itself"
+
+            # ... but we need the state *before* that event
+            if "replaces_state" in event.unsigned:
+                prev_id = event.unsigned["replaces_state"]
+                state_map[(event.type, state_key)] = prev_id
+            else:
+                del state_map[(event.type, state_key)]
+
+        return list(state_map.values())
 
     async def on_backfill_request(
         self, origin: str, room_id: str, pdu_list: List[str], limit: int