diff --git a/synapse/replication/tcp/client.py b/synapse/replication/tcp/client.py
index 322d695bc7..5c2482e40c 100644
--- a/synapse/replication/tcp/client.py
+++ b/synapse/replication/tcp/client.py
@@ -16,6 +16,7 @@
import logging
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from twisted.internet import defer
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IAddress, IConnector
from twisted.internet.protocol import ReconnectingClientFactory
@@ -314,10 +315,21 @@ class ReplicationDataHandler:
self.send_handler.wake_destination(server)
async def wait_for_stream_position(
- self, instance_name: str, stream_name: str, position: int
+ self,
+ instance_name: str,
+ stream_name: str,
+ position: int,
+ raise_on_timeout: bool = True,
) -> None:
"""Wait until this instance has received updates up to and including
the given stream position.
+
+ Args:
+ instance_name
+ stream_name
+ position
+ raise_on_timeout: Whether to raise an exception if we time out
+ waiting for the updates, or if we log an error and return.
"""
if instance_name == self._instance_name:
@@ -345,7 +357,16 @@ class ReplicationDataHandler:
# We measure here to get in flight counts and average waiting time.
with Measure(self._clock, "repl.wait_for_stream_position"):
logger.info("Waiting for repl stream %r to reach %s", stream_name, position)
- await make_deferred_yieldable(deferred)
+ try:
+ await make_deferred_yieldable(deferred)
+ except defer.TimeoutError:
+ logger.error("Timed out waiting for stream %s", stream_name)
+
+ if raise_on_timeout:
+ raise
+
+ return
+
logger.info(
"Finished waiting for repl stream %r to reach %s", stream_name, position
)
|