summary refs log tree commit diff
path: root/tests/handlers/test_appservice.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_appservice.py')
-rw-r--r--tests/handlers/test_appservice.py113
1 files changed, 58 insertions, 55 deletions
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index 9014e60577..5e2ae82cd4 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from typing import Dict, Iterable, List, Optional
-from unittest.mock import Mock
+from unittest.mock import AsyncMock, Mock
 
 from parameterized import parameterized
 
@@ -36,7 +36,7 @@ from synapse.util import Clock
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.test_utils import event_injection, make_awaitable, simple_async_mock
+from tests.test_utils import event_injection, simple_async_mock
 from tests.unittest import override_config
 from tests.utils import MockClock
 
@@ -46,15 +46,13 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def setUp(self) -> None:
         self.mock_store = Mock()
-        self.mock_as_api = Mock()
+        self.mock_as_api = AsyncMock()
         self.mock_scheduler = Mock()
         hs = Mock()
         hs.get_datastores.return_value = Mock(main=self.mock_store)
-        self.mock_store.get_appservice_last_pos.return_value = make_awaitable(None)
-        self.mock_store.set_appservice_last_pos.return_value = make_awaitable(None)
-        self.mock_store.set_appservice_stream_type_pos.return_value = make_awaitable(
-            None
-        )
+        self.mock_store.get_appservice_last_pos = AsyncMock(return_value=None)
+        self.mock_store.set_appservice_last_pos = AsyncMock(return_value=None)
+        self.mock_store.set_appservice_stream_type_pos = AsyncMock(return_value=None)
         hs.get_application_service_api.return_value = self.mock_as_api
         hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
@@ -69,21 +67,25 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested_in_event=False),
         ]
 
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
+        self.mock_as_api.query_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable([])
+        self.mock_store.get_user_by_id = AsyncMock(return_value=[])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, {})),
-            make_awaitable((1, {event.event_id: 0})),
-        ]
-        self.mock_store.get_events_as_list.side_effect = [
-            make_awaitable([]),
-            make_awaitable([event]),
-        ]
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, {}),
+                (1, {event.event_id: 0}),
+            ]
+        )
+        self.mock_store.get_events_as_list = AsyncMock(
+            side_effect=[
+                [],
+                [event],
+            ]
+        )
         self.handler.notify_interested_services(RoomStreamToken(None, 1))
 
         self.mock_scheduler.enqueue_for_appservice.assert_called_once_with(
@@ -95,14 +97,16 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable(None)
+        self.mock_store.get_user_by_id = AsyncMock(return_value=None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, {event.event_id: 0})),
-        ]
-        self.mock_store.get_events_as_list.side_effect = [make_awaitable([event])]
+        self.mock_as_api.query_user.return_value = True
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, {event.event_id: 0}),
+            ]
+        )
+        self.mock_store.get_events_as_list = AsyncMock(side_effect=[[event]])
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
 
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
@@ -112,13 +116,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [self._mkservice(is_interested_in_event=True)]
         services[0].is_interested_in_user.return_value = True
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_user_by_id.return_value = make_awaitable({"name": user_id})
+        self.mock_store.get_user_by_id = AsyncMock(return_value={"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.query_user.return_value = make_awaitable(True)
-        self.mock_store.get_all_new_event_ids_stream.side_effect = [
-            make_awaitable((0, [event], {event.event_id: 0})),
-        ]
+        self.mock_as_api.query_user.return_value = True
+        self.mock_store.get_all_new_event_ids_stream = AsyncMock(
+            side_effect=[
+                (0, [event], {event.event_id: 0}),
+            ]
+        )
 
         self.handler.notify_interested_services(RoomStreamToken(None, 0))
 
@@ -141,10 +147,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice_alias(is_room_alias_in_namespace=False),
         ]
 
-        self.mock_as_api.query_alias.return_value = make_awaitable(True)
+        self.mock_as_api.query_alias = AsyncMock(return_value=True)
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_association_from_room_alias.return_value = make_awaitable(
-            Mock(room_id=room_id, servers=servers)
+        self.mock_store.get_association_from_room_alias = AsyncMock(
+            return_value=Mock(room_id=room_id, servers=servers)
         )
 
         result = self.successResultOf(
@@ -177,7 +183,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_protocol_no_response(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(None)
+        self.mock_as_api.get_3pe_protocol.return_value = None
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -189,9 +195,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_select_one_protocol(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols("my-protocol"))
         )
@@ -205,9 +212,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_get_3pe_protocols_one_protocol(self) -> None:
         service = self._mkservice(False, ["my-protocol"])
         self.mock_store.get_app_services.return_value = [service]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -222,9 +230,10 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         service_one = self._mkservice(False, ["my-protocol"])
         service_two = self._mkservice(False, ["other-protocol"])
         self.mock_store.get_app_services.return_value = [service_one, service_two]
-        self.mock_as_api.get_3pe_protocol.return_value = make_awaitable(
-            {"x-protocol-data": 42, "instances": []}
-        )
+        self.mock_as_api.get_3pe_protocol.return_value = {
+            "x-protocol-data": 42,
+            "instances": [],
+        }
         response = self.successResultOf(
             defer.ensureDeferred(self.handler.get_3pe_protocols())
         )
@@ -287,13 +296,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         interested_service = self._mkservice(is_interested_in_event=True)
         services = [interested_service]
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
-            579
-        )
+        self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=579)
 
         event = Mock(event_id="event_1")
-        self.event_source.sources.receipt.get_new_events_as.return_value = (
-            make_awaitable(([event], None))
+        self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+            return_value=([event], None)
         )
 
         self.handler.notify_interested_services_ephemeral(
@@ -317,13 +324,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         services = [interested_service]
 
         self.mock_store.get_app_services.return_value = services
-        self.mock_store.get_type_stream_id_for_appservice.return_value = make_awaitable(
-            580
-        )
+        self.mock_store.get_type_stream_id_for_appservice = AsyncMock(return_value=580)
 
         event = Mock(event_id="event_1")
-        self.event_source.sources.receipt.get_new_events_as.return_value = (
-            make_awaitable(([event], None))
+        self.event_source.sources.receipt.get_new_events_as = AsyncMock(
+            return_value=([event], None)
         )
 
         self.handler.notify_interested_services_ephemeral(
@@ -350,9 +355,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             A mock representing the ApplicationService.
         """
         service = Mock()
-        service.is_interested_in_event.return_value = make_awaitable(
-            is_interested_in_event
-        )
+        service.is_interested_in_event = AsyncMock(return_value=is_interested_in_event)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         service.protocols = protocols