summary refs log tree commit diff
path: root/tests/replication/tcp/streams
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/tcp/streams')
-rw-r--r--tests/replication/tcp/streams/_base.py55
-rw-r--r--tests/replication/tcp/streams/test_receipts.py52
2 files changed, 84 insertions, 23 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..a755fe2879 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from mock import Mock
 
 from synapse.replication.tcp.commands import ReplicateCommand
@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = server_factory.streamer
-        server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(None)
 
-        # build a replication client, with a dummy handler
-        handler_factory = Mock()
-        self.test_handler = TestReplicationClientHandler()
-        self.test_handler.factory = handler_factory
+        self.test_handler = Mock(wraps=TestReplicationClientHandler())
         self.client = ClientReplicationStreamProtocol(
-            "client", "test", clock, self.test_handler
+            hs, "client", "test", clock, self.test_handler,
         )
 
-        # wire them together
-        self.client.makeConnection(FakeTransport(server, reactor))
-        server.makeConnection(FakeTransport(self.client, reactor))
+        self._client_transport = None
+        self._server_transport = None
+
+    def reconnect(self):
+        if self._client_transport:
+            self.client.close()
+
+        if self._server_transport:
+            self.server.close()
+
+        self._client_transport = FakeTransport(self.server, self.reactor)
+        self.client.makeConnection(self._client_transport)
+
+        self._server_transport = FakeTransport(self.client, self.reactor)
+        self.server.makeConnection(self._server_transport)
+
+    def disconnect(self):
+        if self._client_transport:
+            self._client_transport = None
+            self.client.close()
+
+        if self._server_transport:
+            self._server_transport = None
+            self.server.close()
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def replicate_stream(self, stream, token="NOW"):
+    def replicate_stream(self):
         """Make the client end a REPLICATE command to set up a subscription to a stream"""
-        self.client.send_command(ReplicateCommand(stream, token))
+        self.client.send_command(ReplicateCommand())
 
 
 class TestReplicationClientHandler(object):
     """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
 
     def __init__(self):
-        self.received_rdata_rows = []
+        self.streams = set()
+        self._received_rdata_rows = []
 
     def get_streams_to_replicate(self):
-        return {}
+        positions = {s: 0 for s in self.streams}
+        for stream, token, _ in self._received_rdata_rows:
+            if stream in self.streams:
+                positions[stream] = max(token, positions.get(stream, 0))
+        return positions
 
     def get_currently_syncing_users(self):
         return []
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
     def finished_connecting(self):
         pass
 
+    async def on_position(self, stream_name, token):
+        """Called when we get new position data."""
+
     async def on_rdata(self, stream_name, token, rows):
         for r in rows:
-            self.received_rdata_rows.append((stream_name, token, r))
+            self._received_rdata_rows.append((stream_name, token, r))
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index fa2493cad6..0ec0825a0e 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -17,30 +17,64 @@ from synapse.replication.tcp.streams._base import ReceiptsStream
 from tests.replication.tcp.streams._base import BaseStreamTestCase
 
 USER_ID = "@feeling:blue"
-ROOM_ID = "!room:blue"
-EVENT_ID = "$event:blue"
 
 
 class ReceiptsStreamTestCase(BaseStreamTestCase):
     def test_receipt(self):
+        self.reconnect()
+
         # make the client subscribe to the receipts stream
-        self.replicate_stream("receipts", "NOW")
+        self.replicate_stream()
+        self.test_handler.streams.add("receipts")
 
         # tell the master to send a new receipt
         self.get_success(
             self.hs.get_datastore().insert_receipt(
-                ROOM_ID, "m.read", USER_ID, [EVENT_ID], {"a": 1}
+                "!room:blue", "m.read", USER_ID, ["$event:blue"], {"a": 1}
             )
         )
         self.replicate()
 
         # there should be one RDATA command
-        rdata_rows = self.test_handler.received_rdata_rows
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
-        self.assertEqual(rdata_rows[0][0], "receipts")
-        row = rdata_rows[0][2]  # type: ReceiptsStream.ReceiptsStreamRow
-        self.assertEqual(ROOM_ID, row.room_id)
+        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        self.assertEqual("!room:blue", row.room_id)
         self.assertEqual("m.read", row.receipt_type)
         self.assertEqual(USER_ID, row.user_id)
-        self.assertEqual(EVENT_ID, row.event_id)
+        self.assertEqual("$event:blue", row.event_id)
         self.assertEqual({"a": 1}, row.data)
+
+        # Now let's disconnect and insert some data.
+        self.disconnect()
+
+        self.test_handler.on_rdata.reset_mock()
+
+        self.get_success(
+            self.hs.get_datastore().insert_receipt(
+                "!room2:blue", "m.read", USER_ID, ["$event2:foo"], {"a": 2}
+            )
+        )
+        self.replicate()
+
+        # Nothing should have happened as we are disconnected
+        self.test_handler.on_rdata.assert_not_called()
+
+        self.reconnect()
+        self.pump(0.1)
+
+        # We should now have caught up and get the missing data
+        self.test_handler.on_rdata.assert_called_once()
+        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.assertEqual(stream_name, "receipts")
+        self.assertEqual(token, 3)
+        self.assertEqual(1, len(rdata_rows))
+
+        row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
+        self.assertEqual("!room2:blue", row.room_id)
+        self.assertEqual("m.read", row.receipt_type)
+        self.assertEqual(USER_ID, row.user_id)
+        self.assertEqual("$event2:foo", row.event_id)
+        self.assertEqual({"a": 2}, row.data)