diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9e104fd96a..5ea4a75a9f 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, FrozenSet, List, Optional, Set
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
@@ -29,7 +29,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock
-from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase
@@ -43,12 +42,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
+ self.federation_transport_client.send_transaction = AsyncMock()
hs = self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
- hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
- return_value=make_awaitable({"test", "host2"})
+ hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[assignment]
+ return_value={"test", "host2"}
)
hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
- mock_send_transaction.return_value = make_awaitable({})
+ mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
@@ -104,7 +104,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_thread(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
- mock_send_transaction.return_value = make_awaitable({})
+ mock_send_transaction.return_value = {}
# Create receipts for:
#
@@ -180,7 +180,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
mock_send_transaction = self.federation_transport_client.send_transaction
- mock_send_transaction.return_value = make_awaitable({})
+ mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
@@ -276,6 +276,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.federation_transport_client = Mock(
spec=["send_transaction", "query_user_devices"]
)
+ self.federation_transport_client.send_transaction = AsyncMock()
+ self.federation_transport_client.query_user_devices = AsyncMock()
return self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
@@ -317,13 +319,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.record_transaction
)
- def record_transaction(
+ async def record_transaction(
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
- ) -> "defer.Deferred[JsonDict]":
+ ) -> JsonDict:
assert json_cb is not None
data = json_cb()
self.edus.extend(data["edus"])
- return defer.succeed({})
+ return {}
def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU"""
@@ -354,15 +356,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
- self.federation_transport_client.query_user_devices.return_value = (
- make_awaitable(
- {
- "stream_id": "1",
- "user_id": "@user2:host2",
- "devices": [{"device_id": "D1"}],
- }
- )
- )
+ self.federation_transport_client.query_user_devices.return_value = {
+ "stream_id": "1",
+ "user_id": "@user2:host2",
+ "devices": [{"device_id": "D1"}],
+ }
self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update(
@@ -533,7 +531,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
recovery
"""
mock_send_txn = self.federation_transport_client.send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+ mock_send_txn.side_effect = AssertionError("fail")
# create devices
u1 = self.register_user("user", "pass")
@@ -578,7 +576,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.federation_transport_client.send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+ mock_send_txn.side_effect = AssertionError("fail")
# create devices
u1 = self.register_user("user", "pass")
@@ -636,7 +634,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# now the server goes offline
mock_send_txn = self.federation_transport_client.send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+ mock_send_txn.side_effect = AssertionError("fail")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
|