summary refs log tree commit diff
path: root/tests/federation/test_federation_catch_up.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/federation/test_federation_catch_up.py')
-rw-r--r--tests/federation/test_federation_catch_up.py52
1 files changed, 31 insertions, 21 deletions
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index b8fee72898..a986b15f0a 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,13 +1,17 @@
-from typing import List, Tuple
+from typing import Callable, List, Optional, Tuple
 from unittest.mock import Mock
 
+from twisted.test.proto_helpers import MemoryReactor
+
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
 from synapse.federation.sender import PerDestinationQueue, TransactionManager
-from synapse.federation.units import Edu
+from synapse.federation.units import Edu, Transaction
 from synapse.rest import admin
 from synapse.rest.client import login, room
+from synapse.server import HomeServer
 from synapse.types import JsonDict
+from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
 from tests.test_utils import event_injection, make_awaitable
@@ -28,23 +32,25 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         # stub out get_current_hosts_in_room
-        state_handler = hs.get_state_handler()
+        state_storage_controller = hs.get_storage_controllers().state
 
         # This mock is crucial for destination_rooms to be populated.
-        state_handler.get_current_hosts_in_room = Mock(
-            return_value=make_awaitable(["test", "host2"])
+        # TODO: this seems to no longer be the case---tests pass with this mock
+        # commented out.
+        state_storage_controller.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
+            return_value=make_awaitable({"test", "host2"})
         )
 
         # whenever send_transaction is called, record the pdu data
-        self.pdus = []
-        self.failed_pdus = []
+        self.pdus: List[JsonDict] = []
+        self.failed_pdus: List[JsonDict] = []
         self.is_online = True
         self.hs.get_federation_transport_client().send_transaction.side_effect = (
             self.record_transaction
@@ -55,8 +61,13 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         config["federation_sender_instances"] = None
         return config
 
-    async def record_transaction(self, txn, json_cb):
-        if self.is_online:
+    async def record_transaction(
+        self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]]
+    ) -> JsonDict:
+        if json_cb is None:
+            # The tests seem to expect that this method raises in this situation.
+            raise Exception("Blank json_cb")
+        elif self.is_online:
             data = json_cb()
             self.pdus.extend(data["pdus"])
             return {}
@@ -92,7 +103,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         )[0]
         return {"event_id": event_id, "stream_ordering": stream_ordering}
 
-    def test_catch_up_destination_rooms_tracking(self):
+    def test_catch_up_destination_rooms_tracking(self) -> None:
         """
         Tests that we populate the `destination_rooms` table as needed.
         """
@@ -117,7 +128,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         self.assertEqual(row_2["event_id"], event_id_2)
         self.assertEqual(row_1["stream_ordering"], row_2["stream_ordering"] - 1)
 
-    def test_catch_up_last_successful_stream_ordering_tracking(self):
+    def test_catch_up_last_successful_stream_ordering_tracking(self) -> None:
         """
         Tests that we populate the `destination_rooms` table as needed.
         """
@@ -174,7 +185,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             "Send succeeded but not marked as last_successful_stream_ordering",
         )
 
-    def test_catch_up_from_blank_state(self):
+    def test_catch_up_from_blank_state(self) -> None:
         """
         Runs an overall test of federation catch-up from scratch.
         Further tests will focus on more narrow aspects and edge-cases, but I
@@ -261,16 +272,15 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             destination_tm: str,
             pending_pdus: List[EventBase],
             _pending_edus: List[Edu],
-        ) -> bool:
+        ) -> None:
             assert destination == destination_tm
             results_list.extend(pending_pdus)
-            return True  # success!
 
-        transaction_manager.send_new_transaction = fake_send
+        transaction_manager.send_new_transaction = fake_send  # type: ignore[assignment]
 
         return per_dest_queue, results_list
 
-    def test_catch_up_loop(self):
+    def test_catch_up_loop(self) -> None:
         """
         Tests the behaviour of _catch_up_transmission_loop.
         """
@@ -334,7 +344,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
             event_5.internal_metadata.stream_ordering,
         )
 
-    def test_catch_up_on_synapse_startup(self):
+    def test_catch_up_on_synapse_startup(self) -> None:
         """
         Tests the behaviour of get_catch_up_outstanding_destinations and
             _wake_destinations_needing_catchup.
@@ -412,7 +422,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # patch wake_destination to just count the destinations instead
         woken = []
 
-        def wake_destination_track(destination):
+        def wake_destination_track(destination: str) -> None:
             woken.append(destination)
 
         self.hs.get_federation_sender().wake_destination = wake_destination_track
@@ -432,7 +442,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # - all destinations are woken exactly once; they appear once in woken.
         self.assertCountEqual(woken, server_names[:-1])
 
-    def test_not_latest_event(self):
+    def test_not_latest_event(self) -> None:
         """Test that we send the latest event in the room even if its not ours."""
 
         per_dest_queue, sent_pdus = self.make_fake_destination_queue()