-rw-r--r--  changelog.d/14844.misc               1
-rw-r--r--  synapse/handlers/federation.py     106
-rw-r--r--  tests/handlers/test_federation.py  112
3 files changed, 210 insertions, 9 deletions
diff --git a/changelog.d/14844.misc b/changelog.d/14844.misc
new file mode 100644
index 0000000000..30ce866304
--- /dev/null
+++ b/changelog.d/14844.misc
@@ -0,0 +1 @@
+Add check to avoid starting duplicate partial state syncs.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eca75f1108..e386f77de6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -27,6 +27,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -171,12 +172,23 @@ class FederationHandler:
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        # Tracks running partial state syncs by room ID.
+        # Partial state syncs currently only run on the main process, so it's okay to
+        # track them in-memory for now.
+        self._active_partial_state_syncs: Set[str] = set()
+        # Tracks partial state syncs we may want to restart.
+        # A dictionary mapping room IDs to (initial destination, other destinations)
+        # tuples.
+        self._partial_state_syncs_maybe_needing_restart: Dict[
+            str, Tuple[Optional[str], Collection[str]]
+        ] = {}
+
         # if this is the main process, fire off a background process to resume
         # any partial-state-resync operations which were in flight when we
         # were shut down.
         if not hs.config.worker.worker_app:
             run_as_background_process(
-                "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+                "resume_sync_partial_state_room", self._resume_partial_state_room_sync
             )
 
     @trace
@@ -679,9 +691,7 @@ class FederationHandler:
                 if ret.partial_state:
                     # Kick off the process of asynchronously fetching the state for this
                     # room.
-                    run_as_background_process(
-                        desc="sync_partial_state_room",
-                        func=self._sync_partial_state_room,
+                    self._start_partial_state_room_sync(
                         initial_destination=origin,
                         other_destinations=ret.servers_in_room,
                         room_id=room_id,
@@ -1660,20 +1670,100 @@ class FederationHandler:
         # well.
         return None
 
-    async def _resume_sync_partial_state_room(self) -> None:
+    async def _resume_partial_state_room_sync(self) -> None:
         """Resumes resyncing of all partial-state rooms after a restart."""
         assert not self.config.worker.worker_app
 
         partial_state_rooms = await self.store.get_partial_state_room_resync_info()
         for room_id, resync_info in partial_state_rooms.items():
-            run_as_background_process(
-                desc="sync_partial_state_room",
-                func=self._sync_partial_state_room,
+            self._start_partial_state_room_sync(
                 initial_destination=resync_info.joined_via,
                 other_destinations=resync_info.servers_in_room,
                 room_id=room_id,
             )
 
+    def _start_partial_state_room_sync(
+        self,
+        initial_destination: Optional[str],
+        other_destinations: Collection[str],
+        room_id: str,
+    ) -> None:
+        """Starts the background process to resync the state of a partial state room,
+        if it is not already running.
+
+        Args:
+            initial_destination: the initial homeserver to pull the state from
+            other_destinations: other homeservers to try to pull the state from, if
+                `initial_destination` is unavailable
+            room_id: room to be resynced
+        """
+
+        async def _sync_partial_state_room_wrapper() -> None:
+            if room_id in self._active_partial_state_syncs:
+                # Another local user has joined the room while there is already a
+                # partial state sync running. This implies that there is a new join
+                # event to un-partial state. We might find ourselves in one of a few
+                # scenarios:
+                #  1. There is an existing partial state sync. The partial state sync
+                #     un-partial states the new join event before completing and all is
+                #     well.
+                #  2. Before the latest join, the homeserver was no longer in the room
+                #     and there is an existing partial state sync from our previous
+                #     membership of the room. The partial state sync may have:
+                #      a) succeeded, but not yet terminated. The room will not be
+                #         un-partial stated again unless we restart the partial state
+                #         sync.
+                #      b) failed, because we were no longer in the room and remote
+                #         homeservers were refusing our requests, but not yet
+                #         terminated. After the latest join, remote homeservers may
+                #         start answering our requests again, so we should restart the
+                #         partial state sync.
+                # In both cases where we would want to restart the partial state
+                # sync, the room will still be flagged as partial-stated when the
+                # running sync terminates.
+                self._partial_state_syncs_maybe_needing_restart[room_id] = (
+                    initial_destination,
+                    other_destinations,
+                )
+                return
+
+            self._active_partial_state_syncs.add(room_id)
+
+            try:
+                await self._sync_partial_state_room(
+                    initial_destination=initial_destination,
+                    other_destinations=other_destinations,
+                    room_id=room_id,
+                )
+            finally:
+                # Read the room's partial state flag while we still hold the claim to
+                # being the active partial state sync (so that another partial state
+                # sync can't come along and mess with it under us).
+                # Normally, the partial state flag will be gone. If it isn't, then we
+                # may find ourselves in scenario 2a or 2b as described in the comment
+                # above, where we want to restart the partial state sync.
+                is_still_partial_state_room = await self.store.is_partial_state_room(
+                    room_id
+                )
+                self._active_partial_state_syncs.remove(room_id)
+
+                if room_id in self._partial_state_syncs_maybe_needing_restart:
+                    (
+                        restart_initial_destination,
+                        restart_other_destinations,
+                    ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+                    if is_still_partial_state_room:
+                        self._start_partial_state_room_sync(
+                            initial_destination=restart_initial_destination,
+                            other_destinations=restart_other_destinations,
+                            room_id=room_id,
+                        )
+
+        run_as_background_process(
+            desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+        )
+
     async def _sync_partial_state_room(
         self,
         initial_destination: Optional[str],
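
The wrapper added above boils down to a compact deduplicate-then-maybe-restart pattern: at most one sync task runs per room, later start requests only record their parameters, and the finishing task restarts itself with the newest parameters if the room is still partial-stated. Below is a minimal standalone sketch of the same idea using plain asyncio instead of Synapse's run_as_background_process; the SyncDeduplicator name and its callback parameters are illustrative, not Synapse APIs.

import asyncio
from typing import Awaitable, Callable, Collection, Dict, Optional, Set, Tuple

SyncFn = Callable[[Optional[str], Collection[str], str], Awaitable[None]]
FlagFn = Callable[[str], Awaitable[bool]]

class SyncDeduplicator:
    """At most one sync task per room; restart with the latest parameters
    if the room is still partial-stated when the running task finishes."""

    def __init__(self, sync: SyncFn, is_partial_state_room: FlagFn) -> None:
        self._sync = sync
        self._is_partial_state_room = is_partial_state_room
        self._active: Set[str] = set()
        self._maybe_restart: Dict[str, Tuple[Optional[str], Collection[str]]] = {}

    def start(
        self,
        initial_destination: Optional[str],
        other_destinations: Collection[str],
        room_id: str,
    ) -> None:
        async def _wrapper() -> None:
            if room_id in self._active:
                # A sync is already running: just record the newest parameters
                # so the finishing task can decide whether to restart.
                self._maybe_restart[room_id] = (
                    initial_destination,
                    other_destinations,
                )
                return

            self._active.add(room_id)
            try:
                await self._sync(initial_destination, other_destinations, room_id)
            except Exception:
                # Synapse's run_as_background_process logs failures; either
                # way, the restart decision below still runs.
                pass
            finally:
                # Check the flag while this task still owns the room's slot,
                # so a concurrent start cannot race the check.
                still_partial = await self._is_partial_state_room(room_id)
                self._active.discard(room_id)
                restart = self._maybe_restart.pop(room_id, None)
                if restart is not None and still_partial:
                    self.start(restart[0], restart[1], room_id)

        # Fire-and-forget, mirroring run_as_background_process.
        # Must be called from within a running event loop.
        asyncio.create_task(_wrapper())

Called twice in quick succession for the same room, the second start() only records its destinations; whether a fresh task is launched is decided exactly once, by the task that currently owns the room.
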
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index cedbb9fafc..c1558c40c3 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import cast
+from typing import Collection, Optional, cast
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             f"Stale partial-stated room flag left over for {room_id} after a"
             f" failed do_invite_join!",
         )
+
+    def test_duplicate_partial_state_room_syncs(self) -> None:
+        """
+        Tests that concurrent partial state syncs are not started for the same room.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Try to start another partial state sync.
+            # Nothing should happen.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # End the partial state sync
+            is_partial_state = False
+            end_sync.callback(None)
+
+            # The partial state sync should not be restarted.
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # The next attempt to start the partial state sync should work.
+            is_partial_state = True
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+    def test_partial_state_room_sync_restart(self) -> None:
+        """
+        Tests that a partial state sync is restarted when a second partial state
+        sync is deduplicated and the first partial state sync fails.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Fail the partial state sync.
+            # The partial state sync should not be restarted.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Start the partial state sync again.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Deduplicate another partial state sync.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Fail the partial state sync.
+            # It should restart with the latest parameters.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+            mock_sync_partial_state_room.assert_called_with(
+                initial_destination="hs3",
+                other_destinations=["hs2"],
+                room_id="room_id",
+            )
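
Both tests drive the mocked _sync_partial_state_room through a Twisted Deferred: the coroutine suspends on `await end_sync` until the test fires callback (success) or errback (failure), which makes the otherwise-asynchronous dedup and restart paths fully deterministic. A minimal sketch of that gating technique in isolation, runnable with just Twisted installed and without a reactor; the gated_work name is illustrative.

from twisted.internet.defer import Deferred, ensureDeferred

async def gated_work(gate: "Deferred[None]") -> str:
    # Suspends here until the test fires gate.callback(...) or gate.errback(...).
    await gate
    return "done"

gate: "Deferred[None]" = Deferred()
d = ensureDeferred(gated_work(gate))

results = []
d.addCallback(results.append)
assert results == []  # the coroutine is still suspended on the gate

gate.callback(None)  # release the coroutine; it resumes synchronously
assert results == ["done"]
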