diff options
Diffstat (limited to 'synapse/replication/tcp')
-rw-r--r-- | synapse/replication/tcp/client.py | 29 |
1 files changed, 19 insertions, 10 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py index 5c2482e40c..6e242c5749 100644 --- a/synapse/replication/tcp/client.py +++ b/synapse/replication/tcp/client.py @@ -133,9 +133,9 @@ class ReplicationDataHandler: if hs.should_send_federation(): self.send_handler = FederationSenderHandler(hs) - # Map from stream to list of deferreds waiting for the stream to + # Map from stream and instance to list of deferreds waiting for the stream to # arrive at a particular position. The lists are sorted by stream position. - self._streams_to_waiters: Dict[str, List[Tuple[int, Deferred]]] = {} + self._streams_to_waiters: Dict[Tuple[str, str], List[Tuple[int, Deferred]]] = {} async def on_rdata( self, stream_name: str, instance_name: str, token: int, rows: list @@ -270,7 +270,7 @@ class ReplicationDataHandler: # Notify any waiting deferreds. The list is ordered by position so we # just iterate through the list until we reach a position that is # greater than the received row position. - waiting_list = self._streams_to_waiters.get(stream_name, []) + waiting_list = self._streams_to_waiters.get((stream_name, instance_name), []) # Index of first item with a position after the current token, i.e we # have called all deferreds before this index. If not overwritten by @@ -279,14 +279,13 @@ class ReplicationDataHandler: # `len(list)` works for both cases. index_of_first_deferred_not_called = len(waiting_list) + # We don't fire the deferreds until after we finish iterating over the + # list, to avoid the list changing when we fire the deferreds. + deferreds_to_callback = [] + for idx, (position, deferred) in enumerate(waiting_list): if position <= token: - try: - with PreserveLoggingContext(): - deferred.callback(None) - except Exception: - # The deferred has been cancelled or timed out. - pass + deferreds_to_callback.append(deferred) else: # The list is sorted by position so we don't need to continue # checking any further entries in the list. @@ -297,6 +296,14 @@ class ReplicationDataHandler: # loop. (This maintains the order so no need to resort) waiting_list[:] = waiting_list[index_of_first_deferred_not_called:] + for deferred in deferreds_to_callback: + try: + with PreserveLoggingContext(): + deferred.callback(None) + except Exception: + # The deferred has been cancelled or timed out. + pass + async def on_position( self, stream_name: str, instance_name: str, token: int ) -> None: @@ -349,7 +356,9 @@ class ReplicationDataHandler: deferred, _WAIT_FOR_REPLICATION_TIMEOUT_SECONDS, self._reactor ) - waiting_list = self._streams_to_waiters.setdefault(stream_name, []) + waiting_list = self._streams_to_waiters.setdefault( + (stream_name, instance_name), [] + ) waiting_list.append((position, deferred)) waiting_list.sort(key=lambda t: t[0]) |