summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/14872.misc1
-rw-r--r--synapse/replication/tcp/client.py29
2 files changed, 20 insertions, 10 deletions
diff --git a/changelog.d/14872.misc b/changelog.d/14872.misc
new file mode 100644
index 0000000000..3731d6cbf1
--- /dev/null
+++ b/changelog.d/14872.misc
@@ -0,0 +1 @@
+Fix `wait_for_stream_position` to correctly wait for the right instance to advance its token.
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 5c2482e40c..6e242c5749 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -133,9 +133,9 @@ class ReplicationDataHandler:
         if hs.should_send_federation():
             self.send_handler = FederationSenderHandler(hs)
 
-        # Map from stream to list of deferreds waiting for the stream to
+        # Map from stream and instance to list of deferreds waiting for the stream to
         # arrive at a particular position. The lists are sorted by stream position.
-        self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
+        self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {}
 
     async def on_rdata(
         self, stream_name: str, instance_name: str, token: int, rows: list
@@ -270,7 +270,7 @@ class ReplicationDataHandler:
         # Notify any waiting deferreds. The list is ordered by position so we
         # just iterate through the list until we reach a position that is
         # greater than the received row position.
-        waiting_list = self._streams_to_waiters.get(stream_name, [])
+        waiting_list = self._streams_to_waiters.get((stream_name, instance_name), [])
 
         # Index of first item with a position after the current token, i.e we
         # have called all deferreds before this index. If not overwritten by
@@ -279,14 +279,13 @@ class ReplicationDataHandler:
         # `len(list)` works for both cases.
         index_of_first_deferred_not_called = len(waiting_list)
 
+        # We don't fire the deferreds until after we finish iterating over the
+        # list, to avoid the list changing when we fire the deferreds.
+        deferreds_to_callback = []
+
         for idx, (position, deferred) in enumerate(waiting_list):
             if position <= token:
-                try:
-                    with PreserveLoggingContext():
-                        deferred.callback(None)
-                except Exception:
-                    # The deferred has been cancelled or timed out.
-                    pass
+                deferreds_to_callback.append(deferred)
             else:
                 # The list is sorted by position so we don't need to continue
                 # checking any further entries in the list.
@@ -297,6 +296,14 @@ class ReplicationDataHandler:
         # loop. (This maintains the order so no need to resort)
         waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
 
+        for deferred in deferreds_to_callback:
+            try:
+                with PreserveLoggingContext():
+                    deferred.callback(None)
+            except Exception:
+                # The deferred has been cancelled or timed out.
+                pass
+
     async def on_position(
         self, stream_name: str, instance_name: str, token: int
     ) -> None:
@@ -349,7 +356,9 @@ class ReplicationDataHandler:
             deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
         )
 
-        waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+        waiting_list = self._streams_to_waiters.setdefault(
+            (stream_name, instance_name), []
+        )
 
         waiting_list.append((position, deferred))
         waiting_list.sort(key=lambda t: t[0])