summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/13320.misc1
-rw-r--r--synapse/federation/federation_client.py123
-rw-r--r--synapse/handlers/federation_event.py22
-rw-r--r--synapse/storage/databases/main/events.py23
-rw-r--r--tests/federation/test_federation_client.py125
5 files changed, 233 insertions, 61 deletions
diff --git a/changelog.d/13320.misc b/changelog.d/13320.misc
new file mode 100644
index 0000000000..d33cf3a25a
--- /dev/null
+++ b/changelog.d/13320.misc
@@ -0,0 +1 @@
+Fix `FederationClient.get_pdu()` returning events from the cache as `outliers` instead of original events we saw over federation.
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index 7c450ecad0..842f5327c2 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -53,7 +53,7 @@ from synapse.api.room_versions import (
     RoomVersion,
     RoomVersions,
 )
-from synapse.events import EventBase, builder
+from synapse.events import EventBase, builder, make_event_from_dict
 from synapse.federation.federation_base import (
     FederationBase,
     InvalidEventSignatureError,
@@ -299,7 +299,8 @@ class FederationClient(FederationBase):
                 moving to the next destination. None indicates no timeout.
 
         Returns:
-            The requested PDU, or None if we were unable to find it.
+            A copy of the requested PDU that is safe to modify, or None if we
+            were unable to find it.
 
         Raises:
             SynapseError, NotRetryingDestination, FederationDeniedError
@@ -309,7 +310,7 @@ class FederationClient(FederationBase):
         )
 
         logger.debug(
-            "retrieved event id %s from %s: %r",
+            "get_pdu_from_destination_raw: retrieved event id %s from %s: %r",
             event_id,
             destination,
             transaction_data,
@@ -358,54 +359,92 @@ class FederationClient(FederationBase):
             The requested PDU, or None if we were unable to find it.
         """
 
-        # TODO: Rate limit the number of times we try and get the same event.
+        logger.debug(
+            "get_pdu: event_id=%s from destinations=%s", event_id, destinations
+        )
 
-        ev = self._get_pdu_cache.get(event_id)
-        if ev:
-            return ev
+        # TODO: Rate limit the number of times we try and get the same event.
 
-        pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+        # We might need the same event multiple times in quick succession (before
+        # it gets persisted to the database), so we cache the results of the lookup.
+        # Note that this is separate to the regular get_event cache which caches
+        # events once they have been persisted.
+        event = self._get_pdu_cache.get(event_id)
+
+        # If we don't see the event in the cache, go try to fetch it from the
+        # provided remote federated destinations
+        if not event:
+            pdu_attempts = self.pdu_destination_tried.setdefault(event_id, {})
+
+            for destination in destinations:
+                now = self._clock.time_msec()
+                last_attempt = pdu_attempts.get(destination, 0)
+                if last_attempt + PDU_RETRY_TIME_MS > now:
+                    logger.debug(
+                        "get_pdu: skipping destination=%s because we tried it recently last_attempt=%s and we only check every %s (now=%s)",
+                        destination,
+                        last_attempt,
+                        PDU_RETRY_TIME_MS,
+                        now,
+                    )
+                    continue
+
+                try:
+                    event = await self.get_pdu_from_destination_raw(
+                        destination=destination,
+                        event_id=event_id,
+                        room_version=room_version,
+                        timeout=timeout,
+                    )
 
-        signed_pdu = None
-        for destination in destinations:
-            now = self._clock.time_msec()
-            last_attempt = pdu_attempts.get(destination, 0)
-            if last_attempt + PDU_RETRY_TIME_MS > now:
-                continue
+                    pdu_attempts[destination] = now
 
-            try:
-                signed_pdu = await self.get_pdu_from_destination_raw(
-                    destination=destination,
-                    event_id=event_id,
-                    room_version=room_version,
-                    timeout=timeout,
-                )
+                    if event:
+                        # Prime the cache
+                        self._get_pdu_cache[event.event_id] = event
 
-                pdu_attempts[destination] = now
+                        # FIXME: We should add a `break` here to avoid calling every
+                        # destination after we already found a PDU (will follow-up
+                        # in a separate PR)
 
-            except SynapseError as e:
-                logger.info(
-                    "Failed to get PDU %s from %s because %s", event_id, destination, e
-                )
-                continue
-            except NotRetryingDestination as e:
-                logger.info(str(e))
-                continue
-            except FederationDeniedError as e:
-                logger.info(str(e))
-                continue
-            except Exception as e:
-                pdu_attempts[destination] = now
+                except SynapseError as e:
+                    logger.info(
+                        "Failed to get PDU %s from %s because %s",
+                        event_id,
+                        destination,
+                        e,
+                    )
+                    continue
+                except NotRetryingDestination as e:
+                    logger.info(str(e))
+                    continue
+                except FederationDeniedError as e:
+                    logger.info(str(e))
+                    continue
+                except Exception as e:
+                    pdu_attempts[destination] = now
+
+                    logger.info(
+                        "Failed to get PDU %s from %s because %s",
+                        event_id,
+                        destination,
+                        e,
+                    )
+                    continue
 
-                logger.info(
-                    "Failed to get PDU %s from %s because %s", event_id, destination, e
-                )
-                continue
+        if not event:
+            return None
 
-        if signed_pdu:
-            self._get_pdu_cache[event_id] = signed_pdu
+        # `event` now refers to an object stored in `get_pdu_cache`. Our
+        # callers may need to modify the returned object (eg to set
+        # `event.internal_metadata.outlier = true`), so we return a copy
+        # rather than the original object.
+        event_copy = make_event_from_dict(
+            event.get_pdu_json(),
+            event.room_version,
+        )
 
-        return signed_pdu
+        return event_copy
 
     async def get_room_state_ids(
         self, destination: str, room_id: str, event_id: str
diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py
index e4a5b64d10..a5f4ce7c8a 100644
--- a/synapse/handlers/federation_event.py
+++ b/synapse/handlers/federation_event.py
@@ -766,10 +766,24 @@ class FederationEventHandler:
         """
         logger.info("Processing pulled event %s", event)
 
-        # these should not be outliers.
-        assert (
-            not event.internal_metadata.is_outlier()
-        ), "pulled event unexpectedly flagged as outlier"
+        # This function should not be used to persist outliers (use something
+        # else) because this does a bunch of operations that aren't necessary
+        # (extra work; in particular, it makes sure we have all the prev_events
+        # and resolves the state across those prev events). If you happen to run
+        # into a situation where the event you're trying to process/backfill is
+        # marked as an `outlier`, then you should update that spot to return an
+        # `EventBase` copy that doesn't have `outlier` flag set.
+        #
+        # `EventBase` is used to represent both an event we have not yet
+        # persisted, and one that we have persisted and now keep in the cache.
+        # In an ideal world this method would only be called with the first type
+        # of event, but it turns out that's not actually the case and for
+        # example, you could get an event from cache that is marked as an
+        # `outlier` (fix up that spot though).
+        assert not event.internal_metadata.is_outlier(), (
+            "Outlier event passed to _process_pulled_event. "
+            "To persist an event as a non-outlier, make sure to pass in a copy without `event.internal_metadata.outlier = true`."
+        )
 
         event_id = event.event_id
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 156e1bd5ab..1f600f1190 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1346,9 +1346,24 @@ class PersistEventsStore:
             event_id: outlier for event_id, outlier in txn
         }
 
+        logger.debug(
+            "_update_outliers_txn: events=%s have_persisted=%s",
+            [ev.event_id for ev, _ in events_and_contexts],
+            have_persisted,
+        )
+
         to_remove = set()
         for event, context in events_and_contexts:
-            if event.event_id not in have_persisted:
+            outlier_persisted = have_persisted.get(event.event_id)
+            logger.debug(
+                "_update_outliers_txn: event=%s outlier=%s outlier_persisted=%s",
+                event.event_id,
+                event.internal_metadata.is_outlier(),
+                outlier_persisted,
+            )
+
+            # Ignore events which we haven't persisted at all
+            if outlier_persisted is None:
                 continue
 
             to_remove.add(event)
@@ -1358,7 +1373,6 @@ class PersistEventsStore:
                 # was an outlier or not - what we have is at least as good.
                 continue
 
-            outlier_persisted = have_persisted[event.event_id]
             if not event.internal_metadata.is_outlier() and outlier_persisted:
                 # We received a copy of an event that we had already stored as
                 # an outlier in the database. We now have some state at that event
@@ -1369,7 +1383,10 @@ class PersistEventsStore:
                 # events down /sync. In general they will be historical events, so that
                 # doesn't matter too much, but that is not always the case.
 
-                logger.info("Updating state for ex-outlier event %s", event.event_id)
+                logger.info(
+                    "_update_outliers_txn: Updating state for ex-outlier event %s",
+                    event.event_id,
+                )
 
                 # insert into event_to_state_groups.
                 try:
diff --git a/tests/federation/test_federation_client.py b/tests/federation/test_federation_client.py
index cf6b130e4f..50e376f695 100644
--- a/tests/federation/test_federation_client.py
+++ b/tests/federation/test_federation_client.py
@@ -22,6 +22,7 @@ from twisted.python.failure import Failure
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.room_versions import RoomVersions
+from synapse.events import EventBase
 from synapse.server import HomeServer
 from synapse.types import JsonDict
 from synapse.util import Clock
@@ -38,20 +39,24 @@ class FederationClientTest(FederatingHomeserverTestCase):
         self._mock_agent = mock.create_autospec(twisted.web.client.Agent, spec_set=True)
         homeserver.get_federation_http_client().agent = self._mock_agent
 
-    def test_get_room_state(self):
-        creator = f"@creator:{self.OTHER_SERVER_NAME}"
-        test_room_id = "!room_id"
+        # Move clock up to somewhat realistic time so the PDU destination retry
+        # works (`now` needs to be larger than `0 + PDU_RETRY_TIME_MS`).
+        self.reactor.advance(1000000000)
+
+        self.creator = f"@creator:{self.OTHER_SERVER_NAME}"
+        self.test_room_id = "!room_id"
 
+    def test_get_room_state(self):
         # mock up some events to use in the response.
         # In real life, these would have things in `prev_events` and `auth_events`, but that's
         # a bit annoying to mock up, and the code under test doesn't care, so we don't bother.
         create_event_dict = self.add_hashes_and_signatures_from_other_server(
             {
-                "room_id": test_room_id,
+                "room_id": self.test_room_id,
                 "type": "m.room.create",
                 "state_key": "",
-                "sender": creator,
-                "content": {"creator": creator},
+                "sender": self.creator,
+                "content": {"creator": self.creator},
                 "prev_events": [],
                 "auth_events": [],
                 "origin_server_ts": 500,
@@ -59,10 +64,10 @@ class FederationClientTest(FederatingHomeserverTestCase):
         )
         member_event_dict = self.add_hashes_and_signatures_from_other_server(
             {
-                "room_id": test_room_id,
+                "room_id": self.test_room_id,
                 "type": "m.room.member",
-                "sender": creator,
-                "state_key": creator,
+                "sender": self.creator,
+                "state_key": self.creator,
                 "content": {"membership": "join"},
                 "prev_events": [],
                 "auth_events": [],
@@ -71,9 +76,9 @@ class FederationClientTest(FederatingHomeserverTestCase):
         )
         pl_event_dict = self.add_hashes_and_signatures_from_other_server(
             {
-                "room_id": test_room_id,
+                "room_id": self.test_room_id,
                 "type": "m.room.power_levels",
-                "sender": creator,
+                "sender": self.creator,
                 "state_key": "",
                 "content": {},
                 "prev_events": [],
@@ -103,7 +108,7 @@ class FederationClientTest(FederatingHomeserverTestCase):
         state_resp, auth_resp = self.get_success(
             self.hs.get_federation_client().get_room_state(
                 "yet.another.server",
-                test_room_id,
+                self.test_room_id,
                 "event_id",
                 RoomVersions.V9,
             )
@@ -130,6 +135,102 @@ class FederationClientTest(FederatingHomeserverTestCase):
             ["m.room.create", "m.room.member", "m.room.power_levels"],
         )
 
+    def test_get_pdu_returns_nothing_when_event_does_not_exist(self):
+        """No event should be returned when the event does not exist"""
+        remote_pdu = self.get_success(
+            self.hs.get_federation_client().get_pdu(
+                ["yet.another.server"],
+                "event_should_not_exist",
+                RoomVersions.V9,
+            )
+        )
+        self.assertEqual(remote_pdu, None)
+
+    def test_get_pdu(self):
+        """Test to make sure an event is returned by `get_pdu()`"""
+        self._get_pdu_once()
+
+    def test_get_pdu_event_from_cache_is_pristine(self):
+        """Test that modifications made to events returned by `get_pdu()`
+        do not propagate back to to the internal cache (events returned should
+        be a copy).
+        """
+
+        # Get the PDU in the cache
+        remote_pdu = self._get_pdu_once()
+
+        # Modify the the event reference.
+        # This change should not make it back to the `_get_pdu_cache`.
+        remote_pdu.internal_metadata.outlier = True
+
+        # Get the event again. This time it should read it from cache.
+        remote_pdu2 = self.get_success(
+            self.hs.get_federation_client().get_pdu(
+                ["yet.another.server"],
+                remote_pdu.event_id,
+                RoomVersions.V9,
+            )
+        )
+
+        # Sanity check that we are working against the same event
+        self.assertEqual(remote_pdu.event_id, remote_pdu2.event_id)
+
+        # Make sure the event does not include modification from earlier
+        self.assertIsNotNone(remote_pdu2)
+        self.assertEqual(remote_pdu2.internal_metadata.outlier, False)
+
+    def _get_pdu_once(self) -> EventBase:
+        """Retrieve an event via `get_pdu()` and assert that an event was returned.
+        Also used to prime the cache for subsequent test logic.
+        """
+        message_event_dict = self.add_hashes_and_signatures_from_other_server(
+            {
+                "room_id": self.test_room_id,
+                "type": "m.room.message",
+                "sender": self.creator,
+                "state_key": "",
+                "content": {},
+                "prev_events": [],
+                "auth_events": [],
+                "origin_server_ts": 700,
+                "depth": 10,
+            }
+        )
+
+        # mock up the response, and have the agent return it
+        self._mock_agent.request.side_effect = lambda *args, **kwargs: defer.succeed(
+            _mock_response(
+                {
+                    "origin": "yet.another.server",
+                    "origin_server_ts": 900,
+                    "pdus": [
+                        message_event_dict,
+                    ],
+                }
+            )
+        )
+
+        remote_pdu = self.get_success(
+            self.hs.get_federation_client().get_pdu(
+                ["yet.another.server"],
+                "event_id",
+                RoomVersions.V9,
+            )
+        )
+
+        # check the right call got made to the agent
+        self._mock_agent.request.assert_called_once_with(
+            b"GET",
+            b"matrix://yet.another.server/_matrix/federation/v1/event/event_id",
+            headers=mock.ANY,
+            bodyProducer=None,
+        )
+
+        self.assertIsNotNone(remote_pdu)
+        self.assertEqual(remote_pdu.internal_metadata.outlier, False)
+
+        return remote_pdu
+
 
 def _mock_response(resp: JsonDict):
     body = json.dumps(resp).encode("utf-8")