diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 295c5d58a6..d5dce1f83f 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple
import attr
@@ -21,6 +21,7 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
from twisted.internet.protocol import Protocol
from twisted.internet.task import LoopingCall
from twisted.web.http import HTTPChannel
+from twisted.web.resource import Resource
from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
@@ -28,7 +29,7 @@ from synapse.app.generic_worker import (
)
from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest, SynapseSite
-from synapse.replication.http import ReplicationRestResource, streams
+from synapse.replication.http import ReplicationRestResource
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -54,10 +55,6 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
if not hiredis:
skip = "Requires hiredis"
- servlets = [
- streams.register_servlets,
- ]
-
def prepare(self, reactor, clock, hs):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(hs)
@@ -67,7 +64,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Make a new HomeServer object for the worker
self.reactor.lookups["testserv"] = "1.2.3.4"
self.worker_hs = self.setup_test_homeserver(
- http_client=None,
+ federation_http_client=None,
homeserver_to_use=GenericWorkerServer,
config=self._get_worker_hs_config(),
reactor=self.reactor,
@@ -88,6 +85,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self._client_transport = None
self._server_transport = None
+ def create_resource_dict(self) -> Dict[str, Resource]:
+ d = super().create_resource_dict()
+ d["/_synapse/replication"] = ReplicationRestResource(self.hs)
+ return d
+
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.generic_worker"
@@ -210,6 +212,9 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
# Fake in memory Redis server that servers can connect to.
self._redis_server = FakeRedisPubSubServer()
+ # We may have an attempt to connect to redis for the external cache already.
+ self.connect_any_redis_attempts()
+
store = self.hs.get_datastore()
self.database_pool = store.db_pool
@@ -264,7 +269,7 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
extra_config: Any extra config to use for this instances.
**kwargs: Options that get passed to `self.setup_test_homeserver`,
- useful to e.g. pass some mocks for things like `http_client`
+ useful to e.g. pass some mocks for things like `federation_http_client`
Returns:
The new worker HomeServer instance.
@@ -399,25 +404,23 @@ class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
fake one.
"""
clients = self.reactor.tcpClients
- self.assertEqual(len(clients), 1)
- (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
- self.assertEqual(host, "localhost")
- self.assertEqual(port, 6379)
+ while clients:
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop(0)
+ self.assertEqual(host, "localhost")
+ self.assertEqual(port, 6379)
- client_protocol = client_factory.buildProtocol(None)
- server_protocol = self._redis_server.buildProtocol(None)
+ client_protocol = client_factory.buildProtocol(None)
+ server_protocol = self._redis_server.buildProtocol(None)
- client_to_server_transport = FakeTransport(
- server_protocol, self.reactor, client_protocol
- )
- client_protocol.makeConnection(client_to_server_transport)
-
- server_to_client_transport = FakeTransport(
- client_protocol, self.reactor, server_protocol
- )
- server_protocol.makeConnection(server_to_client_transport)
+ client_to_server_transport = FakeTransport(
+ server_protocol, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
- return client_to_server_transport, server_to_client_transport
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, server_protocol
+ )
+ server_protocol.makeConnection(server_to_client_transport)
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
@@ -622,6 +625,12 @@ class FakeRedisPubSubProtocol(Protocol):
(channel,) = args
self._server.add_subscriber(self)
self.send(["subscribe", channel, 1])
+
+ # Since we use SET/GET to cache things we can safely no-op them.
+ elif command == b"SET":
+ self.send("OK")
+ elif command == b"GET":
+ self.send(None)
else:
raise Exception("Unknown command")
@@ -643,6 +652,8 @@ class FakeRedisPubSubProtocol(Protocol):
# We assume bytes are just unicode strings.
obj = obj.decode("utf-8")
+ if obj is None:
+ return "$-1\r\n"
if isinstance(obj, str):
return "${len}\r\n{str}\r\n".format(len=len(obj), str=obj)
if isinstance(obj, int):
diff --git a/tests/replication/test_auth.py b/tests/replication/test_auth.py
new file mode 100644
index 0000000000..f35a5235e1
--- /dev/null
+++ b/tests/replication/test_auth.py
@@ -0,0 +1,117 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.server import FakeChannel, make_request
+from tests.unittest import override_config
+
+logger = logging.getLogger(__name__)
+
+
+class WorkerAuthenticationTestCase(BaseMultiWorkerStreamTestCase):
+ """Test the authentication of HTTP calls between workers."""
+
+ servlets = [register.register_servlets]
+
+ def make_homeserver(self, reactor, clock):
+ config = self.default_config()
+ # This isn't a real configuration option but is used to provide the main
+ # homeserver and worker homeserver different options.
+ main_replication_secret = config.pop("main_replication_secret", None)
+ if main_replication_secret:
+ config["worker_replication_secret"] = main_replication_secret
+ return self.setup_test_homeserver(config=config)
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+
+ return config
+
+ def _test_register(self) -> FakeChannel:
+ """Run the actual test:
+
+ 1. Create a worker homeserver.
+ 2. Start registration by providing a user/password.
+ 3. Complete registration by providing dummy auth (this hits the main synapse).
+ 4. Return the final request.
+
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+ site = self._hs_to_site[worker_hs]
+
+ channel_1 = make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ )
+ self.assertEqual(channel_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ return make_request(
+ self.reactor,
+ site,
+ "POST",
+ "register",
+ {"auth": {"session": session, "type": "m.login.dummy"}},
+ )
+
+ def test_no_auth(self):
+ """With no authentication the request should finish.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
+
+ @override_config({"main_replication_secret": "my-secret"})
+ def test_missing_auth(self):
+ """If the main process expects a secret that is not provided, an error results.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config(
+ {
+ "main_replication_secret": "my-secret",
+ "worker_replication_secret": "wrong-secret",
+ }
+ )
+ def test_unauthorized(self):
+ """If the main process receives the wrong secret, an error results.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 500)
+
+ @override_config({"worker_replication_secret": "my-secret"})
+ def test_authorized(self):
+ """The request should finish when the worker provides the authentication header.
+ """
+ channel = self._test_register()
+ self.assertEqual(channel.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
index 96801db473..4608b65a0c 100644
--- a/tests/replication/test_client_reader_shard.py
+++ b/tests/replication/test_client_reader_shard.py
@@ -14,27 +14,19 @@
# limitations under the License.
import logging
-from synapse.api.constants import LoginType
-from synapse.http.site import SynapseRequest
from synapse.rest.client.v2_alpha import register
from tests.replication._base import BaseMultiWorkerStreamTestCase
-from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
-from tests.server import FakeChannel, make_request
+from tests.server import make_request
logger = logging.getLogger(__name__)
class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
- """Base class for tests of the replication streams"""
+ """Test using one or more client readers for registration."""
servlets = [register.register_servlets]
- def prepare(self, reactor, clock, hs):
- self.recaptcha_checker = DummyRecaptchaChecker(hs)
- auth_handler = hs.get_auth_handler()
- auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
-
def _get_worker_hs_config(self) -> dict:
config = self.default_config()
config["worker_app"] = "synapse.app.client_reader"
@@ -48,27 +40,27 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs = self.make_worker_hs("synapse.app.client_reader")
site = self._hs_to_site[worker_hs]
- request_1, channel_1 = make_request(
+ channel_1 = make_request(
self.reactor,
site,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
- request_2, channel_2 = make_request(
+ channel_2 = make_request(
self.reactor,
site,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_2.code, 200)
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
@@ -80,28 +72,28 @@ class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
site_1 = self._hs_to_site[worker_hs_1]
- request_1, channel_1 = make_request(
+ channel_1 = make_request(
self.reactor,
site_1,
"POST",
"register",
{"username": "user", "type": "m.login.password", "password": "bar"},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_1.code, 401)
+ )
+ self.assertEqual(channel_1.code, 401)
# Grab the session
session = channel_1.json_body["session"]
# also complete the dummy auth
site_2 = self._hs_to_site[worker_hs_2]
- request_2, channel_2 = make_request(
+ channel_2 = make_request(
self.reactor,
site_2,
"POST",
"register",
{"auth": {"session": session, "type": "m.login.dummy"}},
- ) # type: SynapseRequest, FakeChannel
- self.assertEqual(request_2.code, 200)
+ )
+ self.assertEqual(channel_2.code, 200)
# We're given a registered user.
self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
index 779745ae9d..fffdb742c8 100644
--- a/tests/replication/test_federation_sender_shard.py
+++ b/tests/replication/test_federation_sender_shard.py
@@ -50,7 +50,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.federation_sender",
{"send_federation": True},
- http_client=mock_client,
+ federation_http_client=mock_client,
)
user = self.register_user("user", "pass")
@@ -81,7 +81,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -93,7 +93,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user2", "pass")
@@ -144,7 +144,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender1",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client1,
+ federation_http_client=mock_client1,
)
mock_client2 = Mock(spec=["put_json"])
@@ -156,7 +156,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "sender2",
"federation_sender_instances": ["sender1", "sender2"],
},
- http_client=mock_client2,
+ federation_http_client=mock_client2,
)
user = self.register_user("user3", "pass")
diff --git a/tests/replication/test_multi_media_repo.py b/tests/replication/test_multi_media_repo.py
index 48b574ccbe..d1feca961f 100644
--- a/tests/replication/test_multi_media_repo.py
+++ b/tests/replication/test_multi_media_repo.py
@@ -48,7 +48,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
self.user_id = self.register_user("user", "pass")
self.access_token = self.login("user", "pass")
- self.reactor.lookups["example.com"] = "127.0.0.2"
+ self.reactor.lookups["example.com"] = "1.2.3.4"
def default_config(self):
conf = super().default_config()
@@ -68,7 +68,7 @@ class MediaRepoShardTestCase(BaseMultiWorkerStreamTestCase):
the media which the caller should respond to.
"""
resource = hs.get_media_repository_resource().children[b"download"]
- _, channel = make_request(
+ channel = make_request(
self.reactor,
FakeSite(resource),
"GET",
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
index 67c27a089f..800ad94a04 100644
--- a/tests/replication/test_pusher_shard.py
+++ b/tests/replication/test_pusher_shard.py
@@ -67,7 +67,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
device_display_name="pushy push",
pushkey="a@example.com",
lang=None,
- data={"url": "https://push.example.com/push"},
+ data={"url": "https://push.example.com/_matrix/push/v1/notify"},
)
)
@@ -98,7 +98,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
self.make_worker_hs(
"synapse.app.pusher",
{"start_pushers": True},
- proxied_http_client=http_client_mock,
+ proxied_blacklisted_http_client=http_client_mock,
)
event_id = self._create_pusher_and_send_msg("user")
@@ -109,7 +109,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -133,7 +133,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher1",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock1,
+ proxied_blacklisted_http_client=http_client_mock1,
)
http_client_mock2 = Mock(spec_set=["post_json_get_json"])
@@ -148,7 +148,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
"worker_name": "pusher2",
"pusher_instances": ["pusher1", "pusher2"],
},
- proxied_http_client=http_client_mock2,
+ proxied_blacklisted_http_client=http_client_mock2,
)
# We choose a user name that we know should go to pusher1.
@@ -161,7 +161,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_not_called()
self.assertEqual(
http_client_mock1.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
@@ -183,7 +183,7 @@ class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
http_client_mock2.post_json_get_json.assert_called_once()
self.assertEqual(
http_client_mock2.post_json_get_json.call_args[0][0],
- "https://push.example.com/push",
+ "https://push.example.com/_matrix/push/v1/notify",
)
self.assertEqual(
event_id,
diff --git a/tests/replication/test_sharded_event_persister.py b/tests/replication/test_sharded_event_persister.py
index 77fc3856d5..8d494ebc03 100644
--- a/tests/replication/test_sharded_event_persister.py
+++ b/tests/replication/test_sharded_event_persister.py
@@ -180,7 +180,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Do an initial sync so that we're up to date.
- request, channel = make_request(
+ channel = make_request(
self.reactor, sync_hs_site, "GET", "/sync", access_token=access_token
)
next_batch = channel.json_body["next_batch"]
@@ -206,7 +206,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Check that syncing still gets the new event, despite the gap in the
# stream IDs.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -236,7 +236,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
response = self.helper.send(room_id2, body="Hi!", tok=self.other_access_token)
first_event_in_room2 = response["event_id"]
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -261,7 +261,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
self.helper.send(room_id1, body="Hi again!", tok=self.other_access_token)
self.helper.send(room_id2, body="Hi again!", tok=self.other_access_token)
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -279,7 +279,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back in the first room should not produce any results, as
# no events have happened in it. This tests that we are correctly
# filtering results based on the vector clock portion.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -292,7 +292,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
# Paginating back on the second room should produce the first event
# again. This tests that pagination isn't completely broken.
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -307,7 +307,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
# Paginating forwards should give the same results
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
@@ -318,7 +318,7 @@ class EventPersisterShardTestCase(BaseMultiWorkerStreamTestCase):
)
self.assertListEqual([], channel.json_body["chunk"])
- request, channel = make_request(
+ channel = make_request(
self.reactor,
sync_hs_site,
"GET",
|