summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8003.misc1
-rw-r--r--synapse/crypto/keyring.py201
-rw-r--r--tests/crypto/test_keyring.py39
3 files changed, 109 insertions, 132 deletions
diff --git a/changelog.d/8003.misc b/changelog.d/8003.misc
new file mode 100644
index 0000000000..dfe4c03171
--- /dev/null
+++ b/changelog.d/8003.misc
@@ -0,0 +1 @@
+Convert various parts of the codebase to async/await.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 443cde0b6d..28ef7cfdb9 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -223,8 +223,7 @@ class Keyring(object):
 
         return results
 
-    @defer.inlineCallbacks
-    def _start_key_lookups(self, verify_requests):
+    async def _start_key_lookups(self, verify_requests):
         """Sets off the key fetches for each verify request
 
         Once each fetch completes, verify_request.key_ready will be resolved.
@@ -245,7 +244,7 @@ class Keyring(object):
                 server_to_request_ids.setdefault(server_name, set()).add(request_id)
 
             # Wait for any previous lookups to complete before proceeding.
-            yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+            await self.wait_for_previous_lookups(server_to_request_ids.keys())
 
             # take out a lock on each of the servers by sticking a Deferred in
             # key_downloads
@@ -283,15 +282,14 @@ class Keyring(object):
         except Exception:
             logger.exception("Error starting key lookups")
 
-    @defer.inlineCallbacks
-    def wait_for_previous_lookups(self, server_names):
+    async def wait_for_previous_lookups(self, server_names) -> None:
         """Waits for any previous key lookups for the given servers to finish.
 
         Args:
             server_names (Iterable[str]): list of servers which we want to look up
 
         Returns:
-            Deferred[None]: resolves once all key lookups for the given servers have
+            Resolves once all key lookups for the given servers have
                 completed. Follows the synapse rules of logcontext preservation.
         """
         loop_count = 1
@@ -309,7 +307,7 @@ class Keyring(object):
                 loop_count,
             )
             with PreserveLoggingContext():
-                yield defer.DeferredList((w[1] for w in wait_on))
+                await defer.DeferredList((w[1] for w in wait_on))
 
             loop_count += 1
 
@@ -326,44 +324,44 @@ class Keyring(object):
 
         remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
 
-        @defer.inlineCallbacks
-        def do_iterations():
-            with Measure(self.clock, "get_server_verify_keys"):
-                for f in self._key_fetchers:
-                    if not remaining_requests:
-                        return
-                    yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
+        async def do_iterations():
+            try:
+                with Measure(self.clock, "get_server_verify_keys"):
+                    for f in self._key_fetchers:
+                        if not remaining_requests:
+                            return
+                        await self._attempt_key_fetches_with_fetcher(
+                            f, remaining_requests
+                        )
 
-                # look for any requests which weren't satisfied
+                    # look for any requests which weren't satisfied
+                    with PreserveLoggingContext():
+                        for verify_request in remaining_requests:
+                            verify_request.key_ready.errback(
+                                SynapseError(
+                                    401,
+                                    "No key for %s with ids in %s (min_validity %i)"
+                                    % (
+                                        verify_request.server_name,
+                                        verify_request.key_ids,
+                                        verify_request.minimum_valid_until_ts,
+                                    ),
+                                    Codes.UNAUTHORIZED,
+                                )
+                            )
+            except Exception as err:
+                # we don't really expect to get here, because any errors should already
+                # have been caught and logged. But if we do, let's log the error and make
+                # sure that all of the deferreds are resolved.
+                logger.error("Unexpected error in _get_server_verify_keys: %s", err)
                 with PreserveLoggingContext():
                     for verify_request in remaining_requests:
-                        verify_request.key_ready.errback(
-                            SynapseError(
-                                401,
-                                "No key for %s with ids in %s (min_validity %i)"
-                                % (
-                                    verify_request.server_name,
-                                    verify_request.key_ids,
-                                    verify_request.minimum_valid_until_ts,
-                                ),
-                                Codes.UNAUTHORIZED,
-                            )
-                        )
-
-        def on_err(err):
-            # we don't really expect to get here, because any errors should already
-            # have been caught and logged. But if we do, let's log the error and make
-            # sure that all of the deferreds are resolved.
-            logger.error("Unexpected error in _get_server_verify_keys: %s", err)
-            with PreserveLoggingContext():
-                for verify_request in remaining_requests:
-                    if not verify_request.key_ready.called:
-                        verify_request.key_ready.errback(err)
+                        if not verify_request.key_ready.called:
+                            verify_request.key_ready.errback(err)
 
-        run_in_background(do_iterations).addErrback(on_err)
+        run_in_background(do_iterations)
 
-    @defer.inlineCallbacks
-    def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+    async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
         """Use a key fetcher to attempt to satisfy some key requests
 
         Args:
@@ -390,7 +388,7 @@ class Keyring(object):
                     verify_request.minimum_valid_until_ts,
                 )
 
-        results = yield fetcher.get_keys(missing_keys)
+        results = await fetcher.get_keys(missing_keys)
 
         completed = []
         for verify_request in remaining_requests:
@@ -423,7 +421,7 @@ class Keyring(object):
 
 
 class KeyFetcher(object):
-    def get_keys(self, keys_to_fetch):
+    async def get_keys(self, keys_to_fetch):
         """
         Args:
             keys_to_fetch (dict[str, dict[str, int]]):
@@ -442,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
     def __init__(self, hs):
         self.store = hs.get_datastore()
 
-    @defer.inlineCallbacks
-    def get_keys(self, keys_to_fetch):
+    async def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
 
         keys_to_fetch = (
@@ -452,7 +449,7 @@ class StoreKeyFetcher(KeyFetcher):
             for key_id in keys_for_server.keys()
         )
 
-        res = yield self.store.get_server_verify_keys(keys_to_fetch)
+        res = await self.store.get_server_verify_keys(keys_to_fetch)
         keys = {}
         for (server_name, key_id), key in res.items():
             keys.setdefault(server_name, {})[key_id] = key
@@ -464,8 +461,7 @@ class BaseV2KeyFetcher(object):
         self.store = hs.get_datastore()
         self.config = hs.get_config()
 
-    @defer.inlineCallbacks
-    def process_v2_response(self, from_server, response_json, time_added_ms):
+    async def process_v2_response(self, from_server, response_json, time_added_ms):
         """Parse a 'Server Keys' structure from the result of a /key request
 
         This is used to parse either the entirety of the response from
@@ -537,7 +533,7 @@ class BaseV2KeyFetcher(object):
 
         key_json_bytes = encode_canonical_json(response_json)
 
-        yield make_deferred_yieldable(
+        await make_deferred_yieldable(
             defer.gatherResults(
                 [
                     run_in_background(
@@ -567,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_http_client()
         self.key_servers = self.config.key_servers
 
-    @defer.inlineCallbacks
-    def get_keys(self, keys_to_fetch):
+    async def get_keys(self, keys_to_fetch):
         """see KeyFetcher.get_keys"""
 
-        @defer.inlineCallbacks
-        def get_key(key_server):
+        async def get_key(key_server):
             try:
-                result = yield self.get_server_verify_key_v2_indirect(
+                result = await self.get_server_verify_key_v2_indirect(
                     keys_to_fetch, key_server
                 )
                 return result
@@ -592,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
             return {}
 
-        results = yield make_deferred_yieldable(
+        results = await make_deferred_yieldable(
             defer.gatherResults(
                 [run_in_background(get_key, server) for server in self.key_servers],
                 consumeErrors=True,
@@ -606,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
 
         return union_of_keys
 
-    @defer.inlineCallbacks
-    def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+    async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
         """
         Args:
             keys_to_fetch (dict[str, dict[str, int]]):
@@ -617,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 the keys
 
         Returns:
-            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+            dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
                 from server_name -> key_id -> FetchKeyResult
 
         Raises:
@@ -632,20 +625,18 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         )
 
         try:
-            query_response = yield defer.ensureDeferred(
-                self.client.post_json(
-                    destination=perspective_name,
-                    path="/_matrix/key/v2/query",
-                    data={
-                        "server_keys": {
-                            server_name: {
-                                key_id: {"minimum_valid_until_ts": min_valid_ts}
-                                for key_id, min_valid_ts in server_keys.items()
-                            }
-                            for server_name, server_keys in keys_to_fetch.items()
+            query_response = await self.client.post_json(
+                destination=perspective_name,
+                path="/_matrix/key/v2/query",
+                data={
+                    "server_keys": {
+                        server_name: {
+                            key_id: {"minimum_valid_until_ts": min_valid_ts}
+                            for key_id, min_valid_ts in server_keys.items()
                         }
-                    },
-                )
+                        for server_name, server_keys in keys_to_fetch.items()
+                    }
+                },
             )
         except (NotRetryingDestination, RequestSendFailed) as e:
             # these both have str() representations which we can't really improve upon
@@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             try:
                 self._validate_perspectives_response(key_server, response)
 
-                processed_response = yield self.process_v2_response(
+                processed_response = await self.process_v2_response(
                     perspective_name, response, time_added_ms=time_now_ms
                 )
             except KeyLookupError as e:
@@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
             )
             keys.setdefault(server_name, {}).update(processed_response)
 
-        yield self.store.store_server_verify_keys(
+        await self.store.store_server_verify_keys(
             perspective_name, time_now_ms, added_keys
         )
 
@@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.clock = hs.get_clock()
         self.client = hs.get_http_client()
 
-    def get_keys(self, keys_to_fetch):
+    async def get_keys(self, keys_to_fetch):
         """
         Args:
             keys_to_fetch (dict[str, iterable[str]]):
                 the keys to be fetched. server_name -> key_ids
 
         Returns:
-            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+            dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
                 map from server_name -> key_id -> FetchKeyResult
         """
 
         results = {}
 
-        @defer.inlineCallbacks
-        def get_key(key_to_fetch_item):
+        async def get_key(key_to_fetch_item):
             server_name, key_ids = key_to_fetch_item
             try:
-                keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+                keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                 results[server_name] = keys
             except KeyLookupError as e:
                 logger.warning(
@@ -767,12 +757,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
-            lambda _: results
-        )
+        return await yieldable_gather_results(
+            get_key, keys_to_fetch.items()
+        ).addCallback(lambda _: results)
 
-    @defer.inlineCallbacks
-    def get_server_verify_key_v2_direct(self, server_name, key_ids):
+    async def get_server_verify_key_v2_direct(self, server_name, key_ids):
         """
 
         Args:
@@ -794,25 +783,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
             time_now_ms = self.clock.time_msec()
             try:
-                response = yield defer.ensureDeferred(
-                    self.client.get_json(
-                        destination=server_name,
-                        path="/_matrix/key/v2/server/"
-                        + urllib.parse.quote(requested_key_id),
-                        ignore_backoff=True,
-                        # we only give the remote server 10s to respond. It should be an
-                        # easy request to handle, so if it doesn't reply within 10s, it's
-                        # probably not going to.
-                        #
-                        # Furthermore, when we are acting as a notary server, we cannot
-                        # wait all day for all of the origin servers, as the requesting
-                        # server will otherwise time out before we can respond.
-                        #
-                        # (Note that get_json may make 4 attempts, so this can still take
-                        # almost 45 seconds to fetch the headers, plus up to another 60s to
-                        # read the response).
-                        timeout=10000,
-                    )
+                response = await self.client.get_json(
+                    destination=server_name,
+                    path="/_matrix/key/v2/server/"
+                    + urllib.parse.quote(requested_key_id),
+                    ignore_backoff=True,
+                    # we only give the remote server 10s to respond. It should be an
+                    # easy request to handle, so if it doesn't reply within 10s, it's
+                    # probably not going to.
+                    #
+                    # Furthermore, when we are acting as a notary server, we cannot
+                    # wait all day for all of the origin servers, as the requesting
+                    # server will otherwise time out before we can respond.
+                    #
+                    # (Note that get_json may make 4 attempts, so this can still take
+                    # almost 45 seconds to fetch the headers, plus up to another 60s to
+                    # read the response).
+                    timeout=10000,
                 )
             except (NotRetryingDestination, RequestSendFailed) as e:
                 # these both have str() representations which we can't really improve
@@ -827,12 +814,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
                     % (server_name, response["server_name"])
                 )
 
-            response_keys = yield self.process_v2_response(
+            response_keys = await self.process_v2_response(
                 from_server=server_name,
                 response_json=response,
                 time_added_ms=time_now_ms,
             )
-            yield self.store.store_server_verify_keys(
+            await self.store.store_server_verify_keys(
                 server_name,
                 time_now_ms,
                 ((server_name, key_id, key) for key_id, key in response_keys.items()),
@@ -842,22 +829,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         return keys
 
 
-@defer.inlineCallbacks
-def _handle_key_deferred(verify_request):
+async def _handle_key_deferred(verify_request) -> None:
     """Waits for the key to become available, and then performs a verification
 
     Args:
         verify_request (VerifyJsonRequest):
 
-    Returns:
-        Deferred[None]
-
     Raises:
         SynapseError if there was a problem performing the verification
     """
     server_name = verify_request.server_name
     with PreserveLoggingContext():
-        _, key_id, verify_key = yield verify_request.key_ready
+        _, key_id, verify_key = await verify_request.key_ready
 
     json_object = verify_request.json_object
 
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index e0ad8e8a77..0d4b05304b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -40,6 +40,7 @@ from synapse.logging.context import (
 from synapse.storage.keys import FetchKeyResult
 
 from tests import unittest
+from tests.test_utils import make_awaitable
 
 
 class MockPerspectiveServer(object):
@@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         with a null `ts_valid_until_ms`
         """
         mock_fetcher = keyring.KeyFetcher()
-        mock_fetcher.get_keys = Mock(return_value=defer.succeed({}))
+        mock_fetcher.get_keys = Mock(return_value=make_awaitable({}))
 
         kr = keyring.Keyring(
             self.hs, key_fetchers=(StoreKeyFetcher(self.hs), mock_fetcher)
@@ -244,17 +245,15 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """Two requests for the same key should be deduped."""
         key1 = signedjson.key.generate_signing_key(1)
 
-        def get_keys(keys_to_fetch):
+        async def get_keys(keys_to_fetch):
             # there should only be one request object (with the max validity)
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
 
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                    }
+            return {
+                "server1": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
                 }
-            )
+            }
 
         mock_fetcher = keyring.KeyFetcher()
         mock_fetcher.get_keys = Mock(side_effect=get_keys)
@@ -281,25 +280,19 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         """If the first fetcher cannot provide a recent enough key, we fall back"""
         key1 = signedjson.key.generate_signing_key(1)
 
-        def get_keys1(keys_to_fetch):
+        async def get_keys1(keys_to_fetch):
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)
-                    }
-                }
-            )
+            return {
+                "server1": {get_key_id(key1): FetchKeyResult(get_verify_key(key1), 800)}
+            }
 
-        def get_keys2(keys_to_fetch):
+        async def get_keys2(keys_to_fetch):
             self.assertEqual(keys_to_fetch, {"server1": {get_key_id(key1): 1500}})
-            return defer.succeed(
-                {
-                    "server1": {
-                        get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
-                    }
+            return {
+                "server1": {
+                    get_key_id(key1): FetchKeyResult(get_verify_key(key1), 1200)
                 }
-            )
+            }
 
         mock_fetcher1 = keyring.KeyFetcher()
         mock_fetcher1.get_keys = Mock(side_effect=get_keys1)