diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
index d5b0640e7a..481db9a687 100644
--- a/tests/rest/client/test_account_data.py
+++ b/tests/rest/client/test_account_data.py
@@ -11,13 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
from synapse.rest import admin
from synapse.rest.client import account_data, login, room
from tests import unittest
-from tests.test_utils import make_awaitable
class AccountDataTestCase(unittest.HomeserverTestCase):
@@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
"""Tests that the on_account_data_updated module callback is called correctly when
a user's account data changes.
"""
- mocked_callback = Mock(return_value=make_awaitable(None))
+ mocked_callback = AsyncMock(return_value=None)
self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
mocked_callback
)
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index e12098102b..66b387cea3 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from http import HTTPStatus
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -23,7 +23,6 @@ from synapse.types import UserID
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
class PresenceTestCase(unittest.HomeserverTestCase):
@@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
self.presence_handler = Mock(spec=PresenceHandler)
- self.presence_handler.set_state.return_value = make_awaitable(None)
+ self.presence_handler.set_state = AsyncMock(return_value=None)
hs = self.setup_test_homeserver(
"red",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 9bfe913e45..d3f6191996 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
import urllib.parse
from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
from twisted.test.proto_helpers import MemoryReactor
@@ -28,7 +28,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import inject_event
from tests.unittest import override_config
@@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase):
# Disable the validation to pretend this came over federation.
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
- new=lambda self, event: make_awaitable(None),
+ new_callable=AsyncMock,
+ return_value=None,
):
# Generate a various relations from a different room.
self.get_success(
@@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
# not an event the Client-Server API will allow..
with patch(
"synapse.handlers.message.EventCreationHandler._validate_event_relation",
- new=lambda self, event: make_awaitable(None),
+ new_callable=AsyncMock,
+ return_value=None,
):
# Create a sub-thread off the thread, which is not allowed.
self._send_relation(
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 88e579dc39..53182459e4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
import json
from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call, patch
+from unittest.mock import AsyncMock, Mock, call, patch
from urllib import parse as urlparse
from parameterized import param, parameterized
@@ -52,7 +52,6 @@ from synapse.util.stringutils import random_string
from tests import unittest
from tests.http.server._base import make_request_with_cancellation_test
from tests.storage.test_stream import PaginationTestCase
-from tests.test_utils import make_awaitable
from tests.test_utils.event_injection import create_event
from tests.unittest import override_config
@@ -70,8 +69,8 @@ class RoomBase(unittest.HomeserverTestCase):
)
self.hs.get_federation_handler = Mock() # type: ignore[assignment]
- self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
- return_value=make_awaitable(None)
+ self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock(
+ return_value=None
)
async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
@@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
]
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
- return self.setup_test_homeserver(federation_client=Mock())
+ return self.setup_test_homeserver(federation_client=AsyncMock())
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.register_user("user", "pass")
@@ -2385,7 +2384,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
def test_simple(self) -> None:
"Simple test for searching rooms over federation"
- self.federation_client.get_public_rooms.return_value = make_awaitable({}) # type: ignore[attr-defined]
+ self.federation_client.get_public_rooms.return_value = {} # type: ignore[attr-defined]
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@@ -2413,7 +2412,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
# with a 404, when using search filters.
self.federation_client.get_public_rooms.side_effect = ( # type: ignore[attr-defined]
HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
- make_awaitable({}),
+ {},
)
search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@@ -3413,17 +3412,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+ make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
- self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
- return_value=make_awaitable(None),
+ self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment]
+ return_value=None,
)
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
- mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
+ mock = AsyncMock(return_value=True, spec=lambda *x: None)
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
mock
)
@@ -3451,7 +3450,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Now change the return value of the callback to deny any invite and test that
# we can't send the invite.
- mock.return_value = make_awaitable(False)
+ mock.return_value = False
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
@@ -3477,18 +3476,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Mock a few functions to prevent the test from failing due to failing to talk to
# a remote IS. We keep the mock for make_and_store_3pid_invite around so we
# can check its call_count later on during the test.
- make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+ make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock # type: ignore[assignment]
- self.hs.get_identity_handler().lookup_3pid = Mock( # type: ignore[assignment]
- return_value=make_awaitable(None),
+ self.hs.get_identity_handler().lookup_3pid = AsyncMock( # type: ignore[assignment]
+ return_value=None,
)
# Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
# allow everything for now.
# `spec` argument is needed for this function mock to have `__qualname__`, which
# is needed for `Measure` metrics buried in SpamChecker.
- mock = Mock(
- return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+ mock = AsyncMock(
+ return_value=synapse.module_api.NOT_SPAM,
spec=lambda *x: None,
)
self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
@@ -3519,7 +3518,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
# Now change the return value of the callback to deny any invite and test that
# we can't send the invite. We pick an arbitrary error code to be able to check
# that the same code has been returned
- mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+ mock.return_value = Codes.CONSENT_NOT_GIVEN
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
@@ -3538,7 +3537,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
make_invite_mock.assert_called_once()
# Run variant with `Tuple[Codes, dict]`.
- mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+ mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
channel = self.make_request(
method="POST",
path="/rooms/" + self.room_id + "/invite",
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e5ba5a9706..da37fcb045 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -13,7 +13,7 @@
# limitations under the License.
import threading
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from twisted.test.proto_helpers import MemoryReactor
@@ -33,7 +33,6 @@ from synapse.util import Clock
from synapse.util.frozenutils import unfreeze
from tests import unittest
-from tests.test_utils import make_awaitable
if TYPE_CHECKING:
from synapse.module_api import ModuleApi
@@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
def test_on_new_event(self) -> None:
"""Test that the on_new_event callback is called on new events"""
- on_new_event = Mock(make_awaitable(None))
+ on_new_event = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
on_new_event
)
@@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
# Register a mock callback.
- m = Mock(return_value=make_awaitable(None))
+ m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
m
)
@@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
# Register a mock callback.
- m = Mock(return_value=make_awaitable(None))
+ m = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
m
)
@@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation.
"""
# Register a mocked callback.
- deactivation_mock = Mock(return_value=make_awaitable(None))
+ deactivation_mock = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(
deactivation_mock,
@@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
# Also register a mocked callback for profile updates, to check that the
# deactivation code calls it in a way that let modules know the user is being
# deactivated.
- profile_mock = Mock(return_value=make_awaitable(None))
+ profile_mock = AsyncMock(return_value=None)
self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
profile_mock,
)
@@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
well as a reactivation.
"""
# Register a mock callback.
- m = Mock(return_value=make_awaitable(None))
+ m = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
@@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation.
"""
# Register a mocked callback.
- deactivation_mock = Mock(return_value=make_awaitable(False))
+ deactivation_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock,
@@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing a user's deactivation triggered by a server admin.
"""
# Register a mocked callback.
- deactivation_mock = Mock(return_value=make_awaitable(False))
+ deactivation_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_deactivate_user_callbacks.append(
deactivation_mock,
@@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
correctly when processing an admin's shutdown room request.
"""
# Register a mocked callback.
- shutdown_mock = Mock(return_value=make_awaitable(False))
+ shutdown_mock = AsyncMock(return_value=False)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._check_can_shutdown_room_callbacks.append(
shutdown_mock,
@@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
associating a 3PID to an account.
"""
# Register a mocked callback.
- threepid_bind_mock = Mock(return_value=make_awaitable(None))
+ threepid_bind_mock = AsyncMock(return_value=None)
third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
@@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
just before associating and removing a 3PID to/from an account.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
- on_add_user_third_party_identifier_callback_mock = Mock(
- return_value=make_awaitable(None)
- )
- on_remove_user_third_party_identifier_callback_mock = Mock(
- return_value=make_awaitable(None)
+ on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
+ on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+ return_value=None
)
self.hs.get_module_api().register_third_party_rules_callbacks(
on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
@@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
when a user is deactivated and their third-party ID associations are deleted.
"""
# Pretend to be a Synapse module and register both callbacks as mocks.
- on_remove_user_third_party_identifier_callback_mock = Mock(
- return_value=make_awaitable(None)
+ on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+ return_value=None
)
self.hs.get_module_api().register_third_party_rules_callbacks(
on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index d8dc56261a..951a3cbc43 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -14,7 +14,7 @@
from http import HTTPStatus
from typing import Any, Generator, Tuple, cast
-from unittest.mock import Mock, call
+from unittest.mock import AsyncMock, Mock, call
from twisted.internet import defer, reactor as _reactor
@@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
from tests.utils import MockClock
reactor = cast(ISynapseReactor, _reactor)
@@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_executes_given_function(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
- cb = Mock(return_value=make_awaitable(self.mock_http_response))
+ cb = AsyncMock(return_value=self.mock_http_response)
res = yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
)
@@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
def test_deduplicates_based_on_key(
self,
) -> Generator["defer.Deferred[Any]", object, None]:
- cb = Mock(return_value=make_awaitable(self.mock_http_response))
+ cb = AsyncMock(return_value=self.mock_http_response)
for i in range(3): # invoke multiple times
res = yield self.cache.fetch_or_execute_request(
self.mock_request,
@@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
- cb = Mock(return_value=make_awaitable(self.mock_http_response))
+ cb = AsyncMock(return_value=self.mock_http_response)
yield self.cache.fetch_or_execute_request(
self.mock_request, self.mock_requester, cb, "an arg"
)
|