From 22fc68762ee131229636a20231e03a4c8c525f65 Mon Sep 17 00:00:00 2001 From: Erik Johnston Date: Wed, 4 Mar 2020 16:22:39 +0000 Subject: Move command processing out of transport --- tests/replication/slave/storage/_base.py | 24 +++++++++------- tests/replication/tcp/streams/_base.py | 38 ++++++++++---------------- tests/replication/tcp/streams/test_receipts.py | 1 - 3 files changed, 28 insertions(+), 35 deletions(-) (limited to 'tests') diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py index 2a1e7c7166..f2c0e381c1 100644 --- a/tests/replication/slave/storage/_base.py +++ b/tests/replication/slave/storage/_base.py @@ -15,9 +15,10 @@ from mock import Mock, NonCallableMock -from synapse.replication.tcp.client import ( - ReplicationClientFactory, +from synapse.replication.tcp.client import ReplicationClientFactory +from synapse.replication.tcp.handler import ( ReplicationClientHandler, + WorkerReplicationDataHandler, ) from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory from synapse.storage.database import make_conn @@ -51,16 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase): self.event_id = 0 server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer - - handler_factory = Mock() - self.replication_handler = ReplicationClientHandler(self.slaved_store) - self.replication_handler.factory = handler_factory - - client_factory = ReplicationClientFactory( - self.hs, "client_name", self.replication_handler + self.streamer = hs.get_replication_streamer() + + # We now do some gut wrenching so that we have a client that is based + # off of the slave store rather than the main store. + self.replication_handler = ReplicationClientHandler(self.hs) + self.replication_handler.store = self.slaved_store + self.replication_handler.replication_data_handler = WorkerReplicationDataHandler( + self.slaved_store ) + client_factory = ReplicationClientFactory(self.hs, "client_name") + client_factory.handler = self.replication_handler + server = server_factory.buildProtocol(None) client = client_factory.buildProtocol(None) diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index a755fe2879..2d6e44f625 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -15,7 +15,7 @@ from mock import Mock -from synapse.replication.tcp.commands import ReplicateCommand +from synapse.replication.tcp.handler import ReplicationClientHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory @@ -26,15 +26,20 @@ from tests.server import FakeTransport class BaseStreamTestCase(unittest.HomeserverTestCase): """Base class for tests of the replication streams""" + def make_homeserver(self, reactor, clock): + self.test_handler = Mock(wraps=TestReplicationClientHandler()) + return self.setup_test_homeserver(replication_data_handler=self.test_handler) + def prepare(self, reactor, clock, hs): # build a replication server - server_factory = ReplicationStreamProtocolFactory(self.hs) - self.streamer = server_factory.streamer + server_factory = ReplicationStreamProtocolFactory(hs) + self.streamer = hs.get_replication_streamer() self.server = server_factory.buildProtocol(None) - self.test_handler = Mock(wraps=TestReplicationClientHandler()) + repl_handler = ReplicationClientHandler(hs) + repl_handler.handler = self.test_handler self.client = ClientReplicationStreamProtocol( - hs, "client", "test", clock, self.test_handler, + hs, "client", "test", clock, repl_handler, ) self._client_transport = None @@ -69,14 +74,8 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.streamer.on_notifier_poke() self.pump(0.1) - def replicate_stream(self): - """Make the client end a REPLICATE command to set up a subscription to a stream""" - self.client.send_command(ReplicateCommand()) - - -class TestReplicationClientHandler(object): - """Drop-in for ReplicationClientHandler which just collects RDATA rows""" +class TestReplicationClientHandler: def __init__(self): self.streams = set() self._received_rdata_rows = [] @@ -88,18 +87,9 @@ class TestReplicationClientHandler(object): positions[stream] = max(token, positions.get(stream, 0)) return positions - def get_currently_syncing_users(self): - return [] - - def update_connection(self, connection): - pass - - def finished_connecting(self): - pass - - async def on_position(self, stream_name, token): - """Called when we get new position data.""" - async def on_rdata(self, stream_name, token, rows): for r in rows: self._received_rdata_rows.append((stream_name, token, r)) + + async def on_position(self, stream_name, token): + pass diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index 0ec0825a0e..a0206f7363 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): self.reconnect() # make the client subscribe to the receipts stream - self.replicate_stream() self.test_handler.streams.add("receipts") # tell the master to send a new receipt -- cgit 1.4.1