summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/7775.misc1
-rw-r--r--synapse/appservice/api.py1
-rw-r--r--synapse/handlers/appservice.py74
-rw-r--r--tests/handlers/test_appservice.py68
4 files changed, 68 insertions, 76 deletions
diff --git a/changelog.d/7775.misc b/changelog.d/7775.misc
new file mode 100644
index 0000000000..af6fdb782f
--- /dev/null
+++ b/changelog.d/7775.misc
@@ -0,0 +1 @@
+Convert the appserver handler to async/await.
diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index da9a5e86d4..f92bfb420b 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -98,7 +98,6 @@ class ApplicationServiceApi(SimpleHttpClient):
         if service.url is None:
             return False
         uri = service.url + ("/users/%s" % urllib.parse.quote(user_id))
-        response = None
         try:
             response = yield self.get_json(uri, {"access_token": service.hs_token})
             if response is not None:  # just an empty json object
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index 904c96eeec..92d4c6e16c 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -48,8 +48,7 @@ class ApplicationServicesHandler(object):
         self.current_max = 0
         self.is_processing = False
 
-    @defer.inlineCallbacks
-    def notify_interested_services(self, current_id):
+    async def notify_interested_services(self, current_id):
         """Notifies (pushes) all application services interested in this event.
 
         Pushing is done asynchronously, so this method won't block for any
@@ -74,7 +73,7 @@ class ApplicationServicesHandler(object):
                     (
                         upper_bound,
                         events,
-                    ) = yield self.store.get_new_events_for_appservice(
+                    ) = await self.store.get_new_events_for_appservice(
                         self.current_max, limit
                     )
 
@@ -85,10 +84,9 @@ class ApplicationServicesHandler(object):
                     for event in events:
                         events_by_room.setdefault(event.room_id, []).append(event)
 
-                    @defer.inlineCallbacks
-                    def handle_event(event):
+                    async def handle_event(event):
                         # Gather interested services
-                        services = yield self._get_services_for_event(event)
+                        services = await self._get_services_for_event(event)
                         if len(services) == 0:
                             return  # no services need notifying
 
@@ -96,9 +94,9 @@ class ApplicationServicesHandler(object):
                         # query API for all services which match that user regex.
                         # This needs to block as these user queries need to be
                         # made BEFORE pushing the event.
-                        yield self._check_user_exists(event.sender)
+                        await self._check_user_exists(event.sender)
                         if event.type == EventTypes.Member:
-                            yield self._check_user_exists(event.state_key)
+                            await self._check_user_exists(event.state_key)
 
                         if not self.started_scheduler:
 
@@ -115,17 +113,16 @@ class ApplicationServicesHandler(object):
                             self.scheduler.submit_event_for_as(service, event)
 
                         now = self.clock.time_msec()
-                        ts = yield self.store.get_received_ts(event.event_id)
+                        ts = await self.store.get_received_ts(event.event_id)
                         synapse.metrics.event_processing_lag_by_event.labels(
                             "appservice_sender"
                         ).observe((now - ts) / 1000)
 
-                    @defer.inlineCallbacks
-                    def handle_room_events(events):
+                    async def handle_room_events(events):
                         for event in events:
-                            yield handle_event(event)
+                            await handle_event(event)
 
-                    yield make_deferred_yieldable(
+                    await make_deferred_yieldable(
                         defer.gatherResults(
                             [
                                 run_in_background(handle_room_events, evs)
@@ -135,10 +132,10 @@ class ApplicationServicesHandler(object):
                         )
                     )
 
-                    yield self.store.set_appservice_last_pos(upper_bound)
+                    await self.store.set_appservice_last_pos(upper_bound)
 
                     now = self.clock.time_msec()
-                    ts = yield self.store.get_received_ts(events[-1].event_id)
+                    ts = await self.store.get_received_ts(events[-1].event_id)
 
                     synapse.metrics.event_processing_positions.labels(
                         "appservice_sender"
@@ -161,8 +158,7 @@ class ApplicationServicesHandler(object):
             finally:
                 self.is_processing = False
 
-    @defer.inlineCallbacks
-    def query_user_exists(self, user_id):
+    async def query_user_exists(self, user_id):
         """Check if any application service knows this user_id exists.
 
         Args:
@@ -170,15 +166,14 @@ class ApplicationServicesHandler(object):
         Returns:
             True if this user exists on at least one application service.
         """
-        user_query_services = yield self._get_services_for_user(user_id=user_id)
+        user_query_services = self._get_services_for_user(user_id=user_id)
         for user_service in user_query_services:
-            is_known_user = yield self.appservice_api.query_user(user_service, user_id)
+            is_known_user = await self.appservice_api.query_user(user_service, user_id)
             if is_known_user:
                 return True
         return False
 
-    @defer.inlineCallbacks
-    def query_room_alias_exists(self, room_alias):
+    async def query_room_alias_exists(self, room_alias):
         """Check if an application service knows this room alias exists.
 
         Args:
@@ -193,19 +188,18 @@ class ApplicationServicesHandler(object):
             s for s in services if (s.is_interested_in_alias(room_alias_str))
         ]
         for alias_service in alias_query_services:
-            is_known_alias = yield self.appservice_api.query_alias(
+            is_known_alias = await self.appservice_api.query_alias(
                 alias_service, room_alias_str
             )
             if is_known_alias:
                 # the alias exists now so don't query more ASes.
-                result = yield self.store.get_association_from_room_alias(room_alias)
+                result = await self.store.get_association_from_room_alias(room_alias)
                 return result
 
-    @defer.inlineCallbacks
-    def query_3pe(self, kind, protocol, fields):
-        services = yield self._get_services_for_3pn(protocol)
+    async def query_3pe(self, kind, protocol, fields):
+        services = self._get_services_for_3pn(protocol)
 
-        results = yield make_deferred_yieldable(
+        results = await make_deferred_yieldable(
             defer.DeferredList(
                 [
                     run_in_background(
@@ -224,8 +218,7 @@ class ApplicationServicesHandler(object):
 
         return ret
 
-    @defer.inlineCallbacks
-    def get_3pe_protocols(self, only_protocol=None):
+    async def get_3pe_protocols(self, only_protocol=None):
         services = self.store.get_app_services()
         protocols = {}
 
@@ -238,7 +231,7 @@ class ApplicationServicesHandler(object):
                 if p not in protocols:
                     protocols[p] = []
 
-                info = yield self.appservice_api.get_3pe_protocol(s, p)
+                info = await self.appservice_api.get_3pe_protocol(s, p)
 
                 if info is not None:
                     protocols[p].append(info)
@@ -263,8 +256,7 @@ class ApplicationServicesHandler(object):
 
         return protocols
 
-    @defer.inlineCallbacks
-    def _get_services_for_event(self, event):
+    async def _get_services_for_event(self, event):
         """Retrieve a list of application services interested in this event.
 
         Args:
@@ -280,7 +272,7 @@ class ApplicationServicesHandler(object):
         # inside of a list comprehension anymore.
         interested_list = []
         for s in services:
-            if (yield s.is_interested(event, self.store)):
+            if await s.is_interested(event, self.store):
                 interested_list.append(s)
 
         return interested_list
@@ -288,21 +280,20 @@ class ApplicationServicesHandler(object):
     def _get_services_for_user(self, user_id):
         services = self.store.get_app_services()
         interested_list = [s for s in services if (s.is_interested_in_user(user_id))]
-        return defer.succeed(interested_list)
+        return interested_list
 
     def _get_services_for_3pn(self, protocol):
         services = self.store.get_app_services()
         interested_list = [s for s in services if s.is_interested_in_protocol(protocol)]
-        return defer.succeed(interested_list)
+        return interested_list
 
-    @defer.inlineCallbacks
-    def _is_unknown_user(self, user_id):
+    async def _is_unknown_user(self, user_id):
         if not self.is_mine_id(user_id):
             # we don't know if they are unknown or not since it isn't one of our
             # users. We can't poke ASes.
             return False
 
-        user_info = yield self.store.get_user_by_id(user_id)
+        user_info = await self.store.get_user_by_id(user_id)
         if user_info:
             return False
 
@@ -311,10 +302,9 @@ class ApplicationServicesHandler(object):
         service_list = [s for s in services if s.sender == user_id]
         return len(service_list) == 0
 
-    @defer.inlineCallbacks
-    def _check_user_exists(self, user_id):
-        unknown_user = yield self._is_unknown_user(user_id)
+    async def _check_user_exists(self, user_id):
+        unknown_user = await self._is_unknown_user(user_id)
         if unknown_user:
-            exists = yield self.query_user_exists(user_id)
+            exists = await self.query_user_exists(user_id)
             return exists
         return True
diff --git a/tests/handlers/test_appservice.py b/tests/handlers/test_appservice.py
index ba7148ec01..ebabe9a7d6 100644
--- a/tests/handlers/test_appservice.py
+++ b/tests/handlers/test_appservice.py
@@ -32,10 +32,11 @@ class AppServiceHandlerTestCase(unittest.TestCase):
         self.mock_as_api = Mock()
         self.mock_scheduler = Mock()
         hs = Mock()
-        hs.get_datastore = Mock(return_value=self.mock_store)
-        self.mock_store.get_received_ts.return_value = 0
-        hs.get_application_service_api = Mock(return_value=self.mock_as_api)
-        hs.get_application_service_scheduler = Mock(return_value=self.mock_scheduler)
+        hs.get_datastore.return_value = self.mock_store
+        self.mock_store.get_received_ts.return_value = defer.succeed(0)
+        self.mock_store.set_appservice_last_pos.return_value = defer.succeed(None)
+        hs.get_application_service_api.return_value = self.mock_as_api
+        hs.get_application_service_scheduler.return_value = self.mock_scheduler
         hs.get_clock.return_value = MockClock()
         self.handler = ApplicationServicesHandler(hs)
 
@@ -48,18 +49,18 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice(is_interested=False),
         ]
 
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_user_by_id = Mock(return_value=[])
+        self.mock_as_api.query_user.return_value = defer.succeed(True)
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_user_by_id.return_value = defer.succeed([])
 
         event = Mock(
             sender="@someone:anywhere", type="m.room.message", room_id="!foo:bar"
         )
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            (0, [event]),
-            (0, []),
+            defer.succeed((0, [event])),
+            defer.succeed((0, [])),
         ]
-        self.mock_as_api.push = Mock()
-        yield self.handler.notify_interested_services(0)
+        yield defer.ensureDeferred(self.handler.notify_interested_services(0))
         self.mock_scheduler.submit_event_for_as.assert_called_once_with(
             interested_service, event
         )
@@ -68,36 +69,34 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_query_user_exists_unknown_user(self):
         user_id = "@someone:anywhere"
         services = [self._mkservice(is_interested=True)]
-        services[0].is_interested_in_user = Mock(return_value=True)
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_user_by_id = Mock(return_value=None)
+        services[0].is_interested_in_user.return_value = True
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_user_by_id.return_value = defer.succeed(None)
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.push = Mock()
-        self.mock_as_api.query_user = Mock()
+        self.mock_as_api.query_user.return_value = defer.succeed(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            (0, [event]),
-            (0, []),
+            defer.succeed((0, [event])),
+            defer.succeed((0, [])),
         ]
-        yield self.handler.notify_interested_services(0)
+        yield defer.ensureDeferred(self.handler.notify_interested_services(0))
         self.mock_as_api.query_user.assert_called_once_with(services[0], user_id)
 
     @defer.inlineCallbacks
     def test_query_user_exists_known_user(self):
         user_id = "@someone:anywhere"
         services = [self._mkservice(is_interested=True)]
-        services[0].is_interested_in_user = Mock(return_value=True)
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_user_by_id = Mock(return_value={"name": user_id})
+        services[0].is_interested_in_user.return_value = True
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_user_by_id.return_value = defer.succeed({"name": user_id})
 
         event = Mock(sender=user_id, type="m.room.message", room_id="!foo:bar")
-        self.mock_as_api.push = Mock()
-        self.mock_as_api.query_user = Mock()
+        self.mock_as_api.query_user.return_value = defer.succeed(True)
         self.mock_store.get_new_events_for_appservice.side_effect = [
-            (0, [event]),
-            (0, []),
+            defer.succeed((0, [event])),
+            defer.succeed((0, [])),
         ]
-        yield self.handler.notify_interested_services(0)
+        yield defer.ensureDeferred(self.handler.notify_interested_services(0))
         self.assertFalse(
             self.mock_as_api.query_user.called,
             "query_user called when it shouldn't have been.",
@@ -107,7 +106,7 @@ class AppServiceHandlerTestCase(unittest.TestCase):
     def test_query_room_alias_exists(self):
         room_alias_str = "#foo:bar"
         room_alias = Mock()
-        room_alias.to_string = Mock(return_value=room_alias_str)
+        room_alias.to_string.return_value = room_alias_str
 
         room_id = "!alpha:bet"
         servers = ["aperture"]
@@ -118,12 +117,15 @@ class AppServiceHandlerTestCase(unittest.TestCase):
             self._mkservice_alias(is_interested_in_alias=False),
         ]
 
-        self.mock_store.get_app_services = Mock(return_value=services)
-        self.mock_store.get_association_from_room_alias = Mock(
-            return_value=Mock(room_id=room_id, servers=servers)
+        self.mock_as_api.query_alias.return_value = defer.succeed(True)
+        self.mock_store.get_app_services.return_value = services
+        self.mock_store.get_association_from_room_alias.return_value = defer.succeed(
+            Mock(room_id=room_id, servers=servers)
         )
 
-        result = yield self.handler.query_room_alias_exists(room_alias)
+        result = yield defer.ensureDeferred(
+            self.handler.query_room_alias_exists(room_alias)
+        )
 
         self.mock_as_api.query_alias.assert_called_once_with(
             interested_service, room_alias_str
@@ -133,14 +135,14 @@ class AppServiceHandlerTestCase(unittest.TestCase):
 
     def _mkservice(self, is_interested):
         service = Mock()
-        service.is_interested = Mock(return_value=is_interested)
+        service.is_interested.return_value = defer.succeed(is_interested)
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service
 
     def _mkservice_alias(self, is_interested_in_alias):
         service = Mock()
-        service.is_interested_in_alias = Mock(return_value=is_interested_in_alias)
+        service.is_interested_in_alias.return_value = is_interested_in_alias
         service.token = "mock_service_token"
         service.url = "mock_service_url"
         return service