From 38e1fac8861f12b707609da06008695a05aaf21c Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 9 Jul 2020 09:52:58 -0400 Subject: Fix some spelling mistakes / typos. (#7811) --- tests/crypto/test_keyring.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tests/crypto/test_keyring.py') diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index 70c8e72303..f9ce609923 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -192,7 +192,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): d = _verify_json_for_server(kr, "server9", {}, 0, "test unsigned") self.failureResultOf(d, SynapseError) - # should suceed on a signed object + # should succeed on a signed object d = _verify_json_for_server(kr, "server9", json1, 500, "test signed") # self.assertFalse(d.called) self.get_success(d) -- cgit 1.5.1 From c978f6c4515a631f289aedb1844d8579b9334aaa Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Jul 2020 08:01:33 -0400 Subject: Convert federation client to async/await. (#7975) --- changelog.d/7975.misc | 1 + contrib/cmdclient/console.py | 16 ++-- synapse/crypto/keyring.py | 60 +++++++------- synapse/federation/federation_client.py | 8 +- synapse/federation/sender/__init__.py | 19 ++--- synapse/federation/transport/client.py | 96 ++++++++++------------- synapse/handlers/groups_local.py | 35 ++++----- synapse/http/matrixfederationclient.py | 72 ++++++++--------- tests/crypto/test_keyring.py | 11 +-- tests/federation/test_complexity.py | 21 ++--- tests/federation/test_federation_sender.py | 10 +-- tests/handlers/test_directory.py | 5 +- tests/handlers/test_profile.py | 3 +- tests/http/test_fedclient.py | 50 ++++++++---- tests/replication/test_federation_sender_shard.py | 13 ++- tests/rest/admin/test_admin.py | 4 +- tests/rest/key/v2/test_remote_key_resource.py | 4 +- tests/test_federation.py | 2 +- 18 files changed, 209 insertions(+), 221 deletions(-) create mode 100644 changelog.d/7975.misc (limited to 'tests/crypto/test_keyring.py') diff --git a/changelog.d/7975.misc b/changelog.d/7975.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/7975.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/contrib/cmdclient/console.py b/contrib/cmdclient/console.py index 77422f5e5d..dfc1d294dc 100755 --- a/contrib/cmdclient/console.py +++ b/contrib/cmdclient/console.py @@ -609,13 +609,15 @@ class SynapseCmd(cmd.Cmd): @defer.inlineCallbacks def _do_event_stream(self, timeout): - res = yield self.http_client.get_json( - self._url() + "/events", - { - "access_token": self._tok(), - "timeout": str(timeout), - "from": self.event_stream_token, - }, + res = yield defer.ensureDeferred( + self.http_client.get_json( + self._url() + "/events", + { + "access_token": self._tok(), + "timeout": str(timeout), + "from": self.event_stream_token, + }, + ) ) print(json.dumps(res, indent=4)) diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index dbfc3e8972..443cde0b6d 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -632,18 +632,20 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - query_response = yield self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() + query_response = yield defer.ensureDeferred( + self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + "server_keys": { + server_name: { + key_id: {"minimum_valid_until_ts": min_valid_ts} + for key_id, min_valid_ts in server_keys.items() + } + for server_name, server_keys in keys_to_fetch.items() } - for server_name, server_keys in keys_to_fetch.items() - } - }, + }, + ) ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -792,23 +794,25 @@ class ServerKeyFetcher(BaseV2KeyFetcher): time_now_ms = self.clock.time_msec() try: - response = yield self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, + response = yield defer.ensureDeferred( + self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, + ) ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 994e6c8d5a..38ac7ec699 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -135,7 +135,7 @@ class FederationClient(FederationBase): and try the request anyway. Returns: - a Deferred which will eventually yield a JSON object from the + a Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels(query_type).inc() @@ -157,7 +157,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_device_keys").inc() @@ -180,7 +180,7 @@ class FederationClient(FederationBase): content (dict): The query content. Returns: - a Deferred which will eventually yield a JSON object from the + an Awaitable which will eventually yield a JSON object from the response """ sent_queries_counter.labels("client_one_time_keys").inc() @@ -900,7 +900,7 @@ class FederationClient(FederationBase): party instance Returns: - Deferred[Dict[str, Any]]: The response from the remote server, or None if + Awaitable[Dict[str, Any]]: The response from the remote server, or None if `remote_server` is the same as the local server_name Raises: diff --git a/synapse/federation/sender/__init__.py b/synapse/federation/sender/__init__.py index ba4ddd2370..8f549ae6ee 100644 --- a/synapse/federation/sender/__init__.py +++ b/synapse/federation/sender/__init__.py @@ -288,8 +288,7 @@ class FederationSender(object): for destination in destinations: self._get_per_destination_queue(destination).send_pdu(pdu, order) - @defer.inlineCallbacks - def send_read_receipt(self, receipt: ReadReceipt): + async def send_read_receipt(self, receipt: ReadReceipt) -> None: """Send a RR to any other servers in the room Args: @@ -330,9 +329,7 @@ class FederationSender(object): room_id = receipt.room_id # Work out which remote servers should be poked and poke them. - domains = yield defer.ensureDeferred( - self.state.get_current_hosts_in_room(room_id) - ) + domains = await self.state.get_current_hosts_in_room(room_id) domains = [ d for d in domains @@ -387,8 +384,7 @@ class FederationSender(object): queue.flush_read_receipts_for_room(room_id) @preserve_fn # the caller should not yield on this - @defer.inlineCallbacks - def send_presence(self, states: List[UserPresenceState]): + async def send_presence(self, states: List[UserPresenceState]): """Send the new presence states to the appropriate destinations. This actually queues up the presence states ready for sending and @@ -423,7 +419,7 @@ class FederationSender(object): if not states_map: break - yield self._process_presence_inner(list(states_map.values())) + await self._process_presence_inner(list(states_map.values())) except Exception: logger.exception("Error sending presence states to servers") finally: @@ -450,14 +446,11 @@ class FederationSender(object): self._get_per_destination_queue(destination).send_presence(states) @measure_func("txnqueue._process_presence") - @defer.inlineCallbacks - def _process_presence_inner(self, states: List[UserPresenceState]): + async def _process_presence_inner(self, states: List[UserPresenceState]): """Given a list of states populate self.pending_presence_by_dest and poke to send a new transaction to each destination """ - hosts_and_states = yield defer.ensureDeferred( - get_interested_remotes(self.store, states, self.state) - ) + hosts_and_states = await get_interested_remotes(self.store, states, self.state) for destinations, states in hosts_and_states: for destination in destinations: diff --git a/synapse/federation/transport/client.py b/synapse/federation/transport/client.py index cfdf23d366..9ea821dbb2 100644 --- a/synapse/federation/transport/client.py +++ b/synapse/federation/transport/client.py @@ -18,8 +18,6 @@ import logging import urllib from typing import Any, Dict, Optional -from twisted.internet import defer - from synapse.api.constants import Membership from synapse.api.errors import Codes, HttpResponseException, SynapseError from synapse.api.urls import ( @@ -51,7 +49,7 @@ class TransportLayerClient(object): event_id (str): The event we want the context at. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_room_state_ids dest=%s, room=%s", destination, room_id) @@ -75,7 +73,7 @@ class TransportLayerClient(object): giving up. None indicates no timeout. Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug("get_pdu dest=%s, event_id=%s", destination, event_id) @@ -96,7 +94,7 @@ class TransportLayerClient(object): limit (int) Returns: - Deferred: Results in a dict received from the remote homeserver. + Awaitable: Results in a dict received from the remote homeserver. """ logger.debug( "backfill dest=%s, room_id=%s, event_tuples=%r, limit=%s", @@ -118,16 +116,15 @@ class TransportLayerClient(object): destination, path=path, args=args, try_trailing_slash_on_400=True ) - @defer.inlineCallbacks @log_function - def send_transaction(self, transaction, json_data_callback=None): + async def send_transaction(self, transaction, json_data_callback=None): """ Sends the given Transaction to its destination Args: transaction (Transaction) Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Fails with ``HTTPRequestException`` if we get an HTTP response @@ -154,7 +151,7 @@ class TransportLayerClient(object): path = _create_v1_path("/send/%s", transaction.transaction_id) - response = yield self.client.put_json( + response = await self.client.put_json( transaction.destination, path=path, data=json_data, @@ -166,14 +163,13 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def make_query( + async def make_query( self, destination, query_type, args, retry_on_dns_fail, ignore_backoff=False ): path = _create_v1_path("/query/%s", query_type) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=args, @@ -184,9 +180,10 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def make_membership_event(self, destination, room_id, user_id, membership, params): + async def make_membership_event( + self, destination, room_id, user_id, membership, params + ): """Asks a remote server to build and sign us a membership event Note that this does not append any events to any graphs. @@ -200,7 +197,7 @@ class TransportLayerClient(object): request. Returns: - Deferred: Succeeds when we get a 2xx HTTP response. The result + Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body (ie, the new event). Fails with ``HTTPRequestException`` if we get an HTTP response @@ -231,7 +228,7 @@ class TransportLayerClient(object): ignore_backoff = True retry_on_dns_fail = True - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, args=params, @@ -242,34 +239,31 @@ class TransportLayerClient(object): return content - @defer.inlineCallbacks @log_function - def send_join_v1(self, destination, room_id, event_id, content): + async def send_join_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_join_v2(self, destination, room_id, event_id, content): + async def send_join_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_join/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content ) return response - @defer.inlineCallbacks @log_function - def send_leave_v1(self, destination, room_id, event_id, content): + async def send_leave_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -282,12 +276,11 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_leave_v2(self, destination, room_id, event_id, content): + async def send_leave_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/send_leave/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, @@ -300,31 +293,28 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def send_invite_v1(self, destination, room_id, event_id, content): + async def send_invite_v1(self, destination, room_id, event_id, content): path = _create_v1_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def send_invite_v2(self, destination, room_id, event_id, content): + async def send_invite_v2(self, destination, room_id, event_id, content): path = _create_v2_path("/invite/%s/%s", room_id, event_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=content, ignore_backoff=True ) return response - @defer.inlineCallbacks @log_function - def get_public_rooms( + async def get_public_rooms( self, remote_server: str, limit: Optional[int] = None, @@ -355,7 +345,7 @@ class TransportLayerClient(object): data["filter"] = search_filter try: - response = yield self.client.post_json( + response = await self.client.post_json( destination=remote_server, path=path, data=data, ignore_backoff=True ) except HttpResponseException as e: @@ -381,7 +371,7 @@ class TransportLayerClient(object): args["since"] = [since_token] try: - response = yield self.client.get_json( + response = await self.client.get_json( destination=remote_server, path=path, args=args, ignore_backoff=True ) except HttpResponseException as e: @@ -396,29 +386,26 @@ class TransportLayerClient(object): return response - @defer.inlineCallbacks @log_function - def exchange_third_party_invite(self, destination, room_id, event_dict): + async def exchange_third_party_invite(self, destination, room_id, event_dict): path = _create_v1_path("/exchange_third_party_invite/%s", room_id) - response = yield self.client.put_json( + response = await self.client.put_json( destination=destination, path=path, data=event_dict ) return response - @defer.inlineCallbacks @log_function - def get_event_auth(self, destination, room_id, event_id): + async def get_event_auth(self, destination, room_id, event_id): path = _create_v1_path("/event_auth/%s/%s", room_id, event_id) - content = yield self.client.get_json(destination=destination, path=path) + content = await self.client.get_json(destination=destination, path=path) return content - @defer.inlineCallbacks @log_function - def query_client_keys(self, destination, query_content, timeout): + async def query_client_keys(self, destination, query_content, timeout): """Query the device keys for a list of user ids hosted on a remote server. @@ -453,14 +440,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/keys/query") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def query_user_devices(self, destination, user_id, timeout): + async def query_user_devices(self, destination, user_id, timeout): """Query the devices for a user id hosted on a remote server. Response: @@ -493,14 +479,13 @@ class TransportLayerClient(object): """ path = _create_v1_path("/user/devices/%s", user_id) - content = yield self.client.get_json( + content = await self.client.get_json( destination=destination, path=path, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def claim_client_keys(self, destination, query_content, timeout): + async def claim_client_keys(self, destination, query_content, timeout): """Claim one-time keys for a list of devices hosted on a remote server. Request: @@ -532,14 +517,13 @@ class TransportLayerClient(object): path = _create_v1_path("/user/keys/claim") - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data=query_content, timeout=timeout ) return content - @defer.inlineCallbacks @log_function - def get_missing_events( + async def get_missing_events( self, destination, room_id, @@ -551,7 +535,7 @@ class TransportLayerClient(object): ): path = _create_v1_path("/get_missing_events/%s", room_id) - content = yield self.client.post_json( + content = await self.client.post_json( destination=destination, path=path, data={ diff --git a/synapse/handlers/groups_local.py b/synapse/handlers/groups_local.py index ecdb12a7bf..0e2656ccb3 100644 --- a/synapse/handlers/groups_local.py +++ b/synapse/handlers/groups_local.py @@ -23,39 +23,32 @@ logger = logging.getLogger(__name__) def _create_rerouter(func_name): - """Returns a function that looks at the group id and calls the function + """Returns an async function that looks at the group id and calls the function on federation or the local group server if the group is local """ - def f(self, group_id, *args, **kwargs): + async def f(self, group_id, *args, **kwargs): if self.is_mine_id(group_id): - return getattr(self.groups_server_handler, func_name)( + return await getattr(self.groups_server_handler, func_name)( group_id, *args, **kwargs ) else: destination = get_domain_from_id(group_id) - d = getattr(self.transport_client, func_name)( - destination, group_id, *args, **kwargs - ) - # Capture errors returned by the remote homeserver and - # re-throw specific errors as SynapseErrors. This is so - # when the remote end responds with things like 403 Not - # In Group, we can communicate that to the client instead - # of a 500. - def http_response_errback(failure): - failure.trap(HttpResponseException) - e = failure.value + try: + return await getattr(self.transport_client, func_name)( + destination, group_id, *args, **kwargs + ) + except HttpResponseException as e: + # Capture errors returned by the remote homeserver and + # re-throw specific errors as SynapseErrors. This is so + # when the remote end responds with things like 403 Not + # In Group, we can communicate that to the client instead + # of a 500. raise e.to_synapse_error() - - def request_failed_errback(failure): - failure.trap(RequestSendFailed) + except RequestSendFailed: raise SynapseError(502, "Failed to contact group server") - d.addErrback(http_response_errback) - d.addErrback(request_failed_errback) - return d - return f diff --git a/synapse/http/matrixfederationclient.py b/synapse/http/matrixfederationclient.py index ea026ed9f4..2a6373937a 100644 --- a/synapse/http/matrixfederationclient.py +++ b/synapse/http/matrixfederationclient.py @@ -121,8 +121,7 @@ class MatrixFederationRequest(object): return self.json -@defer.inlineCallbacks -def _handle_json_response(reactor, timeout_sec, request, response): +async def _handle_json_response(reactor, timeout_sec, request, response): """ Reads the JSON body of a response, with a timeout @@ -141,7 +140,7 @@ def _handle_json_response(reactor, timeout_sec, request, response): d = treq.json_content(response) d = timeout_deferred(d, timeout=timeout_sec, reactor=reactor) - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except TimeoutError as e: logger.warning( "{%s} [%s] Timed out reading response", request.txn_id, request.destination, @@ -224,8 +223,7 @@ class MatrixFederationHttpClient(object): self._cooperator = Cooperator(scheduler=schedule) - @defer.inlineCallbacks - def _send_request_with_optional_trailing_slash( + async def _send_request_with_optional_trailing_slash( self, request, try_trailing_slash_on_400=False, **send_request_args ): """Wrapper for _send_request which can optionally retry the request @@ -246,10 +244,10 @@ class MatrixFederationHttpClient(object): (except 429). Returns: - Deferred[Dict]: Parsed JSON response body. + Dict: Parsed JSON response body. """ try: - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) except HttpResponseException as e: # Received an HTTP error > 300. Check if it meets the requirements # to retry with a trailing slash @@ -265,12 +263,11 @@ class MatrixFederationHttpClient(object): logger.info("Retrying request with trailing slash") request.path += "/" - response = yield self._send_request(request, **send_request_args) + response = await self._send_request(request, **send_request_args) return response - @defer.inlineCallbacks - def _send_request( + async def _send_request( self, request, retry_on_dns_fail=True, @@ -311,7 +308,7 @@ class MatrixFederationHttpClient(object): backoff_on_404 (bool): Back off if we get a 404 Returns: - Deferred[twisted.web.client.Response]: resolves with the HTTP + twisted.web.client.Response: resolves with the HTTP response object on success. Raises: @@ -335,7 +332,7 @@ class MatrixFederationHttpClient(object): ): raise FederationDeniedError(request.destination) - limiter = yield synapse.util.retryutils.get_retry_limiter( + limiter = await synapse.util.retryutils.get_retry_limiter( request.destination, self.clock, self._store, @@ -433,7 +430,7 @@ class MatrixFederationHttpClient(object): reactor=self.reactor, ) - response = yield request_deferred + response = await request_deferred except TimeoutError as e: raise RequestSendFailed(e, can_retry=True) from e except DNSLookupError as e: @@ -474,7 +471,7 @@ class MatrixFederationHttpClient(object): ) try: - body = yield make_deferred_yieldable(d) + body = await make_deferred_yieldable(d) except Exception as e: # Eh, we're already going to raise an exception so lets # ignore if this fails. @@ -528,7 +525,7 @@ class MatrixFederationHttpClient(object): delay, ) - yield self.clock.sleep(delay) + await self.clock.sleep(delay) retries_left -= 1 else: raise @@ -591,8 +588,7 @@ class MatrixFederationHttpClient(object): ) return auth_headers - @defer.inlineCallbacks - def put_json( + async def put_json( self, destination, path, @@ -636,7 +632,7 @@ class MatrixFederationHttpClient(object): enabled. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -658,7 +654,7 @@ class MatrixFederationHttpClient(object): json=data, ) - response = yield self._send_request_with_optional_trailing_slash( + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=backoff_on_404, @@ -667,14 +663,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def post_json( + async def post_json( self, destination, path, @@ -707,7 +702,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -725,7 +720,7 @@ class MatrixFederationHttpClient(object): method="POST", destination=destination, path=path, query=args, json=data ) - response = yield self._send_request( + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, @@ -737,13 +732,12 @@ class MatrixFederationHttpClient(object): else: _sec_timeout = self.default_timeout - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, _sec_timeout, request, response ) return body - @defer.inlineCallbacks - def get_json( + async def get_json( self, destination, path, @@ -775,7 +769,7 @@ class MatrixFederationHttpClient(object): response we should try appending a trailing slash to the end of the request. Workaround for #3622 in Synapse <= v0.99.3. Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -792,7 +786,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request_with_optional_trailing_slash( + response = await self._send_request_with_optional_trailing_slash( request, try_trailing_slash_on_400, backoff_on_404=False, @@ -801,14 +795,13 @@ class MatrixFederationHttpClient(object): timeout=timeout, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def delete_json( + async def delete_json( self, destination, path, @@ -836,7 +829,7 @@ class MatrixFederationHttpClient(object): args (dict): query params Returns: - Deferred[dict|list]: Succeeds when we get a 2xx HTTP response. The + dict|list: Succeeds when we get a 2xx HTTP response. The result will be the decoded JSON body. Raises: @@ -853,20 +846,19 @@ class MatrixFederationHttpClient(object): method="DELETE", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, long_retries=long_retries, timeout=timeout, ignore_backoff=ignore_backoff, ) - body = yield _handle_json_response( + body = await _handle_json_response( self.reactor, self.default_timeout, request, response ) return body - @defer.inlineCallbacks - def get_file( + async def get_file( self, destination, path, @@ -886,7 +878,7 @@ class MatrixFederationHttpClient(object): and try the request anyway. Returns: - Deferred[tuple[int, dict]]: Resolves with an (int,dict) tuple of + tuple[int, dict]: Resolves with an (int,dict) tuple of the file length and a dict of the response headers. Raises: @@ -903,7 +895,7 @@ class MatrixFederationHttpClient(object): method="GET", destination=destination, path=path, query=args ) - response = yield self._send_request( + response = await self._send_request( request, retry_on_dns_fail=retry_on_dns_fail, ignore_backoff=ignore_backoff ) @@ -912,7 +904,7 @@ class MatrixFederationHttpClient(object): try: d = _readBodyToFile(response, output_stream, max_size) d.addTimeout(self.default_timeout, self.reactor) - length = yield make_deferred_yieldable(d) + length = await make_deferred_yieldable(d) except Exception as e: logger.warning( "{%s} [%s] Error reading response: %s", diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index f9ce609923..e0ad8e8a77 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -102,11 +102,10 @@ class KeyringTestCase(unittest.HomeserverTestCase): } persp_deferred = defer.Deferred() - @defer.inlineCallbacks - def get_perspectives(**kwargs): + async def get_perspectives(**kwargs): self.assertEquals(current_context().request, "11") with PreserveLoggingContext(): - yield persp_deferred + await persp_deferred return persp_resp self.http_client.post_json.side_effect = get_perspectives @@ -355,7 +354,7 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase): } signedjson.sign.sign_json(response, SERVER_NAME, testkey) - def get_json(destination, path, **kwargs): + async def get_json(destination, path, **kwargs): self.assertEqual(destination, SERVER_NAME) self.assertEqual(path, "/_matrix/key/v2/server/key1") return response @@ -444,7 +443,7 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect a perspectives-server key query """ - def post_json(destination, path, data, **kwargs): + async def post_json(destination, path, data, **kwargs): self.assertEqual(destination, self.mock_perspective_server.server_name) self.assertEqual(path, "/_matrix/key/v2/query") @@ -580,14 +579,12 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase): # remove the perspectives server's signature response = build_response() del response["signatures"][self.mock_perspective_server.server_name] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing persp server sig") # remove the origin server's signature response = build_response() del response["signatures"][SERVER_NAME] - self.http_client.post_json.return_value = {"server_keys": [response]} keys = get_key_from_perspectives(response) self.assertEqual(keys, {}, "Expected empty dict with missing origin server sig") diff --git a/tests/federation/test_complexity.py b/tests/federation/test_complexity.py index 5cd0510f0d..b8ca118716 100644 --- a/tests/federation/test_complexity.py +++ b/tests/federation/test_complexity.py @@ -23,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable class RoomComplexityTests(unittest.FederatingHomeserverTestCase): @@ -78,9 +79,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -109,9 +110,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -147,9 +148,9 @@ class RoomComplexityTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed(None)) + fed_transport.client.get_json = Mock(return_value=make_awaitable(None)) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) # Artificially raise the complexity @@ -203,9 +204,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( @@ -233,9 +234,9 @@ class RoomComplexityAdminTests(unittest.FederatingHomeserverTestCase): fed_transport = self.hs.get_federation_transport_client() # Mock out some things, because we don't want to test the whole join - fed_transport.client.get_json = Mock(return_value=defer.succeed({"v1": 9999})) + fed_transport.client.get_json = Mock(return_value=make_awaitable({"v1": 9999})) handler.federation_handler.do_invite_join = Mock( - return_value=defer.succeed(("", 1)) + return_value=make_awaitable(("", 1)) ) d = handler._remote_join( diff --git a/tests/federation/test_federation_sender.py b/tests/federation/test_federation_sender.py index d1bd18da39..5f512ff8bf 100644 --- a/tests/federation/test_federation_sender.py +++ b/tests/federation/test_federation_sender.py @@ -47,13 +47,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -87,13 +87,13 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): mock_send_transaction = ( self.hs.get_federation_transport_client().send_transaction ) - mock_send_transaction.return_value = defer.succeed({}) + mock_send_transaction.return_value = make_awaitable({}) sender = self.hs.get_federation_sender() receipt = ReadReceipt( "room_id", "m.read", "user_id", ["event_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() @@ -125,7 +125,7 @@ class FederationSenderReceiptsTestCases(HomeserverTestCase): receipt = ReadReceipt( "room_id", "m.read", "user_id", ["other_id"], {"ts": 1234} ) - self.successResultOf(sender.send_read_receipt(receipt)) + self.successResultOf(defer.ensureDeferred(sender.send_read_receipt(receipt))) self.pump() mock_send_transaction.assert_not_called() diff --git a/tests/handlers/test_directory.py b/tests/handlers/test_directory.py index 00bb776271..bc0c5aefdc 100644 --- a/tests/handlers/test_directory.py +++ b/tests/handlers/test_directory.py @@ -16,8 +16,6 @@ from mock import Mock -from twisted.internet import defer - import synapse import synapse.api.errors from synapse.api.constants import EventTypes @@ -26,6 +24,7 @@ from synapse.rest.client.v1 import directory, login, room from synapse.types import RoomAlias, create_requester from tests import unittest +from tests.test_utils import make_awaitable class DirectoryTestCase(unittest.HomeserverTestCase): @@ -71,7 +70,7 @@ class DirectoryTestCase(unittest.HomeserverTestCase): self.assertEquals({"room_id": "!8765qwer:test", "servers": ["test"]}, result) def test_get_remote_association(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"room_id": "!8765qwer:test", "servers": ["test", "remote"]} ) diff --git a/tests/handlers/test_profile.py b/tests/handlers/test_profile.py index 4f1347cd25..d70e1fc608 100644 --- a/tests/handlers/test_profile.py +++ b/tests/handlers/test_profile.py @@ -24,6 +24,7 @@ from synapse.handlers.profile import MasterProfileHandler from synapse.types import UserID from tests import unittest +from tests.test_utils import make_awaitable from tests.utils import setup_test_homeserver @@ -138,7 +139,7 @@ class ProfileTestCase(unittest.TestCase): @defer.inlineCallbacks def test_get_other_name(self): - self.mock_federation.make_query.return_value = defer.succeed( + self.mock_federation.make_query.return_value = make_awaitable( {"displayname": "Alice"} ) diff --git a/tests/http/test_fedclient.py b/tests/http/test_fedclient.py index fff4f0cbf4..ac598249e4 100644 --- a/tests/http/test_fedclient.py +++ b/tests/http/test_fedclient.py @@ -58,7 +58,9 @@ class FederationClientTests(HomeserverTestCase): @defer.inlineCallbacks def do_request(): with LoggingContext("one") as context: - fetch_d = self.cl.get_json("testserv:8008", "foo/bar") + fetch_d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar") + ) # Nothing happened yet self.assertNoResult(fetch_d) @@ -120,7 +122,9 @@ class FederationClientTests(HomeserverTestCase): """ If the DNS lookup returns an error, it will bubble up. """ - d = self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv2:8008", "foo/bar", timeout=10000) + ) self.pump() f = self.failureResultOf(d) @@ -128,7 +132,9 @@ class FederationClientTests(HomeserverTestCase): self.assertIsInstance(f.value.inner_exception, DNSLookupError) def test_client_connection_refused(self): - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -154,7 +160,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is not connected and is timed out, it'll give a ConnectingCancelledError or TimeoutError. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -184,7 +192,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -226,7 +236,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a blacklisted IPv4 address # ------------------------------------------------------ # Make the request - d = cl.get_json("internal:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.get_json("internal:8008", "foo/bar", timeout=10000)) # Nothing happened yet self.assertNoResult(d) @@ -244,7 +254,9 @@ class FederationClientTests(HomeserverTestCase): # Try making a POST request to a blacklisted IPv6 address # ------------------------------------------------------- # Make the request - d = cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + cl.post_json("internalv6:8008", "foo/bar", timeout=10000) + ) # Nothing has happened yet self.assertNoResult(d) @@ -263,7 +275,7 @@ class FederationClientTests(HomeserverTestCase): # Try making a GET request to a non-blacklisted IPv4 address # ---------------------------------------------------------- # Make the request - d = cl.post_json("fine:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred(cl.post_json("fine:8008", "foo/bar", timeout=10000)) # Nothing has happened yet self.assertNoResult(d) @@ -286,7 +298,7 @@ class FederationClientTests(HomeserverTestCase): request = MatrixFederationRequest( method="GET", destination="testserv:8008", path="foo/bar" ) - d = self.cl._send_request(request, timeout=10000) + d = defer.ensureDeferred(self.cl._send_request(request, timeout=10000)) self.pump() @@ -310,7 +322,9 @@ class FederationClientTests(HomeserverTestCase): If the HTTP request is connected, but gets no response before being timed out, it'll give a ResponseNeverReceived. """ - d = self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + d = defer.ensureDeferred( + self.cl.post_json("testserv:8008", "foo/bar", timeout=10000) + ) self.pump() @@ -342,7 +356,9 @@ class FederationClientTests(HomeserverTestCase): requiring a trailing slash. We need to retry the request with a trailing slash. Workaround for Synapse <= v0.99.3, explained in #3622. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -395,7 +411,9 @@ class FederationClientTests(HomeserverTestCase): See test_client_requires_trailing_slashes() for context. """ - d = self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + d = defer.ensureDeferred( + self.cl.get_json("testserv:8008", "foo/bar", try_trailing_slash_on_400=True) + ) # Send the request self.pump() @@ -432,7 +450,11 @@ class FederationClientTests(HomeserverTestCase): self.failureResultOf(d) def test_client_sends_body(self): - self.cl.post_json("testserv:8008", "foo/bar", timeout=10000, data={"a": "b"}) + defer.ensureDeferred( + self.cl.post_json( + "testserv:8008", "foo/bar", timeout=10000, data={"a": "b"} + ) + ) self.pump() @@ -453,7 +475,7 @@ class FederationClientTests(HomeserverTestCase): def test_closes_connection(self): """Check that the client closes unused HTTP connections""" - d = self.cl.get_json("testserv:8008", "foo/bar") + d = defer.ensureDeferred(self.cl.get_json("testserv:8008", "foo/bar")) self.pump() diff --git a/tests/replication/test_federation_sender_shard.py b/tests/replication/test_federation_sender_shard.py index 8d4dbf232e..83f9aa291c 100644 --- a/tests/replication/test_federation_sender_shard.py +++ b/tests/replication/test_federation_sender_shard.py @@ -16,8 +16,6 @@ import logging from mock import Mock -from twisted.internet import defer - from synapse.api.constants import EventTypes, Membership from synapse.events.builder import EventBuilderFactory from synapse.rest.admin import register_servlets_for_client_rest_resource @@ -25,6 +23,7 @@ from synapse.rest.client.v1 import login, room from synapse.types import UserID from tests.replication._base import BaseMultiWorkerStreamTestCase +from tests.test_utils import make_awaitable logger = logging.getLogger(__name__) @@ -46,7 +45,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new event. """ mock_client = Mock(spec=["put_json"]) - mock_client.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", @@ -74,7 +73,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new events. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -86,7 +85,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -137,7 +136,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): new typing EDUs. """ mock_client1 = Mock(spec=["put_json"]) - mock_client1.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client1.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { @@ -149,7 +148,7 @@ class FederationSenderTestCase(BaseMultiWorkerStreamTestCase): ) mock_client2 = Mock(spec=["put_json"]) - mock_client2.put_json.side_effect = lambda *_, **__: defer.succeed({}) + mock_client2.put_json.side_effect = lambda *_, **__: make_awaitable({}) self.make_worker_hs( "synapse.app.federation_sender", { diff --git a/tests/rest/admin/test_admin.py b/tests/rest/admin/test_admin.py index b1a4decced..0f1144fe1e 100644 --- a/tests/rest/admin/test_admin.py +++ b/tests/rest/admin/test_admin.py @@ -178,7 +178,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): self.fetches = [] - def get_file(destination, path, output_stream, args=None, max_size=None): + async def get_file(destination, path, output_stream, args=None, max_size=None): """ Returns tuple[int,dict,str,int] of file length, response headers, absolute URI, and response code. @@ -192,7 +192,7 @@ class QuarantineMediaTestCase(unittest.HomeserverTestCase): d = Deferred() d.addCallback(write_to) self.fetches.append((d, destination, path, args)) - return make_deferred_yieldable(d) + return await make_deferred_yieldable(d) client = Mock() client.get_file = get_file diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py index 99eb477149..6850c666be 100644 --- a/tests/rest/key/v2/test_remote_key_resource.py +++ b/tests/rest/key/v2/test_remote_key_resource.py @@ -53,7 +53,7 @@ class BaseRemoteKeyResourceTestCase(unittest.HomeserverTestCase): Tell the mock http client to expect an outgoing GET request for the given key """ - def get_json(destination, path, ignore_backoff=False, **kwargs): + async def get_json(destination, path, ignore_backoff=False, **kwargs): self.assertTrue(ignore_backoff) self.assertEqual(destination, server_name) key_id = "%s:%s" % (signing_key.alg, signing_key.version) @@ -177,7 +177,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase): # wire up outbound POST /key/v2/query requests from hs2 so that they # will be forwarded to hs1 - def post_json(destination, path, data): + async def post_json(destination, path, data): self.assertEqual(destination, self.hs.hostname) self.assertEqual( path, "/_matrix/key/v2/query", diff --git a/tests/test_federation.py b/tests/test_federation.py index 87a16d7d7a..c2f12c2741 100644 --- a/tests/test_federation.py +++ b/tests/test_federation.py @@ -95,7 +95,7 @@ class MessageAcceptTests(unittest.HomeserverTestCase): prev_events that said event references. """ - def post_json(destination, path, data, headers=None, timeout=0): + async def post_json(destination, path, data, headers=None, timeout=0): # If it asks us for new missing events, give them NOTHING if path.startswith("/_matrix/federation/v1/get_missing_events/"): return {"events": []} -- cgit 1.5.1 From 2a89ce8cd4d563ef22995882e9548f1aff3e42f1 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Mon, 3 Aug 2020 08:29:01 -0400 Subject: Convert the crypto module to async/await. (#8003) --- changelog.d/8003.misc | 1 + synapse/crypto/keyring.py | 201 ++++++++++++++++++++----------------------- tests/crypto/test_keyring.py | 39 ++++----- 3 files changed, 109 insertions(+), 132 deletions(-) create mode 100644 changelog.d/8003.misc (limited to 'tests/crypto/test_keyring.py') diff --git a/changelog.d/8003.misc b/changelog.d/8003.misc new file mode 100644 index 0000000000..dfe4c03171 --- /dev/null +++ b/changelog.d/8003.misc @@ -0,0 +1 @@ +Convert various parts of the codebase to async/await. diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py index 443cde0b6d..28ef7cfdb9 100644 --- a/synapse/crypto/keyring.py +++ b/synapse/crypto/keyring.py @@ -223,8 +223,7 @@ class Keyring(object): return results - @defer.inlineCallbacks - def _start_key_lookups(self, verify_requests): + async def _start_key_lookups(self, verify_requests): """Sets off the key fetches for each verify request Once each fetch completes, verify_request.key_ready will be resolved. @@ -245,7 +244,7 @@ class Keyring(object): server_to_request_ids.setdefault(server_name, set()).add(request_id) # Wait for any previous lookups to complete before proceeding. - yield self.wait_for_previous_lookups(server_to_request_ids.keys()) + await self.wait_for_previous_lookups(server_to_request_ids.keys()) # take out a lock on each of the servers by sticking a Deferred in # key_downloads @@ -283,15 +282,14 @@ class Keyring(object): except Exception: logger.exception("Error starting key lookups") - @defer.inlineCallbacks - def wait_for_previous_lookups(self, server_names): + async def wait_for_previous_lookups(self, server_names) -> None: """Waits for any previous key lookups for the given servers to finish. Args: server_names (Iterable[str]): list of servers which we want to look up Returns: - Deferred[None]: resolves once all key lookups for the given servers have + Resolves once all key lookups for the given servers have completed. Follows the synapse rules of logcontext preservation. """ loop_count = 1 @@ -309,7 +307,7 @@ class Keyring(object): loop_count, ) with PreserveLoggingContext(): - yield defer.DeferredList((w[1] for w in wait_on)) + await defer.DeferredList((w[1] for w in wait_on)) loop_count += 1 @@ -326,44 +324,44 @@ class Keyring(object): remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called} - @defer.inlineCallbacks - def do_iterations(): - with Measure(self.clock, "get_server_verify_keys"): - for f in self._key_fetchers: - if not remaining_requests: - return - yield self._attempt_key_fetches_with_fetcher(f, remaining_requests) + async def do_iterations(): + try: + with Measure(self.clock, "get_server_verify_keys"): + for f in self._key_fetchers: + if not remaining_requests: + return + await self._attempt_key_fetches_with_fetcher( + f, remaining_requests + ) - # look for any requests which weren't satisfied + # look for any requests which weren't satisfied + with PreserveLoggingContext(): + for verify_request in remaining_requests: + verify_request.key_ready.errback( + SynapseError( + 401, + "No key for %s with ids in %s (min_validity %i)" + % ( + verify_request.server_name, + verify_request.key_ids, + verify_request.minimum_valid_until_ts, + ), + Codes.UNAUTHORIZED, + ) + ) + except Exception as err: + # we don't really expect to get here, because any errors should already + # have been caught and logged. But if we do, let's log the error and make + # sure that all of the deferreds are resolved. + logger.error("Unexpected error in _get_server_verify_keys: %s", err) with PreserveLoggingContext(): for verify_request in remaining_requests: - verify_request.key_ready.errback( - SynapseError( - 401, - "No key for %s with ids in %s (min_validity %i)" - % ( - verify_request.server_name, - verify_request.key_ids, - verify_request.minimum_valid_until_ts, - ), - Codes.UNAUTHORIZED, - ) - ) - - def on_err(err): - # we don't really expect to get here, because any errors should already - # have been caught and logged. But if we do, let's log the error and make - # sure that all of the deferreds are resolved. - logger.error("Unexpected error in _get_server_verify_keys: %s", err) - with PreserveLoggingContext(): - for verify_request in remaining_requests: - if not verify_request.key_ready.called: - verify_request.key_ready.errback(err) + if not verify_request.key_ready.called: + verify_request.key_ready.errback(err) - run_in_background(do_iterations).addErrback(on_err) + run_in_background(do_iterations) - @defer.inlineCallbacks - def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): + async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests): """Use a key fetcher to attempt to satisfy some key requests Args: @@ -390,7 +388,7 @@ class Keyring(object): verify_request.minimum_valid_until_ts, ) - results = yield fetcher.get_keys(missing_keys) + results = await fetcher.get_keys(missing_keys) completed = [] for verify_request in remaining_requests: @@ -423,7 +421,7 @@ class Keyring(object): class KeyFetcher(object): - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher): def __init__(self, hs): self.store = hs.get_datastore() - @defer.inlineCallbacks - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """see KeyFetcher.get_keys""" keys_to_fetch = ( @@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher): for key_id in keys_for_server.keys() ) - res = yield self.store.get_server_verify_keys(keys_to_fetch) + res = await self.store.get_server_verify_keys(keys_to_fetch) keys = {} for (server_name, key_id), key in res.items(): keys.setdefault(server_name, {})[key_id] = key @@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object): self.store = hs.get_datastore() self.config = hs.get_config() - @defer.inlineCallbacks - def process_v2_response(self, from_server, response_json, time_added_ms): + async def process_v2_response(self, from_server, response_json, time_added_ms): """Parse a 'Server Keys' structure from the result of a /key request This is used to parse either the entirety of the response from @@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object): key_json_bytes = encode_canonical_json(response_json) - yield make_deferred_yieldable( + await make_deferred_yieldable( defer.gatherResults( [ run_in_background( @@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): self.client = hs.get_http_client() self.key_servers = self.config.key_servers - @defer.inlineCallbacks - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """see KeyFetcher.get_keys""" - @defer.inlineCallbacks - def get_key(key_server): + async def get_key(key_server): try: - result = yield self.get_server_verify_key_v2_indirect( + result = await self.get_server_verify_key_v2_indirect( keys_to_fetch, key_server ) return result @@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return {} - results = yield make_deferred_yieldable( + results = await make_deferred_yieldable( defer.gatherResults( [run_in_background(get_key, server) for server in self.key_servers], consumeErrors=True, @@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): return union_of_keys - @defer.inlineCallbacks - def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): + async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server): """ Args: keys_to_fetch (dict[str, dict[str, int]]): @@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): the keys Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map + dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map from server_name -> key_id -> FetchKeyResult Raises: @@ -632,20 +625,18 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) try: - query_response = yield defer.ensureDeferred( - self.client.post_json( - destination=perspective_name, - path="/_matrix/key/v2/query", - data={ - "server_keys": { - server_name: { - key_id: {"minimum_valid_until_ts": min_valid_ts} - for key_id, min_valid_ts in server_keys.items() - } - for server_name, server_keys in keys_to_fetch.items() + query_response = await self.client.post_json( + destination=perspective_name, + path="/_matrix/key/v2/query", + data={ + "server_keys": { + server_name: { + key_id: {"minimum_valid_until_ts": min_valid_ts} + for key_id, min_valid_ts in server_keys.items() } - }, - ) + for server_name, server_keys in keys_to_fetch.items() + } + }, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve upon @@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): try: self._validate_perspectives_response(key_server, response) - processed_response = yield self.process_v2_response( + processed_response = await self.process_v2_response( perspective_name, response, time_added_ms=time_now_ms ) except KeyLookupError as e: @@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher): ) keys.setdefault(server_name, {}).update(processed_response) - yield self.store.store_server_verify_keys( + await self.store.store_server_verify_keys( perspective_name, time_now_ms, added_keys ) @@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher): self.clock = hs.get_clock() self.client = hs.get_http_client() - def get_keys(self, keys_to_fetch): + async def get_keys(self, keys_to_fetch): """ Args: keys_to_fetch (dict[str, iterable[str]]): the keys to be fetched. server_name -> key_ids Returns: - Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]: + dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]: map from server_name -> key_id -> FetchKeyResult """ results = {} - @defer.inlineCallbacks - def get_key(key_to_fetch_item): + async def get_key(key_to_fetch_item): server_name, key_ids = key_to_fetch_item try: - keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids) + keys = await self.get_server_verify_key_v2_direct(server_name, key_ids) results[server_name] = keys except KeyLookupError as e: logger.warning( @@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher): except Exception: logger.exception("Error getting keys %s from %s", key_ids, server_name) - return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback( - lambda _: results - ) + return await yieldable_gather_results( + get_key, keys_to_fetch.items() + ).addCallback(lambda _: results) - @defer.inlineCallbacks - def get_server_verify_key_v2_direct(self, server_name, key_ids): + async def get_server_verify_key_v2_direct(self, server_name, key_ids): """ Args: @@ -794,25 +783,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher): time_now_ms = self.clock.time_msec() try: - response = yield defer.ensureDeferred( - self.client.get_json( - destination=server_name, - path="/_matrix/key/v2/server/" - + urllib.parse.quote(requested_key_id), - ignore_backoff=True, - # we only give the remote server 10s to respond. It should be an - # easy request to handle, so if it doesn't reply within 10s, it's - # probably not going to. - # - # Furthermore, when we are acting as a notary server, we cannot - # wait all day for all of the origin servers, as the requesting - # server will otherwise time out before we can respond. - # - # (Note that get_json may make 4 attempts, so this can still take - # almost 45 seconds to fetch the headers, plus up to another 60s to - # read the response). - timeout=10000, - ) + response = await self.client.get_json( + destination=server_name, + path="/_matrix/key/v2/server/" + + urllib.parse.quote(requested_key_id), + ignore_backoff=True, + # we only give the remote server 10s to respond. It should be an + # easy request to handle, so if it doesn't reply within 10s, it's + # probably not going to. + # + # Furthermore, when we are acting as a notary server, we cannot + # wait all day for all of the origin servers, as the requesting + # server will otherwise time out before we can respond. + # + # (Note that get_json may make 4 attempts, so this can still take + # almost 45 seconds to fetch the headers, plus up to another 60s to + # read the response). + timeout=10000, ) except (NotRetryingDestination, RequestSendFailed) as e: # these both have str() representations which we can't really improve @@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher): % (server_name, response["server_name"]) ) - response_keys = yield self.process_v2_response( + response_keys = await self.process_v2_response( from_server=server_name, response_json=response, time_added_ms=time_now_ms, ) - yield self.store.store_server_verify_keys( + await self.store.store_server_verify_keys( server_name, time_now_ms, ((server_name, key_id, key) for key_id, key in response_keys.items()), @@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher): return keys -@defer.inlineCallbacks -def _handle_key_deferred(verify_request): +async def _handle_key_deferred(verify_request) -> None: """Waits for the key to become available, and then performs a verification Args: verify_request (VerifyJsonRequest): - Returns: - Deferred[None] - Raises: SynapseError if there was a problem performing the verification """ server_name = verify_request.server_name with PreserveLoggingContext(): - _, key_id, verify_key = yield verify_request.key_ready + _, key_id, verify_key = await verify_request.key_ready json_object = verify_request.json_object diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py index e0ad8e8a77..0d4b05304b 100644 --- a/tests/crypto/test_keyring.py +++ b/tests/crypto/test_keyring.py @@ -40,6 +40,7 @@ from synapse.logging.context import ( from synapse.storage.keys import FetchKeyResult from tests import unittest +from tests.test_utils import make_awaitable class MockPerspectiveServer(object): @@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase): with a null `ts_valid_until_ms` """ mock_fetcher = keyring.KeyFetcher() - mock_fetcher.get_keys = Mock(return_value=defer.succeed({})) + mock_fetcher.get_keys = Mock(return_value=make_awaitable({})) kr = keyring.Keyring( self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher) @@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase): """Two requests for the same key should be deduped.""" key1 = signedjson.key.generate_signing_key(1) - def get_keys(keys_to_fetch): + async def get_keys(keys_to_fetch): # there should only be one request object (with the max validity) self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher = keyring.KeyFetcher() mock_fetcher.get_keys = Mock(side_effect=get_keys) @@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase): """If the first fetcher cannot provide a recent enough key, we fall back""" key1 = signedjson.key.generate_signing_key(1) - def get_keys1(keys_to_fetch): + async def get_keys1(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800) - } - } - ) + return { + "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)} + } - def get_keys2(keys_to_fetch): + async def get_keys2(keys_to_fetch): self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}}) - return defer.succeed( - { - "server1": { - get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) - } + return { + "server1": { + get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200) } - ) + } mock_fetcher1 = keyring.KeyFetcher() mock_fetcher1.get_keys = Mock(side_effect=get_keys1) -- cgit 1.5.1