diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..06575ba0a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
import attr
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
+from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
logger = logging.getLogger(__name__)
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests running multiple workers.
+
+ Automatically handle HTTP replication requests from workers to master,
+ unlike `BaseStreamTestCase`.
+ """
+
+ servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+
+ def setUp(self):
+ super().setUp()
+
+ # build a replication server
+ self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+ self.streamer = self.hs.get_replication_streamer()
+
+ store = self.hs.get_datastore()
+ self.database = store.db
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self._worker_hs_to_resource = {}
+
+ # When we see a connection attempt to the master replication listener we
+ # automatically set up the connection. This is so that tests don't
+ # manually have to go and explicitly set it up each time (plus sometimes
+ # it is impossible to write the handling explicitly in the tests).
+ self.reactor.add_tcp_client_callback(
+ "1.2.3.4", 8765, self._handle_http_replication_attempt
+ )
+
+ def create_test_json_resource(self):
+ """Overrides `HomeserverTestCase.create_test_json_resource`.
+ """
+ # We override this so that it automatically registers all the HTTP
+ # replication servlets, without having to explicitly do that in all
+ # subclassses.
+
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, resource)
+
+ return resource
+
+ def make_worker_hs(
+ self, worker_app: str, extra_config: dict = {}, **kwargs
+ ) -> HomeServer:
+ """Make a new worker HS instance, correctly connecting replcation
+ stream to the master HS.
+
+ Args:
+ worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+ extra_config: Any extra config to use for this instances.
+ **kwargs: Options that get passed to `self.setup_test_homeserver`,
+ useful to e.g. pass some mocks for things like `http_client`
+
+ Returns:
+ The new worker HomeServer instance.
+ """
+
+ config = self._get_worker_hs_config()
+ config["worker_app"] = worker_app
+ config.update(extra_config)
+
+ worker_hs = self.setup_test_homeserver(
+ homeserverToUse=GenericWorkerServer,
+ config=config,
+ reactor=self.reactor,
+ **kwargs
+ )
+
+ store = worker_hs.get_datastore()
+ store.db._db_pool = self.database._db_pool
+
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
+
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
+
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
+
+ # Set up a resource for the worker
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(worker_hs, resource)
+
+ self._worker_hs_to_resource[worker_hs] = resource
+
+ return worker_hs
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+ render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump()
+
+ def _handle_http_replication_attempt(self):
+ """Handles a connection attempt to the master replication HTTP
+ listener.
+ """
+
+ # We should have at least one outbound connection attempt, where the
+ # last is one to the HTTP repication IP/port.
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # Note: at this point we've wired everything up, but we need to return
+ # before the data starts flowing over the connections as this is called
+ # inside `connecTCP` before the connection has been passed back to the
+ # code that requested the TCP connection.
+
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
+ def checkPersistence(self, request, version):
+ """Check whether the connection can be re-used
+ """
+ # We hijack this to always say no for ease of wiring stuff up in
+ # `handle_http_replication_attempt`.
+ request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+ return False
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
|