diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..3379189785 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
- http_client=None,
+ federation_http_client=None,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
@@ -264,7 +266,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
- useful to e.g. pass some mocks for things like `http_client`
+ useful to e.g. pass some mocks for things like `federation_http_client`
Returns:
The new worker HomeServer instance.
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..fe9e4d5f9a
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,119 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+from typing import Tuple
+
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+ """Test the authentication of HTTP calls between workers."""
+
+ servlets = [register.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # This isn't a real configuration option but is used to provide the main
+ # homeserver and worker homeserver different options.
+ main_replication_secret = config.pop("main_replication_secret", None)
+ if main_replication_secret:
+ config["worker_replication_secret"] = main_replication_secret
+ return self.setup_test_homeserver(config=config)
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+
+ return config
+
+ def _test_register(self) -> Tuple[SynapseRequest, FakeChannel]:
+ """Run the actual test:
+
+ 1. Create a worker homeserver.
+ 2. Start registration by providing a user/password.
+ 3. Complete registration by providing dummy auth (this hits the main synapse).
+ 4. Return the final request.
+
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
+
+ request_1, channel_1 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ return make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+
+ def test_no_auth(self):
+ """With no authentication the request should finish.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ @override_config({"main_replication_secret": "my-secret"})
+ def test_missing_auth(self):
+ """If the main process expects a secret that is not provided, an error results.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 500)
+
+ @override_config(
+ {
+ "main_replication_secret": "my-secret",
+ "worker_replication_secret": "wrong-secret",
+ }
+ )
+ def test_unauthorized(self):
+ """If the main process receives the wrong secret, an error results.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 500)
+
+ @override_config({"worker_replication_secret": "my-secret"})
+ def test_authorized(self):
+ """The request should finish when the worker provides the authentication header.
+ """
+ request, channel = self._test_register()
+ self.assertEqual(request.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..fdaad3d8ad 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,20 @@
# limitations under the License.
import logging
-from synapse.api.constants import LoginType
from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
from tests.server import FakeChannel, make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Base class for tests of the replication streams"""
+ """Test using one or more client readers for registration."""
servlets = [register.register_servlets]
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
- http_client=mock_client,
+ federation_http_client=mock_client,
)
user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..83afd9fd2f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
- self.reactor.lookups["example.com"] = "127.0.0.2"
+ self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self):
conf = super().default_config()
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "https://push.example.com/push"},
+ data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.pusher",
{"start_pushers": True},
- proxied_http_client=http_client_mock,
+ proxied_blacklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock1,
+ proxied_blacklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock2,
+ proxied_blacklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
|