summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py69
1 files changed, 33 insertions, 36 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index d5dce1f83f..f6a6aed35e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -79,7 +79,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         repl_handler = ReplicationCommandHandler(self.worker_hs)
         self.client = ClientReplicationStreamProtocol(
-            self.worker_hs, "client", "test", clock, repl_handler,
+            self.worker_hs,
+            "client",
+            "test",
+            clock,
+            repl_handler,
         )
 
         self._client_transport = None
@@ -228,7 +232,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if self.hs.config.redis.redis_enabled:
             # Handle attempts to connect to fake redis server.
             self.reactor.add_tcp_client_callback(
-                "localhost", 6379, self.connect_any_redis_attempts,
+                "localhost",
+                6379,
+                self.connect_any_redis_attempts,
             )
 
             self.hs.get_tcp_replication().start_replication(self.hs)
@@ -246,8 +252,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         )
 
     def create_test_resource(self):
-        """Overrides `HomeserverTestCase.create_test_resource`.
-        """
+        """Overrides `HomeserverTestCase.create_test_resource`."""
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
         # subclassses.
@@ -296,7 +301,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             if instance_loc.host not in self.reactor.lookups:
                 raise Exception(
                     "Host does not have an IP for instance_map[%r].host = %r"
-                    % (instance_name, instance_loc.host,)
+                    % (
+                        instance_name,
+                        instance_loc.host,
+                    )
                 )
 
             self.reactor.add_tcp_client_callback(
@@ -315,7 +323,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if not worker_hs.config.redis_enabled:
             repl_handler = ReplicationCommandHandler(worker_hs)
             client = ClientReplicationStreamProtocol(
-                worker_hs, "client", "test", self.clock, repl_handler,
+                worker_hs,
+                "client",
+                "test",
+                self.clock,
+                repl_handler,
             )
             server = self.server_factory.buildProtocol(None)
 
@@ -485,8 +497,7 @@ class _PushHTTPChannel(HTTPChannel):
             self._pull_to_push_producer.stop()
 
     def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used
-        """
+        """Check whether the connection can be re-used"""
         # We hijack this to always say no for ease of wiring stuff up in
         # `handle_http_replication_attempt`.
         request.responseHeaders.setRawHeaders(b"connection", [b"close"])
@@ -494,8 +505,7 @@ class _PushHTTPChannel(HTTPChannel):
 
 
 class _PullToPushProducer:
-    """A push producer that wraps a pull producer.
-    """
+    """A push producer that wraps a pull producer."""
 
     def __init__(
         self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
@@ -512,39 +522,33 @@ class _PullToPushProducer:
         self._start_loop()
 
     def _start_loop(self):
-        """Start the looping call to
-        """
+        """Start the looping call to"""
 
         if not self._looping_call:
             # Start a looping call which runs every tick.
             self._looping_call = self._clock.looping_call(self._run_once, 0)
 
     def stop(self):
-        """Stops calling resumeProducing.
-        """
+        """Stops calling resumeProducing."""
         if self._looping_call:
             self._looping_call.stop()
             self._looping_call = None
 
     def pauseProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self.stop()
 
     def resumeProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self._start_loop()
 
     def stopProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self.stop()
         self._producer.stopProducing()
 
     def _run_once(self):
-        """Calls resumeProducing on producer once.
-        """
+        """Calls resumeProducing on producer once."""
 
         try:
             self._producer.resumeProducing()
@@ -559,25 +563,21 @@ class _PullToPushProducer:
 
 
 class FakeRedisPubSubServer:
-    """A fake Redis server for pub/sub.
-    """
+    """A fake Redis server for pub/sub."""
 
     def __init__(self):
         self._subscribers = set()
 
     def add_subscriber(self, conn):
-        """A connection has called SUBSCRIBE
-        """
+        """A connection has called SUBSCRIBE"""
         self._subscribers.add(conn)
 
     def remove_subscriber(self, conn):
-        """A connection has called UNSUBSCRIBE
-        """
+        """A connection has called UNSUBSCRIBE"""
         self._subscribers.discard(conn)
 
     def publish(self, conn, channel, msg) -> int:
-        """A connection want to publish a message to subscribers.
-        """
+        """A connection want to publish a message to subscribers."""
         for sub in self._subscribers:
             sub.send(["message", channel, msg])
 
@@ -588,8 +588,7 @@ class FakeRedisPubSubServer:
 
 
 class FakeRedisPubSubProtocol(Protocol):
-    """A connection from a client talking to the fake Redis server.
-    """
+    """A connection from a client talking to the fake Redis server."""
 
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
@@ -613,8 +612,7 @@ class FakeRedisPubSubProtocol(Protocol):
             self.handle_command(msg[0], *msg[1:])
 
     def handle_command(self, command, *args):
-        """Received a Redis command from the client.
-        """
+        """Received a Redis command from the client."""
 
         # We currently only support pub/sub.
         if command == b"PUBLISH":
@@ -635,8 +633,7 @@ class FakeRedisPubSubProtocol(Protocol):
             raise Exception("Unknown command")
 
     def send(self, msg):
-        """Send a message back to the client.
-        """
+        """Send a message back to the client."""
         raw = self.encode(msg).encode("utf-8")
 
         self.transport.write(raw)