summary refs log tree commit diff
path: root/tests/rest
diff options
context:
space:
mode:
authorPatrick Cloke <clokep@users.noreply.github.com>2023-08-24 19:38:46 -0400
committerGitHub <noreply@github.com>2023-08-24 19:38:46 -0400
commitdaf11e26efc210dccaef029422431a7d2803dd8a (patch)
tree7e42649714a38217d25eecf0a089592632dc989c /tests/rest
parent Document `exclude_rooms_fom_sync` configuration option (#16178) (diff)
downloadsynapse-daf11e26efc210dccaef029422431a7d2803dd8a.tar.xz
Replace make_awaitable with AsyncMock (#16179)
Python 3.8 provides a native AsyncMock, we can replace the
homegrown version we have.
Diffstat (limited to 'tests/rest')
-rw-r--r--tests/rest/admin/test_user.py16
-rw-r--r--tests/rest/client/test_account_data.py5
-rw-r--r--tests/rest/client/test_presence.py5
-rw-r--r--tests/rest/client/test_relations.py9
-rw-r--r--tests/rest/client/test_rooms.py37
-rw-r--r--tests/rest/client/test_third_party_rules.py35
-rw-r--r--tests/rest/client/test_transactions.py9
7 files changed, 55 insertions, 61 deletions
diff --git a/tests/rest/admin/test_user.py b/tests/rest/admin/test_user.py
index feb81844ae..339a41c7e1 100644
--- a/tests/rest/admin/test_user.py
+++ b/tests/rest/admin/test_user.py
@@ -18,7 +18,7 @@ import os
 import urllib.parse
 from binascii import unhexlify
 from typing import List, Optional
-from unittest.mock import Mock, patch
+from unittest.mock import AsyncMock, Mock, patch
 
 from parameterized import parameterized, parameterized_class
 
@@ -45,7 +45,7 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeSite, make_request
-from tests.test_utils import SMALL_PNG, make_awaitable
+from tests.test_utils import SMALL_PNG
 from tests.unittest import override_config
 
 
@@ -419,8 +419,8 @@ class UserRegisterTestCase(unittest.HomeserverTestCase):
         store = self.hs.get_datastores().main
 
         # Set monthly active users to the limit
-        store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -1834,8 +1834,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
             )
 
         # Set monthly active users to the limit
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
@@ -1871,8 +1871,8 @@ class UserRestTestCase(unittest.HomeserverTestCase):
         handler = self.hs.get_registration_handler()
 
         # Set monthly active users to the limit
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(self.hs.config.server.max_mau_value)
+        self.store.get_monthly_active_count = AsyncMock(
+            return_value=self.hs.config.server.max_mau_value
         )
         # Check that the blocking of monthly active users is working as expected
         # The registration of a new user fails due to the limit
diff --git a/tests/rest/client/test_account_data.py b/tests/rest/client/test_account_data.py
index d5b0640e7a..481db9a687 100644
--- a/tests/rest/client/test_account_data.py
+++ b/tests/rest/client/test_account_data.py
@@ -11,13 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from synapse.rest import admin
 from synapse.rest.client import account_data, login, room
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class AccountDataTestCase(unittest.HomeserverTestCase):
@@ -32,7 +31,7 @@ class AccountDataTestCase(unittest.HomeserverTestCase):
         """Tests that the on_account_data_updated module callback is called correctly when
         a user's account data changes.
         """
-        mocked_callback = Mock(return_value=make_awaitable(None))
+        mocked_callback = AsyncMock(return_value=None)
         self.hs.get_account_data_handler()._on_account_data_updated_callbacks.append(
             mocked_callback
         )
diff --git a/tests/rest/client/test_presence.py b/tests/rest/client/test_presence.py
index e12098102b..66b387cea3 100644
--- a/tests/rest/client/test_presence.py
+++ b/tests/rest/client/test_presence.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from http import HTTPStatus
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -23,7 +23,6 @@ from synapse.types import UserID
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class PresenceTestCase(unittest.HomeserverTestCase):
@@ -36,7 +35,7 @@ class PresenceTestCase(unittest.HomeserverTestCase):
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
         self.presence_handler = Mock(spec=PresenceHandler)
-        self.presence_handler.set_state.return_value = make_awaitable(None)
+        self.presence_handler.set_state = AsyncMock(return_value=None)
 
         hs = self.setup_test_homeserver(
             "red",
diff --git a/tests/rest/client/test_relations.py b/tests/rest/client/test_relations.py
index 9bfe913e45..d3f6191996 100644
--- a/tests/rest/client/test_relations.py
+++ b/tests/rest/client/test_relations.py
@@ -15,7 +15,7 @@
 
 import urllib.parse
 from typing import Any, Callable, Dict, List, Optional, Tuple
-from unittest.mock import patch
+from unittest.mock import AsyncMock, patch
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -28,7 +28,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import FakeChannel
-from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import inject_event
 from tests.unittest import override_config
 
@@ -264,7 +263,8 @@ class RelationsTestCase(BaseRelationsTestCase):
         # Disable the validation to pretend this came over federation.
         with patch(
             "synapse.handlers.message.EventCreationHandler._validate_event_relation",
-            new=lambda self, event: make_awaitable(None),
+            new_callable=AsyncMock,
+            return_value=None,
         ):
             # Generate a various relations from a different room.
             self.get_success(
@@ -1300,7 +1300,8 @@ class BundledAggregationsTestCase(BaseRelationsTestCase):
         # not an event the Client-Server API will allow..
         with patch(
             "synapse.handlers.message.EventCreationHandler._validate_event_relation",
-            new=lambda self, event: make_awaitable(None),
+            new_callable=AsyncMock,
+            return_value=None,
         ):
             # Create a sub-thread off the thread, which is not allowed.
             self._send_relation(
diff --git a/tests/rest/client/test_rooms.py b/tests/rest/client/test_rooms.py
index 88e579dc39..53182459e4 100644
--- a/tests/rest/client/test_rooms.py
+++ b/tests/rest/client/test_rooms.py
@@ -20,7 +20,7 @@
 import json
 from http import HTTPStatus
 from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
-from unittest.mock import Mock, call, patch
+from unittest.mock import AsyncMock, Mock, call, patch
 from urllib import parse as urlparse
 
 from parameterized import param, parameterized
@@ -52,7 +52,6 @@ from synapse.util.stringutils import random_string
 from tests import unittest
 from tests.http.server._base import make_request_with_cancellation_test
 from tests.storage.test_stream import PaginationTestCase
-from tests.test_utils import make_awaitable
 from tests.test_utils.event_injection import create_event
 from tests.unittest import override_config
 
@@ -70,8 +69,8 @@ class RoomBase(unittest.HomeserverTestCase):
         )
 
         self.hs.get_federation_handler = Mock()  # type: ignore[assignment]
-        self.hs.get_federation_handler.return_value.maybe_backfill = Mock(
-            return_value=make_awaitable(None)
+        self.hs.get_federation_handler.return_value.maybe_backfill = AsyncMock(
+            return_value=None
         )
 
         async def _insert_client_ip(*args: Any, **kwargs: Any) -> None:
@@ -2375,7 +2374,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
     ]
 
     def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
-        return self.setup_test_homeserver(federation_client=Mock())
+        return self.setup_test_homeserver(federation_client=AsyncMock())
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.register_user("user", "pass")
@@ -2385,7 +2384,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
 
     def test_simple(self) -> None:
         "Simple test for searching rooms over federation"
-        self.federation_client.get_public_rooms.return_value = make_awaitable({})  # type: ignore[attr-defined]
+        self.federation_client.get_public_rooms.return_value = {}  # type: ignore[attr-defined]
 
         search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
 
@@ -2413,7 +2412,7 @@ class PublicRoomsTestRemoteSearchFallbackTestCase(unittest.HomeserverTestCase):
         # with a 404, when using search filters.
         self.federation_client.get_public_rooms.side_effect = (  # type: ignore[attr-defined]
             HttpResponseException(HTTPStatus.NOT_FOUND, "Not Found", b""),
-            make_awaitable({}),
+            {},
         )
 
         search_filter = {PublicRoomsFilterFields.GENERIC_SEARCH_TERM: "foobar"}
@@ -3413,17 +3412,17 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
-        make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+        make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
         self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
-        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self.hs.get_identity_handler().lookup_3pid = AsyncMock(  # type: ignore[assignment]
+            return_value=None,
         )
 
         # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
         # allow everything for now.
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
-        mock = Mock(return_value=make_awaitable(True), spec=lambda *x: None)
+        mock = AsyncMock(return_value=True, spec=lambda *x: None)
         self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
             mock
         )
@@ -3451,7 +3450,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
 
         # Now change the return value of the callback to deny any invite and test that
         # we can't send the invite.
-        mock.return_value = make_awaitable(False)
+        mock.return_value = False
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
@@ -3477,18 +3476,18 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Mock a few functions to prevent the test from failing due to failing to talk to
         # a remote IS. We keep the mock for make_and_store_3pid_invite around so we
         # can check its call_count later on during the test.
-        make_invite_mock = Mock(return_value=make_awaitable((Mock(event_id="abc"), 0)))
+        make_invite_mock = AsyncMock(return_value=(Mock(event_id="abc"), 0))
         self.hs.get_room_member_handler()._make_and_store_3pid_invite = make_invite_mock  # type: ignore[assignment]
-        self.hs.get_identity_handler().lookup_3pid = Mock(  # type: ignore[assignment]
-            return_value=make_awaitable(None),
+        self.hs.get_identity_handler().lookup_3pid = AsyncMock(  # type: ignore[assignment]
+            return_value=None,
         )
 
         # Add a mock to the spamchecker callbacks for user_may_send_3pid_invite. Make it
         # allow everything for now.
         # `spec` argument is needed for this function mock to have `__qualname__`, which
         # is needed for `Measure` metrics buried in SpamChecker.
-        mock = Mock(
-            return_value=make_awaitable(synapse.module_api.NOT_SPAM),
+        mock = AsyncMock(
+            return_value=synapse.module_api.NOT_SPAM,
             spec=lambda *x: None,
         )
         self.hs.get_module_api_callbacks().spam_checker._user_may_send_3pid_invite_callbacks.append(
@@ -3519,7 +3518,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         # Now change the return value of the callback to deny any invite and test that
         # we can't send the invite. We pick an arbitrary error code to be able to check
         # that the same code has been returned
-        mock.return_value = make_awaitable(Codes.CONSENT_NOT_GIVEN)
+        mock.return_value = Codes.CONSENT_NOT_GIVEN
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
@@ -3538,7 +3537,7 @@ class ThreepidInviteTestCase(unittest.HomeserverTestCase):
         make_invite_mock.assert_called_once()
 
         # Run variant with `Tuple[Codes, dict]`.
-        mock.return_value = make_awaitable((Codes.EXPIRED_ACCOUNT, {"field": "value"}))
+        mock.return_value = (Codes.EXPIRED_ACCOUNT, {"field": "value"})
         channel = self.make_request(
             method="POST",
             path="/rooms/" + self.room_id + "/invite",
diff --git a/tests/rest/client/test_third_party_rules.py b/tests/rest/client/test_third_party_rules.py
index e5ba5a9706..da37fcb045 100644
--- a/tests/rest/client/test_third_party_rules.py
+++ b/tests/rest/client/test_third_party_rules.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 import threading
 from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -33,7 +33,6 @@ from synapse.util import Clock
 from synapse.util.frozenutils import unfreeze
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 if TYPE_CHECKING:
     from synapse.module_api import ModuleApi
@@ -477,7 +476,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
 
     def test_on_new_event(self) -> None:
         """Test that the on_new_event callback is called on new events"""
-        on_new_event = Mock(make_awaitable(None))
+        on_new_event = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_new_event_callbacks.append(
             on_new_event
         )
@@ -580,7 +579,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
 
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             m
         )
@@ -641,7 +640,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         avatar_url = "mxc://matrix.org/oWQDvfewxmlRaRCkVbfetyEo"
 
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             m
         )
@@ -682,7 +681,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(None))
+        deactivation_mock = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(
             deactivation_mock,
@@ -690,7 +689,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         # Also register a mocked callback for profile updates, to check that the
         # deactivation code calls it in a way that let modules know the user is being
         # deactivated.
-        profile_mock = Mock(return_value=make_awaitable(None))
+        profile_mock = AsyncMock(return_value=None)
         self.hs.get_module_api_callbacks().third_party_event_rules._on_profile_update_callbacks.append(
             profile_mock,
         )
@@ -740,7 +739,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         well as a reactivation.
         """
         # Register a mock callback.
-        m = Mock(return_value=make_awaitable(None))
+        m = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_user_deactivation_status_changed_callbacks.append(m)
 
@@ -794,7 +793,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(False))
+        deactivation_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
@@ -840,7 +839,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing a user's deactivation triggered by a server admin.
         """
         # Register a mocked callback.
-        deactivation_mock = Mock(return_value=make_awaitable(False))
+        deactivation_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_deactivate_user_callbacks.append(
             deactivation_mock,
@@ -879,7 +878,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         correctly when processing an admin's shutdown room request.
         """
         # Register a mocked callback.
-        shutdown_mock = Mock(return_value=make_awaitable(False))
+        shutdown_mock = AsyncMock(return_value=False)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._check_can_shutdown_room_callbacks.append(
             shutdown_mock,
@@ -915,7 +914,7 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         associating a 3PID to an account.
         """
         # Register a mocked callback.
-        threepid_bind_mock = Mock(return_value=make_awaitable(None))
+        threepid_bind_mock = AsyncMock(return_value=None)
         third_party_rules = self.hs.get_module_api_callbacks().third_party_event_rules
         third_party_rules._on_threepid_bind_callbacks.append(threepid_bind_mock)
 
@@ -957,11 +956,9 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         just before associating and removing a 3PID to/from an account.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        on_add_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
-        )
-        on_remove_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
+        on_add_user_third_party_identifier_callback_mock = AsyncMock(return_value=None)
+        on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+            return_value=None
         )
         self.hs.get_module_api().register_third_party_rules_callbacks(
             on_add_user_third_party_identifier=on_add_user_third_party_identifier_callback_mock,
@@ -1021,8 +1018,8 @@ class ThirdPartyRulesTestCase(unittest.FederatingHomeserverTestCase):
         when a user is deactivated and their third-party ID associations are deleted.
         """
         # Pretend to be a Synapse module and register both callbacks as mocks.
-        on_remove_user_third_party_identifier_callback_mock = Mock(
-            return_value=make_awaitable(None)
+        on_remove_user_third_party_identifier_callback_mock = AsyncMock(
+            return_value=None
         )
         self.hs.get_module_api().register_third_party_rules_callbacks(
             on_remove_user_third_party_identifier=on_remove_user_third_party_identifier_callback_mock,
diff --git a/tests/rest/client/test_transactions.py b/tests/rest/client/test_transactions.py
index d8dc56261a..951a3cbc43 100644
--- a/tests/rest/client/test_transactions.py
+++ b/tests/rest/client/test_transactions.py
@@ -14,7 +14,7 @@
 
 from http import HTTPStatus
 from typing import Any, Generator, Tuple, cast
-from unittest.mock import Mock, call
+from unittest.mock import AsyncMock, Mock, call
 
 from twisted.internet import defer, reactor as _reactor
 
@@ -24,7 +24,6 @@ from synapse.types import ISynapseReactor, JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.utils import MockClock
 
 reactor = cast(ISynapseReactor, _reactor)
@@ -53,7 +52,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_executes_given_function(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         res = yield self.cache.fetch_or_execute_request(
             self.mock_request, self.mock_requester, cb, "some_arg", keyword="arg"
         )
@@ -64,7 +63,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
     def test_deduplicates_based_on_key(
         self,
     ) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         for i in range(3):  # invoke multiple times
             res = yield self.cache.fetch_or_execute_request(
                 self.mock_request,
@@ -168,7 +167,7 @@ class HttpTransactionCacheTestCase(unittest.TestCase):
 
     @defer.inlineCallbacks
     def test_cleans_up(self) -> Generator["defer.Deferred[Any]", object, None]:
-        cb = Mock(return_value=make_awaitable(self.mock_http_response))
+        cb = AsyncMock(return_value=self.mock_http_response)
         yield self.cache.fetch_or_execute_request(
             self.mock_request, self.mock_requester, cb, "an arg"
         )