diff options
Diffstat (limited to 'tests/replication')
-rw-r--r-- | tests/replication/_base.py | 59 | ||||
-rw-r--r-- | tests/replication/test_auth.py | 117 | ||||
-rw-r--r-- | tests/replication/test_client_reader_shard.py | 36 | ||||
-rw-r--r-- | tests/replication/test_federation_sender_shard.py | 10 | ||||
-rw-r--r-- | tests/replication/test_multi_media_repo.py | 4 | ||||
-rw-r--r-- | tests/replication/test_pusher_shard.py | 14 | ||||
-rw-r--r-- | tests/replication/test_sharded_event_persister.py | 16 |
7 files changed, 188 insertions, 68 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py index 295c5d58a6..d5dce1f83f 100644 --- a/tests/replication/_base.py +++ b/tests/replication/_base.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple import attr @@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.protocol import Protocol from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel +from twisted.web.resource import Resource from synapse.app.generic_worker import ( GenericWorkerReplicationHandler, @@ -28,7 +29,7 @@ from synapse.app.generic_worker import ( ) from synapse.http.server import JsonResource from synapse.http.site import SynapseRequest, SynapseSite -from synapse.replication.http import ReplicationRestResource, streams +from synapse.replication.http import ReplicationRestResource from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): if not hiredis: skip = "Requires hiredis" - servlets = [ - streams.register_servlets, - ] - def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) @@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): # Make a new HomeServer object for the worker self.reactor.lookups["testserv"] = "1.2.3.4" self.worker_hs = self.setup_test_homeserver( - http_client=None, + federation_http_client=None, homeserver_to_use=GenericWorkerServer, config=self._get_worker_hs_config(), reactor=self.reactor, @@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def create_resource_dict(self) -> Dict[str, Resource]: + d = super().create_resource_dict() + d["/_synapse/replication"] = ReplicationRestResource(self.hs) + return d + def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.generic_worker" @@ -210,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): # Fake in memory Redis server that servers can connect to. self._redis_server = FakeRedisPubSubServer() + # We may have an attempt to connect to redis for the external cache already. + self.connect_any_redis_attempts() + store = self.hs.get_datastore() self.database_pool = store.db_pool @@ -264,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): worker_app: Type of worker, e.g. `synapse.app.federation_sender`. extra_config: Any extra config to use for this instances. **kwargs: Options that get passed to `self.setup_test_homeserver`, - useful to e.g. pass some mocks for things like `http_client` + useful to e.g. pass some mocks for things like `federation_http_client` Returns: The new worker HomeServer instance. @@ -399,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase): fake one. """ clients = self.reactor.tcpClients - self.assertEqual(len(clients), 1) - (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) - self.assertEqual(host, "localhost") - self.assertEqual(port, 6379) + while clients: + (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0) + self.assertEqual(host, "localhost") + self.assertEqual(port, 6379) - client_protocol = client_factory.buildProtocol(None) - server_protocol = self._redis_server.buildProtocol(None) + client_protocol = client_factory.buildProtocol(None) + server_protocol = self._redis_server.buildProtocol(None) - client_to_server_transport = FakeTransport( - server_protocol, self.reactor, client_protocol - ) - client_protocol.makeConnection(client_to_server_transport) - - server_to_client_transport = FakeTransport( - client_protocol, self.reactor, server_protocol - ) - server_protocol.makeConnection(server_to_client_transport) + client_to_server_transport = FakeTransport( + server_protocol, self.reactor, client_protocol + ) + client_protocol.makeConnection(client_to_server_transport) - return client_to_server_transport, server_to_client_transport + server_to_client_transport = FakeTransport( + client_protocol, self.reactor, server_protocol + ) + server_protocol.makeConnection(server_to_client_transport) class TestReplicationDataHandler(GenericWorkerReplicationHandler): @@ -622,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol): (channel,) = args self._server.add_subscriber(self) self.send(["subscribe", channel, 1]) + + # Since we use SET/GET to cache things we can safely no-op them. + elif command == b"SET": + self.send("OK") + elif command == b"GET": + self.send(None) else: raise Exception("Unknown command") @@ -643,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol): # We assume bytes are just unicode strings. obj = obj.decode("utf-8") + if obj is None: + return "$-1\r\n" if isinstance(obj, str): return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj) if isinstance(obj, int): diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py new file mode 100644 index 0000000000..f35a5235e1 --- /dev/null +++ b/tests/replication/test_auth.py @@ -0,0 +1,117 @@ +# -*- coding: utf-8 -*- +# Copyright 2020 The Matrix.org Foundation C.I.C. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging + +from synapse.rest.client.v2_alpha import register + +from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.server import FakeChannel, make_request +from tests.unittest import override_config + +logger = logging.getLogger(__name__) + + +class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase): + """Test the authentication of HTTP calls between workers.""" + + servlets = [register.register_servlets] + + def make_homeserver(self, reactor, clock): + config = self.default_config() + # This isn't a real configuration option but is used to provide the main + # homeserver and worker homeserver different options. + main_replication_secret = config.pop("main_replication_secret", None) + if main_replication_secret: + config["worker_replication_secret"] = main_replication_secret + return self.setup_test_homeserver(config=config) + + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.client_reader" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + + return config + + def _test_register(self) -> FakeChannel: + """Run the actual test: + + 1. Create a worker homeserver. + 2. Start registration by providing a user/password. + 3. Complete registration by providing dummy auth (this hits the main synapse). + 4. Return the final request. + + """ + worker_hs = self.make_worker_hs("synapse.app.client_reader") + site = self._hs_to_site[worker_hs] + + channel_1 = make_request( + self.reactor, + site, + "POST", + "register", + {"username": "user", "type": "m.login.password", "password": "bar"}, + ) + self.assertEqual(channel_1.code, 401) + + # Grab the session + session = channel_1.json_body["session"] + + # also complete the dummy auth + return make_request( + self.reactor, + site, + "POST", + "register", + {"auth": {"session": session, "type": "m.login.dummy"}}, + ) + + def test_no_auth(self): + """With no authentication the request should finish. + """ + channel = self._test_register() + self.assertEqual(channel.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") + + @override_config({"main_replication_secret": "my-secret"}) + def test_missing_auth(self): + """If the main process expects a secret that is not provided, an error results. + """ + channel = self._test_register() + self.assertEqual(channel.code, 500) + + @override_config( + { + "main_replication_secret": "my-secret", + "worker_replication_secret": "wrong-secret", + } + ) + def test_unauthorized(self): + """If the main process receives the wrong secret, an error results. + """ + channel = self._test_register() + self.assertEqual(channel.code, 500) + + @override_config({"worker_replication_secret": "my-secret"}) + def test_authorized(self): + """The request should finish when the worker provides the authentication header. + """ + channel = self._test_register() + self.assertEqual(channel.code, 200) + + # We're given a registered user. + self.assertEqual(channel.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py index 96801db473..4608b65a0c 100644 --- a/tests/replication/test_client_reader_shard.py +++ b/tests/replication/test_client_reader_shard.py @@ -14,27 +14,19 @@ # limitations under the License. import logging -from synapse.api.constants import LoginType -from synapse.http.site import SynapseRequest from synapse.rest.client.v2_alpha import register from tests.replication._base import BaseMultiWorkerStreamTestCase -from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker -from tests.server import FakeChannel, make_request +from tests.server import make_request logger = logging.getLogger(__name__) class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): - """Base class for tests of the replication streams""" + """Test using one or more client readers for registration.""" servlets = [register.register_servlets] - def prepare(self, reactor, clock, hs): - self.recaptcha_checker = DummyRecaptchaChecker(hs) - auth_handler = hs.get_auth_handler() - auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker - def _get_worker_hs_config(self) -> dict: config = self.default_config() config["worker_app"] = "synapse.app.client_reader" @@ -48,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs = self.make_worker_hs("synapse.app.client_reader") site = self._hs_to_site[worker_hs] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth - request_2, channel_2 = make_request( + channel_2 = make_request( self.reactor, site, "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_2.code, 200) + ) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") @@ -80,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase): worker_hs_2 = self.make_worker_hs("synapse.app.client_reader") site_1 = self._hs_to_site[worker_hs_1] - request_1, channel_1 = make_request( + channel_1 = make_request( self.reactor, site_1, "POST", "register", {"username": "user", "type": "m.login.password", "password": "bar"}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_1.code, 401) + ) + self.assertEqual(channel_1.code, 401) # Grab the session session = channel_1.json_body["session"] # also complete the dummy auth site_2 = self._hs_to_site[worker_hs_2] - request_2, channel_2 = make_request( + channel_2 = make_request( self.reactor, site_2, "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}, - ) # type: SynapseRequest, FakeChannel - self.assertEqual(request_2.code, 200) + ) + self.assertEqual(channel_2.code, 200) # We're given a registered user. self.assertEqual(channel_2.json_body["user_id"], "@user:test") diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 779745ae9d..fffdb742c8 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.federation_sender", {"send_federation": True}, - http_client=mock_client, + federation_http_client=mock_client, ) user = self.register_user("user", "pass") @@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user2", "pass") @@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender1", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client1, + federation_http_client=mock_client1, ) mock_client2 = Mock(spec=["put_json"]) @@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "sender2", "federation_sender_instances": ["sender1", "sender2"], }, - http_client=mock_client2, + federation_http_client=mock_client2, ) user = self.register_user("user3", "pass") diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py index 48b574ccbe..d1feca961f 100644 --- a/tests/replication/test_multi_media_repo.py +++ b/tests/replication/test_multi_media_repo.py @@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): self.user_id = self.register_user("user", "pass") self.access_token = self.login("user", "pass") - self.reactor.lookups["example.com"] = "127.0.0.2" + self.reactor.lookups["example.com"] = "1.2.3.4" def default_config(self): conf = super().default_config() @@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase): the media which the caller should respond to. """ resource = hs.get_media_repository_resource().children[b"download"] - _, channel = make_request( + channel = make_request( self.reactor, FakeSite(resource), "GET", diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py index 67c27a089f..800ad94a04 100644 --- a/tests/replication/test_pusher_shard.py +++ b/tests/replication/test_pusher_shard.py @@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): device_display_name="pushy push", pushkey="a@example.com", lang=None, - data={"url": "https://push.example.com/push"}, + data={"url": "https://push.example.com/_matrix/push/v1/notify"}, ) ) @@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): self.make_worker_hs( "synapse.app.pusher", {"start_pushers": True}, - proxied_http_client=http_client_mock, + proxied_blacklisted_http_client=http_client_mock, ) event_id = self._create_pusher_and_send_msg("user") @@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock.post_json_get_json.assert_called_once() self.assertEqual( http_client_mock.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, @@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher1", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock1, + proxied_blacklisted_http_client=http_client_mock1, ) http_client_mock2 = Mock(spec_set=["post_json_get_json"]) @@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): "worker_name": "pusher2", "pusher_instances": ["pusher1", "pusher2"], }, - proxied_http_client=http_client_mock2, + proxied_blacklisted_http_client=http_client_mock2, ) # We choose a user name that we know should go to pusher1. @@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock2.post_json_get_json.assert_not_called() self.assertEqual( http_client_mock1.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, @@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase): http_client_mock2.post_json_get_json.assert_called_once() self.assertEqual( http_client_mock2.post_json_get_json.call_args[0][0], - "https://push.example.com/push", + "https://push.example.com/_matrix/push/v1/notify", ) self.assertEqual( event_id, diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py index 77fc3856d5..8d494ebc03 100644 --- a/tests/replication/test_sharded_event_persister.py +++ b/tests/replication/test_sharded_event_persister.py @@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Do an initial sync so that we're up to date. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token ) next_batch = channel.json_body["next_batch"] @@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Check that syncing still gets the new event, despite the gap in the # stream IDs. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token) first_event_in_room2 = response["event_id"] - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token) self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token) - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back in the first room should not produce any results, as # no events have happened in it. This tests that we are correctly # filtering results based on the vector clock portion. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): # Paginating back on the second room should produce the first event # again. This tests that pagination isn't completely broken. - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) # Paginating forwards should give the same results - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", @@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase): ) self.assertListEqual([], channel.json_body["chunk"]) - request, channel = make_request( + channel = make_request( self.reactor, sync_hs_site, "GET", |