diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 71302facd1..48f39df9fe 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,7 +15,7 @@ import json
import os
import tempfile
from typing import List, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
import yaml
@@ -35,7 +35,6 @@ from synapse.types import DeviceListUpdates
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
@@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
# we aren't testing store._base stuff here, so mock this out
# (ignore needed because Mypy won't allow us to assign to a method otherwise)
- self.store.get_events_as_list = Mock(return_value=make_awaitable(events)) # type: ignore[assignment]
+ self.store.get_events_as_list = AsyncMock(return_value=events) # type: ignore[assignment]
self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
self.get_success(self._insert_txn(service.id, 10, events))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index a4a823a252..2af7280ba3 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
import yaml
@@ -32,7 +32,7 @@ from synapse.types import JsonDict
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import simple_async_mock
from tests.unittest import override_config
@@ -363,9 +363,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
# Register the callbacks with more mocks
self.hs.get_module_api().register_background_update_controller_callbacks(
on_update=self._on_update,
- min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
- default_batch_size=Mock(
- return_value=make_awaitable(self._default_batch_size),
+ min_batch_size=AsyncMock(return_value=self._default_batch_size),
+ default_batch_size=AsyncMock(
+ return_value=self._default_batch_size,
),
)
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 209d68b40b..12e24d4dbd 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@
# limitations under the License.
from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
from parameterized import parameterized
@@ -30,7 +30,6 @@ from synapse.util import Clock
from tests import unittest
from tests.server import make_request
-from tests.test_utils import make_awaitable
from tests.unittest import override_config
@@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
lots_of_users = 100
user_id = "@user:server"
- self.store.get_monthly_active_count = Mock(
- return_value=make_awaitable(lots_of_users)
- )
+ self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
self.get_success(
self.store.insert_client_ip(
user_id, "access_token", "ip", "user_agent", "device_id"
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 2827738379..0bf706ba08 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
from twisted.test.proto_helpers import MemoryReactor
@@ -21,7 +21,6 @@ from synapse.server import HomeServer
from synapse.util import Clock
from tests import unittest
-from tests.test_utils import make_awaitable
from tests.unittest import default_config, override_config
FORTY_DAYS = 40 * 24 * 60 * 60
@@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
)
self.get_success(d)
- self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
+ self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
d = self.store.populate_monthly_active_users(user_id)
self.get_success(d)
@@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
self.store.upsert_monthly_active_user.assert_not_called()
def test_populate_monthly_users_should_update(self) -> None:
- self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
+ self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
- self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
+ self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment]
- self.store.user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(None)
- )
+ self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
d = self.store.populate_monthly_active_users("user_id")
self.get_success(d)
self.store.upsert_monthly_active_user.assert_called_once()
def test_populate_monthly_users_should_not_update(self) -> None:
- self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
+ self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
- self.store.is_trial_user = Mock(return_value=make_awaitable(False)) # type: ignore[assignment]
- self.store.user_last_seen_monthly_active = Mock(
- return_value=make_awaitable(self.hs.get_clock().time_msec())
+ self.store.is_trial_user = AsyncMock(return_value=False) # type: ignore[assignment]
+ self.store.user_last_seen_monthly_active = AsyncMock(
+ return_value=self.hs.get_clock().time_msec()
)
d = self.store.populate_monthly_active_users("user_id")
@@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
@override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
def test_no_users_when_not_tracking(self) -> None:
- self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None)) # type: ignore[assignment]
+ self.store.upsert_monthly_active_user = AsyncMock(return_value=None) # type: ignore[assignment]
self.get_success(self.store.populate_monthly_active_users("@user:sever"))
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 0e3fc2a77f..29be8cdbd0 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import (
PartialStateEventsTracker,
)
-from tests.test_utils import make_awaitable
from tests.unittest import TestCase
@@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase):
class PartialCurrentStateTrackerTestCase(TestCase):
def setUp(self) -> None:
self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+ self.mock_store.is_partial_state_room = mock.AsyncMock()
self.tracker = PartialCurrentStateTracker(self.mock_store)
def test_does_not_block_for_full_state_rooms(self) -> None:
- self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+ self.mock_store.is_partial_state_room.return_value = False
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_blocks_for_partial_room_state(self) -> None:
- self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+ self.mock_store.is_partial_state_room.return_value = True
d = ensureDeferred(self.tracker.await_full_state("room_id"))
@@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
def test_cancellation(self) -> None:
- self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+ self.mock_store.is_partial_state_room.return_value = True
d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
self.assertNoResult(d1)
|