summary refs log tree commit diff
path: root/tests/federation/test_federation_sender.py
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-08-24 19:38:46 -0400
committerGitHub <noreply@github.com>2023-08-24 19:38:46 -0400
commitdaf11e26efc210dccaef029422431a7d2803dd8a (patch)
tree7e42649714a38217d25eecf0a089592632dc989c /tests/federation/test_federation_sender.py
parent Document `exclude_rooms_fom_sync` configuration option (#16178) (diff)
downloadsynapse-daf11e26efc210dccaef029422431a7d2803dd8a.tar.xz
Replace make_awaitable with AsyncMock (#16179)
Python 3.8 provides a native AsyncMock, we can replace the
homegrown version we have.
Diffstat (limited to 'tests/federation/test_federation_sender.py')
-rw-r--r--tests/federation/test_federation_sender.py42
1 files changed, 20 insertions, 22 deletions
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9e104fd96a..5ea4a75a9f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Callable, FrozenSet, List, Optional, Set
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from signedjson import key, sign
 from signedjson.types import BaseKey, SigningKey
@@ -29,7 +29,6 @@ from synapse.server import HomeServer
 from synapse.types import JsonDict, ReadReceipt
 from synapse.util import Clock
 
-from tests.test_utils import make_awaitable
 from tests.unittest import HomeserverTestCase
 
 
@@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.federation_transport_client = Mock(spec=["send_transaction"])
+        self.federation_transport_client.send_transaction = AsyncMock()
         hs = self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
 
-        hs.get_storage_controllers().state.get_current_hosts_in_room = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable({"test", "host2"})
+        hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock(  # type: ignore[assignment]
+            return_value={"test", "host2"}
         )
 
         hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = (  # type: ignore[assignment]
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -104,7 +104,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
 
     def test_send_receipts_thread(self) -> None:
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         # Create receipts for:
         #
@@ -180,7 +180,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
         """Send two receipts in quick succession; the second should be flushed, but
         only after 20ms"""
         mock_send_transaction = self.federation_transport_client.send_transaction
-        mock_send_transaction.return_value = make_awaitable({})
+        mock_send_transaction.return_value = {}
 
         sender = self.hs.get_federation_sender()
         receipt = ReadReceipt(
@@ -276,6 +276,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         self.federation_transport_client = Mock(
             spec=["send_transaction", "query_user_devices"]
         )
+        self.federation_transport_client.send_transaction = AsyncMock()
+        self.federation_transport_client.query_user_devices = AsyncMock()
         return self.setup_test_homeserver(
             federation_transport_client=self.federation_transport_client,
         )
@@ -317,13 +319,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
             self.record_transaction
         )
 
-    def record_transaction(
+    async def record_transaction(
         self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
-    ) -> "defer.Deferred[JsonDict]":
+    ) -> JsonDict:
         assert json_cb is not None
         data = json_cb()
         self.edus.extend(data["edus"])
-        return defer.succeed({})
+        return {}
 
     def test_send_device_updates(self) -> None:
         """Basic case: each device update should result in an EDU"""
@@ -354,15 +356,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # Send the server a device list EDU for the other user, this will cause
         # it to try and resync the device lists.
-        self.federation_transport_client.query_user_devices.return_value = (
-            make_awaitable(
-                {
-                    "stream_id": "1",
-                    "user_id": "@user2:host2",
-                    "devices": [{"device_id": "D1"}],
-                }
-            )
-        )
+        self.federation_transport_client.query_user_devices.return_value = {
+            "stream_id": "1",
+            "user_id": "@user2:host2",
+            "devices": [{"device_id": "D1"}],
+        }
 
         self.get_success(
             self.device_handler.device_list_updater.incoming_device_list_update(
@@ -533,7 +531,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         recovery
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -578,7 +576,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
         This case tests the behaviour when the server has never been reachable.
         """
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         # create devices
         u1 = self.register_user("user", "pass")
@@ -636,7 +634,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
 
         # now the server goes offline
         mock_send_txn = self.federation_transport_client.send_transaction
-        mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+        mock_send_txn.side_effect = AssertionError("fail")
 
         self.login("user", "pass", device_id="D2")
         self.login("user", "pass", device_id="D3")