diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py
index e96ad4ca4e..a755fe2879 100644
--- a/tests/replication/tcp/streams/_base.py
+++ b/tests/replication/tcp/streams/_base.py
@@ -12,6 +12,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
from mock import Mock
from synapse.replication.tcp.commands import ReplicateCommand
@@ -29,19 +30,37 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
# build a replication server
server_factory = ReplicationStreamProtocolFactory(self.hs)
self.streamer = server_factory.streamer
- server = server_factory.buildProtocol(None)
+ self.server = server_factory.buildProtocol(None)
- # build a replication client, with a dummy handler
- handler_factory = Mock()
- self.test_handler = TestReplicationClientHandler()
- self.test_handler.factory = handler_factory
+ self.test_handler = Mock(wraps=TestReplicationClientHandler())
self.client = ClientReplicationStreamProtocol(
- "client", "test", clock, self.test_handler
+ hs, "client", "test", clock, self.test_handler,
)
- # wire them together
- self.client.makeConnection(FakeTransport(server, reactor))
- server.makeConnection(FakeTransport(self.client, reactor))
+ self._client_transport = None
+ self._server_transport = None
+
+ def reconnect(self):
+ if self._client_transport:
+ self.client.close()
+
+ if self._server_transport:
+ self.server.close()
+
+ self._client_transport = FakeTransport(self.server, self.reactor)
+ self.client.makeConnection(self._client_transport)
+
+ self._server_transport = FakeTransport(self.client, self.reactor)
+ self.server.makeConnection(self._server_transport)
+
+ def disconnect(self):
+ if self._client_transport:
+ self._client_transport = None
+ self.client.close()
+
+ if self._server_transport:
+ self._server_transport = None
+ self.server.close()
def replicate(self):
"""Tell the master side of replication that something has happened, and then
@@ -50,19 +69,24 @@ class BaseStreamTestCase(unittest.HomeserverTestCase):
self.streamer.on_notifier_poke()
self.pump(0.1)
- def replicate_stream(self, stream, token="NOW"):
+ def replicate_stream(self):
"""Make the client end a REPLICATE command to set up a subscription to a stream"""
- self.client.send_command(ReplicateCommand(stream, token))
+ self.client.send_command(ReplicateCommand())
class TestReplicationClientHandler(object):
"""Drop-in for ReplicationClientHandler which just collects RDATA rows"""
def __init__(self):
- self.received_rdata_rows = []
+ self.streams = set()
+ self._received_rdata_rows = []
def get_streams_to_replicate(self):
- return {}
+ positions = {s: 0 for s in self.streams}
+ for stream, token, _ in self._received_rdata_rows:
+ if stream in self.streams:
+ positions[stream] = max(token, positions.get(stream, 0))
+ return positions
def get_currently_syncing_users(self):
return []
@@ -73,6 +97,9 @@ class TestReplicationClientHandler(object):
def finished_connecting(self):
pass
+ async def on_position(self, stream_name, token):
+ """Called when we get new position data."""
+
async def on_rdata(self, stream_name, token, rows):
for r in rows:
- self.received_rdata_rows.append((stream_name, token, r))
+ self._received_rdata_rows.append((stream_name, token, r))
|