summary refs log tree commit diff
path: root/tests/replication/tcp
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-02-06 09:55:00 -0500
committerGitHub <noreply@github.com>2023-02-06 09:55:00 -0500
commit156cd88eefe7db100e5cdba48174c709975b93ca (patch)
treebf4059f81c6ba16439ef6dfa19a4e016057da20d /tests/replication/tcp
parentExpect type stubs from canonicaljson (#14992) (diff)
downloadsynapse-156cd88eefe7db100e5cdba48174c709975b93ca.tar.xz
Add missing type hints to tests.replication. (#14987)
Diffstat (limited to 'tests/replication/tcp')
-rw-r--r--tests/replication/tcp/streams/test_account_data.py4
-rw-r--r--tests/replication/tcp/streams/test_events.py18
-rw-r--r--tests/replication/tcp/streams/test_federation.py2
-rw-r--r--tests/replication/tcp/streams/test_partial_state.py2
-rw-r--r--tests/replication/tcp/streams/test_typing.py33
-rw-r--r--tests/replication/tcp/test_commands.py6
-rw-r--r--tests/replication/tcp/test_remote_server_up.py8
7 files changed, 40 insertions, 33 deletions
diff --git a/tests/replication/tcp/streams/test_account_data.py b/tests/replication/tcp/streams/test_account_data.py
index 50fbff5f32..01df1be047 100644
--- a/tests/replication/tcp/streams/test_account_data.py
+++ b/tests/replication/tcp/streams/test_account_data.py
@@ -21,7 +21,7 @@ from tests.replication._base import BaseStreamTestCase
 
 
 class AccountDataStreamTestCase(BaseStreamTestCase):
-    def test_update_function_room_account_data_limit(self):
+    def test_update_function_room_account_data_limit(self) -> None:
         """Test replication with many room account data updates"""
         store = self.hs.get_datastores().main
 
@@ -67,7 +67,7 @@ class AccountDataStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_global_account_data_limit(self):
+    def test_update_function_global_account_data_limit(self) -> None:
         """Test replication with many global account data updates"""
         store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_events.py b/tests/replication/tcp/streams/test_events.py
index 641a94133b..043dbe76af 100644
--- a/tests/replication/tcp/streams/test_events.py
+++ b/tests/replication/tcp/streams/test_events.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import List, Optional
+from typing import Any, List, Optional
+
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
@@ -25,6 +27,8 @@ from synapse.replication.tcp.streams.events import (
 )
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.replication._base import BaseStreamTestCase
 from tests.test_utils.event_injection import inject_event, inject_member_event
@@ -37,7 +41,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         room.register_servlets,
     ]
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         super().prepare(reactor, clock, hs)
         self.user_id = self.register_user("u1", "pass")
         self.user_tok = self.login("u1", "pass")
@@ -47,7 +51,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
         self.room_id = self.helper.create_room_as(tok=self.user_tok)
         self.test_handler.received_rdata_rows.clear()
 
-    def test_update_function_event_row_limit(self):
+    def test_update_function_event_row_limit(self) -> None:
         """Test replication with many non-state events
 
         Checks that all events are correctly replicated when there are lots of
@@ -102,7 +106,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_update_function_huge_state_change(self):
+    def test_update_function_huge_state_change(self) -> None:
         """Test replication with many state events
 
         Ensures that all events are correctly replicated when there are lots of
@@ -256,7 +260,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
             # "None" indicates the state has been deleted
             self.assertIsNone(sr.event_id)
 
-    def test_update_function_state_row_limit(self):
+    def test_update_function_state_row_limit(self) -> None:
         """Test replication with many state events over several stream ids."""
 
         # we want to generate lots of state changes, but for this test, we want to
@@ -376,7 +380,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
 
         self.assertEqual([], received_rows)
 
-    def test_backwards_stream_id(self):
+    def test_backwards_stream_id(self) -> None:
         """
         Test that RDATA that comes after the current position should be discarded.
         """
@@ -437,7 +441,7 @@ class EventsStreamTestCase(BaseStreamTestCase):
     event_count = 0
 
     def _inject_test_event(
-        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs
+        self, body: Optional[str] = None, sender: Optional[str] = None, **kwargs: Any
     ) -> EventBase:
         if sender is None:
             sender = self.user_id
diff --git a/tests/replication/tcp/streams/test_federation.py b/tests/replication/tcp/streams/test_federation.py
index bcb82c9c80..cdbdfaf057 100644
--- a/tests/replication/tcp/streams/test_federation.py
+++ b/tests/replication/tcp/streams/test_federation.py
@@ -26,7 +26,7 @@ class FederationStreamTestCase(BaseStreamTestCase):
         config["federation_sender_instances"] = ["federation_sender1"]
         return config
 
-    def test_catchup(self):
+    def test_catchup(self) -> None:
         """Basic test of catchup on reconnect
 
         Makes sure that updates sent while we are offline are received later.
diff --git a/tests/replication/tcp/streams/test_partial_state.py b/tests/replication/tcp/streams/test_partial_state.py
index 2c10eab4db..38b5020ce0 100644
--- a/tests/replication/tcp/streams/test_partial_state.py
+++ b/tests/replication/tcp/streams/test_partial_state.py
@@ -23,7 +23,7 @@ class PartialStateStreamsTestCase(BaseMultiWorkerStreamTestCase):
     hijack_auth = True
     user_id = "@bob:test"
 
-    def setUp(self):
+    def setUp(self) -> None:
         super().setUp()
         self.store = self.hs.get_datastores().main
 
diff --git a/tests/replication/tcp/streams/test_typing.py b/tests/replication/tcp/streams/test_typing.py
index 9a229dd23f..68de5d1cc2 100644
--- a/tests/replication/tcp/streams/test_typing.py
+++ b/tests/replication/tcp/streams/test_typing.py
@@ -27,10 +27,11 @@ ROOM_ID_2 = "!foo:blue"
 
 
 class TypingStreamTestCase(BaseStreamTestCase):
-    def _build_replication_data_handler(self):
-        return Mock(wraps=super()._build_replication_data_handler())
+    def _build_replication_data_handler(self) -> Mock:
+        self.mock_handler = Mock(wraps=super()._build_replication_data_handler())
+        return self.mock_handler
 
-    def test_typing(self):
+    def test_typing(self) -> None:
         typing = self.hs.get_typing_handler()
 
         self.reconnect()
@@ -43,8 +44,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -54,11 +55,11 @@ class TypingStreamTestCase(BaseStreamTestCase):
         # Now let's disconnect and insert some data.
         self.disconnect()
 
-        self.test_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.reset_mock()
 
         typing._push_update(member=RoomMember(ROOM_ID, USER_ID), typing=False)
 
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.assert_not_called()
 
         self.reconnect()
         self.pump(0.1)
@@ -71,15 +72,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         assert request.args is not None
         self.assertEqual(int(request.args[b"from_token"][0]), token)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
         self.assertEqual(ROOM_ID, row.room_id)
         self.assertEqual([], row.user_ids)
 
-    def test_reset(self):
+    def test_reset(self) -> None:
         """
         Test what happens when a typing stream resets.
 
@@ -98,8 +99,8 @@ class TypingStreamTestCase(BaseStreamTestCase):
         request = self.handle_http_replication_attempt()
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row: TypingStream.TypingStreamRow = rdata_rows[0]
@@ -134,15 +135,15 @@ class TypingStreamTestCase(BaseStreamTestCase):
         self.assert_request_is_get_repl_stream_updates(request, "typing")
 
         # Reset the test code.
-        self.test_handler.on_rdata.reset_mock()
-        self.test_handler.on_rdata.assert_not_called()
+        self.mock_handler.on_rdata.reset_mock()
+        self.mock_handler.on_rdata.assert_not_called()
 
         # Push additional data.
         typing._push_update(member=RoomMember(ROOM_ID_2, USER_ID_2), typing=False)
         self.reactor.advance(0)
 
-        self.test_handler.on_rdata.assert_called_once()
-        stream_name, _, token, rdata_rows = self.test_handler.on_rdata.call_args[0]
+        self.mock_handler.on_rdata.assert_called_once()
+        stream_name, _, token, rdata_rows = self.mock_handler.on_rdata.call_args[0]
         self.assertEqual(stream_name, "typing")
         self.assertEqual(1, len(rdata_rows))
         row = rdata_rows[0]
diff --git a/tests/replication/tcp/test_commands.py b/tests/replication/tcp/test_commands.py
index cca7ebb719..5d6b72b16d 100644
--- a/tests/replication/tcp/test_commands.py
+++ b/tests/replication/tcp/test_commands.py
@@ -21,12 +21,12 @@ from tests.unittest import TestCase
 
 
 class ParseCommandTestCase(TestCase):
-    def test_parse_one_word_command(self):
+    def test_parse_one_word_command(self) -> None:
         line = "REPLICATE"
         cmd = parse_command_from_line(line)
         self.assertIsInstance(cmd, ReplicateCommand)
 
-    def test_parse_rdata(self):
+    def test_parse_rdata(self) -> None:
         line = 'RDATA events master 6287863 ["ev", ["$eventid", "!roomid", "type", null, null, null]]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
@@ -34,7 +34,7 @@ class ParseCommandTestCase(TestCase):
         self.assertEqual(cmd.instance_name, "master")
         self.assertEqual(cmd.token, 6287863)
 
-    def test_parse_rdata_batch(self):
+    def test_parse_rdata_batch(self) -> None:
         line = 'RDATA presence master batch ["@foo:example.com", "online"]'
         cmd = parse_command_from_line(line)
         assert isinstance(cmd, RdataCommand)
diff --git a/tests/replication/tcp/test_remote_server_up.py b/tests/replication/tcp/test_remote_server_up.py
index 545f11acd1..b75fc05fd5 100644
--- a/tests/replication/tcp/test_remote_server_up.py
+++ b/tests/replication/tcp/test_remote_server_up.py
@@ -16,15 +16,17 @@ from typing import Tuple
 
 from twisted.internet.address import IPv4Address
 from twisted.internet.interfaces import IProtocol
-from twisted.test.proto_helpers import StringTransport
+from twisted.test.proto_helpers import MemoryReactor, StringTransport
 
 from synapse.replication.tcp.resource import ReplicationStreamProtocolFactory
+from synapse.server import HomeServer
+from synapse.util import Clock
 
 from tests.unittest import HomeserverTestCase
 
 
 class RemoteServerUpTestCase(HomeserverTestCase):
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.factory = ReplicationStreamProtocolFactory(hs)
 
     def _make_client(self) -> Tuple[IProtocol, StringTransport]:
@@ -40,7 +42,7 @@ class RemoteServerUpTestCase(HomeserverTestCase):
 
         return proto, transport
 
-    def test_relay(self):
+    def test_relay(self) -> None:
         """Test that Synapse will relay REMOTE_SERVER_UP commands to all
         other connections, but not the one that sent it.
         """