diff options
-rw-r--r-- | changelog.d/6275.misc | 1 | ||||
-rw-r--r-- | changelog.d/6279.misc | 1 | ||||
-rw-r--r-- | synapse/federation/federation_server.py | 217 | ||||
-rw-r--r-- | synapse/rest/client/v1/room.py | 166 | ||||
-rw-r--r-- | synapse/util/async_helpers.py | 7 | ||||
-rw-r--r-- | tests/handlers/test_typing.py | 3 |
6 files changed, 173 insertions, 222 deletions
diff --git a/changelog.d/6275.misc b/changelog.d/6275.misc new file mode 100644 index 0000000000..f57e2c4adb --- /dev/null +++ b/changelog.d/6275.misc @@ -0,0 +1 @@ +Port room rest handlers to async/await. diff --git a/changelog.d/6279.misc b/changelog.d/6279.misc new file mode 100644 index 0000000000..5f5144a9ee --- /dev/null +++ b/changelog.d/6279.misc @@ -0,0 +1 @@ +Port `federation_server.py` to async/await. diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 5fc7c1d67b..7c331753ad 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -21,7 +21,6 @@ from six import iteritems from canonicaljson import json from prometheus_client import Counter -from twisted.internet import defer from twisted.internet.abstract import isIPAddress from twisted.python import failure @@ -86,14 +85,12 @@ class FederationServer(FederationBase): # come in waves. self._state_resp_cache = ResponseCache(hs, "state_resp", timeout_ms=30000) - @defer.inlineCallbacks - @log_function - def on_backfill_request(self, origin, room_id, versions, limit): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_backfill_request(self, origin, room_id, versions, limit): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - pdus = yield self.handler.on_backfill_request( + pdus = await self.handler.on_backfill_request( origin, room_id, versions, limit ) @@ -101,9 +98,7 @@ class FederationServer(FederationBase): return 200, res - @defer.inlineCallbacks - @log_function - def on_incoming_transaction(self, origin, transaction_data): + async def on_incoming_transaction(self, origin, transaction_data): # keep this as early as possible to make the calculated origin ts as # accurate as possible. request_time = self._clock.time_msec() @@ -118,18 +113,17 @@ class FederationServer(FederationBase): # use a linearizer to ensure that we don't process the same transaction # multiple times in parallel. with ( - yield self._transaction_linearizer.queue( + await self._transaction_linearizer.queue( (origin, transaction.transaction_id) ) ): - result = yield self._handle_incoming_transaction( + result = await self._handle_incoming_transaction( origin, transaction, request_time ) return result - @defer.inlineCallbacks - def _handle_incoming_transaction(self, origin, transaction, request_time): + async def _handle_incoming_transaction(self, origin, transaction, request_time): """ Process an incoming transaction and return the HTTP response Args: @@ -140,7 +134,7 @@ class FederationServer(FederationBase): Returns: Deferred[(int, object)]: http response code and body """ - response = yield self.transaction_actions.have_responded(origin, transaction) + response = await self.transaction_actions.have_responded(origin, transaction) if response: logger.debug( @@ -159,7 +153,7 @@ class FederationServer(FederationBase): logger.info("Transaction PDU or EDU count too large. Returning 400") response = {} - yield self.transaction_actions.set_response( + await self.transaction_actions.set_response( origin, transaction, 400, response ) return 400, response @@ -195,7 +189,7 @@ class FederationServer(FederationBase): continue try: - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) except NotFoundError: logger.info("Ignoring PDU for unknown room_id: %s", room_id) continue @@ -221,11 +215,10 @@ class FederationServer(FederationBase): # require callouts to other servers to fetch missing events), but # impose a limit to avoid going too crazy with ram/cpu. - @defer.inlineCallbacks - def process_pdus_for_room(room_id): + async def process_pdus_for_room(room_id): logger.debug("Processing PDUs for %s", room_id) try: - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) except AuthError as e: logger.warn("Ignoring PDUs for room %s from banned server", room_id) for pdu in pdus_by_room[room_id]: @@ -237,7 +230,7 @@ class FederationServer(FederationBase): event_id = pdu.event_id with nested_logging_context(event_id): try: - yield self._handle_received_pdu(origin, pdu) + await self._handle_received_pdu(origin, pdu) pdu_results[event_id] = {} except FederationError as e: logger.warn("Error handling PDU %s: %s", event_id, e) @@ -251,36 +244,33 @@ class FederationServer(FederationBase): exc_info=(f.type, f.value, f.getTracebackObject()), ) - yield concurrently_execute( + await concurrently_execute( process_pdus_for_room, pdus_by_room.keys(), TRANSACTION_CONCURRENCY_LIMIT ) if hasattr(transaction, "edus"): for edu in (Edu(**x) for x in transaction.edus): - yield self.received_edu(origin, edu.edu_type, edu.content) + await self.received_edu(origin, edu.edu_type, edu.content) response = {"pdus": pdu_results} logger.debug("Returning: %s", str(response)) - yield self.transaction_actions.set_response(origin, transaction, 200, response) + await self.transaction_actions.set_response(origin, transaction, 200, response) return 200, response - @defer.inlineCallbacks - def received_edu(self, origin, edu_type, content): + async def received_edu(self, origin, edu_type, content): received_edus_counter.inc() - yield self.registry.on_edu(edu_type, origin, content) + await self.registry.on_edu(edu_type, origin, content) - @defer.inlineCallbacks - @log_function - def on_context_state_request(self, origin, room_id, event_id): + async def on_context_state_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") @@ -289,8 +279,8 @@ class FederationServer(FederationBase): # in the cache so we could return it without waiting for the linearizer # - but that's non-trivial to get right, and anyway somewhat defeats # the point of the linearizer. - with (yield self._server_linearizer.queue((origin, room_id))): - resp = yield self._state_resp_cache.wrap( + with (await self._server_linearizer.queue((origin, room_id))): + resp = await self._state_resp_cache.wrap( (room_id, event_id), self._on_context_state_request_compute, room_id, @@ -299,65 +289,58 @@ class FederationServer(FederationBase): return 200, resp - @defer.inlineCallbacks - def on_state_ids_request(self, origin, room_id, event_id): + async def on_state_ids_request(self, origin, room_id, event_id): if not event_id: raise NotImplementedError("Specify an event") origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - in_room = yield self.auth.check_host_in_room(room_id, origin) + in_room = await self.auth.check_host_in_room(room_id, origin) if not in_room: raise AuthError(403, "Host not in room.") - state_ids = yield self.handler.get_state_ids_for_pdu(room_id, event_id) - auth_chain_ids = yield self.store.get_auth_chain_ids(state_ids) + state_ids = await self.handler.get_state_ids_for_pdu(room_id, event_id) + auth_chain_ids = await self.store.get_auth_chain_ids(state_ids) return 200, {"pdu_ids": state_ids, "auth_chain_ids": auth_chain_ids} - @defer.inlineCallbacks - def _on_context_state_request_compute(self, room_id, event_id): - pdus = yield self.handler.get_state_for_pdu(room_id, event_id) - auth_chain = yield self.store.get_auth_chain([pdu.event_id for pdu in pdus]) + async def _on_context_state_request_compute(self, room_id, event_id): + pdus = await self.handler.get_state_for_pdu(room_id, event_id) + auth_chain = await self.store.get_auth_chain([pdu.event_id for pdu in pdus]) return { "pdus": [pdu.get_pdu_json() for pdu in pdus], "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], } - @defer.inlineCallbacks - @log_function - def on_pdu_request(self, origin, event_id): - pdu = yield self.handler.get_persisted_pdu(origin, event_id) + async def on_pdu_request(self, origin, event_id): + pdu = await self.handler.get_persisted_pdu(origin, event_id) if pdu: return 200, self._transaction_from_pdus([pdu]).get_dict() else: return 404, "" - @defer.inlineCallbacks - def on_query_request(self, query_type, args): + async def on_query_request(self, query_type, args): received_queries_counter.labels(query_type).inc() - resp = yield self.registry.on_query(query_type, args) + resp = await self.registry.on_query(query_type, args) return 200, resp - @defer.inlineCallbacks - def on_make_join_request(self, origin, room_id, user_id, supported_versions): + async def on_make_join_request(self, origin, room_id, user_id, supported_versions): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) if room_version not in supported_versions: logger.warn("Room version %s not in %s", room_version, supported_versions) raise IncompatibleRoomVersionError(room_version=room_version) - pdu = yield self.handler.on_make_join_request(origin, room_id, user_id) + pdu = await self.handler.on_make_join_request(origin, room_id, user_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_invite_request(self, origin, content, room_version): + async def on_invite_request(self, origin, content, room_version): if room_version not in KNOWN_ROOM_VERSIONS: raise SynapseError( 400, @@ -369,28 +352,27 @@ class FederationServer(FederationBase): pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) - pdu = yield self._check_sigs_and_hash(room_version, pdu) - ret_pdu = yield self.handler.on_invite_request(origin, pdu) + await self.check_server_matches_acl(origin_host, pdu.room_id) + pdu = await self._check_sigs_and_hash(room_version, pdu) + ret_pdu = await self.handler.on_invite_request(origin, pdu) time_now = self._clock.time_msec() return {"event": ret_pdu.get_pdu_json(time_now)} - @defer.inlineCallbacks - def on_send_join_request(self, origin, content, room_id): + async def on_send_join_request(self, origin, content, room_id): logger.debug("on_send_join_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_join_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - res_pdus = yield self.handler.on_send_join_request(origin, pdu) + res_pdus = await self.handler.on_send_join_request(origin, pdu) time_now = self._clock.time_msec() return ( 200, @@ -402,48 +384,44 @@ class FederationServer(FederationBase): }, ) - @defer.inlineCallbacks - def on_make_leave_request(self, origin, room_id, user_id): + async def on_make_leave_request(self, origin, room_id, user_id): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) - pdu = yield self.handler.on_make_leave_request(origin, room_id, user_id) + await self.check_server_matches_acl(origin_host, room_id) + pdu = await self.handler.on_make_leave_request(origin, room_id, user_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) time_now = self._clock.time_msec() return {"event": pdu.get_pdu_json(time_now), "room_version": room_version} - @defer.inlineCallbacks - def on_send_leave_request(self, origin, content, room_id): + async def on_send_leave_request(self, origin, content, room_id): logger.debug("on_send_leave_request: content: %s", content) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) pdu = event_from_pdu_json(content, format_ver) origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, pdu.room_id) + await self.check_server_matches_acl(origin_host, pdu.room_id) logger.debug("on_send_leave_request: pdu sigs: %s", pdu.signatures) - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) - yield self.handler.on_send_leave_request(origin, pdu) + await self.handler.on_send_leave_request(origin, pdu) return 200, {} - @defer.inlineCallbacks - def on_event_auth(self, origin, room_id, event_id): - with (yield self._server_linearizer.queue((origin, room_id))): + async def on_event_auth(self, origin, room_id, event_id): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) time_now = self._clock.time_msec() - auth_pdus = yield self.handler.on_event_auth(event_id) + auth_pdus = await self.handler.on_event_auth(event_id) res = {"auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus]} return 200, res - @defer.inlineCallbacks - def on_query_auth_request(self, origin, content, room_id, event_id): + async def on_query_auth_request(self, origin, content, room_id, event_id): """ Content is a dict with keys:: auth_chain (list): A list of events that give the auth chain. @@ -462,22 +440,22 @@ class FederationServer(FederationBase): Returns: Deferred: Results in `dict` with the same format as `content` """ - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) - room_version = yield self.store.get_room_version(room_id) + room_version = await self.store.get_room_version(room_id) format_ver = room_version_to_event_format(room_version) auth_chain = [ event_from_pdu_json(e, format_ver) for e in content["auth_chain"] ] - signed_auth = yield self._check_sigs_and_hash_and_fetch( + signed_auth = await self._check_sigs_and_hash_and_fetch( origin, auth_chain, outlier=True, room_version=room_version ) - ret = yield self.handler.on_query_auth( + ret = await self.handler.on_query_auth( origin, event_id, room_id, @@ -503,16 +481,14 @@ class FederationServer(FederationBase): return self.on_query_request("user_devices", user_id) @trace - @defer.inlineCallbacks - @log_function - def on_claim_client_keys(self, origin, content): + async def on_claim_client_keys(self, origin, content): query = [] for user_id, device_keys in content.get("one_time_keys", {}).items(): for device_id, algorithm in device_keys.items(): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = yield self.store.claim_e2e_one_time_keys(query) + results = await self.store.claim_e2e_one_time_keys(query) json_result = {} for user_id, device_keys in results.items(): @@ -536,14 +512,12 @@ class FederationServer(FederationBase): return {"one_time_keys": json_result} - @defer.inlineCallbacks - @log_function - def on_get_missing_events( + async def on_get_missing_events( self, origin, room_id, earliest_events, latest_events, limit ): - with (yield self._server_linearizer.queue((origin, room_id))): + with (await self._server_linearizer.queue((origin, room_id))): origin_host, _ = parse_server_name(origin) - yield self.check_server_matches_acl(origin_host, room_id) + await self.check_server_matches_acl(origin_host, room_id) logger.info( "on_get_missing_events: earliest_events: %r, latest_events: %r," @@ -553,7 +527,7 @@ class FederationServer(FederationBase): limit, ) - missing_events = yield self.handler.on_get_missing_events( + missing_events = await self.handler.on_get_missing_events( origin, room_id, earliest_events, latest_events, limit ) @@ -586,8 +560,7 @@ class FederationServer(FederationBase): destination=None, ) - @defer.inlineCallbacks - def _handle_received_pdu(self, origin, pdu): + async def _handle_received_pdu(self, origin, pdu): """ Process a PDU received in a federation /send/ transaction. If the event is invalid, then this method throws a FederationError. @@ -640,37 +613,34 @@ class FederationServer(FederationBase): logger.info("Accepting join PDU %s from %s", pdu.event_id, origin) # We've already checked that we know the room version by this point - room_version = yield self.store.get_room_version(pdu.room_id) + room_version = await self.store.get_room_version(pdu.room_id) # Check signature. try: - pdu = yield self._check_sigs_and_hash(room_version, pdu) + pdu = await self._check_sigs_and_hash(room_version, pdu) except SynapseError as e: raise FederationError("ERROR", e.code, e.msg, affected=pdu.event_id) - yield self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) + await self.handler.on_receive_pdu(origin, pdu, sent_to_us_directly=True) def __str__(self): return "<ReplicationLayer(%s)>" % self.server_name - @defer.inlineCallbacks - def exchange_third_party_invite( + async def exchange_third_party_invite( self, sender_user_id, target_user_id, room_id, signed ): - ret = yield self.handler.exchange_third_party_invite( + ret = await self.handler.exchange_third_party_invite( sender_user_id, target_user_id, room_id, signed ) return ret - @defer.inlineCallbacks - def on_exchange_third_party_invite_request(self, room_id, event_dict): - ret = yield self.handler.on_exchange_third_party_invite_request( + async def on_exchange_third_party_invite_request(self, room_id, event_dict): + ret = await self.handler.on_exchange_third_party_invite_request( room_id, event_dict ) return ret - @defer.inlineCallbacks - def check_server_matches_acl(self, server_name, room_id): + async def check_server_matches_acl(self, server_name, room_id): """Check if the given server is allowed by the server ACLs in the room Args: @@ -680,13 +650,13 @@ class FederationServer(FederationBase): Raises: AuthError if the server does not match the ACL """ - state_ids = yield self.store.get_current_state_ids(room_id) + state_ids = await self.store.get_current_state_ids(room_id) acl_event_id = state_ids.get((EventTypes.ServerACL, "")) if not acl_event_id: return - acl_event = yield self.store.get_event(acl_event_id) + acl_event = await self.store.get_event(acl_event_id) if server_matches_acl_event(server_name, acl_event): return @@ -799,15 +769,14 @@ class FederationHandlerRegistry(object): self.query_handlers[query_type] = handler - @defer.inlineCallbacks - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): handler = self.edu_handlers.get(edu_type) if not handler: logger.warn("No handler registered for EDU type %s", edu_type) with start_active_span_from_edu(content, "handle_edu"): try: - yield handler(origin, content) + await handler(origin, content) except SynapseError as e: logger.info("Failed to handle edu %r: %r", edu_type, e) except Exception: @@ -840,7 +809,7 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): super(ReplicationFederationHandlerRegistry, self).__init__() - def on_edu(self, edu_type, origin, content): + async def on_edu(self, edu_type, origin, content): """Overrides FederationHandlerRegistry """ if not self.config.use_presence and edu_type == "m.presence": @@ -848,17 +817,17 @@ class ReplicationFederationHandlerRegistry(FederationHandlerRegistry): handler = self.edu_handlers.get(edu_type) if handler: - return super(ReplicationFederationHandlerRegistry, self).on_edu( + return await super(ReplicationFederationHandlerRegistry, self).on_edu( edu_type, origin, content ) - return self._send_edu(edu_type=edu_type, origin=origin, content=content) + return await self._send_edu(edu_type=edu_type, origin=origin, content=content) - def on_query(self, query_type, args): + async def on_query(self, query_type, args): """Overrides FederationHandlerRegistry """ handler = self.query_handlers.get(query_type) if handler: - return handler(args) + return await handler(args) - return self._get_query_client(query_type=query_type, args=args) + return await self._get_query_client(query_type=query_type, args=args) diff --git a/synapse/rest/client/v1/room.py b/synapse/rest/client/v1/room.py index 9c1d41421c..86bbcc0eea 100644 --- a/synapse/rest/client/v1/room.py +++ b/synapse/rest/client/v1/room.py @@ -21,8 +21,6 @@ from six.moves.urllib import parse as urlparse from canonicaljson import json -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.api.errors import ( AuthError, @@ -85,11 +83,10 @@ class RoomCreateRestServlet(TransactionRestServlet): set_tag("txn_id", txn_id) return self.txns.fetch_or_execute_request(request, self.on_POST, request) - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) - info = yield self._room_creation_handler.create_room( + info = await self._room_creation_handler.create_room( requester, self.get_room_config(request) ) @@ -154,15 +151,14 @@ class RoomStateEventRestServlet(TransactionRestServlet): def on_PUT_no_state_key(self, request, room_id, event_type): return self.on_PUT(request, room_id, event_type, "") - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_type, state_key): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_type, state_key): + requester = await self.auth.get_user_by_req(request, allow_guest=True) format = parse_string( request, "format", default="content", allowed_values=["content", "event"] ) msg_handler = self.message_handler - data = yield msg_handler.get_room_data( + data = await msg_handler.get_room_data( user_id=requester.user.to_string(), room_id=room_id, event_type=event_type, @@ -179,9 +175,8 @@ class RoomStateEventRestServlet(TransactionRestServlet): elif format == "content": return 200, data.get_dict()["content"] - @defer.inlineCallbacks - def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, event_type, state_key, txn_id=None): + requester = await self.auth.get_user_by_req(request) if txn_id: set_tag("txn_id", txn_id) @@ -200,7 +195,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): if event_type == EventTypes.Member: membership = content.get("membership", None) - event = yield self.room_member_handler.update_membership( + event = await self.room_member_handler.update_membership( requester, target=UserID.from_string(state_key), room_id=room_id, @@ -208,7 +203,7 @@ class RoomStateEventRestServlet(TransactionRestServlet): content=content, ) else: - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -231,9 +226,8 @@ class RoomSendEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/send/(?P<event_type>[^/]*)" register_txn_path(self, PATTERNS, http_server, with_get=True) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_type, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, event_type, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) content = parse_json_object_from_request(request) event_dict = { @@ -246,7 +240,7 @@ class RoomSendEventRestServlet(TransactionRestServlet): if b"ts" in request.args and requester.app_service: event_dict["origin_server_ts"] = parse_integer(request, "ts", 0) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, event_dict, txn_id=txn_id ) @@ -276,9 +270,8 @@ class JoinRoomAliasServlet(TransactionRestServlet): PATTERNS = "/join/(?P<room_identifier>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_identifier, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_identifier, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: content = parse_json_object_from_request(request) @@ -298,14 +291,14 @@ class JoinRoomAliasServlet(TransactionRestServlet): elif RoomAlias.is_valid(room_identifier): handler = self.room_member_handler room_alias = RoomAlias.from_string(room_identifier) - room_id, remote_room_hosts = yield handler.lookup_room_alias(room_alias) + room_id, remote_room_hosts = await handler.lookup_room_alias(room_alias) room_id = room_id.to_string() else: raise SynapseError( 400, "%s was not legal room ID or room alias" % (room_identifier,) ) - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=requester.user, room_id=room_id, @@ -335,12 +328,11 @@ class PublicRoomListRestServlet(TransactionRestServlet): self.hs = hs self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): + async def on_GET(self, request): server = parse_string(request, "server", default=None) try: - yield self.auth.get_user_by_req(request, allow_guest=True) + await self.auth.get_user_by_req(request, allow_guest=True) except InvalidClientCredentialsError as e: # Option to allow servers to require auth when accessing # /publicRooms via CS API. This is especially helpful in private @@ -367,19 +359,18 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token ) return 200, data - @defer.inlineCallbacks - def on_POST(self, request): - yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request): + await self.auth.get_user_by_req(request, allow_guest=True) server = parse_string(request, "server", default=None) content = parse_json_object_from_request(request) @@ -408,7 +399,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): handler = self.hs.get_room_list_handler() if server: - data = yield handler.get_remote_public_room_list( + data = await handler.get_remote_public_room_list( server, limit=limit, since_token=since_token, @@ -417,7 +408,7 @@ class PublicRoomListRestServlet(TransactionRestServlet): third_party_instance_id=third_party_instance_id, ) else: - data = yield handler.get_local_public_room_list( + data = await handler.get_local_public_room_list( limit=limit, since_token=since_token, search_filter=search_filter, @@ -436,10 +427,9 @@ class RoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): + async def on_GET(self, request, room_id): # TODO support Pagination stream API (limit/tokens) - requester = yield self.auth.get_user_by_req(request) + requester = await self.auth.get_user_by_req(request) handler = self.message_handler # request the state as of a given event, as identified by a stream token, @@ -459,7 +449,7 @@ class RoomMemberListRestServlet(RestServlet): membership = parse_string(request, "membership") not_membership = parse_string(request, "not_membership") - events = yield handler.get_state_events( + events = await handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), at_token=at_token, @@ -488,11 +478,10 @@ class JoinedRoomMemberListRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request) - users_with_profile = yield self.message_handler.get_joined_members( + users_with_profile = await self.message_handler.get_joined_members( requester, room_id ) @@ -508,9 +497,8 @@ class RoomMessageListRestServlet(RestServlet): self.pagination_handler = hs.get_pagination_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request, default_limit=10) as_client_event = b"raw" not in request.args filter_bytes = parse_string(request, b"filter", encoding=None) @@ -521,7 +509,7 @@ class RoomMessageListRestServlet(RestServlet): as_client_event = False else: event_filter = None - msgs = yield self.pagination_handler.get_messages( + msgs = await self.pagination_handler.get_messages( room_id=room_id, requester=requester, pagin_config=pagination_config, @@ -541,11 +529,10 @@ class RoomStateRestServlet(RestServlet): self.message_handler = hs.get_message_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) # Get all the current state for this room - events = yield self.message_handler.get_state_events( + events = await self.message_handler.get_state_events( room_id=room_id, user_id=requester.user.to_string(), is_guest=requester.is_guest, @@ -562,11 +549,10 @@ class RoomInitialSyncRestServlet(RestServlet): self.initial_sync_handler = hs.get_initial_sync_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) pagination_config = PaginationConfig.from_request(request) - content = yield self.initial_sync_handler.room_initial_sync( + content = await self.initial_sync_handler.room_initial_sync( room_id=room_id, requester=requester, pagin_config=pagination_config ) return 200, content @@ -584,11 +570,10 @@ class RoomEventServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) try: - event = yield self.event_handler.get_event( + event = await self.event_handler.get_event( requester.user, room_id, event_id ) except AuthError: @@ -599,7 +584,7 @@ class RoomEventServlet(RestServlet): time_now = self.clock.time_msec() if event: - event = yield self._event_serializer.serialize_event(event, time_now) + event = await self._event_serializer.serialize_event(event, time_now) return 200, event return SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) @@ -617,9 +602,8 @@ class RoomEventContextServlet(RestServlet): self._event_serializer = hs.get_event_client_serializer() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request, room_id, event_id): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request, room_id, event_id): + requester = await self.auth.get_user_by_req(request, allow_guest=True) limit = parse_integer(request, "limit", default=10) @@ -631,7 +615,7 @@ class RoomEventContextServlet(RestServlet): else: event_filter = None - results = yield self.room_context_handler.get_event_context( + results = await self.room_context_handler.get_event_context( requester.user, room_id, event_id, limit, event_filter ) @@ -639,16 +623,16 @@ class RoomEventContextServlet(RestServlet): raise SynapseError(404, "Event not found.", errcode=Codes.NOT_FOUND) time_now = self.clock.time_msec() - results["events_before"] = yield self._event_serializer.serialize_events( + results["events_before"] = await self._event_serializer.serialize_events( results["events_before"], time_now ) - results["event"] = yield self._event_serializer.serialize_event( + results["event"] = await self._event_serializer.serialize_event( results["event"], time_now ) - results["events_after"] = yield self._event_serializer.serialize_events( + results["events_after"] = await self._event_serializer.serialize_events( results["events_after"], time_now ) - results["state"] = yield self._event_serializer.serialize_events( + results["state"] = await self._event_serializer.serialize_events( results["state"], time_now ) @@ -665,11 +649,10 @@ class RoomForgetRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/forget" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=False) + async def on_POST(self, request, room_id, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=False) - yield self.room_member_handler.forget(user=requester.user, room_id=room_id) + await self.room_member_handler.forget(user=requester.user, room_id=room_id) return 200, {} @@ -696,9 +679,8 @@ class RoomMembershipRestServlet(TransactionRestServlet): ) register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, membership_action, txn_id=None): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_POST(self, request, room_id, membership_action, txn_id=None): + requester = await self.auth.get_user_by_req(request, allow_guest=True) if requester.is_guest and membership_action not in { Membership.JOIN, @@ -714,7 +696,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): content = {} if membership_action == "invite" and self._has_3pid_invite_keys(content): - yield self.room_member_handler.do_3pid_invite( + await self.room_member_handler.do_3pid_invite( room_id, requester.user, content["medium"], @@ -735,7 +717,7 @@ class RoomMembershipRestServlet(TransactionRestServlet): if "reason" in content and membership_action in ["kick", "ban"]: event_content = {"reason": content["reason"]} - yield self.room_member_handler.update_membership( + await self.room_member_handler.update_membership( requester=requester, target=target, room_id=room_id, @@ -777,12 +759,11 @@ class RoomRedactEventRestServlet(TransactionRestServlet): PATTERNS = "/rooms/(?P<room_id>[^/]*)/redact/(?P<event_id>[^/]*)" register_txn_path(self, PATTERNS, http_server) - @defer.inlineCallbacks - def on_POST(self, request, room_id, event_id, txn_id=None): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request, room_id, event_id, txn_id=None): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) - event = yield self.event_creation_handler.create_and_send_nonmember_event( + event = await self.event_creation_handler.create_and_send_nonmember_event( requester, { "type": EventTypes.Redaction, @@ -816,29 +797,28 @@ class RoomTypingRestServlet(RestServlet): self.typing_handler = hs.get_typing_handler() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_PUT(self, request, room_id, user_id): - requester = yield self.auth.get_user_by_req(request) + async def on_PUT(self, request, room_id, user_id): + requester = await self.auth.get_user_by_req(request) room_id = urlparse.unquote(room_id) target_user = UserID.from_string(urlparse.unquote(user_id)) content = parse_json_object_from_request(request) - yield self.presence_handler.bump_presence_active_time(requester.user) + await self.presence_handler.bump_presence_active_time(requester.user) # Limit timeout to stop people from setting silly typing timeouts. timeout = min(content.get("timeout", 30000), 120000) if content["typing"]: - yield self.typing_handler.started_typing( + await self.typing_handler.started_typing( target_user=target_user, auth_user=requester.user, room_id=room_id, timeout=timeout, ) else: - yield self.typing_handler.stopped_typing( + await self.typing_handler.stopped_typing( target_user=target_user, auth_user=requester.user, room_id=room_id ) @@ -853,14 +833,13 @@ class SearchRestServlet(RestServlet): self.handlers = hs.get_handlers() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_POST(self, request): - requester = yield self.auth.get_user_by_req(request) + async def on_POST(self, request): + requester = await self.auth.get_user_by_req(request) content = parse_json_object_from_request(request) batch = parse_string(request, "next_batch") - results = yield self.handlers.search_handler.search( + results = await self.handlers.search_handler.search( requester.user, content, batch ) @@ -875,11 +854,10 @@ class JoinedRoomsRestServlet(RestServlet): self.store = hs.get_datastore() self.auth = hs.get_auth() - @defer.inlineCallbacks - def on_GET(self, request): - requester = yield self.auth.get_user_by_req(request, allow_guest=True) + async def on_GET(self, request): + requester = await self.auth.get_user_by_req(request, allow_guest=True) - room_ids = yield self.store.get_rooms_for_user(requester.user.to_string()) + room_ids = await self.store.get_rooms_for_user(requester.user.to_string()) return 200, {"joined_rooms": list(room_ids)} diff --git a/synapse/util/async_helpers.py b/synapse/util/async_helpers.py index 804dbca443..7659eaeb42 100644 --- a/synapse/util/async_helpers.py +++ b/synapse/util/async_helpers.py @@ -138,7 +138,7 @@ def concurrently_execute(func, args, limit): the number of concurrent executions. Args: - func (func): Function to execute, should return a deferred. + func (func): Function to execute, should return a deferred or coroutine. args (list): List of arguments to pass to func, each invocation of func gets a signle argument. limit (int): Maximum number of conccurent executions. @@ -148,11 +148,10 @@ def concurrently_execute(func, args, limit): """ it = iter(args) - @defer.inlineCallbacks - def _concurrently_execute_inner(): + async def _concurrently_execute_inner(): try: while True: - yield func(next(it)) + await maybe_awaitable(func(next(it))) except StopIteration: pass diff --git a/tests/handlers/test_typing.py b/tests/handlers/test_typing.py index 67f1013051..f360c8e965 100644 --- a/tests/handlers/test_typing.py +++ b/tests/handlers/test_typing.py @@ -144,6 +144,9 @@ class TypingNotificationsTestCase(unittest.HomeserverTestCase): self.datastore.get_to_device_stream_token = lambda: 0 self.datastore.get_new_device_msgs_for_remote = lambda *args, **kargs: ([], 0) self.datastore.delete_device_msgs_for_remote = lambda *args, **kargs: None + self.datastore.set_received_txn_response = lambda *args, **kwargs: defer.succeed( + None + ) def test_started_typing_local(self): self.room_members = [U_APPLE, U_BANANA] |