summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py90
1 files changed, 31 insertions, 59 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..ce53f808db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
-    ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+    ClientReplicationStreamProtocol,
     ServerReplicationStreamProtocol,
 )
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.server import HomeServer
 
 from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests running multiple workers.
 
+    Enables Redis, providing a fake Redis server.
+
     Automatically handle HTTP replication requests from workers to master,
     unlike `BaseStreamTestCase`.
     """
 
+    if not hiredis:
+        skip = "Requires hiredis"
+
+    if not USE_POSTGRES_FOR_TESTS:
+        # Redis replication only takes place on Postgres
+        skip = "Requires Postgres"
+
+    def default_config(self) -> Dict[str, Any]:
+        """
+        Overrides the default config to enable Redis.
+        Even if the test only uses make_worker_hs, the main process needs Redis
+        enabled otherwise it won't create a Fake Redis server to listen on the
+        Redis port and accept fake TCP connections.
+        """
+        base = super().default_config()
+        base["redis"] = {"enabled": True}
+        return base
+
     def setUp(self):
         super().setUp()
 
         # build a replication server
-        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = self.hs.get_replication_streamer()
 
         # Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # handling inbound HTTP requests to that instance.
         self._hs_to_site = {self.hs: self.site}
 
-        if self.hs.config.redis.redis_enabled:
-            # Handle attempts to connect to fake redis server.
-            self.reactor.add_tcp_client_callback(
-                "localhost",
-                6379,
-                self.connect_any_redis_attempts,
-            )
+        # Handle attempts to connect to fake redis server.
+        self.reactor.add_tcp_client_callback(
+            "localhost",
+            6379,
+            self.connect_any_redis_attempts,
+        )
 
-            self.hs.get_replication_command_handler().start_replication(self.hs)
+        self.hs.get_replication_command_handler().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         store = worker_hs.get_datastores().main
         store.db_pool._db_pool = self.database_pool._db_pool
 
-        # Set up TCP replication between master and the new worker if we don't
-        # have Redis support enabled.
-        if not worker_hs.config.redis.redis_enabled:
-            repl_handler = ReplicationCommandHandler(worker_hs)
-            client = ClientReplicationStreamProtocol(
-                worker_hs,
-                "client",
-                "test",
-                self.clock,
-                repl_handler,
-            )
-            server = self.server_factory.buildProtocol(
-                IPv4Address("TCP", "127.0.0.1", 0)
-            )
-
-            client_transport = FakeTransport(server, self.reactor)
-            client.makeConnection(client_transport)
-
-            server_transport = FakeTransport(client, self.reactor)
-            server.makeConnection(server_transport)
-
         # Set up a resource for the worker
         resource = ReplicationRestResource(worker_hs)
 
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             reactor=self.reactor,
         )
 
-        if worker_hs.config.redis.redis_enabled:
-            worker_hs.get_replication_command_handler().start_replication(worker_hs)
+        worker_hs.get_replication_command_handler().start_replication(worker_hs)
 
         return worker_hs
 
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
 
     def connectionLost(self, reason):
         self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
-    """
-    A test case that enables Redis, providing a fake Redis server.
-    """
-
-    if not hiredis:
-        skip = "Requires hiredis"
-
-    if not USE_POSTGRES_FOR_TESTS:
-        # Redis replication only takes place on Postgres
-        skip = "Requires Postgres"
-
-    def default_config(self) -> Dict[str, Any]:
-        """
-        Overrides the default config to enable Redis.
-        Even if the test only uses make_worker_hs, the main process needs Redis
-        enabled otherwise it won't create a Fake Redis server to listen on the
-        Redis port and accept fake TCP connections.
-        """
-        base = super().default_config()
-        base["redis"] = {"enabled": True}
-        return base