summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6275.misc1
-rw-r--r--synapse/rest/client/v1/room.py166
2 files changed, 73 insertions, 94 deletions
diff --git a/changelog.d/6275.misc b/changelog.d/6275.misc
new file mode 100644
index 0000000000..f57e2c4adb
--- /dev/null
+++ b/changelog.d/6275.misc
@@ -0,0 +1 @@
+Port room rest handlers to async/await.
diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py
index cebaceb885..91846fe1d7 100644
--- a/synapse/rest/client/v1/room.py
+++ b/synapse/rest/client/v1/room.py
@@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse
 
 from canonicaljson import json
 
-from twisted.internet import defer
-
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import (
     AuthError,
@@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet):
         set_tag("txn_id", txn_id)
         return self.txns.fetch_or_execute_request(request, self.on_POST, request)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request)
 
-        info = yield self._room_creation_handler.create_room(
+        info = await self._room_creation_handler.create_room(
             requester, self.get_room_config(request)
         )
 
@@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet):
     def on_PUT_no_state_key(self, request, room_id, event_type):
         return self.on_PUT(request, room_id, event_type, "")
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id, event_type, state_key):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id, event_type, state_key):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         format = parse_string(
             request, "format", default="content", allowed_values=["content", "event"]
         )
 
         msg_handler = self.message_handler
-        data = yield msg_handler.get_room_data(
+        data = await msg_handler.get_room_data(
             user_id=requester.user.to_string(),
             room_id=room_id,
             event_type=event_type,
@@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet):
         elif format == "content":
             return 200, data.get_dict()["content"]
 
-    @defer.inlineCallbacks
-    def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None):
+        requester = await self.auth.get_user_by_req(request)
 
         if txn_id:
             set_tag("txn_id", txn_id)
@@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
 
         if event_type == EventTypes.Member:
             membership = content.get("membership", None)
-            event = yield self.room_member_handler.update_membership(
+            event = await self.room_member_handler.update_membership(
                 requester,
                 target=UserID.from_string(state_key),
                 room_id=room_id,
@@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet):
                 content=content,
             )
         else:
-            event = yield self.event_creation_handler.create_and_send_nonmember_event(
+            event = await self.event_creation_handler.create_and_send_nonmember_event(
                 requester, event_dict, txn_id=txn_id
             )
 
@@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)"
         register_txn_path(self, PATTERNS, http_server, with_get=True)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, event_type, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request, room_id, event_type, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         content = parse_json_object_from_request(request)
 
         event_dict = {
@@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet):
         if b"ts" in request.args and requester.app_service:
             event_dict["origin_server_ts"] = parse_integer(request, "ts", 0)
 
-        event = yield self.event_creation_handler.create_and_send_nonmember_event(
+        event = await self.event_creation_handler.create_and_send_nonmember_event(
             requester, event_dict, txn_id=txn_id
         )
 
@@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet):
         PATTERNS = "/join/(?P<room_identifier>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_identifier, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request, room_identifier, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         try:
             content = parse_json_object_from_request(request)
@@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet):
         elif RoomAlias.is_valid(room_identifier):
             handler = self.room_member_handler
             room_alias = RoomAlias.from_string(room_identifier)
-            room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias)
+            room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias)
             room_id = room_id.to_string()
         else:
             raise SynapseError(
                 400, "%s was not legal room ID or room alias" % (room_identifier,)
             )
 
-        yield self.room_member_handler.update_membership(
+        await self.room_member_handler.update_membership(
             requester=requester,
             target=requester.user,
             room_id=room_id,
@@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet):
         self.hs = hs
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request):
+    async def on_GET(self, request):
         server = parse_string(request, "server", default=None)
 
         try:
-            yield self.auth.get_user_by_req(request, allow_guest=True)
+            await self.auth.get_user_by_req(request, allow_guest=True)
         except InvalidClientCredentialsError as e:
             # Option to allow servers to require auth when accessing
             # /publicRooms via CS API. This is especially helpful in private
@@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server:
-            data = yield handler.get_remote_public_room_list(
+            data = await handler.get_remote_public_room_list(
                 server, limit=limit, since_token=since_token
             )
         else:
-            data = yield handler.get_local_public_room_list(
+            data = await handler.get_local_public_room_list(
                 limit=limit, since_token=since_token
             )
 
         return 200, data
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request):
+        await self.auth.get_user_by_req(request, allow_guest=True)
 
         server = parse_string(request, "server", default=None)
         content = parse_json_object_from_request(request)
@@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
 
         handler = self.hs.get_room_list_handler()
         if server:
-            data = yield handler.get_remote_public_room_list(
+            data = await handler.get_remote_public_room_list(
                 server,
                 limit=limit,
                 since_token=since_token,
@@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet):
                 third_party_instance_id=third_party_instance_id,
             )
         else:
-            data = yield handler.get_local_public_room_list(
+            data = await handler.get_local_public_room_list(
                 limit=limit,
                 since_token=since_token,
                 search_filter=search_filter,
@@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
+    async def on_GET(self, request, room_id):
         # TODO support Pagination stream API (limit/tokens)
-        requester = yield self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request)
         handler = self.message_handler
 
         # request the state as of a given event, as identified by a stream token,
@@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet):
         membership = parse_string(request, "membership")
         not_membership = parse_string(request, "not_membership")
 
-        events = yield handler.get_state_events(
+        events = await handler.get_state_events(
             room_id=room_id,
             user_id=requester.user.to_string(),
             at_token=at_token,
@@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request)
 
-        users_with_profile = yield self.message_handler.get_joined_members(
+        users_with_profile = await self.message_handler.get_joined_members(
             requester, room_id
         )
 
@@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet):
         self.pagination_handler = hs.get_pagination_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = PaginationConfig.from_request(request, default_limit=10)
         as_client_event = b"raw" not in request.args
         filter_bytes = parse_string(request, b"filter", encoding=None)
@@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet):
                 as_client_event = False
         else:
             event_filter = None
-        msgs = yield self.pagination_handler.get_messages(
+        msgs = await self.pagination_handler.get_messages(
             room_id=room_id,
             requester=requester,
             pagin_config=pagination_config,
@@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet):
         self.message_handler = hs.get_message_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         # Get all the current state for this room
-        events = yield self.message_handler.get_state_events(
+        events = await self.message_handler.get_state_events(
             room_id=room_id,
             user_id=requester.user.to_string(),
             is_guest=requester.is_guest,
@@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet):
         self.initial_sync_handler = hs.get_initial_sync_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         pagination_config = PaginationConfig.from_request(request)
-        content = yield self.initial_sync_handler.room_initial_sync(
+        content = await self.initial_sync_handler.room_initial_sync(
             room_id=room_id, requester=requester, pagin_config=pagination_config
         )
         return 200, content
@@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet):
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id, event_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
         try:
-            event = yield self.event_handler.get_event(
+            event = await self.event_handler.get_event(
                 requester.user, room_id, event_id
             )
         except AuthError:
@@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet):
 
         time_now = self.clock.time_msec()
         if event:
-            event = yield self._event_serializer.serialize_event(event, time_now)
+            event = await self._event_serializer.serialize_event(event, time_now)
             return 200, event
 
         return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
@@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet):
         self._event_serializer = hs.get_event_client_serializer()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request, room_id, event_id):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request, room_id, event_id):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         limit = parse_integer(request, "limit", default=10)
 
@@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet):
         else:
             event_filter = None
 
-        results = yield self.room_context_handler.get_event_context(
+        results = await self.room_context_handler.get_event_context(
             requester.user, room_id, event_id, limit, event_filter
         )
 
@@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet):
             raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND)
 
         time_now = self.clock.time_msec()
-        results["events_before"] = yield self._event_serializer.serialize_events(
+        results["events_before"] = await self._event_serializer.serialize_events(
             results["events_before"], time_now
         )
-        results["event"] = yield self._event_serializer.serialize_event(
+        results["event"] = await self._event_serializer.serialize_event(
             results["event"], time_now
         )
-        results["events_after"] = yield self._event_serializer.serialize_events(
+        results["events_after"] = await self._event_serializer.serialize_events(
             results["events_after"], time_now
         )
-        results["state"] = yield self._event_serializer.serialize_events(
+        results["state"] = await self._event_serializer.serialize_events(
             results["state"], time_now
         )
 
@@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget"
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=False)
+    async def on_POST(self, request, room_id, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=False)
 
-        yield self.room_member_handler.forget(user=requester.user, room_id=room_id)
+        await self.room_member_handler.forget(user=requester.user, room_id=room_id)
 
         return 200, {}
 
@@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         )
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, membership_action, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_POST(self, request, room_id, membership_action, txn_id=None):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
         if requester.is_guest and membership_action not in {
             Membership.JOIN,
@@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
             content = {}
 
         if membership_action == "invite" and self._has_3pid_invite_keys(content):
-            yield self.room_member_handler.do_3pid_invite(
+            await self.room_member_handler.do_3pid_invite(
                 room_id,
                 requester.user,
                 content["medium"],
@@ -736,7 +718,7 @@ class RoomMembershipRestServlet(TransactionRestServlet):
         if "reason" in content and membership_action in ["kick", "ban"]:
             event_content = {"reason": content["reason"]}
 
-        yield self.room_member_handler.update_membership(
+        await self.room_member_handler.update_membership(
             requester=requester,
             target=target,
             room_id=room_id,
@@ -778,12 +760,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet):
         PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)"
         register_txn_path(self, PATTERNS, http_server)
 
-    @defer.inlineCallbacks
-    def on_POST(self, request, room_id, event_id, txn_id=None):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request, room_id, event_id, txn_id=None):
+        requester = await self.auth.get_user_by_req(request)
         content = parse_json_object_from_request(request)
 
-        event = yield self.event_creation_handler.create_and_send_nonmember_event(
+        event = await self.event_creation_handler.create_and_send_nonmember_event(
             requester,
             {
                 "type": EventTypes.Redaction,
@@ -817,29 +798,28 @@ class RoomTypingRestServlet(RestServlet):
         self.typing_handler = hs.get_typing_handler()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_PUT(self, request, room_id, user_id):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_PUT(self, request, room_id, user_id):
+        requester = await self.auth.get_user_by_req(request)
 
         room_id = urlparse.unquote(room_id)
         target_user = UserID.from_string(urlparse.unquote(user_id))
 
         content = parse_json_object_from_request(request)
 
-        yield self.presence_handler.bump_presence_active_time(requester.user)
+        await self.presence_handler.bump_presence_active_time(requester.user)
 
         # Limit timeout to stop people from setting silly typing timeouts.
         timeout = min(content.get("timeout", 30000), 120000)
 
         if content["typing"]:
-            yield self.typing_handler.started_typing(
+            await self.typing_handler.started_typing(
                 target_user=target_user,
                 auth_user=requester.user,
                 room_id=room_id,
                 timeout=timeout,
             )
         else:
-            yield self.typing_handler.stopped_typing(
+            await self.typing_handler.stopped_typing(
                 target_user=target_user, auth_user=requester.user, room_id=room_id
             )
 
@@ -854,14 +834,13 @@ class SearchRestServlet(RestServlet):
         self.handlers = hs.get_handlers()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_POST(self, request):
-        requester = yield self.auth.get_user_by_req(request)
+    async def on_POST(self, request):
+        requester = await self.auth.get_user_by_req(request)
 
         content = parse_json_object_from_request(request)
 
         batch = parse_string(request, "next_batch")
-        results = yield self.handlers.search_handler.search(
+        results = await self.handlers.search_handler.search(
             requester.user, content, batch
         )
 
@@ -876,11 +855,10 @@ class JoinedRoomsRestServlet(RestServlet):
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
-    @defer.inlineCallbacks
-    def on_GET(self, request):
-        requester = yield self.auth.get_user_by_req(request, allow_guest=True)
+    async def on_GET(self, request):
+        requester = await self.auth.get_user_by_req(request, allow_guest=True)
 
-        room_ids = yield self.store.get_rooms_for_user(requester.user.to_string())
+        room_ids = await self.store.get_rooms_for_user(requester.user.to_string())
         return 200, {"joined_rooms": list(room_ids)}