diff --git a/tests/replication/_base.py b/tests/replication/_base.py
index 9d4f0bbe44..ae60874ec3 100644
--- a/tests/replication/_base.py
+++ b/tests/replication/_base.py
@@ -14,7 +14,7 @@
# limitations under the License.
import logging
-from typing import Any, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple
import attr
@@ -26,8 +26,9 @@ from synapse.app.generic_worker import (
GenericWorkerReplicationHandler,
GenericWorkerServer,
)
+from synapse.http.server import JsonResource
from synapse.http.site import SynapseRequest
-from synapse.replication.http import streams
+from synapse.replication.http import ReplicationRestResource, streams
from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -35,7 +36,7 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.server import FakeTransport
+from tests.server import FakeTransport, render
logger = logging.getLogger(__name__)
@@ -64,7 +65,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# Since we use sqlite in memory databases we need to make sure the
# databases objects are the same.
- self.worker_hs.get_datastore().db = hs.get_datastore().db
+ self.worker_hs.get_datastore().db_pool = hs.get_datastore().db_pool
self.test_handler = self._build_replication_data_handler()
self.worker_hs.replication_data_handler = self.test_handler
@@ -180,6 +181,159 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.assertEqual(request.method, b"GET")
+class BaseMultiWorkerStreamTestCase(unittest.HomeserverTestCase):
+ """Base class for tests running multiple workers.
+
+ Unlike `BaseStreamTestCase`, HTTP replication requests from workers to the
+ master are handled automatically.
+ """
+
+ servlets = [] # type: List[Callable[[HomeServer, JsonResource], None]]
+
+ def setUp(self):
+ super().setUp()
+
+ # build a replication server
+ self.server_factory = ReplicationStreamProtocolFactory(self.hs)
+ self.streamer = self.hs.get_replication_streamer()
+
+ store = self.hs.get_datastore()
+ self.database_pool = store.db_pool
+
+ self.reactor.lookups["testserv"] = "1.2.3.4"
+
+ self._worker_hs_to_resource = {}
+
+ # When we see a connection attempt to the master replication listener we
+ # automatically set up the connection. This is so that tests don't
+ # manually have to go and explicitly set it up each time (plus sometimes
+ # it is impossible to write the handling explicitly in the tests).
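+ # (The port 8765 here matches the `worker_replication_http_port` used in
+ # `_get_worker_hs_config` below.)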
+ self.reactor.add_tcp_client_callback(
+ "1.2.3.4", 8765, self._handle_http_replication_attempt
+ )
+
+ def create_test_json_resource(self):
+ """Overrides `HomeserverTestCase.create_test_json_resource`.
+ """
+ # We override this so that it automatically registers all the HTTP
+ # replication servlets, without having to explicitly do that in all
+ # subclasses.
+
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(self.hs, resource)
+
+ return resource
+
+ def make_worker_hs(
+ self, worker_app: str, extra_config: dict = {}, **kwargs
+ ) -> HomeServer:
+ """Make a new worker HS instance, correctly connecting replcation
+ stream to the master HS.
+
+ Args:
+ worker_app: Type of worker, e.g. `synapse.app.federation_sender`.
+ extra_config: Any extra config to use for this instance.
+ **kwargs: Options that get passed to `self.setup_test_homeserver`,
+ useful to e.g. pass some mocks for things like `http_client`.
+
+ Returns:
+ The new worker HomeServer instance.
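+
+ Example (sketch):
+     worker_hs = self.make_worker_hs(
+         "synapse.app.federation_sender", {"send_federation": True}
+     )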
+ """
+
+ config = self._get_worker_hs_config()
+ config["worker_app"] = worker_app
+ config.update(extra_config)
+
+ worker_hs = self.setup_test_homeserver(
+ homeserverToUse=GenericWorkerServer,
+ config=config,
+ reactor=self.reactor,
+ **kwargs
+ )
+
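+ # As with `BaseStreamTestCase`, the worker must share the master's
+ # in-memory sqlite database, so point it at the same connection pool.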
+ store = worker_hs.get_datastore()
+ store.db_pool._db_pool = self.database_pool._db_pool
+
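+ # Wire the worker's replication client directly to the master's
+ # replication server via in-memory fake transports.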
+ repl_handler = ReplicationCommandHandler(worker_hs)
+ client = ClientReplicationStreamProtocol(
+ worker_hs, "client", "test", self.clock, repl_handler,
+ )
+ server = self.server_factory.buildProtocol(None)
+
+ client_transport = FakeTransport(server, self.reactor)
+ client.makeConnection(client_transport)
+
+ server_transport = FakeTransport(client, self.reactor)
+ server.makeConnection(server_transport)
+
+ # Set up a resource for the worker
+ resource = ReplicationRestResource(self.hs)
+
+ for servlet in self.servlets:
+ servlet(worker_hs, resource)
+
+ self._worker_hs_to_resource[worker_hs] = resource
+
+ return worker_hs
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def render_on_worker(self, worker_hs: HomeServer, request: SynapseRequest):
+ render(request, self._worker_hs_to_resource[worker_hs], self.reactor)
+
+ def replicate(self):
+ """Tell the master side of replication that something has happened, and then
+ wait for the replication to occur.
+ """
+ self.streamer.on_notifier_poke()
+ self.pump()
+
+ def _handle_http_replication_attempt(self):
+ """Handles a connection attempt to the master replication HTTP
+ listener.
+ """
+
+ # We should have at least one outbound connection attempt, the last of
+ # which should be to the HTTP replication IP/port.
+ clients = self.reactor.tcpClients
+ self.assertGreaterEqual(len(clients), 1)
+ (host, port, client_factory, _timeout, _bindAddress) = clients.pop()
+ self.assertEqual(host, "1.2.3.4")
+ self.assertEqual(port, 8765)
+
+ # Set up client side protocol
+ client_protocol = client_factory.buildProtocol(None)
+
+ request_factory = OneShotRequestFactory()
+
+ # Set up the server side protocol
+ channel = _PushHTTPChannel(self.reactor)
+ channel.requestFactory = request_factory
+ channel.site = self.site
+
+ # Connect client to server and vice versa.
+ client_to_server_transport = FakeTransport(
+ channel, self.reactor, client_protocol
+ )
+ client_protocol.makeConnection(client_to_server_transport)
+
+ server_to_client_transport = FakeTransport(
+ client_protocol, self.reactor, channel
+ )
+ channel.makeConnection(server_to_client_transport)
+
+ # Note: at this point we've wired everything up, but we need to return
+ # before the data starts flowing over the connections as this is called
+ # inside `connectTCP` before the connection has been passed back to the
+ # code that requested the TCP connection.
+
+
class TestReplicationDataHandler(GenericWorkerReplicationHandler):
"""Drop-in for ReplicationDataHandler which just collects RDATA rows"""
@@ -241,6 +395,14 @@ class _PushHTTPChannel(HTTPChannel):
# We need to manually stop the _PullToPushProducer.
self._pull_to_push_producer.stop()
+ def checkPersistence(self, request, version):
+ """Check whether the connection can be re-used
+ """
+ # We hijack this to always say no for ease of wiring stuff up in
+ # `handle_http_replication_attempt`.
+ request.responseHeaders.setRawHeaders(b"connection", [b"close"])
+ return False
+
class _PullToPushProducer:
"""A push producer that wraps a pull producer.
diff --git a/tests/replication/slave/storage/test_events.py b/tests/replication/slave/storage/test_events.py
index 1a88c7fb80..561258a356 100644
--- a/tests/replication/slave/storage/test_events.py
+++ b/tests/replication/slave/storage/test_events.py
@@ -160,7 +160,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 0},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 0},
)
self.persist(
@@ -173,7 +173,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 0, "notify_count": 1},
+ {"highlight_count": 0, "unread_count": 0, "notify_count": 1},
)
self.persist(
@@ -188,7 +188,7 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
self.check(
"get_unread_event_push_actions_by_room_for_user",
[ROOM_ID, USER_ID_2, event1.event_id],
- {"highlight_count": 1, "notify_count": 2},
+ {"highlight_count": 1, "unread_count": 0, "notify_count": 2},
)
def test_get_rooms_for_user_with_stream_ordering(self):
@@ -366,7 +366,11 @@ class SlavedEventStoreTestCase(BaseSlavedStoreTestCase):
state_handler = self.hs.get_state_handler()
context = self.get_success(state_handler.compute_event_context(event))
- self.master_store.add_push_actions_to_staging(
- event.event_id, {user_id: actions for user_id, actions in push_actions}
+ self.get_success(
+ self.master_store.add_push_actions_to_staging(
+ event.event_id,
+ {user_id: actions for user_id, actions in push_actions},
+ False,
+ )
)
return event, context
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 51bf0ef4e9..c9998e88e6 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -17,6 +17,7 @@ from typing import List, Optional
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
+from synapse.replication.tcp.commands import RdataCommand
from synapse.replication.tcp.streams._base import _STREAM_UPDATE_TARGET_ROW_COUNT
from synapse.replication.tcp.streams.events import (
EventsStreamCurrentStateRow,
@@ -66,11 +67,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
# also one state event
state_event = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -123,7 +119,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
OTHER_USER = "@other_user:localhost"
# have the user join
- inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, OTHER_USER, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -161,24 +159,21 @@ class EventsStreamTestCase(BaseStreamTestCase):
# roll back all the state by de-modding the user
prev_events = fork_point
pls["users"][OTHER_USER] = 0
- pl_event = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ pl_event = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
# one more bit of state that doesn't get rolled back
state2 = self._inject_state_event()
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -277,7 +272,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
# have the users join
for u in user_ids:
- inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ self.get_success(
+ inject_member_event(self.hs, self.room_id, u, Membership.JOIN)
+ )
# Update existing power levels with mod at PL50
pls = self.helper.get_state(
@@ -315,23 +312,20 @@ class EventsStreamTestCase(BaseStreamTestCase):
pl_events = []
for u in user_ids:
pls["users"][u] = 0
- e = inject_event(
- self.hs,
- prev_event_ids=prev_events,
- type=EventTypes.PowerLevels,
- state_key="",
- sender=self.user_id,
- room_id=self.room_id,
- content=pls,
+ e = self.get_success(
+ inject_event(
+ self.hs,
+ prev_event_ids=prev_events,
+ type=EventTypes.PowerLevels,
+ state_key="",
+ sender=self.user_id,
+ room_id=self.room_id,
+ content=pls,
+ )
)
prev_events = [e.event_id]
pl_events.append(e)
- # tell the notifier to catch up to avoid duplicate rows.
- # workaround for https://github.com/matrix-org/synapse/issues/7360
- # FIXME remove this when the above is fixed
- self.replicate()
-
# check we're testing what we think we are: no rows should yet have been
# received
self.assertEqual([], self.test_handler.received_rdata_rows)
@@ -378,6 +372,64 @@ class EventsStreamTestCase(BaseStreamTestCase):
self.assertEqual([], received_rows)
+ def test_backwards_stream_id(self):
+ """
+ Test that RDATA whose token is behind the current position is discarded.
+ """
+ # disconnect, so that we can stack up some changes
+ self.disconnect()
+
+ # Generate an event. We inject it using inject_event so that it is not
+ # sent out over replication until we call self.replicate().
+ event = self._inject_test_event()
+
+ # check we're testing what we think we are: no rows should yet have been
+ # received
+ self.assertEqual([], self.test_handler.received_rdata_rows)
+
+ # now reconnect to pull the updates
+ self.reconnect()
+ self.replicate()
+
+ # We should have received the expected single row (as well as various
+ # cache invalidation updates which we ignore).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+
+ # There should be a single received row.
+ self.assertEqual(len(received_rows), 1)
+
+ stream_name, token, row = received_rows[0]
+ self.assertEqual("events", stream_name)
+ self.assertIsInstance(row, EventsStreamRow)
+ self.assertEqual(row.type, "ev")
+ self.assertIsInstance(row.data, EventsStreamEventRow)
+ self.assertEqual(row.data.event_id, event.event_id)
+
+ # Reset the data.
+ self.test_handler.received_rdata_rows = []
+
+ # Save the current token for later.
+ worker_events_stream = self.worker_hs.get_replication_streams()["events"]
+ prev_token = worker_events_stream.current_token("master")
+
+ # Manually send an old RDATA command, which should get dropped. This
+ # re-uses the row from above, but with an earlier stream token.
+ self.hs.get_tcp_replication().send_command(
+ RdataCommand("events", "master", 1, row)
+ )
+
+ # No updates have been received (because the RDATA was discarded as old).
+ received_rows = [
+ row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+ ]
+ self.assertEqual(len(received_rows), 0)
+
+ # Ensure the stream has not gone backwards.
+ current_token = worker_events_stream.current_token("master")
+ self.assertGreaterEqual(current_token, prev_token)
+
event_count = 0
def _inject_test_event(
@@ -390,13 +442,15 @@ class EventsStreamTestCase(BaseStreamTestCase):
body = "event %i" % (self.event_count,)
self.event_count += 1
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_event",
- content={"body": body},
- **kwargs
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_event",
+ content={"body": body},
+ **kwargs
+ )
)
def _inject_state_event(
@@ -415,11 +469,13 @@ class EventsStreamTestCase(BaseStreamTestCase):
if body is None:
body = "state event %s" % (state_key,)
- return inject_event(
- self.hs,
- room_id=self.room_id,
- sender=sender,
- type="test_state_event",
- state_key=state_key,
- content={"body": body},
+ return self.get_success(
+ inject_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=sender,
+ type="test_state_event",
+ state_key=state_key,
+ content={"body": body},
+ )
)
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index fd62b26356..5acfb3e53e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -16,10 +16,15 @@ from mock import Mock
from synapse.handlers.typing import RoomMember
from synapse.replication.tcp.streams import TypingStream
+from synapse.util.caches.stream_change_cache import StreamChangeCache
from tests.replication._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
+USER_ID_2 = "@da-ba-dee:blue"
+
+ROOM_ID = "!bar:blue"
+ROOM_ID_2 = "!foo:blue"
class TypingStreamTestCase(BaseStreamTestCase):
@@ -29,11 +34,9 @@ class TypingStreamTestCase(BaseStreamTestCase):
def test_typing(self):
typing = self.hs.get_typing_handler()
- room_id = "!bar:blue"
-
self.reconnect()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
self.reactor.advance(0)
@@ -46,7 +49,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0] # type: TypingStream.TypingStreamRow
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
self.assertEqual([USER_ID], row.user_ids)
# Now let's disconnect and insert some data.
@@ -54,7 +57,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.test_handler.on_rdata.reset_mock()
- typing._push_update(member=RoomMember(room_id, USER_ID), typing=False)
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
self.test_handler.on_rdata.assert_not_called()
@@ -73,5 +76,78 @@ class TypingStreamTestCase(BaseStreamTestCase):
self.assertEqual(stream_name, "typing")
self.assertEqual(1, len(rdata_rows))
row = rdata_rows[0]
- self.assertEqual(room_id, row.room_id)
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([], row.user_ids)
+
+ def test_reset(self):
+ """
+ Test what happens when a typing stream resets.
+
+ This is emulated by jumping the stream ahead, then reconnecting (which
+ sends the proper position and RDATA).
+ """
+ typing = self.hs.get_typing_handler()
+
+ self.reconnect()
+
+ typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=True)
+
+ self.reactor.advance(0)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0] # type: TypingStream.TypingStreamRow
+ self.assertEqual(ROOM_ID, row.room_id)
+ self.assertEqual([USER_ID], row.user_ids)
+
+ # Push the stream forward a bunch so it can be reset.
+ for i in range(100):
+ typing._push_update(
+ member=RoomMember(ROOM_ID, "@test%s:blue" % i), typing=True
+ )
+ self.reactor.advance(0)
+
+ # Disconnect.
+ self.disconnect()
+
+ # Reset the typing handler
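+ # (this emulates the master losing its in-memory typing state, e.g. after
+ # a restart, by zeroing the stream tokens and the change cache)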
+ self.hs.get_replication_streams()["typing"].last_token = 0
+ self.hs.get_tcp_replication()._streams["typing"].last_token = 0
+ typing._latest_room_serial = 0
+ typing._typing_stream_change_cache = StreamChangeCache(
+ "TypingStreamChangeCache", typing._latest_room_serial
+ )
+ typing._reset()
+
+ # Reconnect.
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now see an attempt to connect to the master
+ request = self.handle_http_replication_attempt()
+ self.assert_request_is_get_repl_stream_updates(request, "typing")
+
+ # Reset the test code.
+ self.test_handler.on_rdata.reset_mock()
+ self.test_handler.on_rdata.assert_not_called()
+
+ # Push additional data.
+ typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
+ self.reactor.advance(0)
+
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "typing")
+ self.assertEqual(1, len(rdata_rows))
+ row = rdata_rows[0]
+ self.assertEqual(ROOM_ID_2, row.room_id)
self.assertEqual([], row.user_ids)
+
+ # The token should have been reset.
+ self.assertEqual(token, 1)
diff --git a/tests/replication/test_client_reader_shard.py b/tests/replication/test_client_reader_shard.py
new file mode 100644
index 0000000000..86c03fd89c
--- /dev/null
+++ b/tests/replication/test_client_reader_shard.py
@@ -0,0 +1,96 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from synapse.api.constants import LoginType
+from synapse.http.site import SynapseRequest
+from synapse.rest.client.v2_alpha import register
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.rest.client.v2_alpha.test_auth import DummyRecaptchaChecker
+from tests.server import FakeChannel
+
+logger = logging.getLogger(__name__)
+
+
+class ClientReaderTestCase(BaseMultiWorkerStreamTestCase):
+ """Base class for tests of the replication streams"""
+
+ servlets = [register.register_servlets]
+
+ def prepare(self, reactor, clock, hs):
+ self.recaptcha_checker = DummyRecaptchaChecker(hs)
+ auth_handler = hs.get_auth_handler()
+ auth_handler.checkers[LoginType.RECAPTCHA] = self.recaptcha_checker
+
+ def _get_worker_hs_config(self) -> dict:
+ config = self.default_config()
+ config["worker_app"] = "synapse.app.client_reader"
+ config["worker_replication_host"] = "testserv"
+ config["worker_replication_http_port"] = "8765"
+ return config
+
+ def test_register_single_worker(self):
+ """Test that registration works when using a single client reader worker.
+ """
+ worker_hs = self.make_worker_hs("synapse.app.client_reader")
+
+ request_1, channel_1 = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs, request_1)
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
+ request_2, channel_2 = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs, request_2)
+ self.assertEqual(request_2.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test")
+
+ def test_register_multi_worker(self):
+ """Test that registration works when using multiple client reader workers.
+ """
+ worker_hs_1 = self.make_worker_hs("synapse.app.client_reader")
+ worker_hs_2 = self.make_worker_hs("synapse.app.client_reader")
+
+ request_1, channel_1 = self.make_request(
+ "POST",
+ "register",
+ {"username": "user", "type": "m.login.password", "password": "bar"},
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs_1, request_1)
+ self.assertEqual(request_1.code, 401)
+
+ # Grab the session
+ session = channel_1.json_body["session"]
+
+ # also complete the dummy auth
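+ # (note that this is rendered on the *second* worker: the UI auth session
+ # started on worker 1 can be completed on worker 2, since the workers
+ # share storage in this test)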
+ request_2, channel_2 = self.make_request(
+ "POST", "register", {"auth": {"session": session, "type": "m.login.dummy"}}
+ ) # type: SynapseRequest, FakeChannel
+ self.render_on_worker(worker_hs_2, request_2)
+ self.assertEqual(request_2.code, 200)
+
+ # We're given a registered user.
+ self.assertEqual(channel_2.json_body["user_id"], "@user:test")
diff --git a/tests/replication/test_federation_ack.py b/tests/replication/test_federation_ack.py
index 5448d9f0dc..23be1167a3 100644
--- a/tests/replication/test_federation_ack.py
+++ b/tests/replication/test_federation_ack.py
@@ -32,6 +32,7 @@ class FederationAckTestCase(HomeserverTestCase):
def make_homeserver(self, reactor, clock):
hs = self.setup_test_homeserver(homeserverToUse=GenericWorkerServer)
+
return hs
def test_federation_ack_sent(self):
diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py
new file mode 100644
index 0000000000..8b4982ecb1
--- /dev/null
+++ b/tests/replication/test_federation_sender_shard.py
@@ -0,0 +1,234 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from synapse.api.constants import EventTypes, Membership
+from synapse.events.builder import EventBuilderFactory
+from synapse.rest.admin import register_servlets_for_client_rest_resource
+from synapse.rest.client.v1 import login, room
+from synapse.types import UserID, create_requester
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+from tests.test_utils import make_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+class FederationSenderTestCase(BaseMultiWorkerStreamTestCase):
+ servlets = [
+ login.register_servlets,
+ register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ ]
+
+ def default_config(self):
+ conf = super().default_config()
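+ # Disable federation sending on the master; the worker(s) created in each
+ # test take over (they set `send_federation: True`).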
+ conf["send_federation"] = False
+ return conf
+
+ def test_send_event_single_sender(self):
+ """Test that using a single federation sender worker correctly sends a
+ new event.
+ """
+ mock_client = Mock(spec=["put_json"])
+ mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({})
+
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {"send_federation": True},
+ http_client=mock_client,
+ )
+
+ user = self.register_user("user", "pass")
+ token = self.login("user", "pass")
+
+ room = self.create_room_with_remote_server(user, token)
+
+ mock_client.put_json.reset_mock()
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ # Assert that the event was sent out over federation.
+ mock_client.put_json.assert_called()
+ self.assertEqual(mock_client.put_json.call_args[0][0], "other_server")
+ self.assertTrue(mock_client.put_json.call_args[1]["data"].get("pdus"))
+
+ def test_send_event_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new events.
+ """
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client1,
+ )
+
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client2,
+ )
+
+ user = self.register_user("user2", "pass")
+ token = self.login("user2", "pass")
+
+ sent_on_1 = False
+ sent_on_2 = False
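+ # Create rooms on a series of different remote servers until we have seen
+ # both senders transmit at least one event each.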
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
+
+ self.create_and_send_event(room, UserID.from_string(user))
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("pdus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("pdus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def test_send_typing_sharded(self):
+ """Test that using two federation sender workers correctly sends
+ new typing EDUs.
+ """
+ mock_client1 = Mock(spec=["put_json"])
+ mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender1",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client1,
+ )
+
+ mock_client2 = Mock(spec=["put_json"])
+ mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({})
+ self.make_worker_hs(
+ "synapse.app.federation_sender",
+ {
+ "send_federation": True,
+ "worker_name": "sender2",
+ "federation_sender_instances": ["sender1", "sender2"],
+ },
+ http_client=mock_client2,
+ )
+
+ user = self.register_user("user3", "pass")
+ token = self.login("user3", "pass")
+
+ typing_handler = self.hs.get_typing_handler()
+
+ sent_on_1 = False
+ sent_on_2 = False
+ for i in range(20):
+ server_name = "other_server_%d" % (i,)
+ room = self.create_room_with_remote_server(user, token, server_name)
+ mock_client1.reset_mock() # type: ignore[attr-defined]
+ mock_client2.reset_mock() # type: ignore[attr-defined]
+
+ self.get_success(
+ typing_handler.started_typing(
+ target_user=UserID.from_string(user),
+ requester=create_requester(user),
+ room_id=room,
+ timeout=20000,
+ )
+ )
+
+ self.replicate()
+
+ if mock_client1.put_json.called:
+ sent_on_1 = True
+ mock_client2.put_json.assert_not_called()
+ self.assertEqual(mock_client1.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client1.put_json.call_args[1]["data"].get("edus"))
+ elif mock_client2.put_json.called:
+ sent_on_2 = True
+ mock_client1.put_json.assert_not_called()
+ self.assertEqual(mock_client2.put_json.call_args[0][0], server_name)
+ self.assertTrue(mock_client2.put_json.call_args[1]["data"].get("edus"))
+ else:
+ raise AssertionError(
+ "Expected send transaction from one or the other sender"
+ )
+
+ if sent_on_1 and sent_on_2:
+ break
+
+ self.assertTrue(sent_on_1)
+ self.assertTrue(sent_on_2)
+
+ def create_room_with_remote_server(self, user, token, remote_server="other_server"):
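+ """Create a room as `user` and fake a join from a user on `remote_server`,
+ so the room has a remote member to send federation traffic to.
+ """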
+ room = self.helper.create_room_as(user, tok=token)
+ store = self.hs.get_datastore()
+ federation = self.hs.get_handlers().federation_handler
+
+ prev_event_ids = self.get_success(store.get_latest_event_ids_in_room(room))
+ room_version = self.get_success(store.get_room_version(room))
+
+ factory = EventBuilderFactory(self.hs)
+ factory.hostname = remote_server
+
+ user_id = UserID("user", remote_server).to_string()
+
+ event_dict = {
+ "type": EventTypes.Member,
+ "state_key": user_id,
+ "content": {"membership": Membership.JOIN},
+ "sender": user_id,
+ "room_id": room,
+ }
+
+ builder = factory.for_room_version(room_version, event_dict)
+ join_event = self.get_success(builder.build(prev_event_ids))
+
+ self.get_success(federation.on_send_join_request(remote_server, join_event))
+ self.replicate()
+
+ return room
diff --git a/tests/replication/test_pusher_shard.py b/tests/replication/test_pusher_shard.py
new file mode 100644
index 0000000000..2bdc6edbb1
--- /dev/null
+++ b/tests/replication/test_pusher_shard.py
@@ -0,0 +1,193 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from mock import Mock
+
+from twisted.internet import defer
+
+from synapse.rest import admin
+from synapse.rest.client.v1 import login, room
+
+from tests.replication._base import BaseMultiWorkerStreamTestCase
+
+logger = logging.getLogger(__name__)
+
+
+class PusherShardTestCase(BaseMultiWorkerStreamTestCase):
+ """Checks pusher sharding works
+ """
+
+ servlets = [
+ admin.register_servlets_for_client_rest_resource,
+ room.register_servlets,
+ login.register_servlets,
+ ]
+
+ def prepare(self, reactor, clock, hs):
+ # Register a user who sends a message that we'll get notified about
+ self.other_user_id = self.register_user("otheruser", "pass")
+ self.other_access_token = self.login("otheruser", "pass")
+
+ def default_config(self):
+ conf = super().default_config()
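+ # Don't start pushers on the master; the pusher worker(s) created in each
+ # test take over (they set `start_pushers: True`).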
+ conf["start_pushers"] = False
+ return conf
+
+ def _create_pusher_and_send_msg(self, localpart):
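+ """Register `localpart` with an HTTP pusher, then have `otheruser` send
+ them a message. Returns the event ID of that message.
+ """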
+ # Create a user that will get push notifications
+ user_id = self.register_user(localpart, "pass")
+ access_token = self.login(localpart, "pass")
+
+ # Register a pusher
+ user_dict = self.get_success(
+ self.hs.get_datastore().get_user_by_access_token(access_token)
+ )
+ token_id = user_dict["token_id"]
+
+ self.get_success(
+ self.hs.get_pusherpool().add_pusher(
+ user_id=user_id,
+ access_token=token_id,
+ kind="http",
+ app_id="m.http",
+ app_display_name="HTTP Push Notifications",
+ device_display_name="pushy push",
+ pushkey="a@example.com",
+ lang=None,
+ data={"url": "https://push.example.com/push"},
+ )
+ )
+
+ self.pump()
+
+ # Create a room
+ room = self.helper.create_room_as(user_id, tok=access_token)
+
+ # The other user joins
+ self.helper.join(
+ room=room, user=self.other_user_id, tok=self.other_access_token
+ )
+
+ # The other user sends some messages
+ response = self.helper.send(room, body="Hi!", tok=self.other_access_token)
+ event_id = response["event_id"]
+
+ return event_id
+
+ def test_send_push_single_worker(self):
+ """Test that registration works when using a pusher worker.
+ """
+ http_client_mock = Mock(spec_set=["post_json_get_json"])
+ http_client_mock.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {"start_pushers": True},
+ proxied_http_client=http_client_mock,
+ )
+
+ event_id = self._create_pusher_and_send_msg("user")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ def test_send_push_multiple_workers(self):
+ """Test that registration works when using sharded pusher workers.
+ """
+ http_client_mock1 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock1.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher1",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock1,
+ )
+
+ http_client_mock2 = Mock(spec_set=["post_json_get_json"])
+ http_client_mock2.post_json_get_json.side_effect = lambda *_, **__: defer.succeed(
+ {}
+ )
+
+ self.make_worker_hs(
+ "synapse.app.pusher",
+ {
+ "start_pushers": True,
+ "worker_name": "pusher2",
+ "pusher_instances": ["pusher1", "pusher2"],
+ },
+ proxied_http_client=http_client_mock2,
+ )
+
+ # We choose a user name that we know should go to pusher1.
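+ # (which pusher instance handles a given pusher is determined by its user,
+ # so "user2" here and "user4" below were picked to exercise both instances)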
+ event_id = self._create_pusher_and_send_msg("user2")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_called_once()
+ http_client_mock2.post_json_get_json.assert_not_called()
+ self.assertEqual(
+ http_client_mock1.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock1.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
+
+ http_client_mock1.post_json_get_json.reset_mock()
+ http_client_mock2.post_json_get_json.reset_mock()
+
+ # Now we choose a user name that we know should go to pusher2.
+ event_id = self._create_pusher_and_send_msg("user4")
+
+ # Advance time a bit, so the pusher will register something has happened
+ self.pump()
+
+ http_client_mock1.post_json_get_json.assert_not_called()
+ http_client_mock2.post_json_get_json.assert_called_once()
+ self.assertEqual(
+ http_client_mock2.post_json_get_json.call_args[0][0],
+ "https://push.example.com/push",
+ )
+ self.assertEqual(
+ event_id,
+ http_client_mock2.post_json_get_json.call_args[0][1]["notification"][
+ "event_id"
+ ],
+ )
|