diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 3379189785..f6a6aed35e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -79,7 +79,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
repl_handler = ReplicationCommandHandler(self.worker_hs)
self.client = ClientReplicationStreamProtocol(
- self.worker_hs, "client", "test", clock, repl_handler,
+ self.worker_hs,
+ "client",
+ "test",
+ clock,
+ repl_handler,
)
self._client_transport = None
@@ -212,6 +216,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
+ # We may have an attempt to connect to redis for the external cache already.
+ self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -225,7 +232,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if self.hs.config.redis.redis_enabled:
# Handle attempts to connect to fake redis server.
self.reactor.add_tcp_client_callback(
- "localhost", 6379, self.connect_any_redis_attempts,
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
)
self.hs.get_tcp_replication().start_replication(self.hs)
@@ -243,8 +252,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
)
def create_test_resource(self):
- """Overrides `HomeserverTestCase.create_test_resource`.
- """
+ """Overrides `HomeserverTestCase.create_test_resource`."""
# We override this so that it automatically registers all the HTTP
# replication servlets, without having to explicitly do that in all
# subclassses.
@@ -293,7 +301,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if instance_loc.host not in self.reactor.lookups:
raise Exception(
"Host does not have an IP for instance_map[%r].host = %r"
- % (instance_name, instance_loc.host,)
+ % (
+ instance_name,
+ instance_loc.host,
+ )
)
self.reactor.add_tcp_client_callback(
@@ -312,7 +323,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
if not worker_hs.config.redis_enabled:
repl_handler = ReplicationCommandHandler(worker_hs)
client = ClientReplicationStreamProtocol(
- worker_hs, "client", "test", self.clock, repl_handler,
+ worker_hs,
+ "client",
+ "test",
+ self.clock,
+ repl_handler,
)
server = self.server_factory.buildProtocol(None)
@@ -401,25 +416,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
- self.assertEqual(port, 6379)
+ while clients:
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
- client_to_server_transport = FakeTransport(
- server_protocol, self.reactor, client_protocol
- )
- client_protocol.makeConnection(client_to_server_transport)
-
- server_to_client_transport = FakeTransport(
- client_protocol, self.reactor, server_protocol
- )
- server_protocol.makeConnection(server_to_client_transport)
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
- return client_to_server_transport, server_to_client_transport
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -484,8 +497,7 @@ class _PushHTTPChannel(HTTPChannel):
self._pull_to_push_producer.stop()
def checkPersistence(self, request, version):
- """Check whether the connection can be re-used
- """
+ """Check whether the connection can be re-used"""
# We hijack this to always say no for ease of wiring stuff up in
# `handle_http_replication_attempt`.
request.responseHeaders.setRawHeaders(b"connection", [b"close"])
@@ -493,8 +505,7 @@ class _PushHTTPChannel(HTTPChannel):
class _PullToPushProducer:
- """A push producer that wraps a pull producer.
- """
+ """A push producer that wraps a pull producer."""
def __init__(
self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
@@ -511,39 +522,33 @@ class _PullToPushProducer:
self._start_loop()
def _start_loop(self):
- """Start the looping call to
- """
+ """Start the looping call to"""
if not self._looping_call:
# Start a looping call which runs every tick.
self._looping_call = self._clock.looping_call(self._run_once, 0)
def stop(self):
- """Stops calling resumeProducing.
- """
+ """Stops calling resumeProducing."""
if self._looping_call:
self._looping_call.stop()
self._looping_call = None
def pauseProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self.stop()
def resumeProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self._start_loop()
def stopProducing(self):
- """Implements IPushProducer
- """
+ """Implements IPushProducer"""
self.stop()
self._producer.stopProducing()
def _run_once(self):
- """Calls resumeProducing on producer once.
- """
+ """Calls resumeProducing on producer once."""
try:
self._producer.resumeProducing()
@@ -558,25 +563,21 @@ class _PullToPushProducer:
class FakeRedisPubSubServer:
- """A fake Redis server for pub/sub.
- """
+ """A fake Redis server for pub/sub."""
def __init__(self):
self._subscribers = set()
def add_subscriber(self, conn):
- """A connection has called SUBSCRIBE
- """
+ """A connection has called SUBSCRIBE"""
self._subscribers.add(conn)
def remove_subscriber(self, conn):
- """A connection has called UNSUBSCRIBE
- """
+ """A connection has called UNSUBSCRIBE"""
self._subscribers.discard(conn)
def publish(self, conn, channel, msg) -> int:
- """A connection want to publish a message to subscribers.
- """
+ """A connection want to publish a message to subscribers."""
for sub in self._subscribers:
sub.send(["message", channel, msg])
@@ -587,8 +588,7 @@ class FakeRedisPubSubServer:
class FakeRedisPubSubProtocol(Protocol):
- """A connection from a client talking to the fake Redis server.
- """
+ """A connection from a client talking to the fake Redis server."""
def __init__(self, server: FakeRedisPubSubServer):
self._server = server
@@ -612,8 +612,7 @@ class FakeRedisPubSubProtocol(Protocol):
self.handle_command(msg[0], *msg[1:])
def handle_command(self, command, *args):
- """Received a Redis command from the client.
- """
+ """Received a Redis command from the client."""
# We currently only support pub/sub.
if command == b"PUBLISH":
@@ -624,12 +623,17 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
+
+ # Since we use SET/GET to cache things we can safely no-op them.
+ elif command == b"SET":
+ self.send("OK")
+ elif command == b"GET":
+ self.send(None)
else:
raise Exception("Unknown command")
def send(self, msg):
- """Send a message back to the client.
- """
+ """Send a message back to the client."""
raw = self.encode(msg).encode("utf-8")
self.transport.write(raw)
@@ -645,6 +649,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
+ if obj is None:
+ return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
|