diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 970d5e533b..ce53f808db 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite
from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.client import ReplicationDataHandler
from synapse.replication.tcp.handler import ReplicationCommandHandler
-from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
-from synapse.replication.tcp.resource import (
- ReplicationStreamProtocolFactory,
+from synapse.replication.tcp.protocol import (
+ ClientReplicationStreamProtocol,
ServerReplicationStreamProtocol,
)
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.server import HomeServer
from tests import unittest
@@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests running multiple workers.
+ Enables Redis, providing a fake Redis server.
+
Automatically handle HTTP replication requests from workers to master,
unlike `BaseStreamTestCase`.
"""
+ if not hiredis:
+ skip = "Requires hiredis"
+
+ if not USE_POSTGRES_FOR_TESTS:
+ # Redis replication only takes place on Postgres
+ skip = "Requires Postgres"
+
+ def default_config(self) -> Dict[str, Any]:
+ """
+ Overrides the default config to enable Redis.
+ Even if the test only uses make_worker_hs, the main process needs Redis
+ enabled otherwise it won't create a Fake Redis server to listen on the
+ Redis port and accept fake TCP connections.
+ """
+ base = super().default_config()
+ base["redis"] = {"enabled": True}
+ return base
+
def setUp(self):
super().setUp()
# build a replication server
- self.server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = self.hs.get_replication_streamer()
# Fake in memory Redis server that servers can connect to.
@@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# handling inbound HTTP requests to that instance.
self._hs_to_site = {self.hs: self.site}
- if self.hs.config.redis.redis_enabled:
- # Handle attempts to connect to fake redis server.
- self.reactor.add_tcp_client_callback(
- "localhost",
- 6379,
- self.connect_any_redis_attempts,
- )
+ # Handle attempts to connect to fake redis server.
+ self.reactor.add_tcp_client_callback(
+ "localhost",
+ 6379,
+ self.connect_any_redis_attempts,
+ )
- self.hs.get_replication_command_handler().start_replication(self.hs)
+ self.hs.get_replication_command_handler().start_replication(self.hs)
# When we see a connection attempt to the master replication listener we
# automatically set up the connection. This is so that tests don't
@@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
store = worker_hs.get_datastores().main
store.db_pool._db_pool = self.database_pool._db_pool
- # Set up TCP replication between master and the new worker if we don't
- # have Redis support enabled.
- if not worker_hs.config.redis.redis_enabled:
- repl_handler = ReplicationCommandHandler(worker_hs)
- client = ClientReplicationStreamProtocol(
- worker_hs,
- "client",
- "test",
- self.clock,
- repl_handler,
- )
- server = self.server_factory.buildProtocol(
- IPv4Address("TCP", "127.0.0.1", 0)
- )
-
- client_transport = FakeTransport(server, self.reactor)
- client.makeConnection(client_transport)
-
- server_transport = FakeTransport(client, self.reactor)
- server.makeConnection(server_transport)
-
# Set up a resource for the worker
resource = ReplicationRestResource(worker_hs)
@@ -378,8 +375,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
reactor=self.reactor,
)
- if worker_hs.config.redis.redis_enabled:
- worker_hs.get_replication_command_handler().start_replication(worker_hs)
+ worker_hs.get_replication_command_handler().start_replication(worker_hs)
return worker_hs
@@ -582,27 +578,3 @@ class FakeRedisPubSubProtocol(Protocol):
def connectionLost(self, reason):
self._server.remove_subscriber(self)
-
-
-class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase):
- """
- A test case that enables Redis, providing a fake Redis server.
- """
-
- if not hiredis:
- skip = "Requires hiredis"
-
- if not USE_POSTGRES_FOR_TESTS:
- # Redis replication only takes place on Postgres
- skip = "Requires Postgres"
-
- def default_config(self) -> Dict[str, Any]:
- """
- Overrides the default config to enable Redis.
- Even if the test only uses make_worker_hs, the main process needs Redis
- enabled otherwise it won't create a Fake Redis server to listen on the
- Redis port and accept fake TCP connections.
- """
- base = super().default_config()
- base["redis"] = {"enabled": True}
- return base
diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index a5ab093a27..936ab4504a 100644
--- a/tests/replication/http/test__base.py
+++ b/tests/replication/http/test__base.py
@@ -18,14 +18,15 @@ from typing import Tuple
from twisted.web.server import Request
from synapse.api.errors import Codes
-from synapse.http.server import JsonResource, cancellable
+from synapse.http.server import JsonResource
from synapse.replication.http import REPLICATION_PREFIX
from synapse.replication.http._base import ReplicationEndpoint
from synapse.server import HomeServer
from synapse.types import JsonDict
+from synapse.util.cancellation import cancellable
from tests import unittest
-from tests.http.server._base import EndpointCancellationTestHelperMixin
+from tests.http.server._base import test_disconnect
class CancellableReplicationEndpoint(ReplicationEndpoint):
@@ -69,9 +70,7 @@ class UncancellableReplicationEndpoint(ReplicationEndpoint):
return HTTPStatus.OK, {"result": True}
-class ReplicationEndpointCancellationTestCase(
- unittest.HomeserverTestCase, EndpointCancellationTestHelperMixin
-):
+class ReplicationEndpointCancellationTestCase(unittest.HomeserverTestCase):
"""Tests for `ReplicationEndpoint` cancellation."""
def create_test_resource(self):
@@ -87,7 +86,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers with the `@cancellable` flag can be cancelled."""
path = f"{REPLICATION_PREFIX}/{CancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=True,
@@ -98,7 +97,7 @@ class ReplicationEndpointCancellationTestCase(
"""Test that handlers without the `@cancellable` flag cannot be cancelled."""
path = f"{REPLICATION_PREFIX}/{UncancellableReplicationEndpoint.NAME}/"
channel = self.make_request("POST", path, await_result=False)
- self._test_disconnect(
+ test_disconnect(
self.reactor,
channel,
expect_cancellation=False,
diff --git a/tests/replication/slave/storage/test_account_data.py b/tests/replication/slave/storage/test_account_data.py
deleted file mode 100644
index 1524087c43..0000000000
--- a/tests/replication/slave/storage/test_account_data.py
+++ /dev/null
@@ -1,42 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.replication.slave.storage.account_data import SlavedAccountDataStore
-
-from ._base import BaseSlavedStoreTestCase
-
-USER_ID = "@feeling:blue"
-TYPE = "my.type"
-
-
-class SlavedAccountDataStoreTestCase(BaseSlavedStoreTestCase):
-
- STORE_TYPE = SlavedAccountDataStore
-
- def test_user_account_data(self):
- self.get_success(
- self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 1})
- )
- self.replicate()
- self.check(
- "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 1}
- )
-
- self.get_success(
- self.master_store.add_account_data_for_user(USER_ID, TYPE, {"a": 2})
- )
- self.replicate()
- self.check(
- "get_global_account_data_by_type_for_user", [USER_ID, TYPE], {"a": 2}
- )
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 6d3d4afe52..531a0db2d0 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -15,7 +15,9 @@ import logging
from typing import Iterable, Optional
from canonicaljson import encode_canonical_json
+from parameterized import parameterized
+from synapse.api.constants import ReceiptTypes
from synapse.api.room_versions import RoomVersions
from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict
from synapse.handlers.room import RoomEventSource
@@ -156,17 +158,26 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
],
)
- def test_push_actions_for_user(self):
+ @parameterized.expand([(True,), (False,)])
+ def test_push_actions_for_user(self, send_receipt: bool):
self.persist(type="m.room.create", key="", creator=USER_ID)
- self.persist(type="m.room.join", key=USER_ID, membership="join")
+ self.persist(type="m.room.member", key=USER_ID, membership="join")
self.persist(
- type="m.room.join", sender=USER_ID, key=USER_ID_2, membership="join"
+ type="m.room.member", sender=USER_ID, key=USER_ID_2, membership="join"
)
event1 = self.persist(type="m.room.message", msgtype="m.text", body="hello")
self.replicate()
+
+ if send_receipt:
+ self.get_success(
+ self.master_store.insert_receipt(
+ ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {}
+ )
+ )
+
self.check(
"get_unread_event_push_actions_by_room_for_user",
- [ROOM_ID, USER_ID_2, event1.event_id],
+ [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=0),
)
@@ -179,7 +190,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
- [ROOM_ID, USER_ID_2, event1.event_id],
+ [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=0, unread_count=0, notify_count=1),
)
@@ -194,7 +205,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.replicate()
self.check(
"get_unread_event_push_actions_by_room_for_user",
- [ROOM_ID, USER_ID_2, event1.event_id],
+ [ROOM_ID, USER_ID_2],
NotifCounts(highlight_count=1, unread_count=0, notify_count=2),
)
diff --git a/tests/replication/slave/storage/test_receipts.py b/tests/replication/slave/storage/test_receipts.py
deleted file mode 100644
index 19f57115a1..0000000000
--- a/tests/replication/slave/storage/test_receipts.py
+++ /dev/null
@@ -1,265 +0,0 @@
-# Copyright 2016 OpenMarket Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from synapse.api.constants import ReceiptTypes
-from synapse.replication.slave.storage.receipts import SlavedReceiptsStore
-from synapse.types import UserID, create_requester
-
-from tests.test_utils.event_injection import create_event
-
-from ._base import BaseSlavedStoreTestCase
-
-OTHER_USER_ID = "@other:test"
-OUR_USER_ID = "@our:test"
-
-
-class SlavedReceiptTestCase(BaseSlavedStoreTestCase):
-
- STORE_TYPE = SlavedReceiptsStore
-
- def prepare(self, reactor, clock, homeserver):
- super().prepare(reactor, clock, homeserver)
- self.room_creator = homeserver.get_room_creation_handler()
- self.persist_event_storage_controller = (
- self.hs.get_storage_controllers().persistence
- )
-
- # Create a test user
- self.ourUser = UserID.from_string(OUR_USER_ID)
- self.ourRequester = create_requester(self.ourUser)
-
- # Create a second test user
- self.otherUser = UserID.from_string(OTHER_USER_ID)
- self.otherRequester = create_requester(self.otherUser)
-
- # Create a test room
- info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
- self.room_id1 = info["room_id"]
-
- # Create a second test room
- info, _ = self.get_success(self.room_creator.create_room(self.ourRequester, {}))
- self.room_id2 = info["room_id"]
-
- # Join the second user to the first room
- memberEvent, memberEventContext = self.get_success(
- create_event(
- self.hs,
- room_id=self.room_id1,
- type="m.room.member",
- sender=self.otherRequester.user.to_string(),
- state_key=self.otherRequester.user.to_string(),
- content={"membership": "join"},
- )
- )
- self.get_success(
- self.persist_event_storage_controller.persist_event(
- memberEvent, memberEventContext
- )
- )
-
- # Join the second user to the second room
- memberEvent, memberEventContext = self.get_success(
- create_event(
- self.hs,
- room_id=self.room_id2,
- type="m.room.member",
- sender=self.otherRequester.user.to_string(),
- state_key=self.otherRequester.user.to_string(),
- content={"membership": "join"},
- )
- )
- self.get_success(
- self.persist_event_storage_controller.persist_event(
- memberEvent, memberEventContext
- )
- )
-
- def test_return_empty_with_no_data(self):
- res = self.get_success(
- self.master_store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
- )
- )
- self.assertEqual(res, {})
-
- res = self.get_success(
- self.master_store.get_receipts_for_user_with_orderings(
- OUR_USER_ID,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
- )
- self.assertEqual(res, {})
-
- res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
- )
- self.assertEqual(res, None)
-
- def test_get_receipts_for_user(self):
- # Send some events into the first room
- event1_1_id = self.create_and_send_event(
- self.room_id1, UserID.from_string(OTHER_USER_ID)
- )
- event1_2_id = self.create_and_send_event(
- self.room_id1, UserID.from_string(OTHER_USER_ID)
- )
-
- # Send public read receipt for the first event
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
- )
- )
- # Send private read receipt for the second event
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
- )
- )
-
- # Test we get the latest event when we want both private and public receipts
- res = self.get_success(
- self.master_store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
- )
- )
- self.assertEqual(res, {self.room_id1: event1_2_id})
-
- # Test we get the older event when we want only public receipt
- res = self.get_success(
- self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
- )
- self.assertEqual(res, {self.room_id1: event1_1_id})
-
- # Test we get the latest event when we want only the public receipt
- res = self.get_success(
- self.master_store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ_PRIVATE]
- )
- )
- self.assertEqual(res, {self.room_id1: event1_2_id})
-
- # Test receipt updating
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
- )
- )
- res = self.get_success(
- self.master_store.get_receipts_for_user(OUR_USER_ID, [ReceiptTypes.READ])
- )
- self.assertEqual(res, {self.room_id1: event1_2_id})
-
- # Send some events into the second room
- event2_1_id = self.create_and_send_event(
- self.room_id2, UserID.from_string(OTHER_USER_ID)
- )
-
- # Test new room is reflected in what the method returns
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
- )
- )
- res = self.get_success(
- self.master_store.get_receipts_for_user(
- OUR_USER_ID, [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE]
- )
- )
- self.assertEqual(res, {self.room_id1: event1_2_id, self.room_id2: event2_1_id})
-
- def test_get_last_receipt_event_id_for_user(self):
- # Send some events into the first room
- event1_1_id = self.create_and_send_event(
- self.room_id1, UserID.from_string(OTHER_USER_ID)
- )
- event1_2_id = self.create_and_send_event(
- self.room_id1, UserID.from_string(OTHER_USER_ID)
- )
-
- # Send public read receipt for the first event
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_1_id], {}
- )
- )
- # Send private read receipt for the second event
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id1, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event1_2_id], {}
- )
- )
-
- # Test we get the latest event when we want both private and public receipts
- res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id1,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
- )
- self.assertEqual(res, event1_2_id)
-
- # Test we get the older event when we want only public receipt
- res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
- )
- )
- self.assertEqual(res, event1_1_id)
-
- # Test we get the latest event when we want only the private receipt
- res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ_PRIVATE]
- )
- )
- self.assertEqual(res, event1_2_id)
-
- # Test receipt updating
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id1, ReceiptTypes.READ, OUR_USER_ID, [event1_2_id], {}
- )
- )
- res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
- OUR_USER_ID, self.room_id1, [ReceiptTypes.READ]
- )
- )
- self.assertEqual(res, event1_2_id)
-
- # Send some events into the second room
- event2_1_id = self.create_and_send_event(
- self.room_id2, UserID.from_string(OTHER_USER_ID)
- )
-
- # Test new room is reflected in what the method returns
- self.get_success(
- self.master_store.insert_receipt(
- self.room_id2, ReceiptTypes.READ_PRIVATE, OUR_USER_ID, [event2_1_id], {}
- )
- )
- res = self.get_success(
- self.master_store.get_last_receipt_event_id_for_user(
- OUR_USER_ID,
- self.room_id2,
- [ReceiptTypes.READ, ReceiptTypes.READ_PRIVATE],
- )
- )
- self.assertEqual(res, event2_1_id)
diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644
--- a/tests/replication/tcp/test_handler.py
+++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from tests.replication._base import RedisMultiWorkerStreamTestCase
+from tests.replication._base import BaseMultiWorkerStreamTestCase
-class ChannelsTestCase(RedisMultiWorkerStreamTestCase):
+class ChannelsTestCase(BaseMultiWorkerStreamTestCase):
def test_subscribed_to_enough_redis_channels(self) -> None:
# The default main process is subscribed to the USER_IP channel.
self.assertCountEqual(
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator
from tests.replication._base import BaseMultiWorkerStreamTestCase
from tests.server import make_request
-from tests.utils import USE_POSTGRES_FOR_TESTS
logger = logging.getLogger(__name__)
@@ -28,11 +27,6 @@ logger = logging.getLogger(__name__)
class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
"""Checks event persisting sharding works"""
- # Event persister sharding requires postgres (due to needing
- # `MultiWriterIdGenerator`).
- if not USE_POSTGRES_FOR_TESTS:
- skip = "Requires Postgres"
-
servlets = [
admin.register_servlets_for_client_rest_resource,
room.register_servlets,
@@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
def default_config(self):
conf = super().default_config()
- conf["redis"] = {"enabled": "true"}
conf["stream_writers"] = {"events": ["worker1", "worker2"]}
conf["instance_map"] = {
"worker1": {"host": "testserv", "port": 1001},
|