summary refs log tree commit diff
path: root/tests/federation/test_federation_sender.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/federation/test_federation_sender.py')
-rw-r--r--tests/federation/test_federation_sender.py74
1 files changed, 41 insertions, 33 deletions
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 8692d8190f..ddeffe1ad5 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -11,18 +11,22 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional
+from typing import Callable, FrozenSet, List, Optional, Set
 from unittest.mock import Mock
 
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
 
 from twisted.internet import defer
+from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EduTypes, RoomEncryptionAlgorithms
+from synapse.federation.units import Transaction
 from synapse.rest import admin
 from synapse.rest.client import login
+from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
+from synapse.util import Clock
 
 from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
@@ -36,16 +40,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
     re-enabled for the main process.
     """
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         hs = self.setup_test_homeserver(
             federation_transport_client=Mock(spec=["send_transaction"]),
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(
+        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
             return_value=make_awaitable({"test", "host2"})
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (
+        hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
             hs.get_storage_controllers().state.get_current_hosts_in_room
         )
 
@@ -56,7 +60,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         config["federation_sender_instances"] = None
         return config
 
-    def test_send_receipts(self):
+    def test_send_receipts(self) -> None:
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -98,7 +102,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             ],
         )
 
-    def test_send_receipts_thread(self):
+    def test_send_receipts_thread(self) -> None:
         mock_send_transaction = (
             self.hs.get_federation_transport_client().send_transaction
         )
@@ -174,7 +178,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
             ],
         )
 
-    def test_send_receipts_with_backoff(self):
+    def test_send_receipts_with_backoff(self) -> None:
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
         mock_send_transaction = (
@@ -272,51 +276,55 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         login.register_servlets,
     ]
 
-    def make_homeserver(self, reactor, clock):
+    def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         return self.setup_test_homeserver(
             federation_transport_client=Mock(
                 spec=["send_transaction", "query_user_devices"]
             ),
         )
 
-    def default_config(self):
+    def default_config(self) -> JsonDict:
         c = super().default_config()
         # Enable federation sending on the main process.
         c["federation_sender_instances"] = None
         return c
 
-    def prepare(self, reactor, clock, hs):
+    def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         test_room_id = "!room:host1"
 
         # stub out `get_rooms_for_user` and `get_current_hosts_in_room` so that the
         # server thinks the user shares a room with `@user2:host2`
-        def get_rooms_for_user(user_id):
-            return defer.succeed({test_room_id})
+        def get_rooms_for_user(user_id: str) -> "defer.Deferred[FrozenSet[str]]":
+            return defer.succeed(frozenset({test_room_id}))
 
-        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user
+        hs.get_datastores().main.get_rooms_for_user = get_rooms_for_user  # type: ignore[assignment]
 
-        async def get_current_hosts_in_room(room_id):
+        async def get_current_hosts_in_room(room_id: str) -> Set[str]:
             if room_id == test_room_id:
-                return ["host2"]
-
-            # TODO: We should fail the test when we encounter an unxpected room ID.
-            # We can't just use `self.fail(...)` here because the app code is greedy
-            # with `Exception` and will catch it before the test can see it.
+                return {"host2"}
+            else:
+                # TODO: We should fail the test when we encounter an unxpected room ID.
+                # We can't just use `self.fail(...)` here because the app code is greedy
+                # with `Exception` and will catch it before the test can see it.
+                return set()
 
-        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room
+        hs.get_datastores().main.get_current_hosts_in_room = get_current_hosts_in_room  # type: ignore[assignment]
 
         # whenever send_transaction is called, record the edu data
-        self.edus = []
+        self.edus: List[JsonDict] = []
         self.hs.get_federation_transport_client().send_transaction.side_effect = (
             self.record_transaction
         )
 
-    def record_transaction(self, txn, json_cb):
+    def record_transaction(
+        self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
+    ) -> "defer.Deferred[JsonDict]":
+        assert json_cb is not None
         data = json_cb()
         self.edus.extend(data["edus"])
         return defer.succeed({})
 
-    def test_send_device_updates(self):
+    def test_send_device_updates(self) -> None:
         """Basic case: each device update should result in an EDU"""
         # create a device
         u1 = self.register_user("user", "pass")
@@ -340,7 +348,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.assertEqual(len(self.edus), 1)
         self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
 
-    def test_dont_send_device_updates_for_remote_users(self):
+    def test_dont_send_device_updates_for_remote_users(self) -> None:
         """Check that we don't send device updates for remote users"""
 
         # Send the server a device list EDU for the other user, this will cause
@@ -379,7 +387,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         )
         self.assertIn("D1", devices)
 
-    def test_upload_signatures(self):
+    def test_upload_signatures(self) -> None:
         """Uploading signatures on some devices should produce updates for that user"""
 
         e2e_handler = self.hs.get_e2e_keys_handler()
@@ -391,7 +399,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # expect two edus
         self.assertEqual(len(self.edus), 2)
-        stream_id = None
+        stream_id: Optional[int] = None
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D1", stream_id)
         stream_id = self.check_device_update_edu(self.edus.pop(0), u1, "D2", stream_id)
 
@@ -473,13 +481,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
             c = edu["content"]
             if stream_id is not None:
-                self.assertEqual(c["prev_id"], [stream_id])
+                self.assertEqual(c["prev_id"], [stream_id])  # type: ignore[unreachable]
                 self.assertGreaterEqual(c["stream_id"], stream_id)
             stream_id = c["stream_id"]
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2"}, devices)
 
-    def test_delete_devices(self):
+    def test_delete_devices(self) -> None:
         """If devices are deleted, that should result in EDUs too"""
 
         # create devices
@@ -521,7 +529,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
 
-    def test_unreachable_server(self):
+    def test_unreachable_server(self) -> None:
         """If the destination server is unreachable, all the updates should get sent on
         recovery
         """
@@ -555,7 +563,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # for each device, there should be a single update
         self.assertEqual(len(self.edus), 3)
-        stream_id = None
+        stream_id: Optional[int] = None
         for edu in self.edus:
             self.assertEqual(edu["edu_type"], EduTypes.DEVICE_LIST_UPDATE)
             c = edu["content"]
@@ -566,7 +574,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         devices = {edu["content"]["device_id"] for edu in self.edus}
         self.assertEqual({"D1", "D2", "D3"}, devices)
 
-    def test_prune_outbound_device_pokes1(self):
+    def test_prune_outbound_device_pokes1(self) -> None:
         """If a destination is unreachable, and the updates are pruned, we should get
         a single update.
 
@@ -615,7 +623,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         # synapse uses an empty prev_id list to indicate "needs a full resync".
         self.assertEqual(c["prev_id"], [])
 
-    def test_prune_outbound_device_pokes2(self):
+    def test_prune_outbound_device_pokes2(self) -> None:
         """If a destination is unreachable, and the updates are pruned, we should get
         a single update.
 
@@ -741,7 +749,7 @@ def encode_pubkey(sk: SigningKey) -> str:
     return key.encode_verify_key_base64(key.get_verify_key(sk))
 
 
-def build_device_dict(user_id: str, device_id: str, sk: SigningKey):
+def build_device_dict(user_id: str, device_id: str, sk: SigningKey) -> JsonDict:
     """Build a dict representing the given device"""
     return {
         "user_id": user_id,