summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/7860.misc1
-rw-r--r--synapse/handlers/_base.py7
-rw-r--r--synapse/handlers/message.py8
-rw-r--r--synapse/handlers/profile.py63
-rw-r--r--synapse/handlers/receipts.py16
-rw-r--r--tests/handlers/test_profile.py17
6 files changed, 53 insertions, 59 deletions
diff --git a/changelog.d/7860.misc b/changelog.d/7860.misc
new file mode 100644
index 0000000000..fdd48b955c
--- /dev/null
+++ b/changelog.d/7860.misc
@@ -0,0 +1 @@
+Convert _base, profile, and _receipts handlers to async/await.
diff --git a/synapse/handlers/_base.py b/synapse/handlers/_base.py
index 6a4944467a..ba2bf99800 100644
--- a/synapse/handlers/_base.py
+++ b/synapse/handlers/_base.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 import synapse.state
 import synapse.storage
 import synapse.types
@@ -66,8 +64,7 @@ class BaseHandler(object):
 
         self.event_builder_factory = hs.get_event_builder_factory()
 
-    @defer.inlineCallbacks
-    def ratelimit(self, requester, update=True, is_admin_redaction=False):
+    async def ratelimit(self, requester, update=True, is_admin_redaction=False):
         """Ratelimits requests.
 
         Args:
@@ -99,7 +96,7 @@ class BaseHandler(object):
         burst_count = self._rc_message.burst_count
 
         # Check if there is a per user override in the DB.
-        override = yield self.store.get_ratelimit_for_user(user_id)
+        override = await self.store.get_ratelimit_for_user(user_id)
         if override:
             # If overridden with a null Hz then ratelimiting has been entirely
             # disabled for the user
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da206e1ec1..c47764a4ce 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -488,11 +488,15 @@ class EventCreationHandler(object):
 
                 try:
                     if "displayname" not in content:
-                        displayname = yield profile.get_displayname(target)
+                        displayname = yield defer.ensureDeferred(
+                            profile.get_displayname(target)
+                        )
                         if displayname is not None:
                             content["displayname"] = displayname
                     if "avatar_url" not in content:
-                        avatar_url = yield profile.get_avatar_url(target)
+                        avatar_url = yield defer.ensureDeferred(
+                            profile.get_avatar_url(target)
+                        )
                         if avatar_url is not None:
                             content["avatar_url"] = avatar_url
                 except Exception as e:
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index 4b1e3073a8..31a2e5ea18 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.api.errors import (
     AuthError,
     Codes,
@@ -54,16 +52,15 @@ class BaseProfileHandler(BaseHandler):
 
         self.user_directory_handler = hs.get_user_directory_handler()
 
-    @defer.inlineCallbacks
-    def get_profile(self, user_id):
+    async def get_profile(self, user_id):
         target_user = UserID.from_string(user_id)
 
         if self.hs.is_mine(target_user):
             try:
-                displayname = yield self.store.get_profile_displayname(
+                displayname = await self.store.get_profile_displayname(
                     target_user.localpart
                 )
-                avatar_url = yield self.store.get_profile_avatar_url(
+                avatar_url = await self.store.get_profile_avatar_url(
                     target_user.localpart
                 )
             except StoreError as e:
@@ -74,7 +71,7 @@ class BaseProfileHandler(BaseHandler):
             return {"displayname": displayname, "avatar_url": avatar_url}
         else:
             try:
-                result = yield self.federation.make_query(
+                result = await self.federation.make_query(
                     destination=target_user.domain,
                     query_type="profile",
                     args={"user_id": user_id},
@@ -86,8 +83,7 @@ class BaseProfileHandler(BaseHandler):
             except HttpResponseException as e:
                 raise e.to_synapse_error()
 
-    @defer.inlineCallbacks
-    def get_profile_from_cache(self, user_id):
+    async def get_profile_from_cache(self, user_id):
         """Get the profile information from our local cache. If the user is
         ours then the profile information will always be corect. Otherwise,
         it may be out of date/missing.
@@ -95,10 +91,10 @@ class BaseProfileHandler(BaseHandler):
         target_user = UserID.from_string(user_id)
         if self.hs.is_mine(target_user):
             try:
-                displayname = yield self.store.get_profile_displayname(
+                displayname = await self.store.get_profile_displayname(
                     target_user.localpart
                 )
-                avatar_url = yield self.store.get_profile_avatar_url(
+                avatar_url = await self.store.get_profile_avatar_url(
                     target_user.localpart
                 )
             except StoreError as e:
@@ -108,14 +104,13 @@ class BaseProfileHandler(BaseHandler):
 
             return {"displayname": displayname, "avatar_url": avatar_url}
         else:
-            profile = yield self.store.get_from_remote_profile_cache(user_id)
+            profile = await self.store.get_from_remote_profile_cache(user_id)
             return profile or {}
 
-    @defer.inlineCallbacks
-    def get_displayname(self, target_user):
+    async def get_displayname(self, target_user):
         if self.hs.is_mine(target_user):
             try:
-                displayname = yield self.store.get_profile_displayname(
+                displayname = await self.store.get_profile_displayname(
                     target_user.localpart
                 )
             except StoreError as e:
@@ -126,7 +121,7 @@ class BaseProfileHandler(BaseHandler):
             return displayname
         else:
             try:
-                result = yield self.federation.make_query(
+                result = await self.federation.make_query(
                     destination=target_user.domain,
                     query_type="profile",
                     args={"user_id": target_user.to_string(), "field": "displayname"},
@@ -189,11 +184,10 @@ class BaseProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    @defer.inlineCallbacks
-    def get_avatar_url(self, target_user):
+    async def get_avatar_url(self, target_user):
         if self.hs.is_mine(target_user):
             try:
-                avatar_url = yield self.store.get_profile_avatar_url(
+                avatar_url = await self.store.get_profile_avatar_url(
                     target_user.localpart
                 )
             except StoreError as e:
@@ -203,7 +197,7 @@ class BaseProfileHandler(BaseHandler):
             return avatar_url
         else:
             try:
-                result = yield self.federation.make_query(
+                result = await self.federation.make_query(
                     destination=target_user.domain,
                     query_type="profile",
                     args={"user_id": target_user.to_string(), "field": "avatar_url"},
@@ -253,8 +247,7 @@ class BaseProfileHandler(BaseHandler):
 
         await self._update_join_states(requester, target_user)
 
-    @defer.inlineCallbacks
-    def on_profile_query(self, args):
+    async def on_profile_query(self, args):
         user = UserID.from_string(args["user_id"])
         if not self.hs.is_mine(user):
             raise SynapseError(400, "User is not hosted on this homeserver")
@@ -264,12 +257,12 @@ class BaseProfileHandler(BaseHandler):
         response = {}
         try:
             if just_field is None or just_field == "displayname":
-                response["displayname"] = yield self.store.get_profile_displayname(
+                response["displayname"] = await self.store.get_profile_displayname(
                     user.localpart
                 )
 
             if just_field is None or just_field == "avatar_url":
-                response["avatar_url"] = yield self.store.get_profile_avatar_url(
+                response["avatar_url"] = await self.store.get_profile_avatar_url(
                     user.localpart
                 )
         except StoreError as e:
@@ -304,8 +297,7 @@ class BaseProfileHandler(BaseHandler):
                     "Failed to update join event for room %s - %s", room_id, str(e)
                 )
 
-    @defer.inlineCallbacks
-    def check_profile_query_allowed(self, target_user, requester=None):
+    async def check_profile_query_allowed(self, target_user, requester=None):
         """Checks whether a profile query is allowed. If the
         'require_auth_for_profile_requests' config flag is set to True and a
         'requester' is provided, the query is only allowed if the two users
@@ -337,8 +329,8 @@ class BaseProfileHandler(BaseHandler):
             return
 
         try:
-            requester_rooms = yield self.store.get_rooms_for_user(requester.to_string())
-            target_user_rooms = yield self.store.get_rooms_for_user(
+            requester_rooms = await self.store.get_rooms_for_user(requester.to_string())
+            target_user_rooms = await self.store.get_rooms_for_user(
                 target_user.to_string()
             )
 
@@ -371,25 +363,24 @@ class MasterProfileHandler(BaseProfileHandler):
             "Update remote profile", self._update_remote_profile_cache
         )
 
-    @defer.inlineCallbacks
-    def _update_remote_profile_cache(self):
+    async def _update_remote_profile_cache(self):
         """Called periodically to check profiles of remote users we haven't
         checked in a while.
         """
-        entries = yield self.store.get_remote_profile_cache_entries_that_expire(
+        entries = await self.store.get_remote_profile_cache_entries_that_expire(
             last_checked=self.clock.time_msec() - self.PROFILE_UPDATE_EVERY_MS
         )
 
         for user_id, displayname, avatar_url in entries:
-            is_subscribed = yield self.store.is_subscribed_remote_profile_for_user(
+            is_subscribed = await self.store.is_subscribed_remote_profile_for_user(
                 user_id
             )
             if not is_subscribed:
-                yield self.store.maybe_delete_remote_profile_cache(user_id)
+                await self.store.maybe_delete_remote_profile_cache(user_id)
                 continue
 
             try:
-                profile = yield self.federation.make_query(
+                profile = await self.federation.make_query(
                     destination=get_domain_from_id(user_id),
                     query_type="profile",
                     args={"user_id": user_id},
@@ -398,7 +389,7 @@ class MasterProfileHandler(BaseProfileHandler):
             except Exception:
                 logger.exception("Failed to get avatar_url")
 
-                yield self.store.update_remote_profile_cache(
+                await self.store.update_remote_profile_cache(
                     user_id, displayname, avatar_url
                 )
                 continue
@@ -407,4 +398,4 @@ class MasterProfileHandler(BaseProfileHandler):
             new_avatar = profile.get("avatar_url")
 
             # We always hit update to update the last_check timestamp
-            yield self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
+            await self.store.update_remote_profile_cache(user_id, new_name, new_avatar)
diff --git a/synapse/handlers/receipts.py b/synapse/handlers/receipts.py
index 8bc100db42..f922d8a545 100644
--- a/synapse/handlers/receipts.py
+++ b/synapse/handlers/receipts.py
@@ -14,8 +14,6 @@
 # limitations under the License.
 import logging
 
-from twisted.internet import defer
-
 from synapse.handlers._base import BaseHandler
 from synapse.types import ReadReceipt, get_domain_from_id
 from synapse.util.async_helpers import maybe_awaitable
@@ -129,15 +127,14 @@ class ReceiptEventSource(object):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_new_events(self, from_key, room_ids, **kwargs):
+    async def get_new_events(self, from_key, room_ids, **kwargs):
         from_key = int(from_key)
-        to_key = yield self.get_current_key()
+        to_key = self.get_current_key()
 
         if from_key == to_key:
             return [], to_key
 
-        events = yield self.store.get_linearized_receipts_for_rooms(
+        events = await self.store.get_linearized_receipts_for_rooms(
             room_ids, from_key=from_key, to_key=to_key
         )
 
@@ -146,8 +143,7 @@ class ReceiptEventSource(object):
     def get_current_key(self, direction="f"):
         return self.store.get_max_receipt_stream_id()
 
-    @defer.inlineCallbacks
-    def get_pagination_rows(self, user, config, key):
+    async def get_pagination_rows(self, user, config, key):
         to_key = int(config.from_key)
 
         if config.to_key:
@@ -155,8 +151,8 @@ class ReceiptEventSource(object):
         else:
             from_key = None
 
-        room_ids = yield self.store.get_rooms_for_user(user.to_string())
-        events = yield self.store.get_linearized_receipts_for_rooms(
+        room_ids = await self.store.get_rooms_for_user(user.to_string())
+        events = await self.store.get_linearized_receipts_for_rooms(
             room_ids, from_key=from_key, to_key=to_key
         )
 
diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py
index 29dd7d9c6e..4f1347cd25 100644
--- a/tests/handlers/test_profile.py
+++ b/tests/handlers/test_profile.py
@@ -72,7 +72,9 @@ class ProfileTestCase(unittest.TestCase):
     def test_get_my_name(self):
         yield self.store.set_profile_displayname(self.frank.localpart, "Frank")
 
-        displayname = yield self.handler.get_displayname(self.frank)
+        displayname = yield defer.ensureDeferred(
+            self.handler.get_displayname(self.frank)
+        )
 
         self.assertEquals("Frank", displayname)
 
@@ -140,7 +142,9 @@ class ProfileTestCase(unittest.TestCase):
             {"displayname": "Alice"}
         )
 
-        displayname = yield self.handler.get_displayname(self.alice)
+        displayname = yield defer.ensureDeferred(
+            self.handler.get_displayname(self.alice)
+        )
 
         self.assertEquals(displayname, "Alice")
         self.mock_federation.make_query.assert_called_with(
@@ -155,8 +159,10 @@ class ProfileTestCase(unittest.TestCase):
         yield self.store.create_profile("caroline")
         yield self.store.set_profile_displayname("caroline", "Caroline")
 
-        response = yield self.query_handlers["profile"](
-            {"user_id": "@caroline:test", "field": "displayname"}
+        response = yield defer.ensureDeferred(
+            self.query_handlers["profile"](
+                {"user_id": "@caroline:test", "field": "displayname"}
+            )
         )
 
         self.assertEquals({"displayname": "Caroline"}, response)
@@ -166,8 +172,7 @@ class ProfileTestCase(unittest.TestCase):
         yield self.store.set_profile_avatar_url(
             self.frank.localpart, "http://my.server/me.png"
         )
-
-        avatar_url = yield self.handler.get_avatar_url(self.frank)
+        avatar_url = yield defer.ensureDeferred(self.handler.get_avatar_url(self.frank))
 
         self.assertEquals("http://my.server/me.png", avatar_url)