diff options
author | Patrick Cloke <patrickc@matrix.org> | 2022-03-10 14:37:50 -0500 |
---|---|---|
committer | Patrick Cloke <patrickc@matrix.org> | 2022-03-11 10:35:22 -0500 |
commit | 829139c3d595605591ea5897a60542c6ef386ed8 (patch) | |
tree | 976caabad974fbee93bb21c1762f21a0a3e2eab0 | |
parent | More robust-ness against dying connections. (diff) | |
download | synapse-829139c3d595605591ea5897a60542c6ef386ed8.tar.xz |
Attempt to re-connect better.
-rw-r--r-- | tests/replication/_base.py | 91 |
1 files changed, 76 insertions, 15 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py index f35a28de94..a5398029ad 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -15,6 +15,7 @@ import logging from typing import Any, Dict, List, Optional, Tuple from twisted.internet.protocol import Protocol +from twisted.python.failure import Failure from twisted.web.resource import Resource from synapse.app.generic_worker import GenericWorkerServer @@ -34,6 +35,55 @@ except ImportError: logger = logging.getLogger(__name__) +class FakeOutboundConnector: + """ + A fake connector class, reconnects. + """ + + def __init__(self, hs: HomeServer): + self._hs = hs + + def stopConnecting(self): + pass + + def connect(self): + # Restart replication. + from synapse.replication.tcp.redis import lazyConnection + + handler = self._hs.get_outbound_redis_connection() + + reactor = self._hs.get_reactor() + reactor.connectTCP( + self._hs.config.redis.redis_host, + self._hs.config.redis.redis_port, + handler._factory, + timeout=30, + bindAddress=None, + ) + + def getDestination(self): + return "blah" + + +class FakeReplicationHandlerConnector: + """ + A fake connector class, reconnects. + """ + + def __init__(self, hs: HomeServer): + self._hs = hs + + def stopConnecting(self): + pass + + def connect(self): + # Restart replication. + self._hs.get_replication_command_handler().start_replication(self._hs) + + def getDestination(self): + return "blah" + + class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" @@ -114,21 +164,31 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): def reconnect(self): self.disconnect() - # TODO: The following fail as nothing has called on - # `clientConnectionLost` on the factories. I can't figure out *what* is - # meant to call them though. The `txredisapi.HiRedisProtocol` doesn't - # seem to do it, but I don't know if it's *meant* to. - # - # (...time passes...) - # - # After some spelunking it appears that `connectTCP` creates an - # `IConnector`, which is responsible for calling the factory + # Make a `FakeConnector` to emulate the behavior of `connectTCP. That + # creates an `IConnector`, which is responsible for calling the factory # `clientConnectionLost`. The reconnecting factory then calls # `IConnector.connect` to attempt a reconnection. The transport is meant - # to call `connectionLost` on the `IConnector`. So I *think* we need to - # make a `FakeConnector` and pass that to `FakeTransport`? - self.hs.get_replication_command_handler()._factory.retry() - self.worker_hs.get_replication_command_handler()._factory.retry() + # to call `connectionLost` on the `IConnector`. + # + # Most of that is bypassed by directly calling `retry` on the factory, + # which schedules a `connect()` call on the connector. + timeouts = [] + for hs in (self.hs, self.worker_hs): + hs_factory_outbound = hs.get_outbound_redis_connection()._factory + hs_factory_outbound.clientConnectionLost( + FakeOutboundConnector(hs), Failure(RuntimeError("")) + ) + timeouts.append(hs_factory_outbound.delay) + + hs_factory = hs.get_replication_command_handler()._factory + hs_factory.clientConnectionLost( + FakeReplicationHandlerConnector(hs), + Failure(RuntimeError("")), + ) + timeouts.append(hs_factory.delay) + + # Wait for the reconnects to happen. + self.pump(max(timeouts) + 1) self.connect_any_redis_attempts() @@ -137,8 +197,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): client_to_server_transport, server_to_client_transport, ) in self._redis_transports: - client_to_server_transport.loseConnection() - server_to_client_transport.loseConnection() + client_to_server_transport.abortConnection() + server_to_client_transport.abortConnection() + self._redis_transports = [] def replicate(self): """Tell the master side of replication that something has happened, and then |