summary refs log tree commit diff
path: root/tests/replication/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/_base.py')
-rw-r--r--tests/replication/_base.py168
1 files changed, 165 insertions, 3 deletions
diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..06575ba0a6 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
 
 import attr
 
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
     GenericWorkerReplicationHandler,
     GenericWorkerServer,
 )
+from synapse.http.server import JsonResource
 from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
 
 logger = logging.getLogger(__name__)
 
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+    """Base class for tests running multiple workers.
+
+    Automatically handle HTTP replication requests from workers to master,
+    unlike `BaseStreamTestCase`.
+    """
+
+    servlets = []  # type: List[Callable[[HomeServer, JsonResource], None]]
+
+    def setUp(self):
+        super().setUp()
+
+        # build a replication server
+        self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+        self.streamer = self.hs.get_replication_streamer()
+
+        store = self.hs.get_datastore()
+        self.database = store.db
+
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+        self._worker_hs_to_resource = {}
+
+        # When we see a connection attempt to the master replication listener we
+        # automatically set up the connection. This is so that tests don't
+        # manually have to go and explicitly set it up each time (plus sometimes
+        # it is impossible to write the handling explicitly in the tests).
+        self.reactor.add_tcp_client_callback(
+            "1.2.3.4", 8765, self._handle_http_replication_attempt
+        )
+
+    def create_test_json_resource(self):
+        """Overrides `HomeserverTestCase.create_test_json_resource`.
+        """
+        # We override this so that it automatically registers all the HTTP
+        # replication servlets, without having to explicitly do that in all
+        # subclassses.
+
+        resource = ReplicationRestResource(self.hs)
+
+        for servlet in self.servlets:
+            servlet(self.hs, resource)
+
+        return resource
+
+    def make_worker_hs(
+        self, worker_app: str, extra_config: dict = {}, **kwargs
+    ) -> HomeServer:
+        """Make a new worker HS instance, correctly connecting replcation
+        stream to the master HS.
+
+        Args:
+            worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+            extra_config: Any extra config to use for this instances.
+            **kwargs: Options that get passed to `self.setup_test_homeserver`,
+                useful to e.g. pass some mocks for things like `http_client`
+
+        Returns:
+            The new worker HomeServer instance.
+        """
+
+        config = self._get_worker_hs_config()
+        config["worker_app"] = worker_app
+        config.update(extra_config)
+
+        worker_hs = self.setup_test_homeserver(
+            homeserverToUse=GenericWorkerServer,
+            config=config,
+            reactor=self.reactor,
+            **kwargs
+        )
+
+        store = worker_hs.get_datastore()
+        store.db._db_pool = self.database._db_pool
+
+        repl_handler = ReplicationCommandHandler(worker_hs)
+        client = ClientReplicationStreamProtocol(
+            worker_hs, "client", "test", self.clock, repl_handler,
+        )
+        server = self.server_factory.buildProtocol(None)
+
+        client_transport = FakeTransport(server, self.reactor)
+        client.makeConnection(client_transport)
+
+        server_transport = FakeTransport(client, self.reactor)
+        server.makeConnection(server_transport)
+
+        # Set up a resource for the worker
+        resource = ReplicationRestResource(self.hs)
+
+        for servlet in self.servlets:
+            servlet(worker_hs, resource)
+
+        self._worker_hs_to_resource[worker_hs] = resource
+
+        return worker_hs
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
+    def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+        render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+    def replicate(self):
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        self.streamer.on_notifier_poke()
+        self.pump()
+
+    def _handle_http_replication_attempt(self):
+        """Handles a connection attempt to the master replication HTTP
+        listener.
+        """
+
+        # We should have at least one outbound connection attempt, where the
+        # last is one to the HTTP repication IP/port.
+        clients = self.reactor.tcpClients
+        self.assertGreaterEqual(len(clients), 1)
+        (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+        self.assertEqual(host, "1.2.3.4")
+        self.assertEqual(port, 8765)
+
+        # Set up client side protocol
+        client_protocol = client_factory.buildProtocol(None)
+
+        request_factory = OneShotRequestFactory()
+
+        # Set up the server side protocol
+        channel = _PushHTTPChannel(self.reactor)
+        channel.requestFactory = request_factory
+        channel.site = self.site
+
+        # Connect client to server and vice versa.
+        client_to_server_transport = FakeTransport(
+            channel, self.reactor, client_protocol
+        )
+        client_protocol.makeConnection(client_to_server_transport)
+
+        server_to_client_transport = FakeTransport(
+            client_protocol, self.reactor, channel
+        )
+        channel.makeConnection(server_to_client_transport)
+
+        # Note: at this point we've wired everything up, but we need to return
+        # before the data starts flowing over the connections as this is called
+        # inside `connecTCP` before the connection has been passed back to the
+        # code that requested the TCP connection.
+
+
 class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
             # We need to manually stop the _PullToPushProducer.
             self._pull_to_push_producer.stop()
 
+    def checkPersistence(self, request, version):
+        """Check whether the connection can be re-used
+        """
+        # We hijack this to always say no for ease of wiring stuff up in
+        # `handle_http_replication_attempt`.
+        request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+        return False
+
 
 class _PullToPushProducer:
     """A push producer that wraps a pull producer.