summary refs log tree commit diff
path: root/tests/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/tcp/streams/_base.py')
-rw-r--r--tests/replication/tcp/streams/_base.py34
1 files changed, 12 insertions, 22 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 83e16cfe3d..7b56d2028d 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import attr
 
@@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
-from synapse.app.generic_worker import GenericWorkerServer
+from synapse.app.generic_worker import (
+    GenericWorkerReplicationHandler,
+    GenericWorkerServer,
+)
 from synapse.http.site import SynapseRequest
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.client import ReplicationDataHandler
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
@@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._server_transport = None
 
     def _build_replication_data_handler(self):
-        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+        return TestReplicationDataHandler(self.worker_hs)
 
     def reconnect(self):
         if self._client_transport:
@@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
-class TestReplicationDataHandler(ReplicationDataHandler):
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, store: BaseSlavedStore):
-        super().__init__(store)
-
-        # streams to subscribe to: map from stream id to position
-        self.stream_positions = {}  # type: Dict[str, int]
+    def __init__(self, hs: HomeServer):
+        super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
-    def get_streams_to_replicate(self):
-        return self.stream_positions
-
-    async def on_rdata(self, stream_name, token, rows):
-        await super().on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
 
-        if (
-            stream_name in self.stream_positions
-            and token > self.stream_positions[stream_name]
-        ):
-            self.stream_positions[stream_name] = token
-
 
 @attr.s()
 class OneShotRequestFactory: