diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 5c2482e40c..6e242c5749 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -133,9 +133,9 @@ class ReplicationDataHandler:
if hs.should_send_federation():
self.send_handler = FederationSenderHandler(hs)
- # Map from stream to list of deferreds waiting for the stream to
+ # Map from (stream, instance) to list of deferreds waiting for the stream to
# arrive at a particular position. The lists are sorted by stream position.
- self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {}
+ self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {}
async def on_rdata(
self, stream_name: str, instance_name: str, token: int, rows: list
@@ -270,7 +270,7 @@ class ReplicationDataHandler:
# Notify any waiting deferreds. The list is ordered by position so we
# just iterate through the list until we reach a position that is
# greater than the received row position.
- waiting_list = self._streams_to_waiters.get(stream_name, [])
+ waiting_list = self._streams_to_waiters.get((stream_name, instance_name), [])
# Index of first item with a position after the current token, i.e we
# have called all deferreds before this index. If not overwritten by
@@ -279,14 +279,13 @@ class ReplicationDataHandler:
# `len(list)` works for both cases.
index_of_first_deferred_not_called = len(waiting_list)
+ # We don't fire the deferreds until after we finish iterating over the
+ # list, to avoid the list changing when we fire the deferreds.
+ deferreds_to_callback = []
+
for idx, (position, deferred) in enumerate(waiting_list):
if position <= token:
- try:
- with PreserveLoggingContext():
- deferred.callback(None)
- except Exception:
- # The deferred has been cancelled or timed out.
- pass
+ deferreds_to_callback.append(deferred)
else:
# The list is sorted by position so we don't need to continue
# checking any further entries in the list.
@@ -297,6 +296,14 @@ class ReplicationDataHandler:
# loop. (This maintains the order so no need to resort)
waiting_list[:] = waiting_list[index_of_first_deferred_not_called:]
+ for deferred in deferreds_to_callback:
+ try:
+ with PreserveLoggingContext():
+ deferred.callback(None)
+ except Exception:
+ # The deferred has been cancelled or timed out.
+ pass
+
async def on_position(
self, stream_name: str, instance_name: str, token: int
) -> None:
@@ -349,7 +356,9 @@ class ReplicationDataHandler:
deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor
)
- waiting_list = self._streams_to_waiters.setdefault(stream_name, [])
+ waiting_list = self._streams_to_waiters.setdefault(
+ (stream_name, instance_name), []
+ )
waiting_list.append((position, deferred))
waiting_list.sort(key=lambda t: t[0])
|