summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--tests/replication/_base.py90
-rw-r--r--tests/replication/slave/storage/test_account_data.py42
-rw-r--r--tests/replication/tcp/test_handler.py4
-rw-r--r--tests/replication/test_sharded_event_persister.py7
-rw-r--r--tests/storage/test_receipts.py (renamed from tests/replication/slave/storage/test_receipts.py)107
5 files changed, 95 insertions, 155 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py

index 970d5e533b..ce53f808db 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ( - ReplicationStreamProtocolFactory, +from synapse.replication.tcp.protocol import ( + ClientReplicationStreamProtocol, ServerReplicationStreamProtocol, ) +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.server import HomeServer from tests import unittest @@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): """Base class for tests running multiple workers. + Enables Redis, providing a fake Redis server. + Automatically handle HTTP replication requests from workers to master, unlike `BaseStreamTestCase`. """ + if not hiredis: + skip = "Requires hiredis" + + if not USE_POSTGRES_FOR_TESTS: + # Redis replication only takes place on Postgres + skip = "Requires Postgres" + + def default_config(self) -> Dict[str, Any]: + """ + Overrides the default config to enable Redis. + Even if the test only uses make_worker_hs, the main process needs Redis + enabled otherwise it won't create a Fake Redis server to listen on the + Redis port and accept fake TCP connections. + """ + base = super().default_config() + base["redis"] = {"enabled": True} + return base + def setUp(self): super().setUp() # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() # Fake in memory Redis server that servers can connect to. @@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # handling inbound HTTP requests to that instance. self._hs_to_site = {self.hs: self.site} - if self.hs.config.redis.redis_enabled: - # Handle attempts to connect to fake redis server. - self.reactor.add_tcp_client_callback( - "localhost", - 6379, - self.connect_any_redis_attempts, - ) + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", + 6379, + self.connect_any_redis_attempts, + ) - self.hs.get_replication_command_handler().start_replication(self.hs) + self.hs.get_replication_command_handler().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't @@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool - # Set up TCP replication between master and the new worker if we don't - # have Redis support enabled. - if not worker_hs.config.redis.redis_enabled: - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, - "client", - "test", - self.clock, - repl_handler, - ) - server = self.server_factory.buildProtocol( - IPv4Address("TCP", "127.0.0.1", 0) - ) - - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) - - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) - # Set up a resource for the worker resource = ReplicationRestResource(worker_hs) @@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): reactor=self.reactor, ) - if worker_hs.config.redis.redis_enabled: - worker_hs.get_replication_command_handler().start_replication(worker_hs) + worker_hs.get_replication_command_handler().start_replication(worker_hs) return worker_hs @@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol): def connectionLost(self, reason): self._server.remove_subscriber(self) - - -class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase): - """ - A test case that enables Redis, providing a fake Redis server. - """ - - if not hiredis: - skip = "Requires hiredis" - - if not USE_POSTGRES_FOR_TESTS: - # Redis replication only takes place on Postgres - skip = "Requires Postgres" - - def default_config(self) -> Dict[str, Any]: - """ - Overrides the default config to enable Redis. - Even if the test only uses make_worker_hs, the main process needs Redis - enabled otherwise it won't create a Fake Redis server to listen on the - Redis port and accept fake TCP connections. - """ - base = super().default_config() - base["redis"] = {"enabled": True} - return base diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py deleted file mode 100644
index 1524087c43..0000000000 --- a/tests/replication/slave/storage/test_account_data.py +++ /dev/null
@@ -1,42 +0,0 @@ -# Copyright 2016 OpenMarket Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from synapse.replication.slave.storage.account_data import SlavedAccountDataStore - -from ._base import BaseSlavedStoreTestCase - -USER_ID = "@feeling:blue" -TYPE = "my.type" - - -class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase): - - STORE_TYPE = SlavedAccountDataStore - - def test_user_account_data(self): - self.get_success( - self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1}) - ) - self.replicate() - self.check( - "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 1} - ) - - self.get_success( - self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2}) - ) - self.replicate() - self.check( - "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 2} - ) diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.replication._base import RedisMultiWorkerStreamTestCase +from tests.replication._base import BaseMultiWorkerStreamTestCase -class ChannelsTestCase(RedisMultiWorkerStreamTestCase): +class ChannelsTestCase(BaseMultiWorkerStreamTestCase): def test_subscribed_to_enough_redis_channels(self) -> None: # The default main process is subscribed to the USER_IP channel. self.assertCountEqual( diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request -from tests.utils import USE_POSTGRES_FOR_TESTS logger = logging.getLogger(__name__) @@ -28,11 +27,6 @@ logger = logging.getLogger(__name__) class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): """Checks event persisting sharding works""" - # Event persister sharding requires postgres (due to needing - # `MultiWriterIdGenerator`). - if not USE_POSTGRES_FOR_TESTS: - skip = "Requires Postgres" - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): def default_config(self): conf = super().default_config() - conf["redis"] = {"enabled": "true"} conf["stream_writers"] = {"events": ["worker1", "worker2"]} conf["instance_map"] = { "worker1": {"host": "testserv", "port": 1001}, diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/storage/test_receipts.py
index 19f57115a1..191c957fb5 100644 --- a/tests/replication/slave/storage/test_receipts.py +++ b/tests/storage/test_receipts.py
@@ -12,24 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from parameterized import parameterized + from synapse.api.constants import ReceiptTypes -from synapse.replication.slave.storage.receipts import SlavedReceiptsStore from synapse.types import UserID, create_requester from tests.test_utils.event_injection import create_event - -from ._base import BaseSlavedStoreTestCase +from tests.unittest import HomeserverTestCase OTHER_USER_ID = "@other:test" OUR_USER_ID = "@our:test" -class SlavedReceiptTestCase(BaseSlavedStoreTestCase): +class ReceiptTestCase(HomeserverTestCase): + def prepare(self, reactor, clock, homeserver) -> None: + super().prepare(reactor, clock, homeserver) - STORE_TYPE = SlavedReceiptsStore + self.store = homeserver.get_datastores().main - def prepare(self, reactor, clock, homeserver): - super().prepare(reactor, clock, homeserver) self.room_creator = homeserver.get_room_creation_handler() self.persist_event_storage_controller = ( self.hs.get_storage_controllers().persistence @@ -85,32 +85,48 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): ) ) - def test_return_empty_with_no_data(self): + def test_return_empty_with_no_data(self) -> None: res = self.get_success( - self.master_store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + self.store.get_receipts_for_user( + OUR_USER_ID, + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) res = self.get_success( - self.master_store.get_receipts_for_user_with_orderings( + self.store.get_receipts_for_user_with_orderings( OUR_USER_ID, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, {}) res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ + ReceiptTypes.READ, + ReceiptTypes.READ_PRIVATE, + ReceiptTypes.UNSTABLE_READ_PRIVATE, + ], ) ) self.assertEqual(res, None) - def test_get_receipts_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_receipts_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -121,47 +137,45 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Send public read receipt for the first event self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} ) ) # Send private read receipt for the second event self.get_success( - self.master_store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.store.insert_receipt( + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( - self.master_store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + self.store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id}) # Test we get the older event when we want only public receipt res = self.get_success( - self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) ) self.assertEqual(res, {self.room_id1: event1_1_id}) # Test we get the latest event when we want only the public receipt res = self.get_success( - self.master_store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ_PRIVATE] - ) + self.store.get_receipts_for_user(OUR_USER_ID, [receipt_type]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) # Test receipt updating self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} ) ) res = self.get_success( - self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) + self.store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ]) ) self.assertEqual(res, {self.room_id1: event1_2_id}) @@ -172,18 +186,21 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test new room is reflected in what the method returns self.get_success( - self.master_store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.store.insert_receipt( + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( - self.master_store.get_receipts_for_user( - OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE] + self.store.get_receipts_for_user( + OUR_USER_ID, [ReceiptTypes.READ, receipt_type] ) ) self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id}) - def test_get_last_receipt_event_id_for_user(self): + @parameterized.expand( + [ReceiptTypes.READ_PRIVATE, ReceiptTypes.UNSTABLE_READ_PRIVATE] + ) + def test_get_last_receipt_event_id_for_user(self, receipt_type: str) -> None: # Send some events into the first room event1_1_id = self.create_and_send_event( self.room_id1, UserID.from_string(OTHER_USER_ID) @@ -194,30 +211,30 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Send public read receipt for the first event self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {} ) ) # Send private read receipt for the second event self.get_success( - self.master_store.insert_receipt( - self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {} + self.store.insert_receipt( + self.room_id1, receipt_type, OUR_USER_ID, [event1_2_id], {} ) ) # Test we get the latest event when we want both private and public receipts res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event1_2_id) # Test we get the older event when we want only public receipt res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] ) ) @@ -225,20 +242,20 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test we get the latest event when we want only the private receipt res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( - OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE] + self.store.get_last_receipt_event_id_for_user( + OUR_USER_ID, self.room_id1, [receipt_type] ) ) self.assertEqual(res, event1_2_id) # Test receipt updating self.get_success( - self.master_store.insert_receipt( + self.store.insert_receipt( self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {} ) ) res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id1, [ReceiptTypes.READ] ) ) @@ -251,15 +268,15 @@ class SlavedReceiptTestCase(BaseSlavedStoreTestCase): # Test new room is reflected in what the method returns self.get_success( - self.master_store.insert_receipt( - self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {} + self.store.insert_receipt( + self.room_id2, receipt_type, OUR_USER_ID, [event2_1_id], {} ) ) res = self.get_success( - self.master_store.get_last_receipt_event_id_for_user( + self.store.get_last_receipt_event_id_for_user( OUR_USER_ID, self.room_id2, - [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE], + [ReceiptTypes.READ, receipt_type], ) ) self.assertEqual(res, event2_1_id)