summary refs log tree commit diff
path: root/tests/federation
diff options
context:
space:
mode:
Diffstat (limited to 'tests/federation')
-rw-r--r--tests/federation/test_complexity.py33
-rw-r--r--tests/federation/test_federation_catch_up.py8
-rw-r--r--tests/federation/test_federation_sender.py42
3 files changed, 40 insertions, 43 deletions
diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 129d7cfd93..5b58fb13b5 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from synapse.api.errors import Codes, SynapseError
 from synapse.rest import admin
@@ -20,7 +20,6 @@ from synapse.rest.client import login, room
 from synapse.types import JsonDict, UserID, create_requester
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[assignment]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[assignment]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[assignment]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[assignment]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -143,9 +142,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value=None)  # type: ignore[assignment]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[assignment]
+            return_value=("", 1)
         )
 
         # Artificially raise the complexity
@@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[assignment]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[assignment]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
@@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
         fed_transport = self.hs.get_federation_transport_client()
 
         # Mock out some things, because we don't want to test the whole join
-        fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999}))  # type: ignore[assignment]
-        handler.federation_handler.do_invite_join = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(("", 1))
+        fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999})  # type: ignore[assignment]
+        handler.federation_handler.do_invite_join = AsyncMock(  # type: ignore[assignment]
+            return_value=("", 1)
         )
 
         d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index b290b020a2..40318aa1b6 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,6 @@
 from typing import Callable, Collection, List, Optional, Tuple
 from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -19,7 +19,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 from synapse.util.retryutils import NotRetryingDestination
 
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
 from tests.unittest import FederatingHomeserverTestCase
 
 
@@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
         # This mock is crucial for destination_rooms to be populated.
         # TODO: this seems to no longer be the case---tests pass with this mock
         # commented out.
-        state_storage_controller.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        state_storage_controller.get_current_hosts_in_room = AsyncMock(  # type: ignore[assignment]
+            return_value={"test", "host2"}
         )
 
         # whenever send_transaction is called, record the pdu data
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9e104fd96a..5ea4a75a9f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Callable, FrozenSet, List, Optional, Set
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
@@ -29,7 +29,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
 
 
@@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.federation_transport_client = Mock(spec=["send_transaction"])
+        self.federation_transport_client.send_transaction = AsyncMock()
         hs = self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock(  # type: ignore[assignment]
+            return_value={"test", "host2"}
         )
 
         hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -104,7 +104,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts_thread(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         # Create receipts for:
         #
@@ -180,7 +180,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -276,6 +276,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.federation_transport_client = Mock(
             spec=["send_transaction", "query_user_devices"]
         )
+        self.federation_transport_client.send_transaction = AsyncMock()
+        self.federation_transport_client.query_user_devices = AsyncMock()
         return self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
@@ -317,13 +319,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.record_transaction
         )
 
-    def record_transaction(
+    async def record_transaction(
         self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
-    ) -> "defer.Deferred[JsonDict]":
+    ) -> JsonDict:
         assert json_cb is not None
         data = json_cb()
         self.edus.extend(data["edus"])
-        return defer.succeed({})
+        return {}
 
     def test_send_device_updates(self) -> None:
         """Basic case: each device update should result in an EDU"""
@@ -354,15 +356,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # Send the server a device list EDU for the other user, this will cause
         # it to try and resync the device lists.
-        self.federation_transport_client.query_user_devices.return_value = (
-            make_awaitable(
-                {
-                    "stream_id": "1",
-                    "user_id": "@user2:host2",
-                    "devices": [{"device_id": "D1"}],
-                }
-            )
-        )
+        self.federation_transport_client.query_user_devices.return_value = {
+            "stream_id": "1",
+            "user_id": "@user2:host2",
+            "devices": [{"device_id": "D1"}],
+        }
 
         self.get_success(
             self.device_handler.device_list_updater.incoming_device_list_update(
@@ -533,7 +531,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         recovery
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -578,7 +576,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         This case tests the behaviour when the server has never been reachable.
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -636,7 +634,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # now the server goes offline
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         self.login("user", "pass", device_id="D2")
         self.login("user", "pass", device_id="D3")