diff options
Diffstat (limited to 'tests/replication')
-rw-r--r-- | tests/replication/_base.py | 90 | ||||
-rw-r--r-- | tests/replication/tcp/test_handler.py | 4 | ||||
-rw-r--r-- | tests/replication/test_sharded_event_persister.py | 7 |
3 files changed, 33 insertions, 68 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 970d5e533b..ce53f808db 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ( - ReplicationStreamProtocolFactory, +from synapse.replication.tcp.protocol import ( + ClientReplicationStreamProtocol, ServerReplicationStreamProtocol, ) +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.server import HomeServer from tests import unittest @@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): """Base class for tests running multiple workers. + Enables Redis, providing a fake Redis server. + Automatically handle HTTP replication requests from workers to master, unlike `BaseStreamTestCase`. """ + if not hiredis: + skip = "Requires hiredis" + + if not USE_POSTGRES_FOR_TESTS: + # Redis replication only takes place on Postgres + skip = "Requires Postgres" + + def default_config(self) -> Dict[str, Any]: + """ + Overrides the default config to enable Redis. + Even if the test only uses make_worker_hs, the main process needs Redis + enabled otherwise it won't create a Fake Redis server to listen on the + Redis port and accept fake TCP connections. + """ + base = super().default_config() + base["redis"] = {"enabled": True} + return base + def setUp(self): super().setUp() # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() # Fake in memory Redis server that servers can connect to. @@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # handling inbound HTTP requests to that instance. self._hs_to_site = {self.hs: self.site} - if self.hs.config.redis.redis_enabled: - # Handle attempts to connect to fake redis server. - self.reactor.add_tcp_client_callback( - "localhost", - 6379, - self.connect_any_redis_attempts, - ) + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", + 6379, + self.connect_any_redis_attempts, + ) - self.hs.get_replication_command_handler().start_replication(self.hs) + self.hs.get_replication_command_handler().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't @@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool - # Set up TCP replication between master and the new worker if we don't - # have Redis support enabled. - if not worker_hs.config.redis.redis_enabled: - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, - "client", - "test", - self.clock, - repl_handler, - ) - server = self.server_factory.buildProtocol( - IPv4Address("TCP", "127.0.0.1", 0) - ) - - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) - - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) - # Set up a resource for the worker resource = ReplicationRestResource(worker_hs) @@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): reactor=self.reactor, ) - if worker_hs.config.redis.redis_enabled: - worker_hs.get_replication_command_handler().start_replication(worker_hs) + worker_hs.get_replication_command_handler().start_replication(worker_hs) return worker_hs @@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol): def connectionLost(self, reason): self._server.remove_subscriber(self) - - -class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase): - """ - A test case that enables Redis, providing a fake Redis server. - """ - - if not hiredis: - skip = "Requires hiredis" - - if not USE_POSTGRES_FOR_TESTS: - # Redis replication only takes place on Postgres - skip = "Requires Postgres" - - def default_config(self) -> Dict[str, Any]: - """ - Overrides the default config to enable Redis. - Even if the test only uses make_worker_hs, the main process needs Redis - enabled otherwise it won't create a Fake Redis server to listen on the - Redis port and accept fake TCP connections. - """ - base = super().default_config() - base["redis"] = {"enabled": True} - return base diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py index e6a19eafd5..1e299d2d67 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.replication._base import RedisMultiWorkerStreamTestCase +from tests.replication._base import BaseMultiWorkerStreamTestCase -class ChannelsTestCase(RedisMultiWorkerStreamTestCase): +class ChannelsTestCase(BaseMultiWorkerStreamTestCase): def test_subscribed_to_enough_redis_channels(self) -> None: # The default main process is subscribed to the USER_IP channel. self.assertCountEqual( diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index a7ca68069e..541d390286 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request -from tests.utils import USE_POSTGRES_FOR_TESTS logger = logging.getLogger(__name__) @@ -28,11 +27,6 @@ logger = logging.getLogger(__name__) class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): """Checks event persisting sharding works""" - # Event persister sharding requires postgres (due to needing - # `MultiWriterIdGenerator`). - if not USE_POSTGRES_FOR_TESTS: - skip = "Requires Postgres" - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): def default_config(self): conf = super().default_config() - conf["redis"] = {"enabled": "true"} conf["stream_writers"] = {"events": ["worker1", "worker2"]} conf["instance_map"] = { "worker1": {"host": "testserv", "port": 1001}, |