diff --git a/synapse/replication/tcp/redis.py b/synapse/replication/tcp/redis.py
index 2688c1ee8e..d4a9df83b6 100644
--- a/synapse/replication/tcp/redis.py
+++ b/synapse/replication/tcp/redis.py
@@ -263,6 +263,7 @@ class SynapseRedisFactory(txredisapi.RedisFactory):
replyTimeout=replyTimeout,
convertNumbers=convertNumbers,
)
+ self.hs = hs
# Set the homeserver reactor as the clock, if this is not done than
# twisted.internet.protocol.ReconnectingClientFactory.retry will default
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index a5398029ad..d84df7a915 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -163,6 +163,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
def reconnect(self):
self.disconnect()
+ print("RECONNECTING")
# Make a `FakeConnector` to emulate the behavior of `connectTCP. That
# creates an `IConnector`, which is responsible for calling the factory
@@ -193,6 +194,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.connect_any_redis_attempts()
def disconnect(self):
+ print("DISCONNECTING")
for (
client_to_server_transport,
server_to_client_transport,
@@ -296,6 +298,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
client_protocol = client_factory.buildProtocol(None)
server_protocol = self._redis_server.buildProtocol(None)
+ if client_protocol.__class__.__name__ == "RedisSubscriber":
+ print(client_protocol, client_protocol.synapse_handler._presence_handler.hs, client_protocol.synapse_outbound_redis_connection)
+ else:
+ print(client_protocol, client_protocol.factory.hs)
+ print()
client_to_server_transport = FakeTransport(
server_protocol, self.reactor, client_protocol
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 4e9ac3727a..da6ec8a12c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -252,7 +252,8 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
# limit the replication rate from server -> client.
print(len(self._redis_transports))
- print(self._redis_transports)
+ for x in self._redis_transports:
+ print(f"\t{x}")
assert len(self._redis_transports) == 1
for _, repl_transport in self._redis_transports:
assert isinstance(repl_transport, FakeTransport)
|