diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 9014e60577..46d022092e 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Dict, Iterable, List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
+from tests.test_utils import event_injection
from tests.unittest import override_config
from tests.utils import MockClock
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.mock_store = Mock()
- self.mock_as_api = Mock()
+ self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store)
- self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
- self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
- self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
- None
- )
+ self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
+ self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
+ self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested_in_event=False),
]
- self.mock_as_api.query_user.return_value = make_awaitable(True)
+ self.mock_as_api.query_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable([])
+ self.mock_store.get_user_by_id = AsyncMock(return_value=[])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, {})),
- make_awaitable((1, {event.event_id: 0})),
- ]
- self.mock_store.get_events_as_list.side_effect = [
- make_awaitable([]),
- make_awaitable([event]),
- ]
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, {}),
+ (1, {event.event_id: 0}),
+ ]
+ )
+ self.mock_store.get_events_as_list = AsyncMock(
+ side_effect=[
+ [],
+ [event],
+ ]
+ )
self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable(None)
+ self.mock_store.get_user_by_id = AsyncMock(return_value=None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, {event.event_id: 0})),
- ]
- self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
+ self.mock_as_api.query_user.return_value = True
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, {event.event_id: 0}),
+ ]
+ )
+ self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
+ self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
- ]
+ self.mock_as_api.query_user.return_value = True
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, [event], {event.event_id: 0}),
+ ]
+ )
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_room_alias_in_namespace=False),
]
- self.mock_as_api.query_alias.return_value = make_awaitable(True)
+ self.mock_as_api.query_alias = AsyncMock(return_value=True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
- Mock(room_id=room_id, servers=servers)
+ self.mock_store.get_association_from_room_alias = AsyncMock(
+ return_value=Mock(room_id=room_id, servers=servers)
)
result = self.successResultOf(
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+ self.mock_as_api.get_3pe_protocol.return_value = None
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
)
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
- 579
- )
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
event = Mock(event_id="event_1")
- self.event_source.sources.receipt.get_new_events_as.return_value = (
- make_awaitable(([event], None))
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+ return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [interested_service]
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
- 580
- )
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
event = Mock(event_id="event_1")
- self.event_source.sources.receipt.get_new_events_as.return_value = (
- make_awaitable(([event], None))
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+ return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
A mock representing the ApplicationService.
"""
service = Mock()
- service.is_interested_in_event.return_value = make_awaitable(
- is_interested_in_event
- )
+ service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
@@ -396,12 +399,12 @@ class ApplicationServicesHandlerSendEventsTestCase(unittest.HomeserverTestCase):
self.hs = hs
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track any outgoing ephemeral events
- self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment]
+ self.send_mock = AsyncMock()
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign]
return_value=self._services
)
@@ -894,12 +897,12 @@ class ApplicationServicesHandlerDeviceListsTestCase(unittest.HomeserverTestCase)
# Mock ApplicationServiceApi's put_json, so we can verify the raw JSON that
# will be sent over the wire
- self.put_json = simple_async_mock()
- hs.get_application_service_api().put_json = self.put_json # type: ignore[assignment]
+ self.put_json = AsyncMock()
+ hs.get_application_service_api().put_json = self.put_json # type: ignore[method-assign]
# Mock out application services, and allow defining our own in tests
self._services: List[ApplicationService] = []
- self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[assignment]
+ self.hs.get_datastores().main.get_app_services = Mock( # type: ignore[method-assign]
return_value=self._services
)
@@ -1000,8 +1003,8 @@ class ApplicationServicesHandlerOtkCountsTestCase(unittest.HomeserverTestCase):
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
# Mock the ApplicationServiceScheduler's _TransactionController's send method so that
# we can track what's going out
- self.send_mock = simple_async_mock()
- hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[assignment] # We assign to a method.
+ self.send_mock = AsyncMock()
+ hs.get_application_service_handler().scheduler.txn_ctrl.send = self.send_mock # type: ignore[method-assign] # We assign to a method.
# Define an application service for the tests
self._service_token = "VERYSECRET"
|