diff options
Diffstat (limited to 'tests/replication/tcp')
-rw-r--r-- | tests/replication/tcp/streams/_base.py | 34 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_events.py | 24 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_receipts.py | 7 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_typing.py | 7 |
4 files changed, 32 insertions, 40 deletions
diff --git a/tests/replication/tcp/streams/_base.py b/tests/replication/tcp/streams/_base.py index 83e16cfe3d..7b56d2028d 100644 --- a/tests/replication/tcp/streams/_base.py +++ b/tests/replication/tcp/streams/_base.py @@ -14,7 +14,7 @@ # limitations under the License. import logging -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, List, Optional, Tuple import attr @@ -22,13 +22,15 @@ from twisted.internet.interfaces import IConsumer, IPullProducer, IReactorTime from twisted.internet.task import LoopingCall from twisted.web.http import HTTPChannel -from synapse.app.generic_worker import GenericWorkerServer +from synapse.app.generic_worker import ( + GenericWorkerReplicationHandler, + GenericWorkerServer, +) from synapse.http.site import SynapseRequest -from synapse.replication.slave.storage._base import BaseSlavedStore -from synapse.replication.tcp.client import ReplicationDataHandler from synapse.replication.tcp.handler import ReplicationCommandHandler from synapse.replication.tcp.protocol import ClientReplicationStreamProtocol from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory +from synapse.server import HomeServer from synapse.util import Clock from tests import unittest @@ -77,7 +79,7 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self._server_transport = None def _build_replication_data_handler(self): - return TestReplicationDataHandler(self.worker_hs.get_datastore()) + return TestReplicationDataHandler(self.worker_hs) def reconnect(self): if self._client_transport: @@ -172,32 +174,20 @@ class BaseStreamTestCase(unittest.HomeserverTestCase): self.assertEqual(request.method, b"GET") -class TestReplicationDataHandler(ReplicationDataHandler): +class TestReplicationDataHandler(GenericWorkerReplicationHandler): """Drop-in for ReplicationDataHandler which just collects RDATA rows""" - def __init__(self, store: BaseSlavedStore): - super().__init__(store) - - # streams to subscribe to: map from stream id to position - self.stream_positions = {} # type: Dict[str, int] + def __init__(self, hs: HomeServer): + super().__init__(hs) # list of received (stream_name, token, row) tuples self.received_rdata_rows = [] # type: List[Tuple[str, int, Any]] - def get_streams_to_replicate(self): - return self.stream_positions - - async def on_rdata(self, stream_name, token, rows): - await super().on_rdata(stream_name, token, rows) + async def on_rdata(self, stream_name, instance_name, token, rows): + await super().on_rdata(stream_name, instance_name, token, rows) for r in rows: self.received_rdata_rows.append((stream_name, token, r)) - if ( - stream_name in self.stream_positions - and token > self.stream_positions[stream_name] - ): - self.stream_positions[stream_name] = token - @attr.s() class OneShotRequestFactory: diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 1fa28084f9..8bd67bb9f1 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -43,7 +43,6 @@ class EventsStreamTestCase(BaseStreamTestCase): self.user_tok = self.login("u1", "pass") self.reconnect() - self.test_handler.stream_positions["events"] = 0 self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() @@ -80,8 +79,12 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] + for event in events: stream_name, token, row = received_rows.pop(0) self.assertEqual("events", stream_name) @@ -184,7 +187,8 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # now we should have received all the expected rows in the right order. + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) # # we expect: # @@ -193,7 +197,9 @@ class EventsStreamTestCase(BaseStreamTestCase): # of the states that got reverted. # - two rows for state2 - received_rows = self.test_handler.received_rdata_rows + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] # first check the first two rows, which should be state1 @@ -334,9 +340,11 @@ class EventsStreamTestCase(BaseStreamTestCase): self.reconnect() self.replicate() - # we should have received all the expected rows in the right order - - received_rows = self.test_handler.received_rdata_rows + # we should have received all the expected rows in the right order (as + # well as various cache invalidation updates which we ignore) + received_rows = [ + row for row in self.test_handler.received_rdata_rows if row[0] == "events" + ] self.assertGreaterEqual(len(received_rows), len(events)) for i in range(NUM_USERS): # for each user, we expect the PL event row, followed by state rows for diff --git a/tests/replication/tcp/streams/test_receipts.py b/tests/replication/tcp/streams/test_receipts.py index c122b8589c..5853314fd4 100644 --- a/tests/replication/tcp/streams/test_receipts.py +++ b/tests/replication/tcp/streams/test_receipts.py @@ -31,9 +31,6 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): def test_receipt(self): self.reconnect() - # make the client subscribe to the receipts stream - self.test_handler.stream_positions.update({"receipts": 0}) - # tell the master to send a new receipt self.get_success( self.hs.get_datastore().insert_receipt( @@ -44,7 +41,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # there should be one RDATA command self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: ReceiptsStream.ReceiptsStreamRow @@ -74,7 +71,7 @@ class ReceiptsStreamTestCase(BaseStreamTestCase): # We should now have caught up and get the missing data self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "receipts") self.assertEqual(token, 3) self.assertEqual(1, len(rdata_rows)) diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 4d354a9db8..d25a7b194e 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -38,9 +38,6 @@ class TypingStreamTestCase(BaseStreamTestCase): self.reconnect() - # make the client subscribe to the typing stream - self.test_handler.stream_positions.update({"typing": 0}) - typing._push_update(member=RoomMember(room_id, USER_ID), typing=True) self.reactor.advance(0) @@ -50,7 +47,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] # type: TypingStream.TypingStreamRow @@ -77,7 +74,7 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assertEqual(int(request.args[b"from_token"][0]), token) self.test_handler.on_rdata.assert_called_once() - stream_name, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] |