summary refs log tree commit diff
path: root/tests/replication/tcp/streams/_base.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/tcp/streams/_base.py')
-rw-r--r--tests/replication/tcp/streams/_base.py55
1 files changed, 41 insertions, 14 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..a755fe2879 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from mock import Mock
 
 from synapse.replication.tcp.commands import ReplicateCommand
@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         # build a replication server
         server_factory = ReplicationStreamProtocolFactory(self.hs)
         self.streamer = server_factory.streamer
-        server = server_factory.buildProtocol(None)
+        self.server = server_factory.buildProtocol(None)
 
-        # build a replication client, with a dummy handler
-        handler_factory = Mock()
-        self.test_handler = TestReplicationClientHandler()
-        self.test_handler.factory = handler_factory
+        self.test_handler = Mock(wraps=TestReplicationClientHandler())
         self.client = ClientReplicationStreamProtocol(
-            "client", "test", clock, self.test_handler
+            hs, "client", "test", clock, self.test_handler,
         )
 
-        # wire them together
-        self.client.makeConnection(FakeTransport(server, reactor))
-        server.makeConnection(FakeTransport(self.client, reactor))
+        self._client_transport = None
+        self._server_transport = None
+
+    def reconnect(self):
+        if self._client_transport:
+            self.client.close()
+
+        if self._server_transport:
+            self.server.close()
+
+        self._client_transport = FakeTransport(self.server, self.reactor)
+        self.client.makeConnection(self._client_transport)
+
+        self._server_transport = FakeTransport(self.client, self.reactor)
+        self.server.makeConnection(self._server_transport)
+
+    def disconnect(self):
+        if self._client_transport:
+            self._client_transport = None
+            self.client.close()
+
+        if self._server_transport:
+            self._server_transport = None
+            self.server.close()
 
     def replicate(self):
         """Tell the master side of replication that something has happened, and then
@@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
         self.streamer.on_notifier_poke()
         self.pump(0.1)
 
-    def replicate_stream(self, stream, token="NOW"):
+    def replicate_stream(self):
         """Make the client end a REPLICATE command to set up a subscription to a stream"""
-        self.client.send_command(ReplicateCommand(stream, token))
+        self.client.send_command(ReplicateCommand())
 
 
 class TestReplicationClientHandler(object):
     """Drop-in for ReplicationClientHandler which just collects RDATA rows"""
 
     def __init__(self):
-        self.received_rdata_rows = []
+        self.streams = set()
+        self._received_rdata_rows = []
 
     def get_streams_to_replicate(self):
-        return {}
+        positions = {s: 0 for s in self.streams}
+        for stream, token, _ in self._received_rdata_rows:
+            if stream in self.streams:
+                positions[stream] = max(token, positions.get(stream, 0))
+        return positions
 
     def get_currently_syncing_users(self):
         return []
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
     def finished_connecting(self):
         pass
 
+    async def on_position(self, stream_name, token):
+        """Called when we get new position data."""
+
     async def on_rdata(self, stream_name, token, rows):
         for r in rows:
-            self.received_rdata_rows.append((stream_name, token, r))
+            self._received_rdata_rows.append((stream_name, token, r))