summary refs log tree commit diff
path: root/tests/replication/test_federation_sender_shard.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/test_federation_sender_shard.py')
-rw-r--r--tests/replication/test_federation_sender_shard.py286
1 files changed, 286 insertions, 0 deletions
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
new file mode 100644
index 0000000000..519a2dc510
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
@@ -0,0 +1,286 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.app.generic_worker import GenericWorkerServer
+from synapse.events.builder import EventBuilderFactory
+from synapse.replication.http import streams
+from synapse.replication.tcp.handler import ReplicationCommandHandler
+from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
+from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import UserID
+
+from tests import unittest
+from tests.server import FakeTransport
+
+logger = logging.getLogger(__name__)
+
+
+class BaseStreamTestCase(unittest.HomeserverTestCase):
+    """Base class for tests of the replication streams"""
+
+    servlets = [
+        streams.register_servlets,
+    ]
+
+    def prepare(self, reactor, clock, hs):
+        # build a replication server
+        self.server_factory = ReplicationStreamProtocolFactory(hs)
+        self.streamer = hs.get_replication_streamer()
+
+        store = hs.get_datastore()
+        self.database = store.db
+
+        self.reactor.lookups["testserv"] = "1.2.3.4"
+
+    def default_config(self):
+        conf = super().default_config()
+        conf["send_federation"] = False
+        return conf
+
+    def make_worker_hs(self, extra_config={}):
+        config = self._get_worker_hs_config()
+        config.update(extra_config)
+
+        mock_federation_client = Mock(spec=["put_json"])
+        mock_federation_client.put_json.side_effect = lambda *_, **__: defer.succeed({})
+
+        worker_hs = self.setup_test_homeserver(
+            http_client=mock_federation_client,
+            homeserverToUse=GenericWorkerServer,
+            config=config,
+            reactor=self.reactor,
+        )
+
+        store = worker_hs.get_datastore()
+        store.db._db_pool = self.database._db_pool
+
+        repl_handler = ReplicationCommandHandler(worker_hs)
+        client = ClientReplicationStreamProtocol(
+            worker_hs, "client", "test", self.clock, repl_handler,
+        )
+        server = self.server_factory.buildProtocol(None)
+
+        client_transport = FakeTransport(server, self.reactor)
+        client.makeConnection(client_transport)
+
+        server_transport = FakeTransport(client, self.reactor)
+        server.makeConnection(server_transport)
+
+        return worker_hs
+
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.federation_sender"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
+    def replicate(self):
+        """Tell the master side of replication that something has happened, and then
+        wait for the replication to occur.
+        """
+        self.streamer.on_notifier_poke()
+        self.pump()
+
+    def create_room_with_remote_server(self, user, token, remote_server="other_server"):
+        room = self.helper.create_room_as(user, tok=token)
+        store = self.hs.get_datastore()
+        federation = self.hs.get_handlers().federation_handler
+
+        prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+        room_version = self.get_success(store.get_room_version(room))
+
+        factory = EventBuilderFactory(self.hs)
+        factory.hostname = remote_server
+
+        user_id = UserID("user", remote_server).to_string()
+
+        event_dict = {
+            "type": EventTypes.Member,
+            "state_key": user_id,
+            "content": {"membership": Membership.JOIN},
+            "sender": user_id,
+            "room_id": room,
+        }
+
+        builder = factory.for_room_version(room_version, event_dict)
+        join_event = self.get_success(builder.build(prev_event_ids))
+
+        self.get_success(federation.on_send_join_request(remote_server, join_event))
+        self.replicate()
+
+        return room
+
+
+class FederationSenderTestCase(BaseStreamTestCase):
+    servlets = [
+        login.register_servlets,
+        register_servlets_for_client_rest_resource,
+        room.register_servlets,
+    ]
+
+    def test_send_event_single_sender(self):
+        """Test that using a single federation sender worker correctly sends a
+        new event.
+        """
+        worker_hs = self.make_worker_hs({"send_federation": True})
+        mock_client = worker_hs.get_http_client()
+
+        user = self.register_user("user", "pass")
+        token = self.login("user", "pass")
+
+        room = self.create_room_with_remote_server(user, token)
+
+        mock_client.put_json.reset_mock()
+
+        self.create_and_send_event(room, UserID.from_string(user))
+        self.replicate()
+
+        # Assert that the event was sent out over federation.
+        mock_client.put_json.assert_called()
+        self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
+        self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
+
+    def test_send_event_sharded(self):
+        """Test that using two federation sender workers correctly sends
+        new events.
+        """
+        worker1 = self.make_worker_hs(
+            {
+                "send_federation": True,
+                "worker_name": "sender1",
+                "federation_sender_instances": ["sender1", "sender2"],
+            }
+        )
+        mock_client1 = worker1.get_http_client()
+
+        worker2 = self.make_worker_hs(
+            {
+                "send_federation": True,
+                "worker_name": "sender2",
+                "federation_sender_instances": ["sender1", "sender2"],
+            }
+        )
+        mock_client2 = worker2.get_http_client()
+
+        user = self.register_user("user2", "pass")
+        token = self.login("user2", "pass")
+
+        sent_on_1 = False
+        sent_on_2 = False
+        for i in range(20):
+            server_name = "other_server_%d" % (i,)
+            room = self.create_room_with_remote_server(user, token, server_name)
+            mock_client1.reset_mock()
+            mock_client2.reset_mock()
+
+            self.create_and_send_event(room, UserID.from_string(user))
+            self.replicate()
+
+            if mock_client1.put_json.called:
+                sent_on_1 = True
+                mock_client2.put_json.assert_not_called()
+                self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus"))
+            elif mock_client2.put_json.called:
+                sent_on_2 = True
+                mock_client1.put_json.assert_not_called()
+                self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus"))
+            else:
+                raise AssertionError(
+                    "Expected send transaction from one or the other sender"
+                )
+
+            if sent_on_1 and sent_on_2:
+                break
+
+        self.assertTrue(sent_on_1)
+        self.assertTrue(sent_on_2)
+
+    def test_send_typing_sharded(self):
+        """Test that using two federation sender workers correctly sends
+        new typing EDUs.
+        """
+        worker1 = self.make_worker_hs(
+            {
+                "send_federation": True,
+                "worker_name": "sender1",
+                "federation_sender_instances": ["sender1", "sender2"],
+            }
+        )
+        mock_client1 = worker1.get_http_client()
+
+        worker2 = self.make_worker_hs(
+            {
+                "send_federation": True,
+                "worker_name": "sender2",
+                "federation_sender_instances": ["sender1", "sender2"],
+            }
+        )
+        mock_client2 = worker2.get_http_client()
+
+        user = self.register_user("user3", "pass")
+        token = self.login("user3", "pass")
+
+        typing_handler = self.hs.get_typing_handler()
+
+        sent_on_1 = False
+        sent_on_2 = False
+        for i in range(20):
+            server_name = "other_server_%d" % (i,)
+            room = self.create_room_with_remote_server(user, token, server_name)
+            mock_client1.reset_mock()
+            mock_client2.reset_mock()
+
+            self.get_success(
+                typing_handler.started_typing(
+                    target_user=UserID.from_string(user),
+                    auth_user=UserID.from_string(user),
+                    room_id=room,
+                    timeout=20000,
+                )
+            )
+
+            self.replicate()
+
+            if mock_client1.put_json.called:
+                sent_on_1 = True
+                mock_client2.put_json.assert_not_called()
+                self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus"))
+            elif mock_client2.put_json.called:
+                sent_on_2 = True
+                mock_client1.put_json.assert_not_called()
+                self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+                self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus"))
+            else:
+                raise AssertionError(
+                    "Expected send transaction from one or the other sender"
+                )
+
+            if sent_on_1 and sent_on_2:
+                break
+
+        self.assertTrue(sent_on_1)
+        self.assertTrue(sent_on_2)