diff options
Diffstat (limited to 'tests/replication/tcp/streams')
-rw-r--r-- | tests/replication/tcp/streams/test_account_data.py | 4 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_events.py | 26 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_federation.py | 2 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_partial_state.py | 4 | ||||
-rw-r--r-- | tests/replication/tcp/streams/test_typing.py | 37 |
5 files changed, 40 insertions, 33 deletions
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py index 50fbff5f32..01df1be047 100644 --- a/tests/replication/tcp/streams/test_account_data.py +++ b/tests/replication/tcp/streams/test_account_data.py @@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase class AccountDataStreamTestCase(BaseStreamTestCase): - def test_update_function_room_account_data_limit(self): + def test_update_function_room_account_data_limit(self) -> None: """Test replication with many room account data updates""" store = self.hs.get_datastores().main @@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) - def test_update_function_global_account_data_limit(self): + def test_update_function_global_account_data_limit(self) -> None: """Test replication with many global account data updates""" store = self.hs.get_datastores().main diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py index 641a94133b..65ef4bb160 100644 --- a/tests/replication/tcp/streams/test_events.py +++ b/tests/replication/tcp/streams/test_events.py @@ -12,7 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional +from typing import Any, List, Optional, Sequence + +from twisted.test.proto_helpers import MemoryReactor from synapse.api.constants import EventTypes, Membership from synapse.events import EventBase @@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import ( ) from synapse.rest import admin from synapse.rest.client import login, room +from synapse.server import HomeServer +from synapse.util import Clock from tests.replication._base import BaseStreamTestCase from tests.test_utils.event_injection import inject_event, inject_member_event @@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase): room.register_servlets, ] - def prepare(self, reactor, clock, hs): + def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None: super().prepare(reactor, clock, hs) self.user_id = self.register_user("u1", "pass") self.user_tok = self.login("u1", "pass") @@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.room_id = self.helper.create_room_as(tok=self.user_tok) self.test_handler.received_rdata_rows.clear() - def test_update_function_event_row_limit(self): + def test_update_function_event_row_limit(self) -> None: """Test replication with many non-state events Checks that all events are correctly replicated when there are lots of @@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) - def test_update_function_huge_state_change(self): + def test_update_function_huge_state_change(self) -> None: """Test replication with many state events Ensures that all events are correctly replicated when there are lots of @@ -135,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: List[str] = self.get_success( + fork_point: Sequence[str] = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -164,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase): pl_event = self.get_success( inject_event( self.hs, - prev_event_ids=prev_events, + prev_event_ids=list(prev_events), type=EventTypes.PowerLevels, state_key="", sender=self.user_id, @@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase): # "None" indicates the state has been deleted self.assertIsNone(sr.event_id) - def test_update_function_state_row_limit(self): + def test_update_function_state_row_limit(self) -> None: """Test replication with many state events over several stream ids.""" # we want to generate lots of state changes, but for this test, we want to @@ -290,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase): ) # this is the point in the DAG where we make a fork - fork_point: List[str] = self.get_success( + fork_point: Sequence[str] = self.get_success( self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id) ) @@ -319,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase): e = self.get_success( inject_event( self.hs, - prev_event_ids=prev_events, + prev_event_ids=list(prev_events), type=EventTypes.PowerLevels, state_key="", sender=self.user_id, @@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase): self.assertEqual([], received_rows) - def test_backwards_stream_id(self): + def test_backwards_stream_id(self) -> None: """ Test that RDATA that comes after the current position should be discarded. """ @@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase): event_count = 0 def _inject_test_event( - self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs + self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any ) -> EventBase: if sender is None: sender = self.user_id diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py index bcb82c9c80..cdbdfaf057 100644 --- a/tests/replication/tcp/streams/test_federation.py +++ b/tests/replication/tcp/streams/test_federation.py @@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase): config["federation_sender_instances"] = ["federation_sender1"] return config - def test_catchup(self): + def test_catchup(self) -> None: """Basic test of catchup on reconnect Makes sure that updates sent while we are offline are received later. diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py index 2c10eab4db..452ac85069 100644 --- a/tests/replication/tcp/streams/test_partial_state.py +++ b/tests/replication/tcp/streams/test_partial_state.py @@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): hijack_auth = True user_id = "@bob:test" - def setUp(self): + def setUp(self) -> None: super().setUp() self.store = self.hs.get_datastores().main @@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase): room_id = self.helper.create_room_as("@bob:test") # Mark the room as partial-stated. self.get_success( - self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1") + self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1") ) worker = self.make_worker_hs("synapse.app.generic_worker") diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py index 9a229dd23f..5a38ac831f 100644 --- a/tests/replication/tcp/streams/test_typing.py +++ b/tests/replication/tcp/streams/test_typing.py @@ -13,7 +13,7 @@ # limitations under the License. from unittest.mock import Mock -from synapse.handlers.typing import RoomMember +from synapse.handlers.typing import RoomMember, TypingWriterHandler from synapse.replication.tcp.streams import TypingStream from synapse.util.caches.stream_change_cache import StreamChangeCache @@ -27,11 +27,13 @@ ROOM_ID_2 = "!foo:blue" class TypingStreamTestCase(BaseStreamTestCase): - def _build_replication_data_handler(self): - return Mock(wraps=super()._build_replication_data_handler()) + def _build_replication_data_handler(self) -> Mock: + self.mock_handler = Mock(wraps=super()._build_replication_data_handler()) + return self.mock_handler - def test_typing(self): + def test_typing(self) -> None: typing = self.hs.get_typing_handler() + assert isinstance(typing, TypingWriterHandler) self.reconnect() @@ -43,8 +45,8 @@ class TypingStreamTestCase(BaseStreamTestCase): request = self.handle_http_replication_attempt() self.assert_request_is_get_repl_stream_updates(request, "typing") - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row: TypingStream.TypingStreamRow = rdata_rows[0] @@ -54,11 +56,11 @@ class TypingStreamTestCase(BaseStreamTestCase): # Now let's disconnect and insert some data. self.disconnect() - self.test_handler.on_rdata.reset_mock() + self.mock_handler.on_rdata.reset_mock() typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False) - self.test_handler.on_rdata.assert_not_called() + self.mock_handler.on_rdata.assert_not_called() self.reconnect() self.pump(0.1) @@ -71,15 +73,15 @@ class TypingStreamTestCase(BaseStreamTestCase): assert request.args is not None self.assertEqual(int(request.args[b"from_token"][0]), token) - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] self.assertEqual(ROOM_ID, row.room_id) self.assertEqual([], row.user_ids) - def test_reset(self): + def test_reset(self) -> None: """ Test what happens when a typing stream resets. @@ -87,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase): sends the proper position and RDATA). """ typing = self.hs.get_typing_handler() + assert isinstance(typing, TypingWriterHandler) self.reconnect() @@ -98,8 +101,8 @@ class TypingStreamTestCase(BaseStreamTestCase): request = self.handle_http_replication_attempt() self.assert_request_is_get_repl_stream_updates(request, "typing") - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row: TypingStream.TypingStreamRow = rdata_rows[0] @@ -134,15 +137,15 @@ class TypingStreamTestCase(BaseStreamTestCase): self.assert_request_is_get_repl_stream_updates(request, "typing") # Reset the test code. - self.test_handler.on_rdata.reset_mock() - self.test_handler.on_rdata.assert_not_called() + self.mock_handler.on_rdata.reset_mock() + self.mock_handler.on_rdata.assert_not_called() # Push additional data. typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False) self.reactor.advance(0) - self.test_handler.on_rdata.assert_called_once() - stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0] + self.mock_handler.on_rdata.assert_called_once() + stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0] self.assertEqual(stream_name, "typing") self.assertEqual(1, len(rdata_rows)) row = rdata_rows[0] |