diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 2a1e7c7166..8902a5ab69 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -17,8 +17,9 @@ from mock import Mock, NonCallableMock
from synapse.replication.tcp.client import (
ReplicationClientFactory,
- ReplicationClientHandler,
+ ReplicationDataHandler,
)
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.storage.database import make_conn
@@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
+ self.streamer = hs.get_replication_streamer()
- handler_factory = Mock()
- self.replication_handler = ReplicationClientHandler(self.slaved_store)
- self.replication_handler.factory = handler_factory
+ # We now do some gut wrenching so that we have a client that is based
+ # off of the slave store rather than the main store.
+ self.replication_handler = ReplicationCommandHandler(self.hs)
+ self.replication_handler._replication_data_handler = ReplicationDataHandler(
+ self.slaved_store
+ )
client_factory = ReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
+ client_factory.handler = self.replication_handler
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index a755fe2879..32238fe79a 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -15,7 +15,7 @@
from mock import Mock
-from synapse.replication.tcp.commands import ReplicateCommand
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -26,15 +26,20 @@ from tests.server import FakeTransport
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ def make_homeserver(self, reactor, clock):
+ self.test_handler = Mock(wraps=TestReplicationDataHandler())
+ return self.setup_test_homeserver(replication_data_handler=self.test_handler)
+
def prepare(self, reactor, clock, hs):
# build a replication server
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
+ server_factory = ReplicationStreamProtocolFactory(hs)
+ self.streamer = hs.get_replication_streamer()
self.server = server_factory.buildProtocol(None)
- self.test_handler = Mock(wraps=TestReplicationClientHandler())
+ repl_handler = ReplicationCommandHandler(hs)
+ repl_handler.handler = self.test_handler
self.client = ClientReplicationStreamProtocol(
- hs, "client", "test", clock, self.test_handler,
+ hs, "client", "test", clock, repl_handler,
)
self._client_transport = None
@@ -69,13 +74,9 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
- def replicate_stream(self):
- """Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand())
-
-class TestReplicationClientHandler(object):
- """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
+class TestReplicationDataHandler:
+ """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self):
self.streams = set()
@@ -88,18 +89,9 @@ class TestReplicationClientHandler(object):
positions[stream] = max(token, positions.get(stream, 0))
return positions
- def get_currently_syncing_users(self):
- return []
-
- def update_connection(self, connection):
- pass
-
- def finished_connecting(self):
- pass
-
- async def on_position(self, stream_name, token):
- """Called when we get new position data."""
-
async def on_rdata(self, stream_name, token, rows):
for r in rows:
self._received_rdata_rows.append((stream_name, token, r))
+
+ async def on_position(self, stream_name, token):
+ pass
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index 0ec0825a0e..a0206f7363 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -24,7 +24,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
self.reconnect()
# make the client subscribe to the receipts stream
- self.replicate_stream()
self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
|