diff --git a/tests/replication/slave/storage/_base.py b/tests/replication/slave/storage/_base.py
index 2a1e7c7166..395c7d0306 100644
--- a/tests/replication/slave/storage/_base.py
+++ b/tests/replication/slave/storage/_base.py
@@ -16,9 +16,10 @@
from mock import Mock, NonCallableMock
from synapse.replication.tcp.client import (
- ReplicationClientFactory,
- ReplicationClientHandler,
+ DirectTcpReplicationClientFactory,
+ ReplicationDataHandler,
)
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
from synapse.storage.database import make_conn
@@ -51,15 +52,19 @@ class BaseSlavedStoreTestCase(unittest.HomeserverTestCase):
self.event_id = 0
server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
+ self.streamer = hs.get_replication_streamer()
- handler_factory = Mock()
- self.replication_handler = ReplicationClientHandler(self.slaved_store)
- self.replication_handler.factory = handler_factory
+ # We now do some gut wrenching so that we have a client that is based
+ # off of the slave store rather than the main store.
+ self.replication_handler = ReplicationCommandHandler(self.hs)
+ self.replication_handler._replication_data_handler = ReplicationDataHandler(
+ self.slaved_store
+ )
- client_factory = ReplicationClientFactory(
+ client_factory = DirectTcpReplicationClientFactory(
self.hs, "client_name", self.replication_handler
)
+ client_factory.handler = self.replication_handler
server = server_factory.buildProtocol(None)
client = client_factory.buildProtocol(None)
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..32238fe79a 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,9 +12,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from mock import Mock
-from synapse.replication.tcp.commands import ReplicateCommand
+from synapse.replication.tcp.handler import ReplicationCommandHandler
from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
@@ -25,23 +26,46 @@ from tests.server import FakeTransport
class BaseStreamTestCase(unittest.HomeserverTestCase):
"""Base class for tests of the replication streams"""
+ def make_homeserver(self, reactor, clock):
+ self.test_handler = Mock(wraps=TestReplicationDataHandler())
+ return self.setup_test_homeserver(replication_data_handler=self.test_handler)
+
def prepare(self, reactor, clock, hs):
# build a replication server
- server_factory = ReplicationStreamProtocolFactory(self.hs)
- self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
-
- # build a replication client, with a dummy handler
- handler_factory = Mock()
- self.test_handler = TestReplicationClientHandler()
- self.test_handler.factory = handler_factory
+ server_factory = ReplicationStreamProtocolFactory(hs)
+ self.streamer = hs.get_replication_streamer()
+ self.server = server_factory.buildProtocol(None)
+
+ repl_handler = ReplicationCommandHandler(hs)
+ repl_handler.handler = self.test_handler
self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
+ hs, "client", "test", clock, repl_handler,
)
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
+ self._client_transport = None
+ self._server_transport = None
+
+ def reconnect(self):
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -50,29 +74,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
- def replicate_stream(self, stream, token="NOW"):
- """Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand(stream, token))
-
-class TestReplicationClientHandler(object):
- """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
+class TestReplicationDataHandler:
+ """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
def __init__(self):
- self.received_rdata_rows = []
+ self.streams = set()
+ self._received_rdata_rows = []
def get_streams_to_replicate(self):
- return {}
-
- def get_currently_syncing_users(self):
- return []
-
- def update_connection(self, connection):
- pass
-
- def finished_connecting(self):
- pass
+ positions = {s: 0 for s in self.streams}
+ for stream, token, _ in self._received_rdata_rows:
+ if stream in self.streams:
+ positions[stream] = max(token, positions.get(stream, 0))
+ return positions
async def on_rdata(self, stream_name, token, rows):
for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
+ self._received_rdata_rows.append((stream_name, token, r))
+
+ async def on_position(self, stream_name, token):
+ pass
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index d5a99f6caa..a0206f7363 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -12,35 +12,68 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from synapse.replication.tcp.streams._base import ReceiptsStreamRow
+from synapse.replication.tcp.streams._base import ReceiptsStream
from tests.replication.tcp.streams._base import BaseStreamTestCase
USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
class ReceiptsStreamTestCase(BaseStreamTestCase):
def test_receipt(self):
+ self.reconnect()
+
# make the client subscribe to the receipts stream
- self.replicate_stream("receipts", "NOW")
+ self.test_handler.streams.add("receipts")
# tell the master to send a new receipt
self.get_success(
self.hs.get_datastore().insert_receipt(
- ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+ "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
)
)
self.replicate()
# there should be one RDATA command
- rdata_rows = self.test_handler.received_rdata_rows
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
self.assertEqual(1, len(rdata_rows))
- self.assertEqual(rdata_rows[0][0], "receipts")
- row = rdata_rows[0][2] # type: ReceiptsStreamRow
- self.assertEqual(ROOM_ID, row.room_id)
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room:blue", row.room_id)
self.assertEqual("m.read", row.receipt_type)
self.assertEqual(USER_ID, row.user_id)
- self.assertEqual(EVENT_ID, row.event_id)
+ self.assertEqual("$event:blue", row.event_id)
self.assertEqual({"a": 1}, row.data)
+
+ # Now let's disconnect and insert some data.
+ self.disconnect()
+
+ self.test_handler.on_rdata.reset_mock()
+
+ self.get_success(
+ self.hs.get_datastore().insert_receipt(
+ "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+ )
+ )
+ self.replicate()
+
+ # Nothing should have happened as we are disconnected
+ self.test_handler.on_rdata.assert_not_called()
+
+ self.reconnect()
+ self.pump(0.1)
+
+ # We should now have caught up and get the missing data
+ self.test_handler.on_rdata.assert_called_once()
+ stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+ self.assertEqual(stream_name, "receipts")
+ self.assertEqual(token, 3)
+ self.assertEqual(1, len(rdata_rows))
+
+ row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow
+ self.assertEqual("!room2:blue", row.room_id)
+ self.assertEqual("m.read", row.receipt_type)
+ self.assertEqual(USER_ID, row.user_id)
+ self.assertEqual("$event2:foo", row.event_id)
+ self.assertEqual({"a": 2}, row.data)
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
new file mode 100644
index 0000000000..3cbcb513cc
--- /dev/null
+++ b/tests/replication/tcp/test_commands.py
@@ -0,0 +1,42 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from synapse.replication.tcp.commands import (
+ RdataCommand,
+ ReplicateCommand,
+ parse_command_from_line,
+)
+
+from tests.unittest import TestCase
+
+
+class ParseCommandTestCase(TestCase):
+ def test_parse_one_word_command(self):
+ line = "REPLICATE"
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, ReplicateCommand)
+
+ def test_parse_rdata(self):
+ line = 'RDATA events 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "events")
+ self.assertEqual(cmd.token, 6287863)
+
+ def test_parse_rdata_batch(self):
+ line = 'RDATA presence batch ["@foo:example.com", "online"]'
+ cmd = parse_command_from_line(line)
+ self.assertIsInstance(cmd, RdataCommand)
+ self.assertEqual(cmd.stream_name, "presence")
+ self.assertIsNone(cmd.token)
|