summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_appservice.py5
-rw-r--r--tests/storage/test_background_update.py10
-rw-r--r--tests/storage/test_client_ips.py7
-rw-r--r--tests/storage/test_monthly_active_users.py23
-rw-r--r--tests/storage/util/test_partial_state_events_tracker.py8
5 files changed, 23 insertions, 30 deletions
diff --git a/tests/storage/test_appservice.py b/tests/storage/test_appservice.py
index 71302facd1..48f39df9fe 100644
--- a/tests/storage/test_appservice.py
+++ b/tests/storage/test_appservice.py
@@ -15,7 +15,7 @@ import json
 import os
 import tempfile
 from typing import List, cast
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import yaml
 
@@ -35,7 +35,6 @@ from synapse.types import DeviceListUpdates
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 
 
 class ApplicationServiceStoreTestCase(unittest.HomeserverTestCase):
@@ -339,7 +338,7 @@ class ApplicationServiceTransactionStoreTestCase(unittest.HomeserverTestCase):
 
         # we aren't testing store._base stuff here, so mock this out
         # (ignore needed because Mypy won't allow us to assign to a method otherwise)
-        self.store.get_events_as_list = Mock(return_value=make_awaitable(events))  # type: ignore[assignment]
+        self.store.get_events_as_list = AsyncMock(return_value=events)  # type: ignore[assignment]
 
         self.get_success(self._insert_txn(self.as_list[1]["id"], 9, other_events))
         self.get_success(self._insert_txn(service.id, 10, events))
diff --git a/tests/storage/test_background_update.py b/tests/storage/test_background_update.py
index a4a823a252..2af7280ba3 100644
--- a/tests/storage/test_background_update.py
+++ b/tests/storage/test_background_update.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 import yaml
 
@@ -32,7 +32,7 @@ from synapse.types import JsonDict
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable, simple_async_mock
+from tests.test_utils import simple_async_mock
 from tests.unittest import override_config
 
 
@@ -363,9 +363,9 @@ class BackgroundUpdateControllerTestCase(unittest.HomeserverTestCase):
         # Register the callbacks with more mocks
         self.hs.get_module_api().register_background_update_controller_callbacks(
             on_update=self._on_update,
-            min_batch_size=Mock(return_value=make_awaitable(self._default_batch_size)),
-            default_batch_size=Mock(
-                return_value=make_awaitable(self._default_batch_size),
+            min_batch_size=AsyncMock(return_value=self._default_batch_size),
+            default_batch_size=AsyncMock(
+                return_value=self._default_batch_size,
             ),
         )
 
diff --git a/tests/storage/test_client_ips.py b/tests/storage/test_client_ips.py
index 209d68b40b..12e24d4dbd 100644
--- a/tests/storage/test_client_ips.py
+++ b/tests/storage/test_client_ips.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 
 from typing import Any, Dict
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from parameterized import parameterized
 
@@ -30,7 +30,6 @@ from synapse.util import Clock
 
 from tests import unittest
 from tests.server import make_request
-from tests.test_utils import make_awaitable
 from tests.unittest import override_config
 
 
@@ -443,9 +442,7 @@ class ClientIpStoreTestCase(unittest.HomeserverTestCase):
         lots_of_users = 100
         user_id = "@user:server"
 
-        self.store.get_monthly_active_count = Mock(
-            return_value=make_awaitable(lots_of_users)
-        )
+        self.store.get_monthly_active_count = AsyncMock(return_value=lots_of_users)
         self.get_success(
             self.store.insert_client_ip(
                 user_id, "access_token", "ip", "user_agent", "device_id"
diff --git a/tests/storage/test_monthly_active_users.py b/tests/storage/test_monthly_active_users.py
index 2827738379..0bf706ba08 100644
--- a/tests/storage/test_monthly_active_users.py
+++ b/tests/storage/test_monthly_active_users.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, List
-from unittest.mock import Mock
+from unittest.mock import AsyncMock
 
 from twisted.test.proto_helpers import MemoryReactor
 
@@ -21,7 +21,6 @@ from synapse.server import HomeServer
 from synapse.util import Clock
 
 from tests import unittest
-from tests.test_utils import make_awaitable
 from tests.unittest import default_config, override_config
 
 FORTY_DAYS = 40 * 24 * 60 * 60
@@ -253,7 +252,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         )
         self.get_success(d)
 
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[assignment]
 
         d = self.store.populate_monthly_active_users(user_id)
         self.get_success(d)
@@ -261,24 +260,22 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
         self.store.upsert_monthly_active_user.assert_not_called()
 
     def test_populate_monthly_users_should_update(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
+        self.store.is_trial_user = AsyncMock(return_value=False)  # type: ignore[assignment]
 
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(None)
-        )
+        self.store.user_last_seen_monthly_active = AsyncMock(return_value=None)
         d = self.store.populate_monthly_active_users("user_id")
         self.get_success(d)
 
         self.store.upsert_monthly_active_user.assert_called_once()
 
     def test_populate_monthly_users_should_not_update(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[assignment]
 
-        self.store.is_trial_user = Mock(return_value=make_awaitable(False))  # type: ignore[assignment]
-        self.store.user_last_seen_monthly_active = Mock(
-            return_value=make_awaitable(self.hs.get_clock().time_msec())
+        self.store.is_trial_user = AsyncMock(return_value=False)  # type: ignore[assignment]
+        self.store.user_last_seen_monthly_active = AsyncMock(
+            return_value=self.hs.get_clock().time_msec()
         )
 
         d = self.store.populate_monthly_active_users("user_id")
@@ -359,7 +356,7 @@ class MonthlyActiveUsersTestCase(unittest.HomeserverTestCase):
 
     @override_config({"limit_usage_by_mau": False, "mau_stats_only": False})
     def test_no_users_when_not_tracking(self) -> None:
-        self.store.upsert_monthly_active_user = Mock(return_value=make_awaitable(None))  # type: ignore[assignment]
+        self.store.upsert_monthly_active_user = AsyncMock(return_value=None)  # type: ignore[assignment]
 
         self.get_success(self.store.populate_monthly_active_users("@user:sever"))
 
diff --git a/tests/storage/util/test_partial_state_events_tracker.py b/tests/storage/util/test_partial_state_events_tracker.py
index 0e3fc2a77f..29be8cdbd0 100644
--- a/tests/storage/util/test_partial_state_events_tracker.py
+++ b/tests/storage/util/test_partial_state_events_tracker.py
@@ -22,7 +22,6 @@ from synapse.storage.util.partial_state_events_tracker import (
     PartialStateEventsTracker,
 )
 
-from tests.test_utils import make_awaitable
 from tests.unittest import TestCase
 
 
@@ -124,16 +123,17 @@ class PartialStateEventsTrackerTestCase(TestCase):
 class PartialCurrentStateTrackerTestCase(TestCase):
     def setUp(self) -> None:
         self.mock_store = mock.Mock(spec_set=["is_partial_state_room"])
+        self.mock_store.is_partial_state_room = mock.AsyncMock()
 
         self.tracker = PartialCurrentStateTracker(self.mock_store)
 
     def test_does_not_block_for_full_state_rooms(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(False)
+        self.mock_store.is_partial_state_room.return_value = False
 
         self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
 
     def test_blocks_for_partial_room_state(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+        self.mock_store.is_partial_state_room.return_value = True
 
         d = ensureDeferred(self.tracker.await_full_state("room_id"))
 
@@ -156,7 +156,7 @@ class PartialCurrentStateTrackerTestCase(TestCase):
         self.successResultOf(ensureDeferred(self.tracker.await_full_state("room_id")))
 
     def test_cancellation(self) -> None:
-        self.mock_store.is_partial_state_room.return_value = make_awaitable(True)
+        self.mock_store.is_partial_state_room.return_value = True
 
         d1 = ensureDeferred(self.tracker.await_full_state("room_id"))
         self.assertNoResult(d1)