summary refs log tree commit diff
path: root/tests/replication
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication')
-rw-r--r--tests/replication/tcp/streams/_base.py34
-rw-r--r--tests/replication/tcp/streams/test_events.py24
-rw-r--r--tests/replication/tcp/streams/test_receipts.py7
-rw-r--r--tests/replication/tcp/streams/test_typing.py7
4 files changed, 32 insertions, 40 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 83e16cfe3d..7b56d2028d 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import attr
 
@@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
-from synapse.app.generic_worker import GenericWorkerServer
+from synapse.app.generic_worker import (
+    GenericWorkerReplicationHandler,
+    GenericWorkerServer,
+)
 from synapse.http.site import SynapseRequest
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
@@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._server_transport = None
 
     def _build_replication_data_handler(self):
-        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+        return TestReplicationDataHandler(self.worker_hs)
 
     def reconnect(self):
         if self._client_transport:
@@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
-class TestReplicationDataHandler(ReplicationDataHandler):
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, store: BaseSlavedStore):
-        super().__init__(store)
-
-        # streams to subscribe to: map from stream id to position
-        self.stream_positions = {}  # type: Dict[str, int]
+    def __init__(self, hs: HomeServer):
+        super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
-    def get_streams_to_replicate(self):
-        return self.stream_positions
-
-    async def on_rdata(self, stream_name, token, rows):
-        await super().on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
 
-        if (
-            stream_name in self.stream_positions
-            and token > self.stream_positions[stream_name]
-        ):
-            self.stream_positions[stream_name] = token
-
 
 @attr.s()
 class OneShotRequestFactory:
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 1fa28084f9..8bd67bb9f1 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.user_tok = self.login("u1", "pass")
 
         self.reconnect()
-        self.test_handler.stream_positions["events"] = 0
 
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
@@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # we should have received all the expected rows in the right order
-        received_rows = self.test_handler.received_rdata_rows
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
+
         for event in events:
             stream_name, token, row = received_rows.pop(0)
             self.assertEqual("events", stream_name)
@@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # now we should have received all the expected rows in the right order.
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
         #
         # we expect:
         #
@@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase):
         #       of the states that got reverted.
         # - two rows for state2
 
-        received_rows = self.test_handler.received_rdata_rows
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
 
         # first check the first two rows, which should be state1
 
@@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.reconnect()
         self.replicate()
 
-        # we should have received all the expected rows in the right order
-
-        received_rows = self.test_handler.received_rdata_rows
+        # we should have received all the expected rows in the right order (as
+        # well as various cache invalidation updates which we ignore)
+        received_rows = [
+            row for row in self.test_handler.received_rdata_rows if row[0] == "events"
+        ]
         self.assertGreaterEqual(len(received_rows), len(events))
         for i in range(NUM_USERS):
             # for each user, we expect the PL event row, followed by state rows for
diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py
index c122b8589c..5853314fd4 100644
--- a/tests/replication/tcp/streams/test_receipts.py
+++ b/tests/replication/tcp/streams/test_receipts.py
@@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
     def test_receipt(self):
         self.reconnect()
 
-        # make the client subscribe to the receipts stream
-        self.test_handler.stream_positions.update({"receipts": 0})
-
         # tell the master to send a new receipt
         self.get_success(
             self.hs.get_datastore().insert_receipt(
@@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # there should be one RDATA command
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: ReceiptsStream.ReceiptsStreamRow
@@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase):
 
         # We should now have caught up and get the missing data
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "receipts")
         self.assertEqual(token, 3)
         self.assertEqual(1, len(rdata_rows))
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 4d354a9db8..d25a7b194e 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase):
 
         self.reconnect()
 
-        # make the client subscribe to the typing stream
-        self.test_handler.stream_positions.update({"typing": 0})
-
         typing._push_update(member=RoomMember(room_id, USER_ID), typing=True)
 
         self.reactor.advance(0)
@@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]  # type: TypingStream.TypingStreamRow
@@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
         self.test_handler.on_rdata.assert_called_once()
-        stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]