summary refs log tree commit diff
path: root/tests/replication/tcp/streams
diff options
context:
space:
mode:
Diffstat (limited to 'tests/replication/tcp/streams')
-rw-r--r--tests/replication/tcp/streams/test_account_data.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py26
-rw-r--r--tests/replication/tcp/streams/test_federation.py2
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py4
-rw-r--r--tests/replication/tcp/streams/test_typing.py37
5 files changed, 40 insertions, 33 deletions
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
 
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
-    def test_update_function_room_account_data_limit(self):
+    def test_update_function_room_account_data_limit(self) -> None:
         """Test replication with many room account data updates"""
         store = self.hs.get_datastores().main
 
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_global_account_data_limit(self):
+    def test_update_function_global_account_data_limit(self) -> None:
         """Test replication with many global account data updates"""
         store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..65ef4bb160 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Any, List, Optional, Sequence
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
 )
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseStreamTestCase
 from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
         self.user_id = self.register_user("u1", "pass")
         self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
 
-    def test_update_function_event_row_limit(self):
+    def test_update_function_event_row_limit(self) -> None:
         """Test replication with many non-state events
 
         Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self):
+    def test_update_function_huge_state_change(self) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
@@ -135,7 +139,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -164,7 +168,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         pl_event = self.get_success(
             inject_event(
                 self.hs,
-                prev_event_ids=prev_events,
+                prev_event_ids=list(prev_events),
                 type=EventTypes.PowerLevels,
                 state_key="",
                 sender=self.user_id,
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             # "None" indicates the state has been deleted
             self.assertIsNone(sr.event_id)
 
-    def test_update_function_state_row_limit(self):
+    def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
@@ -290,7 +294,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         )
 
         # this is the point in the DAG where we make a fork
-        fork_point: List[str] = self.get_success(
+        fork_point: Sequence[str] = self.get_success(
             self.hs.get_datastores().main.get_latest_event_ids_in_room(self.room_id)
         )
 
@@ -319,7 +323,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             e = self.get_success(
                 inject_event(
                     self.hs,
-                    prev_event_ids=prev_events,
+                    prev_event_ids=list(prev_events),
                     type=EventTypes.PowerLevels,
                     state_key="",
                     sender=self.user_id,
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_backwards_stream_id(self):
+    def test_backwards_stream_id(self) -> None:
         """
         Test that RDATA that comes after the current position should be discarded.
         """
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
     event_count = 0
 
     def _inject_test_event(
-        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
     ) -> EventBase:
         if sender is None:
             sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def test_catchup(self):
+    def test_catchup(self) -> None:
         """Basic test of catchup on reconnect
 
         Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..452ac85069 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
     hijack_auth = True
     user_id = "@bob:test"
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.store = self.hs.get_datastores().main
 
@@ -37,7 +37,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
         room_id = self.helper.create_room_as("@bob:test")
         # Mark the room as partial-stated.
         self.get_success(
-            self.store.store_partial_state_room(room_id, ["serv1", "serv2"], 0, "serv1")
+            self.store.store_partial_state_room(room_id, {"serv1", "serv2"}, 0, "serv1")
         )
 
         worker = self.make_worker_hs("synapse.app.generic_worker")
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..5a38ac831f 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 from unittest.mock import Mock
 
-from synapse.handlers.typing import RoomMember
+from synapse.handlers.typing import RoomMember, TypingWriterHandler
 from synapse.replication.tcp.streams import TypingStream
 from synapse.util.caches.stream_change_cache import StreamChangeCache
 
@@ -27,11 +27,13 @@ ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    def _build_replication_data_handler(self):
-        return Mock(wraps=super()._build_replication_data_handler())
+    def _build_replication_data_handler(self) -> Mock:
+        self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+        return self.mock_handler
 
-    def test_typing(self):
+    def test_typing(self) -> None:
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -43,8 +45,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +56,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
         # Now let's disconnect and insert some data.
         self.disconnect()
 
-        self.test_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.reset_mock()
 
         typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.assert_not_called()
 
         self.reconnect()
         self.pump(0.1)
@@ -71,15 +73,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([], row.user_ids)
 
-    def test_reset(self):
+    def test_reset(self) -> None:
         """
         Test what happens when a typing stream resets.
 
@@ -87,6 +89,7 @@ class TypingStreamTestCase(BaseStreamTestCase):
         sends the proper position and RDATA).
         """
         typing = self.hs.get_typing_handler()
+        assert isinstance(typing, TypingWriterHandler)
 
         self.reconnect()
 
@@ -98,8 +101,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +137,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # Reset the test code.
-        self.test_handler.on_rdata.reset_mock()
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.assert_not_called()
 
         # Push additional data.
         typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
         self.reactor.advance(0)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]