diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 9014e60577..5e2ae82cd4 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
# limitations under the License.
from typing import Dict, Iterable, List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
from parameterized import parameterized
@@ -36,7 +36,7 @@ from synapse.util import Clock
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, simple_async_mock
from tests.unittest import override_config
from tests.utils import MockClock
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def setUp(self) -> None:
self.mock_store = Mock()
- self.mock_as_api = Mock()
+ self.mock_as_api = AsyncMock()
self.mock_scheduler = Mock()
hs = Mock()
hs.get_datastores.return_value = Mock(main=self.mock_store)
- self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
- self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
- self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
- None
- )
+ self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
+ self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
+ self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
hs.get_application_service_api.return_value = self.mock_as_api
hs.get_application_service_scheduler.return_value = self.mock_scheduler
hs.get_clock.return_value = MockClock()
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice(is_interested_in_event=False),
]
- self.mock_as_api.query_user.return_value = make_awaitable(True)
+ self.mock_as_api.query_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable([])
+ self.mock_store.get_user_by_id = AsyncMock(return_value=[])
event = Mock(
sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, {})),
- make_awaitable((1, {event.event_id: 0})),
- ]
- self.mock_store.get_events_as_list.side_effect = [
- make_awaitable([]),
- make_awaitable([event]),
- ]
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, {}),
+ (1, {event.event_id: 0}),
+ ]
+ )
+ self.mock_store.get_events_as_list = AsyncMock(
+ side_effect=[
+ [],
+ [event],
+ ]
+ )
self.handler.notify_interested_services(RoomStreamToken(None, 1))
self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable(None)
+ self.mock_store.get_user_by_id = AsyncMock(return_value=None)
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, {event.event_id: 0})),
- ]
- self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
+ self.mock_as_api.query_user.return_value = True
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, {event.event_id: 0}),
+ ]
+ )
+ self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
self.handler.notify_interested_services(RoomStreamToken(None, 0))
self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [self._mkservice(is_interested_in_event=True)]
services[0].is_interested_in_user.return_value = True
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
+ self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
- self.mock_as_api.query_user.return_value = make_awaitable(True)
- self.mock_store.get_all_new_event_ids_stream.side_effect = [
- make_awaitable((0, [event], {event.event_id: 0})),
- ]
+ self.mock_as_api.query_user.return_value = True
+ self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+ side_effect=[
+ (0, [event], {event.event_id: 0}),
+ ]
+ )
self.handler.notify_interested_services(RoomStreamToken(None, 0))
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
self._mkservice_alias(is_room_alias_in_namespace=False),
]
- self.mock_as_api.query_alias.return_value = make_awaitable(True)
+ self.mock_as_api.query_alias = AsyncMock(return_value=True)
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
- Mock(room_id=room_id, servers=servers)
+ self.mock_store.get_association_from_room_alias = AsyncMock(
+ return_value=Mock(room_id=room_id, servers=servers)
)
result = self.successResultOf(
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_protocol_no_response(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+ self.mock_as_api.get_3pe_protocol.return_value = None
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_select_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
)
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
def test_get_3pe_protocols_one_protocol(self) -> None:
service = self._mkservice(False, ["my-protocol"])
self.mock_store.get_app_services.return_value = [service]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
service_one = self._mkservice(False, ["my-protocol"])
service_two = self._mkservice(False, ["other-protocol"])
self.mock_store.get_app_services.return_value = [service_one, service_two]
- self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
- {"x-protocol-data": 42, "instances": []}
- )
+ self.mock_as_api.get_3pe_protocol.return_value = {
+ "x-protocol-data": 42,
+ "instances": [],
+ }
response = self.successResultOf(
defer.ensureDeferred(self.handler.get_3pe_protocols())
)
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
interested_service = self._mkservice(is_interested_in_event=True)
services = [interested_service]
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
- 579
- )
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
event = Mock(event_id="event_1")
- self.event_source.sources.receipt.get_new_events_as.return_value = (
- make_awaitable(([event], None))
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+ return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
services = [interested_service]
self.mock_store.get_app_services.return_value = services
- self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
- 580
- )
+ self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
event = Mock(event_id="event_1")
- self.event_source.sources.receipt.get_new_events_as.return_value = (
- make_awaitable(([event], None))
+ self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+ return_value=([event], None)
)
self.handler.notify_interested_services_ephemeral(
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
A mock representing the ApplicationService.
"""
service = Mock()
- service.is_interested_in_event.return_value = make_awaitable(
- is_interested_in_event
- )
+ service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
service.token = "mock_service_token"
service.url = "mock_service_url"
service.protocols = protocols
|