diff --git a/changelog.d/14844.misc b/changelog.d/14844.misc
new file mode 100644
index 0000000000..30ce866304
--- /dev/null
+++ b/changelog.d/14844.misc
@@ -0,0 +1 @@
+Add check to avoid starting duplicate partial state syncs.
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eca75f1108..e386f77de6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -27,6 +27,7 @@ from typing import (
Iterable,
List,
Optional,
+ Set,
Tuple,
Union,
)
@@ -171,12 +172,23 @@ class FederationHandler:
self.third_party_event_rules = hs.get_third_party_event_rules()
+ # Tracks running partial state syncs by room ID.
+ # Partial state syncs currently only run on the main process, so it's okay to
+ # track them in memory for now.
+ self._active_partial_state_syncs: Set[str] = set()
+ # Tracks partial state syncs we may want to restart.
+ # A dictionary mapping room IDs to (initial destination, other destinations)
+ # tuples.
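+ # An entry is added when a newly requested sync is deduplicated against
+ # one that is already running, and is consumed when the running sync
+ # completes.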
+ self._partial_state_syncs_maybe_needing_restart: Dict[
+ str, Tuple[Optional[str], Collection[str]]
+ ] = {}
+
# if this is the main process, fire off a background process to resume
# any partial-state-resync operations which were in flight when we
# were shut down.
if not hs.config.worker.worker_app:
run_as_background_process(
- "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+ "resume_sync_partial_state_room", self._resume_partial_state_room_sync
)
@trace
@@ -679,9 +691,7 @@ class FederationHandler:
if ret.partial_state:
# Kick off the process of asynchronously fetching the state for this
# room.
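+ # The request is deduplicated against any partial state sync that is
+ # already running for this room.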
- run_as_background_process(
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
+ self._start_partial_state_room_sync(
initial_destination=origin,
other_destinations=ret.servers_in_room,
room_id=room_id,
@@ -1660,20 +1670,100 @@ class FederationHandler:
# well.
return None
- async def _resume_sync_partial_state_room(self) -> None:
+ async def _resume_partial_state_room_sync(self) -> None:
"""Resumes resyncing of all partial-state rooms after a restart."""
assert not self.config.worker.worker_app
partial_state_rooms = await self.store.get_partial_state_room_resync_info()
for room_id, resync_info in partial_state_rooms.items():
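+ # Each interrupted sync is restarted with the destinations that were
+ # recorded in the database when the partial join happened.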
- run_as_background_process(
- desc="sync_partial_state_room",
- func=self._sync_partial_state_room,
+ self._start_partial_state_room_sync(
initial_destination=resync_info.joined_via,
other_destinations=resync_info.servers_in_room,
room_id=room_id,
)
+ def _start_partial_state_room_sync(
+ self,
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ """Starts the background process to resync the state of a partial state room,
+ if it is not already running.
+
+ Args:
+ initial_destination: the initial homeserver to pull the state from
+ other_destinations: other homeservers to try to pull the state from, if
+ `initial_destination` is unavailable
+ room_id: room to be resynced
+ """
+
+ async def _sync_partial_state_room_wrapper() -> None:
+ if room_id in self._active_partial_state_syncs:
+ # Another local user has joined the room while there is already a
+ # partial state sync running. This implies that there is a new join
+ # event to un-partial state. We might find ourselves in one of a few
+ # scenarios:
+ # 1. There is an existing partial state sync. The partial state sync
+ # un-partial states the new join event before completing and all is
+ # well.
+ # 2. Before the latest join, the homeserver was no longer in the room
+ # and there is an existing partial state sync from our previous
+ # membership of the room. The partial state sync may have:
+ # a) succeeded, but not yet terminated. The room will not be
+ # un-partial stated again unless we restart the partial state
+ # sync.
+ # b) failed, because we were no longer in the room and remote
+ # homeservers were refusing our requests, but not yet
+ # terminated. After the latest join, remote homeservers may
+ # start answering our requests again, so we should restart the
+ # partial state sync.
+ # In the cases where we would want to restart the partial state sync,
+ # the room will still have the partial state flag when the partial state
+ # sync terminates.
+ self._partial_state_syncs_maybe_needing_restart[room_id] = (
+ initial_destination,
+ other_destinations,
+ )
+ return
+
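+ # Claim the slot for this room so that concurrent requests are
+ # deduplicated against this sync instead of running in parallel.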
+ self._active_partial_state_syncs.add(room_id)
+
+ try:
+ await self._sync_partial_state_room(
+ initial_destination=initial_destination,
+ other_destinations=other_destinations,
+ room_id=room_id,
+ )
+ finally:
+ # Read the room's partial state flag while we still hold the claim to
+ # being the active partial state sync (so that another partial state
+ # sync can't come along and mess with it under us).
+ # Normally, the partial state flag will be gone. If it isn't, then we
+ # may find ourselves in scenario 2a or 2b as described in the comment
+ # above, where we want to restart the partial state sync.
+ is_still_partial_state_room = await self.store.is_partial_state_room(
+ room_id
+ )
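+ # Drop the claim; from here another sync for this room may start.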
+ self._active_partial_state_syncs.remove(room_id)
+
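+ # If a sync request for this room was deduplicated while we ran,
+ # decide whether to restart with its more recent destinations.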
+ if room_id in self._partial_state_syncs_maybe_needing_restart:
+ (
+ restart_initial_destination,
+ restart_other_destinations,
+ ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+ if is_still_partial_state_room:
+ self._start_partial_state_room_sync(
+ initial_destination=restart_initial_destination,
+ other_destinations=restart_other_destinations,
+ room_id=room_id,
+ )
+
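+ # The wrapper runs as a detached background process; failures are
+ # logged by the background process machinery rather than raised here.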
+ run_as_background_process(
+ desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+ )
+
async def _sync_partial_state_room(
self,
initial_destination: Optional[str],
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index cedbb9fafc..c1558c40c3 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import cast
+from typing import Collection, Optional, cast
from unittest import TestCase
from unittest.mock import Mock, patch
+from twisted.internet.defer import Deferred
from twisted.test.proto_helpers import MemoryReactor
from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
f"Stale partial-stated room flag left over for {room_id} after a"
f" failed do_invite_join!",
)
+
+ def test_duplicate_partial_state_room_syncs(self) -> None:
+ """
+ Tests that concurrent partial state syncs are not started for the same room.
+ """
+ is_partial_state = True
+ end_sync: "Deferred[None]" = Deferred()
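+ # `is_partial_state` drives the mocked `is_partial_state_room` below;
+ # resolving `end_sync` completes the mocked partial state sync.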
+
+ async def is_partial_state_room(room_id: str) -> bool:
+ return is_partial_state
+
+ async def sync_partial_state_room(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ nonlocal end_sync
+ try:
+ await end_sync
+ finally:
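+ # Re-arm the Deferred so a subsequent sync can be completed
+ # independently of this one.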
+ end_sync = Deferred()
+
+ mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+ mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+ fed_handler = self.hs.get_federation_handler()
+ store = self.hs.get_datastores().main
+
+ with patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ # Start the partial state sync.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Try to start another partial state sync.
+ # Nothing should happen.
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # End the partial state sync.
+ is_partial_state = False
+ end_sync.callback(None)
+
+ # The partial state sync should not be restarted.
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # The next attempt to start the partial state sync should work.
+ is_partial_state = True
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ def test_partial_state_room_sync_restart(self) -> None:
+ """
+ Tests that a partial state sync is restarted when a second partial state
+ sync was deduplicated against it and the first partial state sync fails.
+ """
+ is_partial_state = True
+ end_sync: "Deferred[None]" = Deferred()
+
+ async def is_partial_state_room(room_id: str) -> bool:
+ return is_partial_state
+
+ async def sync_partial_state_room(
+ initial_destination: Optional[str],
+ other_destinations: Collection[str],
+ room_id: str,
+ ) -> None:
+ nonlocal end_sync
+ try:
+ await end_sync
+ finally:
+ end_sync = Deferred()
+
+ mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+ mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+ fed_handler = self.hs.get_federation_handler()
+ store = self.hs.get_datastores().main
+
+ with patch.object(
+ fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+ ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+ # Start the partial state sync.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Fail the partial state sync.
+ # The partial state sync should not be restarted.
+ end_sync.errback(Exception("Failed to request /state_ids"))
+ self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+ # Start the partial state sync again.
+ fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ # Deduplicate another partial state sync.
+ fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+ self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+ # Fail the partial state sync.
+ # It should restart with the latest parameters.
+ end_sync.errback(Exception("Failed to request /state_ids"))
+ self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+ mock_sync_partial_state_room.assert_called_with(
+ initial_destination="hs3",
+ other_destinations=["hs2"],
+ room_id="room_id",
+ )