summary refs log tree commit diff
path: root/tests/replication/tcp/streams
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2020-05-06 11:43:11 +0100
committerRichard van der Hoff <richard@matrix.org>2020-05-06 11:43:11 +0100
commit30a19daa02cf61d2a811efb0fd619ff48ce6bcf6 (patch)
treec430aa536e76a2329831d04ae7d8d761ed2ec188 /tests/replication/tcp/streams
parentuse an upsert to update device_lists_outbound_last_success (diff)
parentFix typing annotations in synapse/federation (#7382) (diff)
downloadsynapse-30a19daa02cf61d2a811efb0fd619ff48ce6bcf6.tar.xz
Merge branch 'develop' into rav/upsert_for_device_list
Diffstat (limited to 'tests/replication/tcp/streams')
-rw-r--r--tests/replication/tcp/streams/_base.py54
-rw-r--r--tests/replication/tcp/streams/test_events.py24
-rw-r--r--tests/replication/tcp/streams/test_federation.py81
-rw-r--r--tests/replication/tcp/streams/test_receipts.py7
-rw-r--r--tests/replication/tcp/streams/test_typing.py12
5 files changed, 126 insertions, 52 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py

index 83e16cfe3d..9d4f0bbe44 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple import attr @@ -22,13 +22,16 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel -from synapse.app.generic_worker import GenericWorkerServer +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) from synapse.http.site import SynapseRequest -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.client import ReplicationDataHandler +from synapse.replication.http import streams from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer from synapse.util import Clock from tests import unittest @@ -40,6 +43,10 @@ logger = logging.getLogger(__name__) class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + servlets = [ + streams.register_servlets, + ] + def prepare(self, reactor, clock, hs): # build a replication server server_factory = ReplicationStreamProtocolFactory(hs) @@ -47,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.server = server_factory.buildProtocol(None) # Make a new HomeServer object for the worker - config = self.default_config() - config["worker_app"] = "synapse.app.generic_worker" - config["worker_replication_host"] = "testserv" - config["worker_replication_http_port"] = "8765" - self.reactor.lookups["testserv"] = "1.2.3.4" - self.worker_hs = self.setup_test_homeserver( http_client=None, homeserverToUse=GenericWorkerServer, - config=config, + config=self._get_worker_hs_config(), reactor=self.reactor, ) @@ -76,8 +77,15 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._client_transport = None self._server_transport = None + def _get_worker_hs_config(self) -> dict: + config = self.default_config() + config["worker_app"] = "synapse.app.generic_worker" + config["worker_replication_host"] = "testserv" + config["worker_replication_http_port"] = "8765" + return config + def _build_replication_data_handler(self): - return TestReplicationDataHandler(self.worker_hs.get_datastore()) + return TestReplicationDataHandler(self.worker_hs) def reconnect(self): if self._client_transport: @@ -172,32 +180,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") -class TestReplicationDataHandler(ReplicationDataHandler): +class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, store: BaseSlavedStore): - super().__init__(store) - - # streams to subscribe to: map from stream id to position - self.stream_positions = {} # type: Dict[str, int] + def __init__(self, hs: HomeServer): + super().__init__(hs) # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - def get_streams_to_replicate(self): - return self.stream_positions - - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) - if ( - stream_name in self.stream_positions - and token > self.stream_positions[stream_name] - ): - self.stream_positions[stream_name] = token - @attr.s() class OneShotRequestFactory: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 1fa28084f9..8bd67bb9f1 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py
@@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase): self.user_tok = self.login("u1", "pass") self.reconnect() - self.test_handler.stream_positions["events"] = 0 self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() @@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + for event in events: stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) @@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # now we should have received all the expected rows in the right order. + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) # # we expect: # @@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # of the states that got reverted. # - two rows for state2 - received_rows = self.test_handler.received_rdata_rows + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] # first check the first two rows, which should be state1 @@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] self.assertGreaterEqual(len(received_rows), len(events)) for i in range(NUM_USERS): # for each user, we expect the PL event row, followed by state rows for diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py new file mode 100644
index 0000000000..eea4565da3 --- /dev/null +++ b/tests/replication/tcp/streams/test_federation.py
@@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# Copyright 2019 New Vector Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from synapse.federation.send_queue import EduRow +from synapse.replication.tcp.streams.federation import FederationStream + +from tests.replication.tcp.streams._base import BaseStreamTestCase + + +class FederationStreamTestCase(BaseStreamTestCase): + def _get_worker_hs_config(self) -> dict: + # enable federation sending on the worker + config = super()._get_worker_hs_config() + # TODO: make it so we don't need both of these + config["send_federation"] = True + config["worker_app"] = "synapse.app.federation_sender" + return config + + def test_catchup(self): + """Basic test of catchup on reconnect + + Makes sure that updates sent while we are offline are received later. + """ + fed_sender = self.hs.get_federation_sender() + received_rows = self.test_handler.received_rdata_rows + + fed_sender.build_and_send_edu("testdest", "m.test_edu", {"a": "b"}) + + self.reconnect() + self.reactor.advance(0) + + # check we're testing what we think we are: no rows should yet have been + # received + self.assertEqual(received_rows, []) + + # We should now see an attempt to connect to the master + request = self.handle_http_replication_attempt() + self.assert_request_is_get_repl_stream_updates(request, "federation") + + # we should have received an update row + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test_edu") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"a": "b"}) + + self.assertEqual(received_rows, []) + + # additional updates should be transferred without an HTTP hit + fed_sender.build_and_send_edu("testdest", "m.test1", {"c": "d"}) + self.reactor.advance(0) + # there should be no http hit + self.assertEqual(len(self.reactor.tcpClients), 0) + # ... but we should have a row + self.assertEqual(len(received_rows), 1) + + stream_name, token, row = received_rows.pop() + self.assertEqual(stream_name, "federation") + self.assertIsInstance(row, FederationStream.FederationStreamRow) + self.assertEqual(row.type, EduRow.TypeId) + edurow = EduRow.from_data(row.data) + self.assertEqual(edurow.edu.edu_type, "m.test1") + self.assertEqual(edurow.edu.origin, self.hs.hostname) + self.assertEqual(edurow.edu.destination, "testdest") + self.assertEqual(edurow.edu.content, {"c": "d"}) diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index c122b8589c..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py
@@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.stream_positions.update({"receipts": 0}) - # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( @@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 4d354a9db8..125c63dab5 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py
@@ -15,7 +15,6 @@ from mock import Mock from synapse.handlers.typing import RoomMember -from synapse.replication.http import streams from synapse.replication.tcp.streams import TypingStream from tests.replication.tcp.streams._base import BaseStreamTestCase @@ -24,10 +23,6 @@ USER_ID = "@feeling:blue" class TypingStreamTestCase(BaseStreamTestCase): - servlets = [ - streams.register_servlets, - ] - def _build_replication_data_handler(self): return Mock(wraps=super()._build_replication_data_handler()) @@ -38,9 +33,6 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the typing stream - self.test_handler.stream_positions.update({"typing": 0}) - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) self.reactor.advance(0) @@ -50,7 +42,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -77,7 +69,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0]