summary refs log tree commit diff
path: root/synapse/replication/tcp/client.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/replication/tcp/client.py')
-rw-r--r--synapse/replication/tcp/client.py25
1 files changed, 23 insertions, 2 deletions
diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 322d695bc7..5c2482e40c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,6 +16,7 @@
 import logging
 from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
 
+from twisted.internet import defer
 from twisted.internet.defer import Deferred
 from twisted.internet.interfaces import IAddress, IConnector
 from twisted.internet.protocol import ReconnectingClientFactory
@@ -314,10 +315,21 @@ class ReplicationDataHandler:
             self.send_handler.wake_destination(server)
 
     async def wait_for_stream_position(
-        self, instance_name: str, stream_name: str, position: int
+        self,
+        instance_name: str,
+        stream_name: str,
+        position: int,
+        raise_on_timeout: bool = True,
     ) -> None:
         """Wait until this instance has received updates up to and including
         the given stream position.
+
+        Args:
+            instance_name
+            stream_name
+            position
+            raise_on_timeout: Whether to raise an exception if we time out
+                waiting for the updates, or if we log an error and return.
         """
 
         if instance_name == self._instance_name:
@@ -345,7 +357,16 @@ class ReplicationDataHandler:
         # We measure here to get in flight counts and average waiting time.
         with Measure(self._clock, "repl.wait_for_stream_position"):
             logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
-            await make_deferred_yieldable(deferred)
+            try:
+                await make_deferred_yieldable(deferred)
+            except defer.TimeoutError:
+                logger.error("Timed out waiting for stream %s", stream_name)
+
+                if raise_on_timeout:
+                    raise
+
+                return
+
             logger.info(
                 "Finished waiting for repl stream %r to reach %s", stream_name, position
             )