summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py110
1 files changed, 58 insertions, 52 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py

index 3379189785..f6a6aed35e 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -79,7 +79,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): repl_handler = ReplicationCommandHandler(self.worker_hs) self.client = ClientReplicationStreamProtocol( - self.worker_hs, "client", "test", clock, repl_handler, + self.worker_hs, + "client", + "test", + clock, + repl_handler, ) self._client_transport = None @@ -212,6 +216,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Fake in memory Redis server that servers can connect to. self._redis_server = FakeRedisPubSubServer() + # We may have an attempt to connect to redis for the external cache already. + self.connect_any_redis_attempts() + store = self.hs.get_datastore() self.database_pool = store.db_pool @@ -225,7 +232,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if self.hs.config.redis.redis_enabled: # Handle attempts to connect to fake redis server. self.reactor.add_tcp_client_callback( - "localhost", 6379, self.connect_any_redis_attempts, + "localhost", + 6379, + self.connect_any_redis_attempts, ) self.hs.get_tcp_replication().start_replication(self.hs) @@ -243,8 +252,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): ) def create_test_resource(self): - """Overrides `HomeserverTestCase.create_test_resource`. - """ + """Overrides `HomeserverTestCase.create_test_resource`.""" # We override this so that it automatically registers all the HTTP # replication servlets, without having to explicitly do that in all # subclassses. @@ -293,7 +301,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if instance_loc.host not in self.reactor.lookups: raise Exception( "Host does not have an IP for instance_map[%r].host = %r" - % (instance_name, instance_loc.host,) + % ( + instance_name, + instance_loc.host, + ) ) self.reactor.add_tcp_client_callback( @@ -312,7 +323,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): if not worker_hs.config.redis_enabled: repl_handler = ReplicationCommandHandler(worker_hs) client = ClientReplicationStreamProtocol( - worker_hs, "client", "test", self.clock, repl_handler, + worker_hs, + "client", + "test", + self.clock, + repl_handler, ) server = self.server_factory.buildProtocol(None) @@ -401,25 +416,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): fake one. """ clients = self.reactor.tcpClients - self.assertEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") - self.assertEqual(port, 6379) + while clients: + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) - client_to_server_transport = FakeTransport( - server_protocol, self.reactor, client_protocol - ) - client_protocol.makeConnection(client_to_server_transport) - - server_to_client_transport = FakeTransport( - client_protocol, self.reactor, server_protocol - ) - server_protocol.makeConnection(server_to_client_transport) + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) - return client_to_server_transport, server_to_client_transport + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) class TestReplicationDataHandler(GenericWorkerReplicationHandler): @@ -484,8 +497,7 @@ class _PushHTTPChannel(HTTPChannel): self._pull_to_push_producer.stop() def checkPersistence(self, request, version): - """Check whether the connection can be re-used - """ + """Check whether the connection can be re-used""" # We hijack this to always say no for ease of wiring stuff up in # `handle_http_replication_attempt`. request.responseHeaders.setRawHeaders(b"connection", [b"close"]) @@ -493,8 +505,7 @@ class _PushHTTPChannel(HTTPChannel): class _PullToPushProducer: - """A push producer that wraps a pull producer. - """ + """A push producer that wraps a pull producer.""" def __init__( self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer @@ -511,39 +522,33 @@ class _PullToPushProducer: self._start_loop() def _start_loop(self): - """Start the looping call to - """ + """Start the looping call to""" if not self._looping_call: # Start a looping call which runs every tick. self._looping_call = self._clock.looping_call(self._run_once, 0) def stop(self): - """Stops calling resumeProducing. - """ + """Stops calling resumeProducing.""" if self._looping_call: self._looping_call.stop() self._looping_call = None def pauseProducing(self): - """Implements IPushProducer - """ + """Implements IPushProducer""" self.stop() def resumeProducing(self): - """Implements IPushProducer - """ + """Implements IPushProducer""" self._start_loop() def stopProducing(self): - """Implements IPushProducer - """ + """Implements IPushProducer""" self.stop() self._producer.stopProducing() def _run_once(self): - """Calls resumeProducing on producer once. - """ + """Calls resumeProducing on producer once.""" try: self._producer.resumeProducing() @@ -558,25 +563,21 @@ class _PullToPushProducer: class FakeRedisPubSubServer: - """A fake Redis server for pub/sub. - """ + """A fake Redis server for pub/sub.""" def __init__(self): self._subscribers = set() def add_subscriber(self, conn): - """A connection has called SUBSCRIBE - """ + """A connection has called SUBSCRIBE""" self._subscribers.add(conn) def remove_subscriber(self, conn): - """A connection has called UNSUBSCRIBE - """ + """A connection has called UNSUBSCRIBE""" self._subscribers.discard(conn) def publish(self, conn, channel, msg) -> int: - """A connection want to publish a message to subscribers. - """ + """A connection want to publish a message to subscribers.""" for sub in self._subscribers: sub.send(["message", channel, msg]) @@ -587,8 +588,7 @@ class FakeRedisPubSubServer: class FakeRedisPubSubProtocol(Protocol): - """A connection from a client talking to the fake Redis server. - """ + """A connection from a client talking to the fake Redis server.""" def __init__(self, server: FakeRedisPubSubServer): self._server = server @@ -612,8 +612,7 @@ class FakeRedisPubSubProtocol(Protocol): self.handle_command(msg[0], *msg[1:]) def handle_command(self, command, *args): - """Received a Redis command from the client. - """ + """Received a Redis command from the client.""" # We currently only support pub/sub. if command == b"PUBLISH": @@ -624,12 +623,17 @@ class FakeRedisPubSubProtocol(Protocol): (channel,) = args self._server.add_subscriber(self) self.send(["subscribe", channel, 1]) + + # Since we use SET/GET to cache things we can safely no-op them. + elif command == b"SET": + self.send("OK") + elif command == b"GET": + self.send(None) else: raise Exception("Unknown command") def send(self, msg): - """Send a message back to the client. - """ + """Send a message back to the client.""" raw = self.encode(msg).encode("utf-8") self.transport.write(raw) @@ -645,6 +649,8 @@ class FakeRedisPubSubProtocol(Protocol): # We assume bytes are just unicode strings. obj = obj.decode("utf-8") + if obj is None: + return "$-1\r\n" if isinstance(obj, str): return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) if isinstance(obj, int):