Diffstat (limited to 'tests/replication')
-rw-r--r--  tests/replication/_base.py                            69
-rw-r--r--  tests/replication/slave/storage/test_events.py         5
-rw-r--r--  tests/replication/tcp/streams/test_account_data.py     6
-rw-r--r--  tests/replication/tcp/streams/test_events.py           13
-rw-r--r--  tests/replication/tcp/test_remote_server_up.py          3
-rw-r--r--  tests/replication/test_auth.py                         12
-rw-r--r--  tests/replication/test_client_reader_shard.py           6
-rw-r--r--  tests/replication/test_multi_media_repo.py             12
-rw-r--r--  tests/replication/test_pusher_shard.py                 21
-rw-r--r--  tests/replication/test_sharded_event_persister.py      25
10 files changed, 83 insertions(+), 89 deletions(-)
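
The changes below look like an auto-formatter pass rather than functional edits: the pattern throughout matches black's "magic trailing comma" behaviour, where a call whose argument list already ends in a trailing comma is exploded to one argument per line, and a docstring whose closing quotes sat on their own line is collapsed onto a single line. A hypothetical before/after sketch of both rules (not taken from the Synapse source):

# Before: trailing comma after the last argument; docstring closes on its own line.
def send_probe(client, payload):
    """Send a probe to the fake server.
    """
    client.send(
        "localhost", 6379, payload,
    )


# After: the trailing comma forces one argument per line, and the short
# docstring is closed on the same line it opens on.
def send_probe(client, payload):
    """Send a probe to the fake server."""
    client.send(
        "localhost",
        6379,
        payload,
    )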
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index d5dce1f83f..f6a6aed35e 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -79,7 +79,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 
         repl_handler = ReplicationCommandHandler(self.worker_hs)
         self.client = ClientReplicationStreamProtocol(
-            self.worker_hs, "client", "test", clock, repl_handler,
+            self.worker_hs,
+            "client",
+            "test",
+            clock,
+            repl_handler,
         )
 
         self._client_transport = None
@@ -228,7 +232,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if self.hs.config.redis.redis_enabled:
             # Handle attempts to connect to fake redis server.
             self.reactor.add_tcp_client_callback(
-                "localhost", 6379, self.connect_any_redis_attempts,
+                "localhost",
+                6379,
+                self.connect_any_redis_attempts,
             )
 
             self.hs.get_tcp_replication().start_replication(self.hs)
@@ -246,8 +252,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         )
 
     def create_test_resource(self):
-        """Overrides `HomeserverTestCase.create_test_resource`.
-        """
+        """Overrides `HomeserverTestCase.create_test_resource`."""
         # We override this so that it automatically registers all the HTTP
         # replication servlets, without having to explicitly do that in all
         # subclassses.
@@ -296,7 +301,10 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             if instance_loc.host not in self.reactor.lookups:
                 raise Exception(
                     "Host does not have an IP for instance_map[%r].host = %r"
-                    % (instance_name, instance_loc.host,)
+                    % (
+                        instance_name,
+                        instance_loc.host,
+                    )
                 )
 
             self.reactor.add_tcp_client_callback(
@@ -315,7 +323,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         if not worker_hs.config.redis_enabled:
             repl_handler = ReplicationCommandHandler(worker_hs)
             client = ClientReplicationStreamProtocol(
-                worker_hs, "client", "test", self.clock, repl_handler,
+                worker_hs,
+                "client",
+                "test",
+                self.clock,
+                repl_handler,
             )
             server = self.server_factory.buildProtocol(None)
 
@@ -485,8 +497,7 @@ class _PushHTTPChannel(HTTPChannel):
             self._pull_to_push_producer.stop()
 
     def checkPersistence(self, request, version):
-        """Check whether the connection can be re-used
-        """
+        """Check whether the connection can be re-used"""
         # We hijack this to always say no for ease of wiring stuff up in
         # `handle_http_replication_attempt`.
         request.responseHeaders.setRawHeaders(b"connection", [b"close"])
@@ -494,8 +505,7 @@ class _PushHTTPChannel(HTTPChannel):
 
 
 class _PullToPushProducer:
-    """A push producer that wraps a pull producer.
-    """
+    """A push producer that wraps a pull producer."""
 
     def __init__(
         self, reactor: IReactorTime, producer: IPullProducer, consumer: IConsumer
@@ -512,39 +522,33 @@ class _PullToPushProducer:
         self._start_loop()
 
     def _start_loop(self):
-        """Start the looping call to
-        """
+        """Start the looping call to"""
 
         if not self._looping_call:
             # Start a looping call which runs every tick.
             self._looping_call = self._clock.looping_call(self._run_once, 0)
 
     def stop(self):
-        """Stops calling resumeProducing.
-        """
+        """Stops calling resumeProducing."""
         if self._looping_call:
             self._looping_call.stop()
             self._looping_call = None
 
     def pauseProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self.stop()
 
     def resumeProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self._start_loop()
 
     def stopProducing(self):
-        """Implements IPushProducer
-        """
+        """Implements IPushProducer"""
         self.stop()
         self._producer.stopProducing()
 
     def _run_once(self):
-        """Calls resumeProducing on producer once.
-        """
+        """Calls resumeProducing on producer once."""
 
         try:
             self._producer.resumeProducing()
@@ -559,25 +563,21 @@ class _PullToPushProducer:
 
 
 class FakeRedisPubSubServer:
-    """A fake Redis server for pub/sub.
-    """
+    """A fake Redis server for pub/sub."""
 
     def __init__(self):
         self._subscribers = set()
 
     def add_subscriber(self, conn):
-        """A connection has called SUBSCRIBE
-        """
+        """A connection has called SUBSCRIBE"""
         self._subscribers.add(conn)
 
     def remove_subscriber(self, conn):
-        """A connection has called UNSUBSCRIBE
-        """
+        """A connection has called UNSUBSCRIBE"""
         self._subscribers.discard(conn)
 
     def publish(self, conn, channel, msg) -> int:
-        """A connection want to publish a message to subscribers.
-        """
+        """A connection want to publish a message to subscribers."""
         for sub in self._subscribers:
             sub.send(["message", channel, msg])
 
@@ -588,8 +588,7 @@ class FakeRedisPubSubServer:
 
 
 class FakeRedisPubSubProtocol(Protocol):
-    """A connection from a client talking to the fake Redis server.
-    """
+    """A connection from a client talking to the fake Redis server."""
 
     def __init__(self, server: FakeRedisPubSubServer):
         self._server = server
@@ -613,8 +612,7 @@ class FakeRedisPubSubProtocol(Protocol):
             self.handle_command(msg[0], *msg[1:])
 
     def handle_command(self, command, *args):
-        """Received a Redis command from the client.
-        """
+        """Received a Redis command from the client."""
 
         # We currently only support pub/sub.
         if command == b"PUBLISH":
@@ -635,8 +633,7 @@ class FakeRedisPubSubProtocol(Protocol):
             raise Exception("Unknown command")
 
     def send(self, msg):
-        """Send a message back to the client.
-        """
+        """Send a message back to the client."""
         raw = self.encode(msg).encode("utf-8")
 
         self.transport.write(raw)
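
Beyond the reformatting, the class touched here, _PullToPushProducer, is worth a gloss: it adapts a Twisted IPullProducer (which only writes when its resumeProducing() is called) into an IPushProducer by polling it from a zero-interval looping call, pausing and resuming that loop as the consumer demands. A minimal standalone sketch of the same idea, assuming plain twisted.internet.task.LoopingCall instead of Synapse's Clock (the class name and wiring below are illustrative, not Synapse's implementation):

from twisted.internet.task import LoopingCall


class PullToPushSketch:
    """Drive a pull producer from a looping call so it acts like a push producer."""

    def __init__(self, producer, consumer):
        self._producer = producer
        # Register with the consumer as a streaming (push) producer.
        consumer.registerProducer(self, True)
        self._loop = LoopingCall(self._run_once)
        self._loop.start(0)  # run once per reactor tick

    def pauseProducing(self):
        # IPushProducer: the consumer's buffers are full, so stop polling.
        if self._loop.running:
            self._loop.stop()

    def resumeProducing(self):
        # IPushProducer: the consumer can take more data, so resume polling.
        if not self._loop.running:
            self._loop.start(0)

    def stopProducing(self):
        self.pauseProducing()
        self._producer.stopProducing()

    def _run_once(self):
        # Ask the wrapped pull producer to write one chunk to the consumer.
        self._producer.resumeProducing()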
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index c0ee1cfbd6..0ceb0f935c 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -66,7 +66,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
 
         self.get_success(
             self.master_store.store_room(
-                ROOM_ID, USER_ID, is_public=False, room_version=RoomVersions.V1,
+                ROOM_ID,
+                USER_ID,
+                is_public=False,
+                room_version=RoomVersions.V1,
             )
         )
 
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 6a5116dd2a..153634d4ee 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -23,8 +23,7 @@ from tests.replication._base import BaseStreamTestCase
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
     def test_update_function_room_account_data_limit(self):
-        """Test replication with many room account data updates
-        """
+        """Test replication with many room account data updates"""
         store = self.hs.get_datastore()
 
         # generate lots of account data updates
@@ -70,8 +69,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
         self.assertEqual([], received_rows)
 
     def test_update_function_global_account_data_limit(self):
-        """Test replication with many global account data updates
-        """
+        """Test replication with many global account data updates"""
         store = self.hs.get_datastore()
 
         # generate lots of account data updates
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index bad0df08cf..77856fc304 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -129,7 +129,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
         pls["users"][OTHER_USER] = 50
         self.helper.send_state(
-            self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+            self.room_id,
+            EventTypes.PowerLevels,
+            pls,
+            tok=self.user_tok,
         )
 
         # this is the point in the DAG where we make a fork
@@ -255,8 +258,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             self.assertIsNone(sr.event_id)
 
     def test_update_function_state_row_limit(self):
-        """Test replication with many state events over several stream ids.
-        """
+        """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
         # spread out the state changes over a few stream IDs.
@@ -282,7 +284,10 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
         pls["users"].update({u: 50 for u in user_ids})
         self.helper.send_state(
-            self.room_id, EventTypes.PowerLevels, pls, tok=self.user_tok,
+            self.room_id,
+            EventTypes.PowerLevels,
+            pls,
+            tok=self.user_tok,
         )
 
         # this is the point in the DAG where we make a fork
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index d1c15caeb0..1fe9d5b4d0 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -28,8 +28,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
         self.factory = ReplicationStreamProtocolFactory(hs)
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
-        """Create a new direct TCP replication connection
-        """
+        """Create a new direct TCP replication connection"""
 
         proto = self.factory.buildProtocol(("127.0.0.1", 0))
         transport = StringTransport()
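
The _make_client helper above relies on Twisted's in-memory StringTransport, which lets a test hand-feed bytes to a protocol and inspect what it writes back without opening a real socket. A minimal sketch of that testing pattern with a toy line protocol (EchoLine and its payload are illustrative; Synapse's replication protocol is not shown here):

from twisted.protocols.basic import LineOnlyReceiver
from twisted.test.proto_helpers import StringTransport


class EchoLine(LineOnlyReceiver):
    delimiter = b"\n"

    def lineReceived(self, line):
        # Echo each received line back out over the transport.
        self.sendLine(b"echo " + line)


proto = EchoLine()
transport = StringTransport()
proto.makeConnection(transport)     # attach the protocol to the fake transport

proto.dataReceived(b"PING\n")       # inject inbound bytes as if they arrived on the wire
assert transport.value() == b"echo PING\n"   # everything the protocol wrote is captured
transport.clear()                   # reset the outbound buffer between assertions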
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
index f35a5235e1..f8fd8a843c 100644
--- a/tests/replication/test_auth.py
+++ b/tests/replication/test_auth.py
@@ -79,8 +79,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
         )
 
     def test_no_auth(self):
-        """With no authentication the request should finish.
-        """
+        """With no authentication the request should finish."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
 
@@ -89,8 +88,7 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
 
     @override_config({"main_replication_secret": "my-secret"})
     def test_missing_auth(self):
-        """If the main process expects a secret that is not provided, an error results.
-        """
+        """If the main process expects a secret that is not provided, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
 
@@ -101,15 +99,13 @@ class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
         }
     )
     def test_unauthorized(self):
-        """If the main process receives the wrong secret, an error results.
-        """
+        """If the main process receives the wrong secret, an error results."""
         channel = self._test_register()
         self.assertEqual(channel.code, 500)
 
     @override_config({"worker_replication_secret": "my-secret"})
     def test_authorized(self):
-        """The request should finish when the worker provides the authentication header.
-        """
+        """The request should finish when the worker provides the authentication header."""
         channel = self._test_register()
         self.assertEqual(channel.code, 200)
 
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 4608b65a0c..5da1d5dc4d 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -35,8 +35,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         return config
 
     def test_register_single_worker(self):
-        """Test that registration works when using a single client reader worker.
-        """
+        """Test that registration works when using a single client reader worker."""
         worker_hs = self.make_worker_hs("synapse.app.client_reader")
         site = self._hs_to_site[worker_hs]
 
@@ -66,8 +65,7 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
 
     def test_register_multi_worker(self):
-        """Test that registration works when using multiple client reader workers.
-        """
+        """Test that registration works when using multiple client reader workers."""
         worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
         worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index d1feca961f..7ff11cde10 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -36,8 +36,7 @@ test_server_connection_factory = None
 
 
 class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
-    """Checks running multiple media repos work correctly.
-    """
+    """Checks running multiple media repos work correctly."""
 
     servlets = [
         admin.register_servlets_for_client_rest_resource,
@@ -124,8 +123,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         return channel, request
 
     def test_basic(self):
-        """Test basic fetching of remote media from a single worker.
-        """
+        """Test basic fetching of remote media from a single worker."""
         hs1 = self.make_worker_hs("synapse.app.generic_worker")
 
         channel, request = self._get_media_req(hs1, "example.com:443", "ABC123")
@@ -223,16 +221,14 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.assertEqual(start_count + 3, self._count_remote_thumbnails())
 
     def _count_remote_media(self) -> int:
-        """Count the number of files in our remote media directory.
-        """
+        """Count the number of files in our remote media directory."""
         path = os.path.join(
             self.hs.get_media_repository().primary_base_path, "remote_content"
         )
         return sum(len(files) for _, _, files in os.walk(path))
 
     def _count_remote_thumbnails(self) -> int:
-        """Count the number of files in our remote thumbnails directory.
-        """
+        """Count the number of files in our remote thumbnails directory."""
         path = os.path.join(
             self.hs.get_media_repository().primary_base_path, "remote_thumbnail"
         )
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 800ad94a04..f118fe32af 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -27,8 +27,7 @@ logger = logging.getLogger(__name__)
 
 
 class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
-    """Checks pusher sharding works
-    """
+    """Checks pusher sharding works"""
 
     servlets = [
         admin.register_servlets_for_client_rest_resource,
@@ -88,11 +87,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         return event_id
 
     def test_send_push_single_worker(self):
-        """Test that registration works when using a pusher worker.
-        """
+        """Test that registration works when using a pusher worker."""
         http_client_mock = Mock(spec_set=["post_json_get_json"])
-        http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
-            {}
+        http_client_mock.post_json_get_json.side_effect = (
+            lambda *_, **__: defer.succeed({})
         )
 
         self.make_worker_hs(
@@ -119,11 +117,10 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
     def test_send_push_multiple_workers(self):
-        """Test that registration works when using sharded pusher workers.
-        """
+        """Test that registration works when using sharded pusher workers."""
         http_client_mock1 = Mock(spec_set=["post_json_get_json"])
-        http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
-            {}
+        http_client_mock1.post_json_get_json.side_effect = (
+            lambda *_, **__: defer.succeed({})
         )
 
         self.make_worker_hs(
@@ -137,8 +134,8 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         http_client_mock2 = Mock(spec_set=["post_json_get_json"])
-        http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
-            {}
+        http_client_mock2.post_json_get_json.side_effect = (
+            lambda *_, **__: defer.succeed({})
         )
 
         self.make_worker_hs(
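
The mock wiring above is a standard way to stub out the pusher's HTTP client in these tests: Mock(spec_set=[...]) limits the stub to the single method the pusher calls, and the side_effect returns an already-fired Deferred so the awaited fake request resolves immediately. A small standalone sketch of the same pattern (the URL and body are made up for illustration):

from unittest.mock import Mock

from twisted.internet import defer

# Only post_json_get_json may be accessed; any other attribute raises AttributeError.
http_client_mock = Mock(spec_set=["post_json_get_json"])
http_client_mock.post_json_get_json.side_effect = (
    lambda *_, **__: defer.succeed({})
)

d = http_client_mock.post_json_get_json(
    "https://push.example.com/_matrix/push/v1/notify", {"notification": {}}
)

results = []
d.addCallback(results.append)   # fires synchronously: the Deferred already has a result
assert results == [{}]
http_client_mock.post_json_get_json.assert_called_once()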
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 8d494ebc03..c9b773fbd2 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -29,8 +29,7 @@ logger = logging.getLogger(__name__)
 
 
 class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
-    """Checks event persisting sharding works
-    """
+    """Checks event persisting sharding works"""
 
     # Event persister sharding requires postgres (due to needing
     # `MutliWriterIdGenerator`).
@@ -63,8 +62,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         return conf
 
     def _create_room(self, room_id: str, user_id: str, tok: str):
-        """Create a room with given room_id
-        """
+        """Create a room with given room_id"""
 
         # We control the room ID generation by patching out the
         # `_generate_room_id` method
@@ -91,11 +89,13 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         """
 
         self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker1"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker1"},
         )
 
         self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker2"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker2"},
         )
 
         persisted_on_1 = False
@@ -139,15 +139,18 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         """
 
         self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker1"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker1"},
         )
 
         worker_hs2 = self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "worker2"},
+            "synapse.app.generic_worker",
+            {"worker_name": "worker2"},
         )
 
         sync_hs = self.make_worker_hs(
-            "synapse.app.generic_worker", {"worker_name": "sync"},
+            "synapse.app.generic_worker",
+            {"worker_name": "sync"},
         )
         sync_hs_site = self._hs_to_site[sync_hs]
 
@@ -323,7 +326,9 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
             sync_hs_site,
             "GET",
             "/rooms/{}/messages?from={}&to={}&dir=f".format(
-                room_id2, vector_clock_token, prev_batch2,
+                room_id2,
+                vector_clock_token,
+                prev_batch2,
             ),
             access_token=access_token,
         )