summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/_base.py59
-rw-r--r--tests/replication/test_auth.py117
-rw-r--r--tests/replication/test_client_reader_shard.py36
-rw-r--r--tests/replication/test_federation_sender_shard.py10
-rw-r--r--tests/replication/test_multi_media_repo.py4
-rw-r--r--tests/replication/test_pusher_shard.py14
-rw-r--r--tests/replication/test_sharded_event_persister.py16
7 files changed, 188 insertions, 68 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import attr
 
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.protocol import Protocol
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
 
 from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
 )
 from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
     if not hiredis:
         skip = "Requires hiredis"
 
-    servlets = [
-        streams.register_servlets,
-    ]
-
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # Make a new HomeServer object for the worker
         self.reactor.lookups["testserv"] = "1.2.3.4"
         self.worker_hs = self.setup_test_homeserver(
-            http_client=None,
+            federation_http_client=None,
             homeserver_to_use=GenericWorkerServer,
             config=self._get_worker_hs_config(),
             reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def create_resource_dict(self) -> Dict[str, Resource]:
+        d = super().create_resource_dict()
+        d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+        return d
+
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.generic_worker"
@@ -210,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         # Fake in memory Redis server that servers can connect to.
         self._redis_server = FakeRedisPubSubServer()
 
+        # We may have an attempt to connect to redis for the external cache already.
+        self.connect_any_redis_attempts()
+
         store = self.hs.get_datastore()
         self.database_pool = store.db_pool
 
@@ -264,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
             worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
             extra_config: Any extra config to use for this instances.
             **kwargs: Options that get passed to `self.setup_test_homeserver`,
-                useful to e.g. pass some mocks for things like `http_client`
+                useful to e.g. pass some mocks for things like `federation_http_client`
 
         Returns:
             The new worker HomeServer instance.
@@ -399,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
         fake one.
         """
         clients = self.reactor.tcpClients
-        self.assertEqual(len(clients), 1)
-        (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
-        self.assertEqual(host, "localhost")
-        self.assertEqual(port, 6379)
+        while clients:
+            (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+            self.assertEqual(host, "localhost")
+            self.assertEqual(port, 6379)
 
-        client_protocol = client_factory.buildProtocol(None)
-        server_protocol = self._redis_server.buildProtocol(None)
+            client_protocol = client_factory.buildProtocol(None)
+            server_protocol = self._redis_server.buildProtocol(None)
 
-        client_to_server_transport = FakeTransport(
-            server_protocol, self.reactor, client_protocol
-        )
-        client_protocol.makeConnection(client_to_server_transport)
-
-        server_to_client_transport = FakeTransport(
-            client_protocol, self.reactor, server_protocol
-        )
-        server_protocol.makeConnection(server_to_client_transport)
+            client_to_server_transport = FakeTransport(
+                server_protocol, self.reactor, client_protocol
+            )
+            client_protocol.makeConnection(client_to_server_transport)
 
-        return client_to_server_transport, server_to_client_transport
+            server_to_client_transport = FakeTransport(
+                client_protocol, self.reactor, server_protocol
+            )
+            server_protocol.makeConnection(server_to_client_transport)
 
 
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -622,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
             (channel,) = args
             self._server.add_subscriber(self)
             self.send(["subscribe", channel, 1])
+
+        # Since we use SET/GET to cache things we can safely no-op them.
+        elif command == b"SET":
+            self.send("OK")
+        elif command == b"GET":
+            self.send(None)
         else:
             raise Exception("Unknown command")
 
@@ -643,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
             # We assume bytes are just unicode strings.
             obj = obj.decode("utf-8")
 
+        if obj is None:
+            return "$-1\r\n"
         if isinstance(obj, str):
             return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
         if isinstance(obj, int):
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f35a5235e1
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+    """Test the authentication of HTTP calls between workers."""
+
+    servlets = [register.register_servlets]
+
+    def make_homeserver(self, reactor, clock):
+        config = self.default_config()
+        # This isn't a real configuration option but is used to provide the main
+        # homeserver and worker homeserver different options.
+        main_replication_secret = config.pop("main_replication_secret", None)
+        if main_replication_secret:
+            config["worker_replication_secret"] = main_replication_secret
+        return self.setup_test_homeserver(config=config)
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.client_reader"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+
+        return config
+
+    def _test_register(self) -> FakeChannel:
+        """Run the actual test:
+
+        1. Create a worker homeserver.
+        2. Start registration by providing a user/password.
+        3. Complete registration by providing dummy auth (this hits the main synapse).
+        4. Return the final request.
+
+        """
+        worker_hs = self.make_worker_hs("synapse.app.client_reader")
+        site = self._hs_to_site[worker_hs]
+
+        channel_1 = make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"username": "user", "type": "m.login.password", "password": "bar"},
+        )
+        self.assertEqual(channel_1.code, 401)
+
+        # Grab the session
+        session = channel_1.json_body["session"]
+
+        # also complete the dummy auth
+        return make_request(
+            self.reactor,
+            site,
+            "POST",
+            "register",
+            {"auth": {"session": session, "type": "m.login.dummy"}},
+        )
+
+    def test_no_auth(self):
+        """With no authentication the request should finish.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+    @override_config({"main_replication_secret": "my-secret"})
+    def test_missing_auth(self):
+        """If the main process expects a secret that is not provided, an error results.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 500)
+
+    @override_config(
+        {
+            "main_replication_secret": "my-secret",
+            "worker_replication_secret": "wrong-secret",
+        }
+    )
+    def test_unauthorized(self):
+        """If the main process receives the wrong secret, an error results.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 500)
+
+    @override_config({"worker_replication_secret": "my-secret"})
+    def test_authorized(self):
+        """The request should finish when the worker provides the authentication header.
+        """
+        channel = self._test_register()
+        self.assertEqual(channel.code, 200)
+
+        # We're given a registered user.
+        self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..4608b65a0c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
 # limitations under the License.
 import logging
 
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
 from synapse.rest.client.v2_alpha import register
 
 from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, make_request
+from tests.server import make_request
 
 logger = logging.getLogger(__name__)
 
 
 class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
-    """Base class for tests of the replication streams"""
+    """Test using one or more client readers for registration."""
 
     servlets = [register.register_servlets]
 
-    def prepare(self, reactor, clock, hs):
-        self.recaptcha_checker = DummyRecaptchaChecker(hs)
-        auth_handler = hs.get_auth_handler()
-        auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
     def _get_worker_hs_config(self) -> dict:
         config = self.default_config()
         config["worker_app"] = "synapse.app.client_reader"
@@ -48,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs = self.make_worker_hs("synapse.app.client_reader")
         site = self._hs_to_site[worker_hs]
 
-        request_1, channel_1 = make_request(
+        channel_1 = make_request(
             self.reactor,
             site,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_1.code, 401)
+        )
+        self.assertEqual(channel_1.code, 401)
 
         # Grab the session
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
-        request_2, channel_2 = make_request(
+        channel_2 = make_request(
             self.reactor,
             site,
             "POST",
             "register",
             {"auth": {"session": session, "type": "m.login.dummy"}},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_2.code, 200)
+        )
+        self.assertEqual(channel_2.code, 200)
 
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
@@ -80,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
         worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
 
         site_1 = self._hs_to_site[worker_hs_1]
-        request_1, channel_1 = make_request(
+        channel_1 = make_request(
             self.reactor,
             site_1,
             "POST",
             "register",
             {"username": "user", "type": "m.login.password", "password": "bar"},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_1.code, 401)
+        )
+        self.assertEqual(channel_1.code, 401)
 
         # Grab the session
         session = channel_1.json_body["session"]
 
         # also complete the dummy auth
         site_2 = self._hs_to_site[worker_hs_2]
-        request_2, channel_2 = make_request(
+        channel_2 = make_request(
             self.reactor,
             site_2,
             "POST",
             "register",
             {"auth": {"session": session, "type": "m.login.dummy"}},
-        )  # type: SynapseRequest, FakeChannel
-        self.assertEqual(request_2.code, 200)
+        )
+        self.assertEqual(channel_2.code, 200)
 
         # We're given a registered user.
         self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.federation_sender",
             {"send_federation": True},
-            http_client=mock_client,
+            federation_http_client=mock_client,
         )
 
         user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender1",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client1,
+            federation_http_client=mock_client1,
         )
 
         mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "sender2",
                 "federation_sender_instances": ["sender1", "sender2"],
             },
-            http_client=mock_client2,
+            federation_http_client=mock_client2,
         )
 
         user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..d1feca961f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
         self.user_id = self.register_user("user", "pass")
         self.access_token = self.login("user", "pass")
 
-        self.reactor.lookups["example.com"] = "127.0.0.2"
+        self.reactor.lookups["example.com"] = "1.2.3.4"
 
     def default_config(self):
         conf = super().default_config()
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
             the media which the caller should respond to.
         """
         resource = hs.get_media_repository_resource().children[b"download"]
-        _, channel = make_request(
+        channel = make_request(
             self.reactor,
             FakeSite(resource),
             "GET",
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 device_display_name="pushy push",
                 pushkey="a@example.com",
                 lang=None,
-                data={"url": "https://push.example.com/push"},
+                data={"url": "https://push.example.com/_matrix/push/v1/notify"},
             )
         )
 
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         self.make_worker_hs(
             "synapse.app.pusher",
             {"start_pushers": True},
-            proxied_http_client=http_client_mock,
+            proxied_blacklisted_http_client=http_client_mock,
         )
 
         event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher1",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock1,
+            proxied_blacklisted_http_client=http_client_mock1,
         )
 
         http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
                 "worker_name": "pusher2",
                 "pusher_instances": ["pusher1", "pusher2"],
             },
-            proxied_http_client=http_client_mock2,
+            proxied_blacklisted_http_client=http_client_mock2,
         )
 
         # We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_not_called()
         self.assertEqual(
             http_client_mock1.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
         http_client_mock2.post_json_get_json.assert_called_once()
         self.assertEqual(
             http_client_mock2.post_json_get_json.call_args[0][0],
-            "https://push.example.com/push",
+            "https://push.example.com/_matrix/push/v1/notify",
         )
         self.assertEqual(
             event_id,
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 77fc3856d5..8d494ebc03 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Do an initial sync so that we're up to date.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
         )
         next_batch = channel.json_body["next_batch"]
@@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Check that syncing still gets the new event, despite the gap in the
         # stream IDs.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
         first_event_in_room2 = response["event_id"]
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
         self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         # Paginating back in the first room should not produce any results, as
         # no events have happened in it. This tests that we are correctly
         # filtering results based on the vector clock portion.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
 
         # Paginating back on the second room should produce the first event
         # again. This tests that pagination isn't completely broken.
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
 
         # Paginating forwards should give the same results
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",
@@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
         )
         self.assertListEqual([], channel.json_body["chunk"])
 
-        request, channel = make_request(
+        channel = make_request(
             self.reactor,
             sync_hs_site,
             "GET",