summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6274.misc1
-rw-r--r--changelog.d/6275.misc1
-rw-r--r--synapse/federation/federation_server.py12
-rw-r--r--synapse/replication/http/_base.py6
-rw-r--r--synapse/replication/http/federation.py24
-rw-r--r--synapse/replication/http/login.py7
-rw-r--r--synapse/replication/http/membership.py14
-rw-r--r--synapse/replication/http/register.py12
-rw-r--r--synapse/replication/http/send_event.py7
-rw-r--r--synapse/rest/client/v1/room.py166
10 files changed, 106 insertions, 144 deletions
diff --git a/changelog.d/6274.misc b/changelog.d/6274.misc
new file mode 100644
index 0000000000..eb4966124f
--- /dev/null
+++ b/changelog.d/6274.misc
@@ -0,0 +1 @@
+Port replication http server endpoints to async/await.
diff --git a/changelog.d/6275.misc b/changelog.d/6275.misc
new file mode 100644
index 0000000000..f57e2c4adb
--- /dev/null
+++ b/changelog.d/6275.misc
@@ -0,0 +1 @@
+Port room rest handlers to async/await.
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 15c1fa0a51..7c331753ad 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -809,7 +809,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
 
         super(ReplicationFederationHandlerRegistry, self).__init__()
 
-    def on_edu(self, edu_type, origin, content):
+    async def on_edu(self, edu_type, origin, content):
         """Overrides FederationHandlerRegistry
         """
         if not self.config.use_presence and edu_type == "m.presence":
@@ -817,17 +817,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry):
 
         handler = self.edu_handlers.get(edu_type)
         if handler:
-            return super(ReplicationFederationHandlerRegistry, self).on_edu(
+            return await super(ReplicationFederationHandlerRegistry, self).on_edu(
                 edu_type, origin, content
             )
 
-        return self._send_edu(edu_type=edu_type, origin=origin, content=content)
+        return await self._send_edu(edu_type=edu_type, origin=origin, content=content)
 
-    def on_query(self, query_type, args):
+    async def on_query(self, query_type, args):
         """Overrides FederationHandlerRegistry
         """
         handler = self.query_handlers.get(query_type)
         if handler:
-            return handler(args)
+            return await handler(args)
 
-        return self._get_query_client(query_type=query_type, args=args)
+        return await self._get_query_client(query_type=query_type, args=args)
diff --git a/synapse/replication/http/_base.py b/synapse/replication/http/_base.py
index 03560c1f0e..9be37cd998 100644
--- a/synapse/replication/http/_base.py
+++ b/synapse/replication/http/_base.py
@@ -110,14 +110,14 @@ class ReplicationEndpoint(object):
         return {}
 
     @abc.abstractmethod
-    def _handle_request(self, request, **kwargs):
+    async def _handle_request(self, request, **kwargs):
         """Handle incoming request.
 
         This is called with the request object and PATH_ARGS.
 
         Returns:
-            Deferred[dict]: A JSON serialisable dict to be used as response
-            body of request.
+            tuple[int, dict]: HTTP status code and a JSON serialisable dict
+            to be used as response body of request.
         """
         pass
 
diff --git a/synapse/replication/http/federation.py b/synapse/replication/http/federation.py
index 2f16955954..9af4e7e173 100644
--- a/synapse/replication/http/federation.py
+++ b/synapse/replication/http/federation.py
@@ -82,8 +82,7 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
 
         return payload
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request):
+    async def _handle_request(self, request):
         with Measure(self.clock, "repl_fed_send_events_parse"):
             content = parse_json_object_from_request(request)
 
@@ -101,15 +100,13 @@ class ReplicationFederationSendEventsRestServlet(ReplicationEndpoint):
                 EventType = event_type_from_format_version(format_ver)
                 event = EventType(event_dict, internal_metadata, rejected_reason)
 
-                context = yield EventContext.deserialize(
-                    self.store, event_payload["context"]
-                )
+                context = EventContext.deserialize(self.store, event_payload["context"])
 
                 event_and_contexts.append((event, context))
 
         logger.info("Got %d events from federation", len(event_and_contexts))
 
-        yield self.federation_handler.persist_events_and_notify(
+        await self.federation_handler.persist_events_and_notify(
             event_and_contexts, backfilled
         )
 
@@ -144,8 +141,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
     def _serialize_payload(edu_type, origin, content):
         return {"origin": origin, "content": content}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, edu_type):
+    async def _handle_request(self, request, edu_type):
         with Measure(self.clock, "repl_fed_send_edu_parse"):
             content = parse_json_object_from_request(request)
 
@@ -154,7 +150,7 @@ class ReplicationFederationSendEduRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r edu from %s", edu_type, origin)
 
-        result = yield self.registry.on_edu(edu_type, origin, edu_content)
+        result = await self.registry.on_edu(edu_type, origin, edu_content)
 
         return 200, result
 
@@ -193,8 +189,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
         """
         return {"args": args}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, query_type):
+    async def _handle_request(self, request, query_type):
         with Measure(self.clock, "repl_fed_query_parse"):
             content = parse_json_object_from_request(request)
 
@@ -202,7 +197,7 @@ class ReplicationGetQueryRestServlet(ReplicationEndpoint):
 
         logger.info("Got %r query", query_type)
 
-        result = yield self.registry.on_query(query_type, args)
+        result = await self.registry.on_query(query_type, args)
 
         return 200, result
 
@@ -234,9 +229,8 @@ class ReplicationCleanRoomRestServlet(ReplicationEndpoint):
         """
         return {}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id):
-        yield self.store.clean_room_for_join(room_id)
+    async def _handle_request(self, request, room_id):
+        await self.store.clean_room_for_join(room_id)
 
         return 200, {}
 
diff --git a/synapse/replication/http/login.py b/synapse/replication/http/login.py
index 786f5232b2..798b9d3af5 100644
--- a/synapse/replication/http/login.py
+++ b/synapse/replication/http/login.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
@@ -52,15 +50,14 @@ class RegisterDeviceReplicationServlet(ReplicationEndpoint):
             "is_guest": is_guest,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, user_id):
+    async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
         device_id = content["device_id"]
         initial_display_name = content["initial_display_name"]
         is_guest = content["is_guest"]
 
-        device_id, access_token = yield self.registration_handler.register_device(
+        device_id, access_token = await self.registration_handler.register_device(
             user_id, device_id, initial_display_name, is_guest
         )
 
diff --git a/synapse/replication/http/membership.py b/synapse/replication/http/membership.py
index b9ce3477ad..b5f5f13a62 100644
--- a/synapse/replication/http/membership.py
+++ b/synapse/replication/http/membership.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 from synapse.types import Requester, UserID
@@ -65,8 +63,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
             "content": content,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id, user_id):
+    async def _handle_request(self, request, room_id, user_id):
         content = parse_json_object_from_request(request)
 
         remote_room_hosts = content["remote_room_hosts"]
@@ -79,7 +76,7 @@ class ReplicationRemoteJoinRestServlet(ReplicationEndpoint):
 
         logger.info("remote_join: %s into room: %s", user_id, room_id)
 
-        yield self.federation_handler.do_invite_join(
+        await self.federation_handler.do_invite_join(
             remote_room_hosts, room_id, user_id, event_content
         )
 
@@ -123,8 +120,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             "remote_room_hosts": remote_room_hosts,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, room_id, user_id):
+    async def _handle_request(self, request, room_id, user_id):
         content = parse_json_object_from_request(request)
 
         remote_room_hosts = content["remote_room_hosts"]
@@ -137,7 +133,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
         logger.info("remote_reject_invite: %s out of room: %s", user_id, room_id)
 
         try:
-            event = yield self.federation_handler.do_remotely_reject_invite(
+            event = await self.federation_handler.do_remotely_reject_invite(
                 remote_room_hosts, room_id, user_id
             )
             ret = event.get_pdu_json()
@@ -150,7 +146,7 @@ class ReplicationRemoteRejectInviteRestServlet(ReplicationEndpoint):
             #
             logger.warn("Failed to reject invite: %s", e)
 
-            yield self.store.locally_reject_invite(user_id, room_id)
+            await self.store.locally_reject_invite(user_id, room_id)
             ret = {}
 
         return 200, ret
diff --git a/synapse/replication/http/register.py b/synapse/replication/http/register.py
index 38260256cf..915cfb9430 100644
--- a/synapse/replication/http/register.py
+++ b/synapse/replication/http/register.py
@@ -15,8 +15,6 @@
 
 import logging
 
-from twisted.internet import defer
-
 from synapse.http.servlet import parse_json_object_from_request
 from synapse.replication.http._base import ReplicationEndpoint
 
@@ -74,11 +72,10 @@ class ReplicationRegisterServlet(ReplicationEndpoint):
             "address": address,
         }
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, user_id):
+    async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
-        yield self.registration_handler.register_with_store(
+        await self.registration_handler.register_with_store(
             user_id=user_id,
             password_hash=content["password_hash"],
             was_guest=content["was_guest"],
@@ -117,14 +114,13 @@ class ReplicationPostRegisterActionsServlet(ReplicationEndpoint):
         """
         return {"auth_result": auth_result, "access_token": access_token}
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, user_id):
+    async def _handle_request(self, request, user_id):
         content = parse_json_object_from_request(request)
 
         auth_result = content["auth_result"]
         access_token = content["access_token"]
 
-        yield self.registration_handler.post_registration_actions(
+        await self.registration_handler.post_registration_actions(
             user_id=user_id, auth_result=auth_result, access_token=access_token
         )
 
diff --git a/synapse/replication/http/send_event.py b/synapse/replication/http/send_event.py
index adb9b2f7f4..9bafd60b14 100644
--- a/synapse/replication/http/send_event.py
+++ b/synapse/replication/http/send_event.py
@@ -87,8 +87,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
 
         return payload
 
-    @defer.inlineCallbacks
-    def _handle_request(self, request, event_id):
+    async def _handle_request(self, request, event_id):
         with Measure(self.clock, "repl_send_event_parse"):
             content = parse_json_object_from_request(request)
 
@@ -101,7 +100,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             event = EventType(event_dict, internal_metadata, rejected_reason)
 
             requester = Requester.deserialize(self.store, content["requester"])
-            context = yield EventContext.deserialize(self.store, content["context"])
+            context = EventContext.deserialize(self.store, content["context"])
 
             ratelimit = content["ratelimit"]
             extra_users = [UserID.from_string(u) for u in content["extra_users"]]
@@ -113,7 +112,7 @@ class ReplicationSendEventRestServlet(ReplicationEndpoint):
             "Got event to send with ID: %s into room: %s", event.event_id, event.room_id
         )
 
-        yield self.event_creation_handler.persist_and_notify_client_event(
+        await self.event_creation_handler.persist_and_notify_client_event(
             requester, event, context, ratelimit=ratelimit, extra_users=extra_users
         )
 
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index 9c1d41421c..86bbcc0eea 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
@@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet):
         set_tag("txn_id", txn_id)
         return self.txns.fetch_or_execute_request(request, self.on_POST, request)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request)
 
-        info = yield self._room_creation_handler.create_room(
+        info = await self._room_creation_handler.create_room(
             requester, self.get_room_config(request)
         )
 
@@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
     def on_PUT_no_state_key(self, request, room_id, event_type):
         return self.on_PUT(request, room_id, event_type, "")
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id, event_type, state_key):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id, event_type, state_key):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         format = parse_string(
             request, "format", default="content", allowed_values=["content", "event"]
         )
 
         msg_handler = self.message_handler
-        data = yield msg_handler.get_room_data(
+        data = await msg_handler.get_room_data(
             user_id=requester.user.to_string(),
             room_id=room_id,
             event_type=event_type,
@@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         elif format == "content":
             return 200, data.get_dict()["content"]
 
-    @defer.inlineCallbacks
-    def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+        requester = await self.auth.get_user_by_req(request)
 
         if txn_id:
             set_tag("txn_id", txn_id)
@@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 
         if event_type == EventTypes.Member:
             membership = content.get("membership", None)
-            event = yield self.room_member_handler.update_membership(
+            event = await self.room_member_handler.update_membership(
                 requester,
                 target=UserID.from_string(state_key),
                 room_id=room_id,
@@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
                 content=content,
             )
         else:
-            event = yield self.event_creation_handler.create_and_send_nonmember_event(
+            event = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester, event_dict, txn_id=txn_id
             )
 
@@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
         register_txn_path(self, PATTERNS, http_server, with_get=True)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, event_type, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request, room_id, event_type, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)
 
         event_dict = {
@@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         if b"ts" in request.args and requester.app_service:
             event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
 
-        event = yield self.event_creation_handler.create_and_send_nonmember_event(
+        event = await self.event_creation_handler.create_and_send_nonmember_event(
             requester, event_dict, txn_id=txn_id
         )
 
@@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
         PATTERNS = "/join/(?P<room_identifier>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_identifier, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request, room_identifier, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         try:
             content = parse_json_object_from_request(request)
@@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet):
         elif RoomAlias.is_valid(room_identifier):
             handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
-            room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+            room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
             room_id = room_id.to_string()
         else:
             raise SynapseError(
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
             )
 
-        yield self.room_member_handler.update_membership(
+        await self.room_member_handler.update_membership(
             requester=requester,
             target=requester.user,
             room_id=room_id,
@@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request):
+    async def on_GET(self, request):
         server = parse_string(request, "server", default=None)
 
         try:
-            yield self.auth.get_user_by_req(request, allow_guest=True)
+            await self.auth.get_user_by_req(request, allow_guest=True)
         except InvalidClientCredentialsError as e:
             # Option to allow servers to require auth when accessing
             # /publicRooms via CS API. This is especially helpful in private
@@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server:
-            data = yield handler.get_remote_public_room_list(
+            data = await handler.get_remote_public_room_list(
                 server, limit=limit, since_token=since_token
             )
         else:
-            data = yield handler.get_local_public_room_list(
+            data = await handler.get_local_public_room_list(
                 limit=limit, since_token=since_token
             )
 
         return 200, data
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request):
+        await self.auth.get_user_by_req(request, allow_guest=True)
 
         server = parse_string(request, "server", default=None)
         content = parse_json_object_from_request(request)
@@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server:
-            data = yield handler.get_remote_public_room_list(
+            data = await handler.get_remote_public_room_list(
                 server,
                 limit=limit,
                 since_token=since_token,
@@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
                 third_party_instance_id=third_party_instance_id,
             )
         else:
-            data = yield handler.get_local_public_room_list(
+            data = await handler.get_local_public_room_list(
                 limit=limit,
                 since_token=since_token,
                 search_filter=search_filter,
@@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
+    async def on_GET(self, request, room_id):
         # TODO support Pagination stream API (limit/tokens)
-        requester = yield self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request)
         handler = self.message_handler
 
         # request the state as of a given event, as identified by a stream token,
@@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet):
         membership = parse_string(request, "membership")
         not_membership = parse_string(request, "not_membership")
 
-        events = yield handler.get_state_events(
+        events = await handler.get_state_events(
             room_id=room_id,
             user_id=requester.user.to_string(),
             at_token=at_token,
@@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
 
-        users_with_profile = yield self.message_handler.get_joined_members(
+        users_with_profile = await self.message_handler.get_joined_members(
             requester, room_id
         )
 
@@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet):
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = PaginationConfig.from_request(request, default_limit=10)
         as_client_event = b"raw" not in request.args
         filter_bytes = parse_string(request, b"filter", encoding=None)
@@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet):
                 as_client_event = False
         else:
             event_filter = None
-        msgs = yield self.pagination_handler.get_messages(
+        msgs = await self.pagination_handler.get_messages(
             room_id=room_id,
             requester=requester,
             pagin_config=pagination_config,
@@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         # Get all the current state for this room
-        events = yield self.message_handler.get_state_events(
+        events = await self.message_handler.get_state_events(
             room_id=room_id,
             user_id=requester.user.to_string(),
             is_guest=requester.is_guest,
@@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet):
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = PaginationConfig.from_request(request)
-        content = yield self.initial_sync_handler.room_initial_sync(
+        content = await self.initial_sync_handler.room_initial_sync(
             room_id=room_id, requester=requester, pagin_config=pagination_config
         )
         return 200, content
@@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet):
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id, event_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         try:
-            event = yield self.event_handler.get_event(
+            event = await self.event_handler.get_event(
                 requester.user, room_id, event_id
             )
         except AuthError:
@@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         if event:
-            event = yield self._event_serializer.serialize_event(event, time_now)
+            event = await self._event_serializer.serialize_event(event, time_now)
             return 200, event
 
         return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet):
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id, event_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         limit = parse_integer(request, "limit", default=10)
 
@@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet):
         else:
             event_filter = None
 
-        results = yield self.room_context_handler.get_event_context(
+        results = await self.room_context_handler.get_event_context(
             requester.user, room_id, event_id, limit, event_filter
         )
 
@@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet):
             raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
 
         time_now = self.clock.time_msec()
-        results["events_before"] = yield self._event_serializer.serialize_events(
+        results["events_before"] = await self._event_serializer.serialize_events(
             results["events_before"], time_now
         )
-        results["event"] = yield self._event_serializer.serialize_event(
+        results["event"] = await self._event_serializer.serialize_event(
             results["event"], time_now
         )
-        results["events_after"] = yield self._event_serializer.serialize_events(
+        results["events_after"] = await self._event_serializer.serialize_events(
             results["events_after"], time_now
         )
-        results["state"] = yield self._event_serializer.serialize_events(
+        results["state"] = await self._event_serializer.serialize_events(
             results["state"], time_now
         )
 
@@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+    async def on_POST(self, request, room_id, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
-        yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
+        await self.room_member_handler.forget(user=requester.user, room_id=room_id)
 
         return 200, {}
 
@@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         )
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, membership_action, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request, room_id, membership_action, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         if requester.is_guest and membership_action not in {
             Membership.JOIN,
@@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
             content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
-            yield self.room_member_handler.do_3pid_invite(
+            await self.room_member_handler.do_3pid_invite(
                 room_id,
                 requester.user,
                 content["medium"],
@@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         if "reason" in content and membership_action in ["kick", "ban"]:
             event_content = {"reason": content["reason"]}
 
-        yield self.room_member_handler.update_membership(
+        await self.room_member_handler.update_membership(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, event_id, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request, room_id, event_id, txn_id=None):
+        requester = await self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
-        event = yield self.event_creation_handler.create_and_send_nonmember_event(
+        event = await self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Redaction,
@@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet):
         self.typing_handler = hs.get_typing_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_PUT(self, request, room_id, user_id):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_PUT(self, request, room_id, user_id):
+        requester = await self.auth.get_user_by_req(request)
 
         room_id = urlparse.unquote(room_id)
         target_user = UserID.from_string(urlparse.unquote(user_id))
 
         content = parse_json_object_from_request(request)
 
-        yield self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(requester.user)
 
         # Limit timeout to stop people from setting silly typing timeouts.
         timeout = min(content.get("timeout", 30000), 120000)
 
         if content["typing"]:
-            yield self.typing_handler.started_typing(
+            await self.typing_handler.started_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
                 timeout=timeout,
             )
         else:
-            yield self.typing_handler.stopped_typing(
+            await self.typing_handler.stopped_typing(
                 target_user=target_user, auth_user=requester.user, room_id=room_id
             )
 
@@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet):
         self.handlers = hs.get_handlers()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request)
 
         content = parse_json_object_from_request(request)
 
         batch = parse_string(request, "next_batch")
-        results = yield self.handlers.search_handler.search(
+        results = await self.handlers.search_handler.search(
             requester.user, content, batch
         )
 
@@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
+        room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
         return 200, {"joined_rooms": list(room_ids)}