summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/handlers/test_room_member.py4
-rw-r--r--tests/module_api/test_api.py7
-rw-r--r--tests/replication/_base.py90
-rw-r--r--tests/replication/tcp/test_handler.py4
-rw-r--r--tests/replication/test_sharded_event_persister.py7
5 files changed, 35 insertions, 77 deletions
diff --git a/tests/handlers/test_room_member.py b/tests/handlers/test_room_member.py
index b4e1405aee..1d13ed1e88 100644
--- a/tests/handlers/test_room_member.py
+++ b/tests/handlers/test_room_member.py
@@ -14,7 +14,7 @@ from synapse.server import HomeServer
 from synapse.types import UserID, create_requester
 from synapse.util import Clock
 
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
 from tests.test_utils import make_awaitable
 from tests.unittest import FederatingHomeserverTestCase, override_config
@@ -216,7 +216,7 @@ class TestJoinsLimitedByPerRoomRateLimiter(FederatingHomeserverTestCase):
     #   - trying to remote-join again.
 
 
-class TestReplicatedJoinsLimitedByPerRoomRateLimiter(RedisMultiWorkerStreamTestCase):
+class TestReplicatedJoinsLimitedByPerRoomRateLimiter(BaseMultiWorkerStreamTestCase):
     servlets = [
         synapse.rest.admin.register_servlets,
         synapse.rest.client.login.register_servlets,
diff --git a/tests/module_api/test_api.py b/tests/module_api/test_api.py
index 106159fa65..02cef6f876 100644
--- a/tests/module_api/test_api.py
+++ b/tests/module_api/test_api.py
@@ -30,7 +30,6 @@ from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.test_utils import simple_async_mock
 from tests.test_utils.event_injection import inject_member_event
 from tests.unittest import HomeserverTestCase, override_config
-from tests.utils import USE_POSTGRES_FOR_TESTS
 
 
 class ModuleApiTestCase(HomeserverTestCase):
@@ -738,11 +737,6 @@ class ModuleApiTestCase(HomeserverTestCase):
 class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
     """For testing ModuleApi functionality in a multi-worker setup"""
 
-    # Testing stream ID replication from the main to worker processes requires postgres
-    # (due to needing `MultiWriterIdGenerator`).
-    if not USE_POSTGRES_FOR_TESTS:
-        skip = "Requires Postgres"
-
     servlets = [
         admin.register_servlets,
         login.register_servlets,
@@ -752,7 +746,6 @@ class ModuleApiWorkerTestCase(BaseMultiWorkerStreamTestCase):
 
     def default_config(self):
         conf = super().default_config()
-        conf["redis"] = {"enabled": "true"}
         conf["stream_writers"] = {"presence": ["presence_writer"]}
         conf["instance_map"] = {
             "presence_writer": {"host": "testserv", "port": 1001},
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..ce53f808db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
 from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
-    ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+    ClientReplicationStreamProtocol,
     ServerReplicationStreamProtocol,
 )
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
 from synapse.server import HomeServer
 
 from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
 class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests running multiple workers.
 
+    Enables Redis, providing a fake Redis server.
+
     Automatically handle HTTP replication requests from workers to master,
     unlike `BaseStreamTestCase`.
     """
 
+    if not hiredis:
+        skip = "Requires hiredis"
+
+    if not USE_POSTGRES_FOR_TESTS:
+        # Redis replication only takes place on Postgres
+        skip = "Requires Postgres"
+
+    def default_config(self) -> Dict[str, Any]:
+        """
+        Overrides the default config to enable Redis.
+        Even if the test only uses make_worker_hs, the main process needs Redis
+        enabled otherwise it won't create a Fake Redis server to listen on the
+        Redis port and accept fake TCP connections.
+        """
+        base = super().default_config()
+        base["redis"] = {"enabled": True}
+        return base
+
     def setUp(self):
         super().setUp()
 
         # build a replication server
-        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = self.hs.get_replication_streamer()
 
         # Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # handling inbound HTTP requests to that instance.
         self._hs_to_site = {self.hs: self.site}
 
-        if self.hs.config.redis.redis_enabled:
-            # Handle attempts to connect to fake redis server.
-            self.reactor.add_tcp_client_callback(
-                "localhost",
-                6379,
-                self.connect_any_redis_attempts,
-            )
+        # Handle attempts to connect to fake redis server.
+        self.reactor.add_tcp_client_callback(
+            "localhost",
+            6379,
+            self.connect_any_redis_attempts,
+        )
 
-            self.hs.get_replication_command_handler().start_replication(self.hs)
+        self.hs.get_replication_command_handler().start_replication(self.hs)
 
         # When we see a connection attempt to the master replication listener we
         # automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         store = worker_hs.get_datastores().main
         store.db_pool._db_pool = self.database_pool._db_pool
 
-        # Set up TCP replication between master and the new worker if we don't
-        # have Redis support enabled.
-        if not worker_hs.config.redis.redis_enabled:
-            repl_handler = ReplicationCommandHandler(worker_hs)
-            client = ClientReplicationStreamProtocol(
-                worker_hs,
-                "client",
-                "test",
-                self.clock,
-                repl_handler,
-            )
-            server = self.server_factory.buildProtocol(
-                IPv4Address("TCP", "127.0.0.1", 0)
-            )
-
-            client_transport = FakeTransport(server, self.reactor)
-            client.makeConnection(client_transport)
-
-            server_transport = FakeTransport(client, self.reactor)
-            server.makeConnection(server_transport)
-
         # Set up a resource for the worker
         resource = ReplicationRestResource(worker_hs)
 
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             reactor=self.reactor,
         )
 
-        if worker_hs.config.redis.redis_enabled:
-            worker_hs.get_replication_command_handler().start_replication(worker_hs)
+        worker_hs.get_replication_command_handler().start_replication(worker_hs)
 
         return worker_hs
 
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
 
     def connectionLost(self, reason):
         self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
-    """
-    A test case that enables Redis, providing a fake Redis server.
-    """
-
-    if not hiredis:
-        skip = "Requires hiredis"
-
-    if not USE_POSTGRES_FOR_TESTS:
-        # Redis replication only takes place on Postgres
-        skip = "Requires Postgres"
-
-    def default_config(self) -> Dict[str, Any]:
-        """
-        Overrides the default config to enable Redis.
-        Even if the test only uses make_worker_hs, the main process needs Redis
-        enabled otherwise it won't create a Fake Redis server to listen on the
-        Redis port and accept fake TCP connections.
-        """
-        base = super().default_config()
-        base["redis"] = {"enabled": True}
-        return base
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
 
 
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
     def test_subscribed_to_enough_redis_channels(self) -> None:
         # The default main process is subscribed to the USER_IP channel.
         self.assertCountEqual(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
 from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
 
 logger = logging.getLogger(__name__)
 
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
 class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
     """Checks event persisting sharding works"""
 
-    # Event persister sharding requires postgres (due to needing
-    # `MultiWriterIdGenerator`).
-    if not USE_POSTGRES_FOR_TESTS:
-        skip = "Requires Postgres"
-
     servlets = [
         admin.register_servlets_for_client_rest_resource,
         room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
     def default_config(self):
         conf = super().default_config()
-        conf["redis"] = {"enabled": "true"}
         conf["stream_writers"] = {"events": ["worker1", "worker2"]}
         conf["instance_map"] = {
             "worker1": {"host": "testserv", "port": 1001},