summary refs log tree commit diff
path: root/tests/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/tcp/streams/_base.py')
-rw-r--r--tests/replication/tcp/streams/_base.py54
1 files changed, 25 insertions, 29 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index 83e16cfe3d..9d4f0bbe44 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 import logging
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple
 
 import attr
 
@@ -22,13 +22,16 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime
 from twisted.internet.task import LoopingCall
 from twisted.web.http import HTTPChannel
 
-from synapse.app.generic_worker import GenericWorkerServer
+from synapse.app.generic_worker import (
+    GenericWorkerReplicationHandler,
+    GenericWorkerServer,
+)
 from synapse.http.site import SynapseRequest
-from synapse.replication.slave.storage._base import BaseSlavedStore
-from synapse.replication.tcp.client import ReplicationDataHandler
+from synapse.replication.http import streams
 from synapse.replication.tcp.handler import ReplicationCommandHandler
 from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
@@ -40,6 +43,10 @@ logger = logging.getLogger(__name__)
 class BaseStreamTestCase(unittest.HomeserverTestCase):
     """Base class for tests of the replication streams"""
 
+    servlets = [
+        streams.register_servlets,
+    ]
+
     def prepare(self, reactor, clock, hs):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(hs)
@@ -47,17 +54,11 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.server = server_factory.buildProtocol(None)
 
         # Make a new HomeServer object for the worker
-        config = self.default_config()
-        config["worker_app"] = "synapse.app.generic_worker"
-        config["worker_replication_host"] = "testserv"
-        config["worker_replication_http_port"] = "8765"
-
         self.reactor.lookups["testserv"] = "1.2.3.4"
-
         self.worker_hs = self.setup_test_homeserver(
             http_client=None,
             homeserverToUse=GenericWorkerServer,
-            config=config,
+            config=self._get_worker_hs_config(),
             reactor=self.reactor,
         )
 
@@ -76,8 +77,15 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self._client_transport = None
         self._server_transport = None
 
+    def _get_worker_hs_config(self) -> dict:
+        config = self.default_config()
+        config["worker_app"] = "synapse.app.generic_worker"
+        config["worker_replication_host"] = "testserv"
+        config["worker_replication_http_port"] = "8765"
+        return config
+
     def _build_replication_data_handler(self):
-        return TestReplicationDataHandler(self.worker_hs.get_datastore())
+        return TestReplicationDataHandler(self.worker_hs)
 
     def reconnect(self):
         if self._client_transport:
@@ -172,32 +180,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.assertEqual(request.method, b"GET")
 
 
-class TestReplicationDataHandler(ReplicationDataHandler):
+class TestReplicationDataHandler(GenericWorkerReplicationHandler):
     """Drop-in for ReplicationDataHandler which just collects RDATA rows"""
 
-    def __init__(self, store: BaseSlavedStore):
-        super().__init__(store)
-
-        # streams to subscribe to: map from stream id to position
-        self.stream_positions = {}  # type: Dict[str, int]
+    def __init__(self, hs: HomeServer):
+        super().__init__(hs)
 
         # list of received (stream_name, token, row) tuples
         self.received_rdata_rows = []  # type: List[Tuple[str, int, Any]]
 
-    def get_streams_to_replicate(self):
-        return self.stream_positions
-
-    async def on_rdata(self, stream_name, token, rows):
-        await super().on_rdata(stream_name, token, rows)
+    async def on_rdata(self, stream_name, instance_name, token, rows):
+        await super().on_rdata(stream_name, instance_name, token, rows)
         for r in rows:
             self.received_rdata_rows.append((stream_name, token, r))
 
-        if (
-            stream_name in self.stream_positions
-            and token > self.stream_positions[stream_name]
-        ):
-            self.stream_positions[stream_name] = token
-
 
 @attr.s()
 class OneShotRequestFactory: