summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/_base.py99
-rw-r--r--tests/replication/http/test__base.py3
-rw-r--r--tests/replication/slave/storage/test_events.py33
-rw-r--r--tests/replication/tcp/streams/test_receipts.py15
-rw-r--r--tests/replication/tcp/test_handler.py4
-rw-r--r--tests/replication/test_module_cache_invalidation.py79
-rw-r--r--tests/replication/test_multi_media_repo.py14
-rw-r--r--tests/replication/test_pusher_shard.py2
-rw-r--r--tests/replication/test_sharded_event_persister.py7
9 files changed, 165 insertions, 91 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py

index 970d5e533b..3029a16dda 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py
@@ -24,11 +24,11 @@ from synapse.http.site import SynapseRequest, SynapseSite from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler -from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol -from synapse.replication.tcp.resource import ( - ReplicationStreamProtocolFactory, +from synapse.replication.tcp.protocol import ( + ClientReplicationStreamProtocol, ServerReplicationStreamProtocol, ) +from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.server import HomeServer from tests import unittest @@ -220,15 +220,34 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): """Base class for tests running multiple workers. + Enables Redis, providing a fake Redis server. + Automatically handle HTTP replication requests from workers to master, unlike `BaseStreamTestCase`. """ + if not hiredis: + skip = "Requires hiredis" + + if not USE_POSTGRES_FOR_TESTS: + # Redis replication only takes place on Postgres + skip = "Requires Postgres" + + def default_config(self) -> Dict[str, Any]: + """ + Overrides the default config to enable Redis. + Even if the test only uses make_worker_hs, the main process needs Redis + enabled otherwise it won't create a Fake Redis server to listen on the + Redis port and accept fake TCP connections. + """ + base = super().default_config() + base["redis"] = {"enabled": True} + return base + def setUp(self): super().setUp() # build a replication server - self.server_factory = ReplicationStreamProtocolFactory(self.hs) self.streamer = self.hs.get_replication_streamer() # Fake in memory Redis server that servers can connect to. @@ -247,15 +266,14 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # handling inbound HTTP requests to that instance. self._hs_to_site = {self.hs: self.site} - if self.hs.config.redis.redis_enabled: - # Handle attempts to connect to fake redis server. - self.reactor.add_tcp_client_callback( - "localhost", - 6379, - self.connect_any_redis_attempts, - ) + # Handle attempts to connect to fake redis server. + self.reactor.add_tcp_client_callback( + "localhost", + 6379, + self.connect_any_redis_attempts, + ) - self.hs.get_replication_command_handler().start_replication(self.hs) + self.hs.get_replication_command_handler().start_replication(self.hs) # When we see a connection attempt to the master replication listener we # automatically set up the connection. This is so that tests don't @@ -339,27 +357,6 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): store = worker_hs.get_datastores().main store.db_pool._db_pool = self.database_pool._db_pool - # Set up TCP replication between master and the new worker if we don't - # have Redis support enabled. - if not worker_hs.config.redis.redis_enabled: - repl_handler = ReplicationCommandHandler(worker_hs) - client = ClientReplicationStreamProtocol( - worker_hs, - "client", - "test", - self.clock, - repl_handler, - ) - server = self.server_factory.buildProtocol( - IPv4Address("TCP", "127.0.0.1", 0) - ) - - client_transport = FakeTransport(server, self.reactor) - client.makeConnection(client_transport) - - server_transport = FakeTransport(client, self.reactor) - server.makeConnection(server_transport) - # Set up a resource for the worker resource = ReplicationRestResource(worker_hs) @@ -374,12 +371,11 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): config=worker_hs.config.server.listeners[0], resource=resource, server_version_string="1", - max_request_body_size=4096, + max_request_body_size=8192, reactor=self.reactor, ) - if worker_hs.config.redis.redis_enabled: - worker_hs.get_replication_command_handler().start_replication(worker_hs) + worker_hs.get_replication_command_handler().start_replication(worker_hs) return worker_hs @@ -546,8 +542,13 @@ class FakeRedisPubSubProtocol(Protocol): self.send("OK") elif command == b"GET": self.send(None) + + # Connection keep-alives. + elif command == b"PING": + self.send("PONG") + else: - raise Exception("Unknown command") + raise Exception(f"Unknown command: {command}") def send(self, msg): """Send a message back to the client.""" @@ -582,27 +583,3 @@ class FakeRedisPubSubProtocol(Protocol): def connectionLost(self, reason): self._server.remove_subscriber(self) - - -class RedisMultiWorkerStreamTestCase(BaseMultiWorkerStreamTestCase): - """ - A test case that enables Redis, providing a fake Redis server. - """ - - if not hiredis: - skip = "Requires hiredis" - - if not USE_POSTGRES_FOR_TESTS: - # Redis replication only takes place on Postgres - skip = "Requires Postgres" - - def default_config(self) -> Dict[str, Any]: - """ - Overrides the default config to enable Redis. - Even if the test only uses make_worker_hs, the main process needs Redis - enabled otherwise it won't create a Fake Redis server to listen on the - Redis port and accept fake TCP connections. - """ - base = super().default_config() - base["redis"] = {"enabled": True} - return base diff --git a/tests/replication/http/test__base.py b/tests/replication/http/test__base.py
index 822a957c3a..936ab4504a 100644 --- a/tests/replication/http/test__base.py +++ b/tests/replication/http/test__base.py
@@ -18,11 +18,12 @@ from typing import Tuple from twisted.web.server import Request from synapse.api.errors import Codes -from synapse.http.server import JsonResource, cancellable +from synapse.http.server import JsonResource from synapse.replication.http import REPLICATION_PREFIX from synapse.replication.http._base import ReplicationEndpoint from synapse.server import HomeServer from synapse.types import JsonDict +from synapse.util.cancellation import cancellable from tests import unittest from tests.http.server._base import test_disconnect diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 531a0db2d0..dce71f7334 100644 --- a/tests/replication/slave/storage/test_events.py +++ b/tests/replication/slave/storage/test_events.py
@@ -21,8 +21,11 @@ from synapse.api.constants import ReceiptTypes from synapse.api.room_versions import RoomVersions from synapse.events import FrozenEvent, _EventInternalMetadata, make_event_from_dict from synapse.handlers.room import RoomEventSource -from synapse.replication.slave.storage.events import SlavedEventStore -from synapse.storage.databases.main.event_push_actions import NotifCounts +from synapse.storage.databases.main.event_push_actions import ( + NotifCounts, + RoomNotifCounts, +) +from synapse.storage.databases.main.events_worker import EventsWorkerStore from synapse.storage.roommember import GetRoomsForUserWithStreamOrdering, RoomsForUser from synapse.types import PersistedEventPosition @@ -55,9 +58,9 @@ def patch__eq__(cls): return unpatch -class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): +class EventsWorkerStoreTestCase(BaseSlavedStoreTestCase): - STORE_TYPE = SlavedEventStore + STORE_TYPE = EventsWorkerStore def setUp(self): # Patch up the equality operator for events so that we can check @@ -140,6 +143,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.persist(type="m.room.create", key="", creator=USER_ID) self.check("get_invited_rooms_for_local_user", [USER_ID_2], []) event = self.persist(type="m.room.member", key=USER_ID_2, membership="invite") + assert event.internal_metadata.stream_ordering is not None self.replicate() @@ -171,14 +175,16 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): if send_receipt: self.get_success( self.master_store.insert_receipt( - ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], {} + ROOM_ID, ReceiptTypes.READ, USER_ID_2, [event1.event_id], None, {} ) ) self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=0), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=0), {} + ), ) self.persist( @@ -191,7 +197,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=0, unread_count=0, notify_count=1), + RoomNotifCounts( + NotifCounts(highlight_count=0, unread_count=0, notify_count=1), {} + ), ) self.persist( @@ -206,7 +214,9 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): self.check( "get_unread_event_push_actions_by_room_for_user", [ROOM_ID, USER_ID_2], - NotifCounts(highlight_count=1, unread_count=0, notify_count=2), + RoomNotifCounts( + NotifCounts(highlight_count=1, unread_count=0, notify_count=2), {} + ), ) def test_get_rooms_for_user_with_stream_ordering(self): @@ -221,6 +231,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): j2 = self.persist( type="m.room.member", sender=USER_ID_2, key=USER_ID_2, membership="join" ) + assert j2.internal_metadata.stream_ordering is not None self.replicate() expected_pos = PersistedEventPosition( @@ -278,6 +289,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): ) ) self.replicate() + assert j2.internal_metadata.stream_ordering is not None event_source = RoomEventSource(self.hs) event_source.store = self.slaved_store @@ -327,10 +339,10 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event_id = 0 - def persist(self, backfill=False, **kwargs): + def persist(self, backfill=False, **kwargs) -> FrozenEvent: """ Returns: - synapse.events.FrozenEvent: The event that was persisted. + The event that was persisted. """ event, context = self.build_event(**kwargs) @@ -404,6 +416,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase): event.event_id, {user_id: actions for user_id, actions in push_actions}, False, + "main", ) ) return event, context diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index eb00117845..ede6d0c118 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py
@@ -33,7 +33,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # tell the master to send a new receipt self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1} + "!room:blue", + "m.read", + USER_ID, + ["$event:blue"], + thread_id=None, + data={"a": 1}, ) ) self.replicate() @@ -48,6 +53,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.assertEqual("m.read", row.receipt_type) self.assertEqual(USER_ID, row.user_id) self.assertEqual("$event:blue", row.event_id) + self.assertIsNone(row.thread_id) self.assertEqual({"a": 1}, row.data) # Now let's disconnect and insert some data. @@ -57,7 +63,12 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.get_success( self.hs.get_datastores().main.insert_receipt( - "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2} + "!room2:blue", + "m.read", + USER_ID, + ["$event2:foo"], + thread_id=None, + data={"a": 2}, ) ) self.replicate() diff --git a/tests/replication/tcp/test_handler.py b/tests/replication/tcp/test_handler.py
index e6a19eafd5..1e299d2d67 100644 --- a/tests/replication/tcp/test_handler.py +++ b/tests/replication/tcp/test_handler.py
@@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. -from tests.replication._base import RedisMultiWorkerStreamTestCase +from tests.replication._base import BaseMultiWorkerStreamTestCase -class ChannelsTestCase(RedisMultiWorkerStreamTestCase): +class ChannelsTestCase(BaseMultiWorkerStreamTestCase): def test_subscribed_to_enough_redis_channels(self) -> None: # The default main process is subscribed to the USER_IP channel. self.assertCountEqual( diff --git a/tests/replication/test_module_cache_invalidation.py b/tests/replication/test_module_cache_invalidation.py new file mode 100644
index 0000000000..b93cae67d3 --- /dev/null +++ b/tests/replication/test_module_cache_invalidation.py
@@ -0,0 +1,79 @@ +# Copyright 2022 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +import synapse +from synapse.module_api import cached + +from tests.replication._base import BaseMultiWorkerStreamTestCase + +logger = logging.getLogger(__name__) + +FIRST_VALUE = "one" +SECOND_VALUE = "two" + +KEY = "mykey" + + +class TestCache: + current_value = FIRST_VALUE + + @cached() + async def cached_function(self, user_id: str) -> str: + return self.current_value + + +class ModuleCacheInvalidationTestCase(BaseMultiWorkerStreamTestCase): + servlets = [ + synapse.rest.admin.register_servlets, + ] + + def test_module_cache_full_invalidation(self): + main_cache = TestCache() + self.hs.get_module_api().register_cached_function(main_cache.cached_function) + + worker_hs = self.make_worker_hs("synapse.app.generic_worker") + + worker_cache = TestCache() + worker_hs.get_module_api().register_cached_function( + worker_cache.cached_function + ) + + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + main_cache.current_value = SECOND_VALUE + worker_cache.current_value = SECOND_VALUE + # No invalidation yet, should return the cached value on both the main process and the worker + self.assertEqual(FIRST_VALUE, self.get_success(main_cache.cached_function(KEY))) + self.assertEqual( + FIRST_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) + + # Full invalidation on the main process, should be replicated on the worker that + # should returned the updated value too + self.get_success( + self.hs.get_module_api().invalidate_cache( + main_cache.cached_function, (KEY,) + ) + ) + + self.assertEqual( + SECOND_VALUE, self.get_success(main_cache.cached_function(KEY)) + ) + self.assertEqual( + SECOND_VALUE, self.get_success(worker_cache.cached_function(KEY)) + ) diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 13aa5eb51a..96cdf2c45b 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py
@@ -15,8 +15,9 @@ import logging import os from typing import Optional, Tuple +from twisted.internet.interfaces import IOpenSSLServerConnectionCreator from twisted.internet.protocol import Factory -from twisted.protocols.tls import TLSMemoryBIOFactory +from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.web.http import HTTPChannel from twisted.web.server import Request @@ -102,7 +103,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): ) # fish the test server back out of the server-side TLS protocol. - http_server = server_tls_protocol.wrappedProtocol + http_server: HTTPChannel = server_tls_protocol.wrappedProtocol # type: ignore[assignment] # give the reactor a pump to get the TLS juices flowing. self.reactor.pump((0.1,)) @@ -238,16 +239,15 @@ def get_connection_factory(): return test_server_connection_factory -def _build_test_server(connection_creator): +def _build_test_server( + connection_creator: IOpenSSLServerConnectionCreator, +) -> TLSMemoryBIOProtocol: """Construct a test server This builds an HTTP channel, wrapped with a TLSMemoryBIOProtocol Args: - connection_creator (IOpenSSLServerConnectionCreator): thing to build - SSL connections - sanlist (list[bytes]): list of the SAN entries for the cert returned - by the server + connection_creator: thing to build SSL connections Returns: TLSMemoryBIOProtocol diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 8f4f6688ce..59fea93e49 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py
@@ -55,7 +55,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): token_id = user_dict.token_id self.get_success( - self.hs.get_pusherpool().add_pusher( + self.hs.get_pusherpool().add_or_update_pusher( user_id=user_id, access_token=token_id, kind="http", diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index a7ca68069e..541d390286 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py
@@ -20,7 +20,6 @@ from synapse.storage.util.id_generators import MultiWriterIdGenerator from tests.replication._base import BaseMultiWorkerStreamTestCase from tests.server import make_request -from tests.utils import USE_POSTGRES_FOR_TESTS logger = logging.getLogger(__name__) @@ -28,11 +27,6 @@ logger = logging.getLogger(__name__) class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): """Checks event persisting sharding works""" - # Event persister sharding requires postgres (due to needing - # `MultiWriterIdGenerator`). - if not USE_POSTGRES_FOR_TESTS: - skip = "Requires Postgres" - servlets = [ admin.register_servlets_for_client_rest_resource, room.register_servlets, @@ -50,7 +44,6 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): def default_config(self): conf = super().default_config() - conf["redis"] = {"enabled": "true"} conf["stream_writers"] = {"events": ["worker1", "worker2"]} conf["instance_map"] = { "worker1": {"host": "testserv", "port": 1001},