summary refs log tree commit diff
diff options
context:
space:
mode:
authorPatrick Cloke <patrickc@matrix.org>2022-03-10 14:37:50 -0500
committerPatrick Cloke <patrickc@matrix.org>2022-03-11 10:35:22 -0500
commit829139c3d595605591ea5897a60542c6ef386ed8 (patch)
tree976caabad974fbee93bb21c1762f21a0a3e2eab0
parentMore robust-ness against dying connections. (diff)
downloadsynapse-829139c3d595605591ea5897a60542c6ef386ed8.tar.xz
Attempt to re-connect better.
-rw-r--r--tests/replication/_base.py91
1 files changed, 76 insertions, 15 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index f35a28de94..a5398029ad 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -15,6 +15,7 @@ import logging
 from typing import Any, Dict, List, Optional, Tuple
 
 from twisted.internet.protocol import Protocol
+from twisted.python.failure import Failure
 from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import GenericWorkerServer
@@ -34,6 +35,55 @@ except ImportError:
 logger = logging.getLogger(__name__)
 
 
+class FakeOutboundConnector:
+    """
+    A fake connector which re-connects the outbound Redis connection.
+    """
+
+    def __init__(self, hs: HomeServer):
+        self._hs = hs
+
+    def stopConnecting(self):
+        pass
+
+    def connect(self):
+        # Restart replication.
+        from synapse.replication.tcp.redis import lazyConnection
+
+        handler = self._hs.get_outbound_redis_connection()
+
+        reactor = self._hs.get_reactor()
+        reactor.connectTCP(
+            self._hs.config.redis.redis_host,
+            self._hs.config.redis.redis_port,
+            handler._factory,
+            timeout=30,
+            bindAddress=None,
+        )
+
+    def getDestination(self):
+        return "blah"
+
+
+class FakeReplicationHandlerConnector:
+    """
+    A fake connector which restarts replication when asked to connect.
+    """
+
+    def __init__(self, hs: HomeServer):
+        self._hs = hs
+
+    def stopConnecting(self):
+        pass
+
+    def connect(self):
+        # Restart replication.
+        self._hs.get_replication_command_handler().start_replication(self._hs)
+
+    def getDestination(self):
+        return "blah"
+
+
 class BaseStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests of the replication streams"""
 
@@ -114,21 +164,31 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     def reconnect(self):
         self.disconnect()
 
-        # TODO: The following fail as nothing has called on
-        # `clientConnectionLost` on the factories. I can't figure out *what* is
-        # meant to call them though. The `txredisapi.HiRedisProtocol` doesn't
-        # seem to do it, but I don't know if it's *meant* to.
-        #
-        # (...time passes...)
-        #
-        # After some spelunking it appears that `connectTCP` creates an
-        # `IConnector`, which is responsible for calling the factory
+        # Make a `FakeConnector` to emulate the behavior of `connectTCP`. That
+        # creates an `IConnector`, which is responsible for calling the factory
         # `clientConnectionLost`. The reconnecting factory then calls
         # `IConnector.connect` to attempt a reconnection. The transport is meant
-        # to call `connectionLost` on the `IConnector`. So I *think* we need to
-        # make a `FakeConnector` and pass that to `FakeTransport`?
-        self.hs.get_replication_command_handler()._factory.retry()
-        self.worker_hs.get_replication_command_handler()._factory.retry()
+        # to call `connectionLost` on the `IConnector`.
+        #
+        # Most of that is bypassed by calling `clientConnectionLost` on the
+        # factory directly, which schedules a `connect()` call on the connector.
+        timeouts = []
+        for hs in (self.hs, self.worker_hs):
+            hs_factory_outbound = hs.get_outbound_redis_connection()._factory
+            hs_factory_outbound.clientConnectionLost(
+                FakeOutboundConnector(hs), Failure(RuntimeError(""))
+            )
+            timeouts.append(hs_factory_outbound.delay)
+
+            hs_factory = hs.get_replication_command_handler()._factory
+            hs_factory.clientConnectionLost(
+                FakeReplicationHandlerConnector(hs),
+                Failure(RuntimeError("")),
+            )
+            timeouts.append(hs_factory.delay)
+
+        # Wait for the reconnects to happen.
+        self.pump(max(timeouts) + 1)
 
         self.connect_any_redis_attempts()
 
@@ -137,8 +197,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
             client_to_server_transport,
             server_to_client_transport,
         ) in self._redis_transports:
-            client_to_server_transport.loseConnection()
-            server_to_client_transport.loseConnection()
+            client_to_server_transport.abortConnection()
+            server_to_client_transport.abortConnection()
+        self._redis_transports = []
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then