diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py
index 129d7cfd93..73a2766baf 100644
--- a/tests/federation/test_complexity.py
+++ b/tests/federation/test_complexity.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
from synapse.api.errors import Codes, SynapseError
from synapse.rest import admin
@@ -20,7 +20,6 @@ from synapse.rest.client import login, room
from synapse.types import JsonDict, UserID, create_requester
from tests import unittest
-from tests.test_utils import make_awaitable
class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
@@ -58,7 +57,7 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
async def get_current_state_event_counts(room_id: str) -> int:
return int(500 * 1.23)
- store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
+ store.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign]
# Get the room complexity again -- make sure it's our artificial value
channel = self.make_signed_federation_request(
@@ -75,9 +74,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
- handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
- return_value=make_awaitable(("", 1))
+ fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign]
+ handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign]
+ return_value=("", 1)
)
d = handler._remote_join(
@@ -106,9 +105,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
- handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
- return_value=make_awaitable(("", 1))
+ fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign]
+ handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign]
+ return_value=("", 1)
)
d = handler._remote_join(
@@ -143,16 +142,16 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
- handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
- return_value=make_awaitable(("", 1))
+ fed_transport.client.get_json = AsyncMock(return_value=None) # type: ignore[method-assign]
+ handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign]
+ return_value=("", 1)
)
# Artificially raise the complexity
async def get_current_state_event_counts(room_id: str) -> int:
return 600
- self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[assignment]
+ self.hs.get_datastores().main.get_current_state_event_counts = get_current_state_event_counts # type: ignore[method-assign]
d = handler._remote_join(
create_requester(u1),
@@ -200,9 +199,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
- handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
- return_value=make_awaitable(("", 1))
+ fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign]
+ handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign]
+ return_value=("", 1)
)
d = handler._remote_join(
@@ -230,9 +229,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase):
fed_transport = self.hs.get_federation_transport_client()
# Mock out some things, because we don't want to test the whole join
- fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) # type: ignore[assignment]
- handler.federation_handler.do_invite_join = Mock( # type: ignore[assignment]
- return_value=make_awaitable(("", 1))
+ fed_transport.client.get_json = AsyncMock(return_value={"v1": 9999}) # type: ignore[method-assign]
+ handler.federation_handler.do_invite_join = AsyncMock( # type: ignore[method-assign]
+ return_value=("", 1)
)
d = handler._remote_join(
diff --git a/tests/federation/test_federation_catch_up.py b/tests/federation/test_federation_catch_up.py
index b290b020a2..75ae740b43 100644
--- a/tests/federation/test_federation_catch_up.py
+++ b/tests/federation/test_federation_catch_up.py
@@ -1,6 +1,6 @@
from typing import Callable, Collection, List, Optional, Tuple
from unittest import mock
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -19,7 +19,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from synapse.util.retryutils import NotRetryingDestination
-from tests.test_utils import event_injection, make_awaitable
+from tests.test_utils import event_injection
from tests.unittest import FederatingHomeserverTestCase
@@ -50,8 +50,8 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
# This mock is crucial for destination_rooms to be populated.
# TODO: this seems to no longer be the case---tests pass with this mock
# commented out.
- state_storage_controller.get_current_hosts_in_room = Mock( # type: ignore[assignment]
- return_value=make_awaitable({"test", "host2"})
+ state_storage_controller.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign]
+ return_value={"test", "host2"}
)
# whenever send_transaction is called, record the pdu data
@@ -436,7 +436,7 @@ class FederationCatchUpTestCases(FederatingHomeserverTestCase):
def wake_destination_track(destination: str) -> None:
woken.add(destination)
- self.federation_sender.wake_destination = wake_destination_track # type: ignore[assignment]
+ self.federation_sender.wake_destination = wake_destination_track # type: ignore[method-assign]
# We wait quite long so that all dests can be woken up, since there is a delay
# between them.
diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py
index 9e104fd96a..caf04b54cb 100644
--- a/tests/federation/test_federation_sender.py
+++ b/tests/federation/test_federation_sender.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, FrozenSet, List, Optional, Set
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from signedjson import key, sign
from signedjson.types import BaseKey, SigningKey
@@ -29,7 +29,6 @@ from synapse.server import HomeServer
from synapse.types import JsonDict, ReadReceipt
from synapse.util import Clock
-from tests.test_utils import make_awaitable
from tests.unittest import HomeserverTestCase
@@ -43,15 +42,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.federation_transport_client = Mock(spec=["send_transaction"])
+ self.federation_transport_client.send_transaction = AsyncMock()
hs = self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
- hs.get_storage_controllers().state.get_current_hosts_in_room = Mock( # type: ignore[assignment]
- return_value=make_awaitable({"test", "host2"})
+ hs.get_storage_controllers().state.get_current_hosts_in_room = AsyncMock( # type: ignore[method-assign]
+ return_value={"test", "host2"}
)
- hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[assignment]
+ hs.get_storage_controllers().state.get_current_hosts_in_room_or_partial_state_approximation = ( # type: ignore[method-assign]
hs.get_storage_controllers().state.get_current_hosts_in_room
)
@@ -64,7 +64,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
- mock_send_transaction.return_value = make_awaitable({})
+ mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
@@ -75,7 +75,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
thread_id=None,
data={"ts": 1234},
)
- self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+ self.get_success(sender.send_read_receipt(receipt))
self.pump()
@@ -104,13 +104,16 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
def test_send_receipts_thread(self) -> None:
mock_send_transaction = self.federation_transport_client.send_transaction
- mock_send_transaction.return_value = make_awaitable({})
+ mock_send_transaction.return_value = {}
# Create receipts for:
#
# * The same room / user on multiple threads.
# * A different user in the same room.
sender = self.hs.get_federation_sender()
+ # Hack so that we have a txn in-flight so we batch up read receipts
+ # below
+ sender.wake_destination("host2")
for user, thread in (
("alice", None),
("alice", "thread"),
@@ -125,9 +128,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
thread_id=thread,
data={"ts": 1234},
)
- self.successResultOf(
- defer.ensureDeferred(sender.send_read_receipt(receipt))
- )
+ defer.ensureDeferred(sender.send_read_receipt(receipt))
self.pump()
@@ -180,7 +181,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
"""Send two receipts in quick succession; the second should be flushed, but
only after 20ms"""
mock_send_transaction = self.federation_transport_client.send_transaction
- mock_send_transaction.return_value = make_awaitable({})
+ mock_send_transaction.return_value = {}
sender = self.hs.get_federation_sender()
receipt = ReadReceipt(
@@ -191,7 +192,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase):
thread_id=None,
data={"ts": 1234},
)
- self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt)))
+ self.get_success(sender.send_read_receipt(receipt))
self.pump()
@@ -276,6 +277,8 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.federation_transport_client = Mock(
spec=["send_transaction", "query_user_devices"]
)
+ self.federation_transport_client.send_transaction = AsyncMock()
+ self.federation_transport_client.query_user_devices = AsyncMock()
return self.setup_test_homeserver(
federation_transport_client=self.federation_transport_client,
)
@@ -317,13 +320,13 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.record_transaction
)
- def record_transaction(
+ async def record_transaction(
self, txn: Transaction, json_cb: Optional[Callable[[], JsonDict]] = None
- ) -> "defer.Deferred[JsonDict]":
+ ) -> JsonDict:
assert json_cb is not None
data = json_cb()
self.edus.extend(data["edus"])
- return defer.succeed({})
+ return {}
def test_send_device_updates(self) -> None:
"""Basic case: each device update should result in an EDU"""
@@ -340,7 +343,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
self.reactor.advance(1)
# a second call should produce no new device EDUs
- self.hs.get_federation_sender().send_device_messages("host2")
+ self.get_success(
+ self.hs.get_federation_sender().send_device_messages(["host2"])
+ )
self.assertEqual(self.edus, [])
# a second device
@@ -354,15 +359,11 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# Send the server a device list EDU for the other user, this will cause
# it to try and resync the device lists.
- self.federation_transport_client.query_user_devices.return_value = (
- make_awaitable(
- {
- "stream_id": "1",
- "user_id": "@user2:host2",
- "devices": [{"device_id": "D1"}],
- }
- )
- )
+ self.federation_transport_client.query_user_devices.return_value = {
+ "stream_id": "1",
+ "user_id": "@user2:host2",
+ "devices": [{"device_id": "D1"}],
+ }
self.get_success(
self.device_handler.device_list_updater.incoming_device_list_update(
@@ -533,7 +534,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
recovery
"""
mock_send_txn = self.federation_transport_client.send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+ mock_send_txn.side_effect = AssertionError("fail")
# create devices
u1 = self.register_user("user", "pass")
@@ -552,7 +553,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
- self.hs.get_federation_sender().send_device_messages("host2")
+ self.get_success(
+ self.hs.get_federation_sender().send_device_messages(["host2"])
+ )
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -578,7 +581,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
This case tests the behaviour when the server has never been reachable.
"""
mock_send_txn = self.federation_transport_client.send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+ mock_send_txn.side_effect = AssertionError("fail")
# create devices
u1 = self.register_user("user", "pass")
@@ -603,7 +606,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
- self.hs.get_federation_sender().send_device_messages("host2")
+ self.get_success(
+ self.hs.get_federation_sender().send_device_messages(["host2"])
+ )
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
@@ -636,7 +641,7 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# now the server goes offline
mock_send_txn = self.federation_transport_client.send_transaction
- mock_send_txn.side_effect = lambda t, cb: defer.fail(AssertionError("fail"))
+ mock_send_txn.side_effect = AssertionError("fail")
self.login("user", "pass", device_id="D2")
self.login("user", "pass", device_id="D3")
@@ -658,7 +663,9 @@ class FederationSenderDevicesTestCases(HomeserverTestCase):
# recover the server
mock_send_txn.side_effect = self.record_transaction
- self.hs.get_federation_sender().send_device_messages("host2")
+ self.get_success(
+ self.hs.get_federation_sender().send_device_messages(["host2"])
+ )
# We queue up device list updates to be sent over federation, so we
# advance to clear the queue.
diff --git a/tests/federation/transport/test_knocking.py b/tests/federation/transport/test_knocking.py
index 70209ab090..3f42f79f26 100644
--- a/tests/federation/transport/test_knocking.py
+++ b/tests/federation/transport/test_knocking.py
@@ -218,7 +218,7 @@ class FederationKnockingTestCase(
) -> EventBase:
return pdu
- homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[assignment]
+ homeserver.get_federation_server()._check_sigs_and_hash = ( # type: ignore[method-assign]
approve_all_signature_checking
)
@@ -229,7 +229,7 @@ class FederationKnockingTestCase(
) -> None:
pass
- homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[assignment]
+ homeserver.get_federation_event_handler()._check_event_auth = _check_event_auth # type: ignore[method-assign]
return super().prepare(reactor, clock, homeserver)
|