summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2023-01-20 16:06:52 +0000
committerErik Johnston <erik@matrix.org>2023-01-20 16:06:52 +0000
commitbc8136dd81b049b0ec7934d5b901e897a0740147 (patch)
tree894f71640642a5bf444d475bbb7831cc512d9b13
parentNewsfile (diff)
parentNewsfile (diff)
downloadsynapse-erikj/fix_wait_for_stream.tar.xz
Merge branch 'erikj/repl_notifieri' into erikj/fix_wait_for_stream github/erikj/fix_wait_for_stream erikj/fix_wait_for_stream
-rw-r--r--changelog.d/14844.misc1
-rw-r--r--changelog.d/14875.docker1
-rw-r--r--changelog.d/14877.misc1
-rw-r--r--docker/Dockerfile84
-rwxr-xr-xsynapse/_scripts/synapse_port_db.py4
-rw-r--r--synapse/handlers/federation.py106
-rw-r--r--synapse/notifier.py31
-rw-r--r--synapse/server.py6
-rw-r--r--synapse/storage/databases/main/account_data.py2
-rw-r--r--synapse/storage/databases/main/cache.py1
-rw-r--r--synapse/storage/databases/main/deviceinbox.py3
-rw-r--r--synapse/storage/databases/main/devices.py1
-rw-r--r--synapse/storage/databases/main/end_to_end_keys.py5
-rw-r--r--synapse/storage/databases/main/events_worker.py10
-rw-r--r--synapse/storage/databases/main/presence.py3
-rw-r--r--synapse/storage/databases/main/push_rule.py1
-rw-r--r--synapse/storage/databases/main/pusher.py1
-rw-r--r--synapse/storage/databases/main/receipts.py2
-rw-r--r--synapse/storage/databases/main/room.py6
-rw-r--r--synapse/storage/util/id_generators.py26
-rw-r--r--tests/handlers/test_federation.py112
-rw-r--r--tests/module_api/test_api.py3
-rw-r--r--tests/replication/tcp/test_handler.py23
-rw-r--r--tests/storage/test_id_generators.py4
24 files changed, 357 insertions, 80 deletions
diff --git a/changelog.d/14844.misc b/changelog.d/14844.misc
new file mode 100644
index 0000000000..30ce866304
--- /dev/null
+++ b/changelog.d/14844.misc
@@ -0,0 +1 @@
+Add check to avoid starting duplicate partial state syncs.
diff --git a/changelog.d/14875.docker b/changelog.d/14875.docker
new file mode 100644
index 0000000000..584fc10470
--- /dev/null
+++ b/changelog.d/14875.docker
@@ -0,0 +1 @@
+Bump default Python version in the Dockerfile from 3.9 to 3.11.
diff --git a/changelog.d/14877.misc b/changelog.d/14877.misc
new file mode 100644
index 0000000000..4e9c3fa33f
--- /dev/null
+++ b/changelog.d/14877.misc
@@ -0,0 +1 @@
+Always notify replication when a stream advances automatically.
diff --git a/docker/Dockerfile b/docker/Dockerfile
index b2ec005917..a85fd3d691 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -20,7 +20,7 @@
 # `poetry export | pip install -r /dev/stdin`, but beware: we have experienced bugs in
 # in `poetry export` in the past.
 
-ARG PYTHON_VERSION=3.9
+ARG PYTHON_VERSION=3.11
 
 ###
 ### Stage 0: generate requirements.txt
@@ -34,11 +34,11 @@ FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as requirements
 # Here we use it to set up a cache for apt (and below for pip), to improve
 # rebuild speeds on slow connections.
 RUN \
-   --mount=type=cache,target=/var/cache/apt,sharing=locked \
-   --mount=type=cache,target=/var/lib/apt,sharing=locked \
-    apt-get update -qq && apt-get install -yqq \
-      build-essential git libffi-dev libssl-dev \
-    && rm -rf /var/lib/apt/lists/*
+  --mount=type=cache,target=/var/cache/apt,sharing=locked \
+  --mount=type=cache,target=/var/lib/apt,sharing=locked \
+  apt-get update -qq && apt-get install -yqq \
+  build-essential git libffi-dev libssl-dev \
+  && rm -rf /var/lib/apt/lists/*
 
 # We install poetry in its own build stage to avoid its dependencies conflicting with
 # synapse's dependencies.
@@ -64,9 +64,9 @@ ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
 # Otherwise, just create an empty requirements file so that the Dockerfile can
 # proceed.
 RUN if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
-    /root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
+  /root/.local/bin/poetry export --extras all -o /synapse/requirements.txt ${TEST_ONLY_SKIP_DEP_HASH_VERIFICATION:+--without-hashes}; \
   else \
-    touch /synapse/requirements.txt; \
+  touch /synapse/requirements.txt; \
   fi
 
 ###
@@ -76,24 +76,24 @@ FROM docker.io/python:${PYTHON_VERSION}-slim-bullseye as builder
 
 # install the OS build deps
 RUN \
-   --mount=type=cache,target=/var/cache/apt,sharing=locked \
-   --mount=type=cache,target=/var/lib/apt,sharing=locked \
- apt-get update -qq && apt-get install -yqq \
-    build-essential \
-    libffi-dev \
-    libjpeg-dev \
-    libpq-dev \
-    libssl-dev \
-    libwebp-dev \
-    libxml++2.6-dev \
-    libxslt1-dev \
-    openssl \
-    zlib1g-dev \
-    git \
-    curl \
-    libicu-dev \
-    pkg-config \
-    && rm -rf /var/lib/apt/lists/*
+  --mount=type=cache,target=/var/cache/apt,sharing=locked \
+  --mount=type=cache,target=/var/lib/apt,sharing=locked \
+  apt-get update -qq && apt-get install -yqq \
+  build-essential \
+  libffi-dev \
+  libjpeg-dev \
+  libpq-dev \
+  libssl-dev \
+  libwebp-dev \
+  libxml++2.6-dev \
+  libxslt1-dev \
+  openssl \
+  zlib1g-dev \
+  git \
+  curl \
+  libicu-dev \
+  pkg-config \
+  && rm -rf /var/lib/apt/lists/*
 
 
 # Install rust and ensure its in the PATH
@@ -134,9 +134,9 @@ ARG TEST_ONLY_IGNORE_POETRY_LOCKFILE
 RUN --mount=type=cache,target=/synapse/target,sharing=locked \
   --mount=type=cache,target=${CARGO_HOME}/registry,sharing=locked \
   if [ -z "$TEST_ONLY_IGNORE_POETRY_LOCKFILE" ]; then \
-    pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
+  pip install --prefix="/install" --no-deps --no-warn-script-location /synapse[all]; \
   else \
-    pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
+  pip install --prefix="/install" --no-warn-script-location /synapse[all]; \
   fi
 
 ###
@@ -151,20 +151,20 @@ LABEL org.opencontainers.image.source='https://github.com/matrix-org/synapse.git
 LABEL org.opencontainers.image.licenses='Apache-2.0'
 
 RUN \
-   --mount=type=cache,target=/var/cache/apt,sharing=locked \
-   --mount=type=cache,target=/var/lib/apt,sharing=locked \
+  --mount=type=cache,target=/var/cache/apt,sharing=locked \
+  --mount=type=cache,target=/var/lib/apt,sharing=locked \
   apt-get update -qq && apt-get install -yqq \
-    curl \
-    gosu \
-    libjpeg62-turbo \
-    libpq5 \
-    libwebp6 \
-    xmlsec1 \
-    libjemalloc2 \
-    libicu67 \
-    libssl-dev \
-    openssl \
-    && rm -rf /var/lib/apt/lists/*
+  curl \
+  gosu \
+  libjpeg62-turbo \
+  libpq5 \
+  libwebp6 \
+  xmlsec1 \
+  libjemalloc2 \
+  libicu67 \
+  libssl-dev \
+  openssl \
+  && rm -rf /var/lib/apt/lists/*
 
 COPY --from=builder /install /usr/local
 COPY ./docker/start.py /start.py
@@ -175,4 +175,4 @@ EXPOSE 8008/tcp 8009/tcp 8448/tcp
 ENTRYPOINT ["/start.py"]
 
 HEALTHCHECK --start-period=5s --interval=15s --timeout=5s \
-    CMD curl -fSs http://localhost:8008/health || exit 1
+  CMD curl -fSs http://localhost:8008/health || exit 1
diff --git a/synapse/_scripts/synapse_port_db.py b/synapse/_scripts/synapse_port_db.py
index c463b60b26..5e137dbbf7 100755
--- a/synapse/_scripts/synapse_port_db.py
+++ b/synapse/_scripts/synapse_port_db.py
@@ -51,6 +51,7 @@ from synapse.logging.context import (
     make_deferred_yieldable,
     run_in_background,
 )
+from synapse.notifier import ReplicationNotifier
 from synapse.storage.database import DatabasePool, LoggingTransaction, make_conn
 from synapse.storage.databases.main import PushRuleStore
 from synapse.storage.databases.main.account_data import AccountDataWorkerStore
@@ -260,6 +261,9 @@ class MockHomeserver:
     def should_send_federation(self) -> bool:
         return False
 
+    def get_replication_notifier(self) -> ReplicationNotifier:
+        return ReplicationNotifier()
+
 
 class Porter:
     def __init__(
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index eca75f1108..e386f77de6 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -27,6 +27,7 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Set,
     Tuple,
     Union,
 )
@@ -171,12 +172,23 @@ class FederationHandler:
 
         self.third_party_event_rules = hs.get_third_party_event_rules()
 
+        # Tracks running partial state syncs by room ID.
+        # Partial state syncs currently only run on the main process, so it's okay to
+        # track them in-memory for now.
+        self._active_partial_state_syncs: Set[str] = set()
+        # Tracks partial state syncs we may want to restart.
+        # A dictionary mapping room IDs to (initial destination, other destinations)
+        # tuples.
+        self._partial_state_syncs_maybe_needing_restart: Dict[
+            str, Tuple[Optional[str], Collection[str]]
+        ] = {}
+
         # if this is the main process, fire off a background process to resume
         # any partial-state-resync operations which were in flight when we
         # were shut down.
         if not hs.config.worker.worker_app:
             run_as_background_process(
-                "resume_sync_partial_state_room", self._resume_sync_partial_state_room
+                "resume_sync_partial_state_room", self._resume_partial_state_room_sync
             )
 
     @trace
@@ -679,9 +691,7 @@ class FederationHandler:
                 if ret.partial_state:
                     # Kick off the process of asynchronously fetching the state for this
                     # room.
-                    run_as_background_process(
-                        desc="sync_partial_state_room",
-                        func=self._sync_partial_state_room,
+                    self._start_partial_state_room_sync(
                         initial_destination=origin,
                         other_destinations=ret.servers_in_room,
                         room_id=room_id,
@@ -1660,20 +1670,100 @@ class FederationHandler:
         # well.
         return None
 
-    async def _resume_sync_partial_state_room(self) -> None:
+    async def _resume_partial_state_room_sync(self) -> None:
         """Resumes resyncing of all partial-state rooms after a restart."""
         assert not self.config.worker.worker_app
 
         partial_state_rooms = await self.store.get_partial_state_room_resync_info()
         for room_id, resync_info in partial_state_rooms.items():
-            run_as_background_process(
-                desc="sync_partial_state_room",
-                func=self._sync_partial_state_room,
+            self._start_partial_state_room_sync(
                 initial_destination=resync_info.joined_via,
                 other_destinations=resync_info.servers_in_room,
                 room_id=room_id,
             )
 
+    def _start_partial_state_room_sync(
+        self,
+        initial_destination: Optional[str],
+        other_destinations: Collection[str],
+        room_id: str,
+    ) -> None:
+        """Starts the background process to resync the state of a partial state room,
+        if it is not already running.
+
+        Args:
+            initial_destination: the initial homeserver to pull the state from
+            other_destinations: other homeservers to try to pull the state from, if
+                `initial_destination` is unavailable
+            room_id: room to be resynced
+        """
+
+        async def _sync_partial_state_room_wrapper() -> None:
+            if room_id in self._active_partial_state_syncs:
+                # Another local user has joined the room while there is already a
+                # partial state sync running. This implies that there is a new join
+                # event to un-partial state. We might find ourselves in one of a few
+                # scenarios:
+                #  1. There is an existing partial state sync. The partial state sync
+                #     un-partial states the new join event before completing and all is
+                #     well.
+                #  2. Before the latest join, the homeserver was no longer in the room
+                #     and there is an existing partial state sync from our previous
+                #     membership of the room. The partial state sync may have:
+                #      a) succeeded, but not yet terminated. The room will not be
+                #         un-partial stated again unless we restart the partial state
+                #         sync.
+                #      b) failed, because we were no longer in the room and remote
+                #         homeservers were refusing our requests, but not yet
+                #         terminated. After the latest join, remote homeservers may
+                #         start answering our requests again, so we should restart the
+                #         partial state sync.
+                # In the cases where we would want to restart the partial state sync,
+                # the room would have the partial state flag when the partial state sync
+                # terminates.
+                self._partial_state_syncs_maybe_needing_restart[room_id] = (
+                    initial_destination,
+                    other_destinations,
+                )
+                return
+
+            self._active_partial_state_syncs.add(room_id)
+
+            try:
+                await self._sync_partial_state_room(
+                    initial_destination=initial_destination,
+                    other_destinations=other_destinations,
+                    room_id=room_id,
+                )
+            finally:
+                # Read the room's partial state flag while we still hold the claim to
+                # being the active partial state sync (so that another partial state
+                # sync can't come along and mess with it under us).
+                # Normally, the partial state flag will be gone. If it isn't, then we
+                # may find ourselves in scenario 2a or 2b as described in the comment
+                # above, where we want to restart the partial state sync.
+                is_still_partial_state_room = await self.store.is_partial_state_room(
+                    room_id
+                )
+                self._active_partial_state_syncs.remove(room_id)
+
+                if room_id in self._partial_state_syncs_maybe_needing_restart:
+                    (
+                        restart_initial_destination,
+                        restart_other_destinations,
+                    ) = self._partial_state_syncs_maybe_needing_restart.pop(room_id)
+
+                    if is_still_partial_state_room:
+                        self._start_partial_state_room_sync(
+                            initial_destination=restart_initial_destination,
+                            other_destinations=restart_other_destinations,
+                            room_id=room_id,
+                        )
+
+        run_as_background_process(
+            desc="sync_partial_state_room", func=_sync_partial_state_room_wrapper
+        )
+
     async def _sync_partial_state_room(
         self,
         initial_destination: Optional[str],
diff --git a/synapse/notifier.py b/synapse/notifier.py
index 26b97cf766..28f0d4a25a 100644
--- a/synapse/notifier.py
+++ b/synapse/notifier.py
@@ -226,8 +226,7 @@ class Notifier:
         self.store = hs.get_datastores().main
         self.pending_new_room_events: List[_PendingRoomEventEntry] = []
 
-        # Called when there are new things to stream over replication
-        self.replication_callbacks: List[Callable[[], None]] = []
+        self._replication_notifier = hs.get_replication_notifier()
         self._new_join_in_room_callbacks: List[Callable[[str, str], None]] = []
 
         self._federation_client = hs.get_federation_http_client()
@@ -279,7 +278,7 @@ class Notifier:
         it needs to do any asynchronous work, a background thread should be started and
         wrapped with run_as_background_process.
         """
-        self.replication_callbacks.append(cb)
+        self._replication_notifier.add_replication_callback(cb)
 
     def add_new_join_in_room_callback(self, cb: Callable[[str, str], None]) -> None:
         """Add a callback that will be called when a user joins a room.
@@ -741,8 +740,7 @@ class Notifier:
 
     def notify_replication(self) -> None:
         """Notify the any replication listeners that there's a new event"""
-        for cb in self.replication_callbacks:
-            cb()
+        self._replication_notifier.notify_replication()
 
     def notify_user_joined_room(self, event_id: str, room_id: str) -> None:
         for cb in self._new_join_in_room_callbacks:
@@ -759,3 +757,26 @@ class Notifier:
         # Tell the federation client about the fact the server is back up, so
         # that any in flight requests can be immediately retried.
         self._federation_client.wake_destination(server)
+
+
+@attr.s(auto_attribs=True)
+class ReplicationNotifier:
+    """Tracks callbacks for things that need to know about stream changes.
+
+    This is separate from the notifier to avoid circular dependencies.
+    """
+
+    _replication_callbacks: List[Callable[[], None]] = attr.Factory(list)
+
+    def add_replication_callback(self, cb: Callable[[], None]) -> None:
+        """Add a callback that will be called when some new data is available.
+        Callback is not given any arguments. It should *not* return a Deferred - if
+        it needs to do any asynchronous work, a background thread should be started and
+        wrapped with run_as_background_process.
+        """
+        self._replication_callbacks.append(cb)
+
+    def notify_replication(self) -> None:
+        """Notify the any replication listeners that there's a new event"""
+        for cb in self._replication_callbacks:
+            cb()
diff --git a/synapse/server.py b/synapse/server.py
index f4ab94c4f3..9d6d268f49 100644
--- a/synapse/server.py
+++ b/synapse/server.py
@@ -107,7 +107,7 @@ from synapse.http.client import InsecureInterceptableContextFactory, SimpleHttpC
 from synapse.http.matrixfederationclient import MatrixFederationHttpClient
 from synapse.metrics.common_usage_metrics import CommonUsageMetricsManager
 from synapse.module_api import ModuleApi
-from synapse.notifier import Notifier
+from synapse.notifier import Notifier, ReplicationNotifier
 from synapse.push.bulk_push_rule_evaluator import BulkPushRuleEvaluator
 from synapse.push.pusherpool import PusherPool
 from synapse.replication.tcp.client import ReplicationDataHandler
@@ -390,6 +390,10 @@ class HomeServer(metaclass=abc.ABCMeta):
         return Notifier(self)
 
     @cache_in_self
+    def get_replication_notifier(self) -> ReplicationNotifier:
+        return ReplicationNotifier()
+
+    @cache_in_self
     def get_auth(self) -> Auth:
         return Auth(self)
 
diff --git a/synapse/storage/databases/main/account_data.py b/synapse/storage/databases/main/account_data.py
index 881d7089db..8a359d7eb8 100644
--- a/synapse/storage/databases/main/account_data.py
+++ b/synapse/storage/databases/main/account_data.py
@@ -75,6 +75,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             self._account_data_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="account_data",
                 instance_name=self._instance_name,
                 tables=[
@@ -95,6 +96,7 @@ class AccountDataWorkerStore(PushRulesWorkerStore, CacheInvalidationWorkerStore)
             # SQLite).
             self._account_data_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "room_account_data",
                 "stream_id",
                 extra_tables=[("room_tags_revisions", "stream_id")],
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2179a8bf59..5b66431691 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -75,6 +75,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
             self._cache_id_gen = MultiWriterIdGenerator(
                 db_conn,
                 database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="caches",
                 instance_name=hs.get_instance_name(),
                 tables=[
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index 713be91c5d..8e61aba454 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -91,6 +91,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 MultiWriterIdGenerator(
                     db_conn=db_conn,
                     db=database,
+                    notifier=hs.get_replication_notifier(),
                     stream_name="to_device",
                     instance_name=self._instance_name,
                     tables=[("device_inbox", "instance_name", "stream_id")],
@@ -101,7 +102,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         else:
             self._can_write_to_device = True
             self._device_inbox_id_gen = StreamIdGenerator(
-                db_conn, "device_inbox", "stream_id"
+                db_conn, hs.get_replication_notifier(), "device_inbox", "stream_id"
             )
 
         max_device_inbox_id = self._device_inbox_id_gen.get_current_token()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index cd186c8472..903606fb46 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -92,6 +92,7 @@ class DeviceWorkerStore(RoomMemberWorkerStore, EndToEndKeyWorkerStore):
         # class below that is used on the main process.
         self._device_list_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
             db_conn,
+            hs.get_replication_notifier(),
             "device_lists_stream",
             "stream_id",
             extra_tables=[
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index 4c691642e2..c4ac6c33ba 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -1181,7 +1181,10 @@ class EndToEndKeyStore(EndToEndKeyWorkerStore, SQLBaseStore):
         super().__init__(database, db_conn, hs)
 
         self._cross_signing_id_gen = StreamIdGenerator(
-            db_conn, "e2e_cross_signing_keys", "stream_id"
+            db_conn,
+            hs.get_replication_notifier(),
+            "e2e_cross_signing_keys",
+            "stream_id",
         )
 
     async def set_e2e_device_keys(
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d150fa8a94..d8a8bcafb6 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -191,6 +191,7 @@ class EventsWorkerStore(SQLBaseStore):
             self._stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="events",
                 instance_name=hs.get_instance_name(),
                 tables=[("events", "instance_name", "stream_ordering")],
@@ -200,6 +201,7 @@ class EventsWorkerStore(SQLBaseStore):
             self._backfill_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="backfill",
                 instance_name=hs.get_instance_name(),
                 tables=[("events", "instance_name", "stream_ordering")],
@@ -217,12 +219,14 @@ class EventsWorkerStore(SQLBaseStore):
             # SQLite).
             self._stream_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "events",
                 "stream_ordering",
                 is_writer=hs.get_instance_name() in hs.config.worker.writers.events,
             )
             self._backfill_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "events",
                 "stream_ordering",
                 step=-1,
@@ -300,6 +304,7 @@ class EventsWorkerStore(SQLBaseStore):
             self._un_partial_stated_events_stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="un_partial_stated_event_stream",
                 instance_name=hs.get_instance_name(),
                 tables=[
@@ -311,7 +316,10 @@ class EventsWorkerStore(SQLBaseStore):
             )
         else:
             self._un_partial_stated_events_stream_id_gen = StreamIdGenerator(
-                db_conn, "un_partial_stated_event_stream", "stream_id"
+                db_conn,
+                hs.get_replication_notifier(),
+                "un_partial_stated_event_stream",
+                "stream_id",
             )
 
     def get_un_partial_stated_events_token(self) -> int:
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index 7b60815043..beb210f8ee 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -77,6 +77,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             self._presence_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="presence_stream",
                 instance_name=self._instance_name,
                 tables=[("presence_stream", "instance_name", "stream_id")],
@@ -85,7 +86,7 @@ class PresenceStore(PresenceBackgroundUpdateStore, CacheInvalidationWorkerStore)
             )
         else:
             self._presence_id_gen = StreamIdGenerator(
-                db_conn, "presence_stream", "stream_id"
+                db_conn, hs.get_replication_notifier(), "presence_stream", "stream_id"
             )
 
         self.hs = hs
diff --git a/synapse/storage/databases/main/push_rule.py b/synapse/storage/databases/main/push_rule.py
index 03182887d1..14ca167b34 100644
--- a/synapse/storage/databases/main/push_rule.py
+++ b/synapse/storage/databases/main/push_rule.py
@@ -118,6 +118,7 @@ class PushRulesWorkerStore(
         # class below that is used on the main process.
         self._push_rules_stream_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
             db_conn,
+            hs.get_replication_notifier(),
             "push_rules_stream",
             "stream_id",
             is_writer=hs.config.worker.worker_app is None,
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index 7f24a3b6ec..df53e726e6 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -62,6 +62,7 @@ class PusherWorkerStore(SQLBaseStore):
         # class below that is used on the main process.
         self._pushers_id_gen: AbstractStreamIdTracker = StreamIdGenerator(
             db_conn,
+            hs.get_replication_notifier(),
             "pushers",
             "id",
             extra_tables=[("deleted_pushers", "stream_id")],
diff --git a/synapse/storage/databases/main/receipts.py b/synapse/storage/databases/main/receipts.py
index 86f5bce5f0..3468f354e6 100644
--- a/synapse/storage/databases/main/receipts.py
+++ b/synapse/storage/databases/main/receipts.py
@@ -73,6 +73,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             self._receipts_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="receipts",
                 instance_name=self._instance_name,
                 tables=[("receipts_linearized", "instance_name", "stream_id")],
@@ -91,6 +92,7 @@ class ReceiptsWorkerStore(SQLBaseStore):
             # SQLite).
             self._receipts_id_gen = StreamIdGenerator(
                 db_conn,
+                hs.get_replication_notifier(),
                 "receipts_linearized",
                 "stream_id",
                 is_writer=hs.get_instance_name() in hs.config.worker.writers.receipts,
diff --git a/synapse/storage/databases/main/room.py b/synapse/storage/databases/main/room.py
index 78906a5e1d..7264a33cd4 100644
--- a/synapse/storage/databases/main/room.py
+++ b/synapse/storage/databases/main/room.py
@@ -126,6 +126,7 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             self._un_partial_stated_rooms_stream_id_gen = MultiWriterIdGenerator(
                 db_conn=db_conn,
                 db=database,
+                notifier=hs.get_replication_notifier(),
                 stream_name="un_partial_stated_room_stream",
                 instance_name=self._instance_name,
                 tables=[
@@ -137,7 +138,10 @@ class RoomWorkerStore(CacheInvalidationWorkerStore):
             )
         else:
             self._un_partial_stated_rooms_stream_id_gen = StreamIdGenerator(
-                db_conn, "un_partial_stated_room_stream", "stream_id"
+                db_conn,
+                hs.get_replication_notifier(),
+                "un_partial_stated_room_stream",
+                "stream_id",
             )
 
     async def store_room(
diff --git a/synapse/storage/util/id_generators.py b/synapse/storage/util/id_generators.py
index 8670ffbfa3..9adff3f4f5 100644
--- a/synapse/storage/util/id_generators.py
+++ b/synapse/storage/util/id_generators.py
@@ -20,6 +20,7 @@ from collections import OrderedDict
 from contextlib import contextmanager
 from types import TracebackType
 from typing import (
+    TYPE_CHECKING,
     AsyncContextManager,
     ContextManager,
     Dict,
@@ -49,6 +50,9 @@ from synapse.storage.database import (
 from synapse.storage.types import Cursor
 from synapse.storage.util.sequence import PostgresSequenceGenerator
 
+if TYPE_CHECKING:
+    from synapse.notifier import ReplicationNotifier
+
 logger = logging.getLogger(__name__)
 
 
@@ -182,6 +186,7 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
     def __init__(
         self,
         db_conn: LoggingDatabaseConnection,
+        notifier: "ReplicationNotifier",
         table: str,
         column: str,
         extra_tables: Iterable[Tuple[str, str]] = (),
@@ -205,6 +210,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
         # The key and values are the same, but we never look at the values.
         self._unfinished_ids: OrderedDict[int, int] = OrderedDict()
 
+        self._notifier = notifier
+
     def advance(self, instance_name: str, new_id: int) -> None:
         # Advance should never be called on a writer instance, only over replication
         if self._is_writer:
@@ -227,6 +234,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
                 with self._lock:
                     self._unfinished_ids.pop(next_id)
 
+                self._notifier.notify_replication()
+
         return _AsyncCtxManagerWrapper(manager())
 
     def get_next_mult(self, n: int) -> AsyncContextManager[Sequence[int]]:
@@ -250,6 +259,8 @@ class StreamIdGenerator(AbstractStreamIdGenerator):
                     for next_id in next_ids:
                         self._unfinished_ids.pop(next_id)
 
+                self._notifier.notify_replication()
+
         return _AsyncCtxManagerWrapper(manager())
 
     def get_current_token(self) -> int:
@@ -296,6 +307,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         self,
         db_conn: LoggingDatabaseConnection,
         db: DatabasePool,
+        notifier: "ReplicationNotifier",
         stream_name: str,
         instance_name: str,
         tables: List[Tuple[str, str, str]],
@@ -304,6 +316,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         positive: bool = True,
     ) -> None:
         self._db = db
+        self._notifier = notifier
         self._stream_name = stream_name
         self._instance_name = instance_name
         self._positive = positive
@@ -535,7 +548,9 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
         # Cast safety: the second argument to _MultiWriterCtxManager, multiple_ids,
         # controls the return type. If `None` or omitted, the context manager yields
         # a single integer stream_id; otherwise it yields a list of stream_ids.
-        return cast(AsyncContextManager[int], _MultiWriterCtxManager(self))
+        return cast(
+            AsyncContextManager[int], _MultiWriterCtxManager(self, self._notifier)
+        )
 
     def get_next_mult(self, n: int) -> AsyncContextManager[List[int]]:
         # If we have a list of instances that are allowed to write to this
@@ -544,7 +559,10 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
             raise Exception("Tried to allocate stream ID on non-writer")
 
         # Cast safety: see get_next.
-        return cast(AsyncContextManager[List[int]], _MultiWriterCtxManager(self, n))
+        return cast(
+            AsyncContextManager[List[int]],
+            _MultiWriterCtxManager(self, self._notifier, n),
+        )
 
     def get_next_txn(self, txn: LoggingTransaction) -> int:
         """
@@ -563,6 +581,7 @@ class MultiWriterIdGenerator(AbstractStreamIdGenerator):
 
         txn.call_after(self._mark_id_as_finished, next_id)
         txn.call_on_exception(self._mark_id_as_finished, next_id)
+        txn.call_after(self._notifier.notify_replication)
 
         # Update the `stream_positions` table with newly updated stream
         # ID (unless self._writers is not set in which case we don't
@@ -787,6 +806,7 @@ class _MultiWriterCtxManager:
     """Async context manager returned by MultiWriterIdGenerator"""
 
     id_gen: MultiWriterIdGenerator
+    notifier: "ReplicationNotifier"
     multiple_ids: Optional[int] = None
     stream_ids: List[int] = attr.Factory(list)
 
@@ -814,6 +834,8 @@ class _MultiWriterCtxManager:
         for i in self.stream_ids:
             self.id_gen._mark_id_as_finished(i)
 
+        self.notifier.notify_replication()
+
         if exc_type is not None:
             return False
 
diff --git a/tests/handlers/test_federation.py b/tests/handlers/test_federation.py
index cedbb9fafc..c1558c40c3 100644
--- a/tests/handlers/test_federation.py
+++ b/tests/handlers/test_federation.py
@@ -12,10 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import cast
+from typing import Collection, Optional, cast
 from unittest import TestCase
 from unittest.mock import Mock, patch
 
+from twisted.internet.defer import Deferred
 from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
@@ -679,3 +680,112 @@ class PartialJoinTestCase(unittest.FederatingHomeserverTestCase):
             f"Stale partial-stated room flag left over for {room_id} after a"
             f" failed do_invite_join!",
         )
+
+    def test_duplicate_partial_state_room_syncs(self) -> None:
+        """
+        Tests that concurrent partial state syncs are not started for the same room.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Try to start another partial state sync.
+            # Nothing should happen.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # End the partial state sync
+            is_partial_state = False
+            end_sync.callback(None)
+
+            # The partial state sync should not be restarted.
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # The next attempt to start the partial state sync should work.
+            is_partial_state = True
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+    def test_partial_state_room_sync_restart(self) -> None:
+        """
+        Tests that partial state syncs are restarted when a second partial state sync
+        was deduplicated and the first partial state sync fails.
+        """
+        is_partial_state = True
+        end_sync: "Deferred[None]" = Deferred()
+
+        async def is_partial_state_room(room_id: str) -> bool:
+            return is_partial_state
+
+        async def sync_partial_state_room(
+            initial_destination: Optional[str],
+            other_destinations: Collection[str],
+            room_id: str,
+        ) -> None:
+            nonlocal end_sync
+            try:
+                await end_sync
+            finally:
+                end_sync = Deferred()
+
+        mock_is_partial_state_room = Mock(side_effect=is_partial_state_room)
+        mock_sync_partial_state_room = Mock(side_effect=sync_partial_state_room)
+
+        fed_handler = self.hs.get_federation_handler()
+        store = self.hs.get_datastores().main
+
+        with patch.object(
+            fed_handler, "_sync_partial_state_room", mock_sync_partial_state_room
+        ), patch.object(store, "is_partial_state_room", mock_is_partial_state_room):
+            # Start the partial state sync.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Fail the partial state sync.
+            # The partial state sync should not be restarted.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 1)
+
+            # Start the partial state sync again.
+            fed_handler._start_partial_state_room_sync("hs1", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Deduplicate another partial state sync.
+            fed_handler._start_partial_state_room_sync("hs3", ["hs2"], "room_id")
+            self.assertEqual(mock_sync_partial_state_room.call_count, 2)
+
+            # Fail the partial state sync.
+            # It should restart with the latest parameters.
+            end_sync.errback(Exception("Failed to request /state_ids"))
+            self.assertEqual(mock_sync_partial_state_room.call_count, 3)
+            mock_sync_partial_state_room.assert_called_with(
+                initial_destination="hs3",
+                other_destinations=["hs2"],
+                room_id="room_id",
+            )
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 9919938e80..8f88c0117d 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -404,6 +404,9 @@ class ModuleApiTestCase(HomeserverTestCase):
             self.module_api.send_local_online_presence_to([remote_user_id])
         )
 
+        # We don't always send out federation immediately, so we advance the clock.
+        self.reactor.advance(1000)
+
         # Check that a presence update was sent as part of a federation transaction
         found_update = False
         calls = (
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index 555922409d..6e4055cc21 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -14,7 +14,7 @@
 
 from twisted.internet import defer
 
-from synapse.replication.tcp.commands import PositionCommand, RdataCommand
+from synapse.replication.tcp.commands import PositionCommand
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 
@@ -111,20 +111,14 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
         next_token = self.get_success(ctx.__aenter__())
         self.get_success(ctx.__aexit__(None, None, None))
 
-        cmd_handler.send_command(
-            RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
-        )
-        self.replicate()
-
         self.get_success(
             data_handler.wait_for_stream_position("worker1", "caches", next_token)
         )
 
-        # `wait_for_stream_position` should only return once master receives an
-        # RDATA from the worker
-        ctx = cache_id_gen.get_next()
-        next_token = self.get_success(ctx.__aenter__())
-        self.get_success(ctx.__aexit__(None, None, None))
+        # `wait_for_stream_position` should only return once master receives a
+        # notification that `next_token` has persisted.
+        ctx_worker1 = cache_id_gen.get_next()
+        next_token = self.get_success(ctx_worker1.__aenter__())
 
         d = defer.ensureDeferred(
             data_handler.wait_for_stream_position("worker1", "caches", next_token)
@@ -142,10 +136,7 @@ class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
         )
         self.assertFalse(d.called)
 
-        # ... but receiving the RDATA should
-        cmd_handler.send_command(
-            RdataCommand("caches", "worker1", next_token, ("func_name", [], 0))
-        )
-        self.replicate()
+        # ... but worker1 finishing (and so sending an update) should.
+        self.get_success(ctx_worker1.__aexit__(None, None, None))
 
         self.assertTrue(d.called)
diff --git a/tests/storage/test_id_generators.py b/tests/storage/test_id_generators.py
index ff9691c518..9174fb0964 100644
--- a/tests/storage/test_id_generators.py
+++ b/tests/storage/test_id_generators.py
@@ -52,6 +52,7 @@ class StreamIdGeneratorTestCase(HomeserverTestCase):
         def _create(conn: LoggingDatabaseConnection) -> StreamIdGenerator:
             return StreamIdGenerator(
                 db_conn=conn,
+                notifier=self.hs.get_replication_notifier(),
                 table="foobar",
                 column="stream_id",
             )
@@ -196,6 +197,7 @@ class MultiWriterIdGeneratorTestCase(HomeserverTestCase):
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
+                notifier=self.hs.get_replication_notifier(),
                 stream_name="test_stream",
                 instance_name=instance_name,
                 tables=[("foobar", "instance_name", "stream_id")],
@@ -630,6 +632,7 @@ class BackwardsMultiWriterIdGeneratorTestCase(HomeserverTestCase):
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
+                notifier=self.hs.get_replication_notifier(),
                 stream_name="test_stream",
                 instance_name=instance_name,
                 tables=[("foobar", "instance_name", "stream_id")],
@@ -766,6 +769,7 @@ class MultiTableMultiWriterIdGeneratorTestCase(HomeserverTestCase):
             return MultiWriterIdGenerator(
                 conn,
                 self.db_pool,
+                notifier=self.hs.get_replication_notifier(),
                 stream_name="test_stream",
                 instance_name=instance_name,
                 tables=[