diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index a9f4025bfe..32c31b1cd1 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -15,11 +15,9 @@
# limitations under the License.
import logging
+import urllib
from collections import defaultdict
-import six
-from six.moves import urllib
-
import attr
from signedjson.key import (
decode_verify_key_bytes,
@@ -59,7 +57,7 @@ logger = logging.getLogger(__name__)
@attr.s(slots=True, cmp=False)
-class VerifyJsonRequest(object):
+class VerifyJsonRequest:
"""
A request to verify a JSON object.
@@ -98,7 +96,7 @@ class KeyLookupError(ValueError):
pass
-class Keyring(object):
+class Keyring:
def __init__(self, hs, key_fetchers=None):
self.clock = hs.get_clock()
@@ -225,8 +223,7 @@ class Keyring(object):
return results
- @defer.inlineCallbacks
- def _start_key_lookups(self, verify_requests):
+ async def _start_key_lookups(self, verify_requests):
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
@@ -247,7 +244,7 @@ class Keyring(object):
server_to_request_ids.setdefault(server_name, set()).add(request_id)
# Wait for any previous lookups to complete before proceeding.
- yield self.wait_for_previous_lookups(server_to_request_ids.keys())
+ await self.wait_for_previous_lookups(server_to_request_ids.keys())
# take out a lock on each of the servers by sticking a Deferred in
# key_downloads
@@ -285,15 +282,14 @@ class Keyring(object):
except Exception:
logger.exception("Error starting key lookups")
- @defer.inlineCallbacks
- def wait_for_previous_lookups(self, server_names):
+ async def wait_for_previous_lookups(self, server_names) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
server_names (Iterable[str]): list of servers which we want to look up
Returns:
- Deferred[None]: resolves once all key lookups for the given servers have
+ Resolves once all key lookups for the given servers have
completed. Follows the synapse rules of logcontext preservation.
"""
loop_count = 1
@@ -311,7 +307,7 @@ class Keyring(object):
loop_count,
)
with PreserveLoggingContext():
- yield defer.DeferredList((w[1] for w in wait_on))
+ await defer.DeferredList((w[1] for w in wait_on))
loop_count += 1
@@ -328,44 +324,44 @@ class Keyring(object):
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
- @defer.inlineCallbacks
- def do_iterations():
- with Measure(self.clock, "get_server_verify_keys"):
- for f in self._key_fetchers:
- if not remaining_requests:
- return
- yield self._attempt_key_fetches_with_fetcher(f, remaining_requests)
+ async def do_iterations():
+ try:
+ with Measure(self.clock, "get_server_verify_keys"):
+ for f in self._key_fetchers:
+ if not remaining_requests:
+ return
+ await self._attempt_key_fetches_with_fetcher(
+ f, remaining_requests
+ )
- # look for any requests which weren't satisfied
+ # look for any requests which weren't satisfied
+ with PreserveLoggingContext():
+ for verify_request in remaining_requests:
+ verify_request.key_ready.errback(
+ SynapseError(
+ 401,
+ "No key for %s with ids in %s (min_validity %i)"
+ % (
+ verify_request.server_name,
+ verify_request.key_ids,
+ verify_request.minimum_valid_until_ts,
+ ),
+ Codes.UNAUTHORIZED,
+ )
+ )
+ except Exception as err:
+ # we don't really expect to get here, because any errors should already
+ # have been caught and logged. But if we do, let's log the error and make
+ # sure that all of the deferreds are resolved.
+ logger.error("Unexpected error in _get_server_verify_keys: %s", err)
with PreserveLoggingContext():
for verify_request in remaining_requests:
- verify_request.key_ready.errback(
- SynapseError(
- 401,
- "No key for %s with ids in %s (min_validity %i)"
- % (
- verify_request.server_name,
- verify_request.key_ids,
- verify_request.minimum_valid_until_ts,
- ),
- Codes.UNAUTHORIZED,
- )
- )
+ if not verify_request.key_ready.called:
+ verify_request.key_ready.errback(err)
- def on_err(err):
- # we don't really expect to get here, because any errors should already
- # have been caught and logged. But if we do, let's log the error and make
- # sure that all of the deferreds are resolved.
- logger.error("Unexpected error in _get_server_verify_keys: %s", err)
- with PreserveLoggingContext():
- for verify_request in remaining_requests:
- if not verify_request.key_ready.called:
- verify_request.key_ready.errback(err)
-
- run_in_background(do_iterations).addErrback(on_err)
+ run_in_background(do_iterations)
- @defer.inlineCallbacks
- def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
@@ -392,7 +388,7 @@ class Keyring(object):
verify_request.minimum_valid_until_ts,
)
- results = yield fetcher.get_keys(missing_keys)
+ results = await fetcher.get_keys(missing_keys)
completed = []
for verify_request in remaining_requests:
@@ -424,8 +420,8 @@ class Keyring(object):
remaining_requests.difference_update(completed)
-class KeyFetcher(object):
- def get_keys(self, keys_to_fetch):
+class KeyFetcher:
+ async def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@@ -444,8 +440,7 @@ class StoreKeyFetcher(KeyFetcher):
def __init__(self, hs):
self.store = hs.get_datastore()
- @defer.inlineCallbacks
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
keys_to_fetch = (
@@ -454,20 +449,19 @@ class StoreKeyFetcher(KeyFetcher):
for key_id in keys_for_server.keys()
)
- res = yield self.store.get_server_verify_keys(keys_to_fetch)
+ res = await self.store.get_server_verify_keys(keys_to_fetch)
keys = {}
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
-class BaseV2KeyFetcher(object):
+class BaseV2KeyFetcher:
def __init__(self, hs):
self.store = hs.get_datastore()
self.config = hs.get_config()
- @defer.inlineCallbacks
- def process_v2_response(self, from_server, response_json, time_added_ms):
+ async def process_v2_response(self, from_server, response_json, time_added_ms):
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -539,7 +533,7 @@ class BaseV2KeyFetcher(object):
key_json_bytes = encode_canonical_json(response_json)
- yield make_deferred_yieldable(
+ await make_deferred_yieldable(
defer.gatherResults(
[
run_in_background(
@@ -569,14 +563,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
self.client = hs.get_http_client()
self.key_servers = self.config.key_servers
- @defer.inlineCallbacks
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""see KeyFetcher.get_keys"""
- @defer.inlineCallbacks
- def get_key(key_server):
+ async def get_key(key_server):
try:
- result = yield self.get_server_verify_key_v2_indirect(
+ result = await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
return result
@@ -594,7 +586,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return {}
- results = yield make_deferred_yieldable(
+ results = await make_deferred_yieldable(
defer.gatherResults(
[run_in_background(get_key, server) for server in self.key_servers],
consumeErrors=True,
@@ -608,8 +600,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return union_of_keys
- @defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+ async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
"""
Args:
keys_to_fetch (dict[str, dict[str, int]]):
@@ -619,7 +610,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
the keys
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+ dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
from server_name -> key_id -> FetchKeyResult
Raises:
@@ -634,7 +625,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
try:
- query_response = yield self.client.post_json(
+ query_response = await self.client.post_json(
destination=perspective_name,
path="/_matrix/key/v2/query",
data={
@@ -661,7 +652,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
- if not isinstance(server_name, six.string_types):
+ if not isinstance(server_name, str):
raise KeyLookupError(
"Malformed response from key notary server %s: invalid server_name"
% (perspective_name,)
@@ -670,7 +661,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
try:
self._validate_perspectives_response(key_server, response)
- processed_response = yield self.process_v2_response(
+ processed_response = await self.process_v2_response(
perspective_name, response, time_added_ms=time_now_ms
)
except KeyLookupError as e:
@@ -689,7 +680,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
)
keys.setdefault(server_name, {}).update(processed_response)
- yield self.store.store_server_verify_keys(
+ await self.store.store_server_verify_keys(
perspective_name, time_now_ms, added_keys
)
@@ -741,24 +732,23 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
self.clock = hs.get_clock()
self.client = hs.get_http_client()
- def get_keys(self, keys_to_fetch):
+ async def get_keys(self, keys_to_fetch):
"""
Args:
keys_to_fetch (dict[str, iterable[str]]):
the keys to be fetched. server_name -> key_ids
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+ dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
map from server_name -> key_id -> FetchKeyResult
"""
results = {}
- @defer.inlineCallbacks
- def get_key(key_to_fetch_item):
+ async def get_key(key_to_fetch_item):
server_name, key_ids = key_to_fetch_item
try:
- keys = yield self.get_server_verify_key_v2_direct(server_name, key_ids)
+ keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
results[server_name] = keys
except KeyLookupError as e:
logger.warning(
@@ -767,12 +757,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except Exception:
logger.exception("Error getting keys %s from %s", key_ids, server_name)
- return yieldable_gather_results(get_key, keys_to_fetch.items()).addCallback(
- lambda _: results
- )
+ await yieldable_gather_results(get_key, keys_to_fetch.items())
+ return results
- @defer.inlineCallbacks
- def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ async def get_server_verify_key_v2_direct(self, server_name, key_ids):
"""
Args:
@@ -780,7 +768,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
key_ids (iterable[str]):
Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key ID to lookup result
+ dict[str, FetchKeyResult]: map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
@@ -794,7 +782,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
time_now_ms = self.clock.time_msec()
try:
- response = yield self.client.get_json(
+ response = await self.client.get_json(
destination=server_name,
path="/_matrix/key/v2/server/"
+ urllib.parse.quote(requested_key_id),
@@ -825,12 +813,12 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
% (server_name, response["server_name"])
)
- response_keys = yield self.process_v2_response(
+ response_keys = await self.process_v2_response(
from_server=server_name,
response_json=response,
time_added_ms=time_now_ms,
)
- yield self.store.store_server_verify_keys(
+ await self.store.store_server_verify_keys(
server_name,
time_now_ms,
((server_name, key_id, key) for key_id, key in response_keys.items()),
@@ -840,22 +828,18 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
-@defer.inlineCallbacks
-def _handle_key_deferred(verify_request):
+async def _handle_key_deferred(verify_request) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
verify_request (VerifyJsonRequest):
- Returns:
- Deferred[None]
-
Raises:
SynapseError if there was a problem performing the verification
"""
server_name = verify_request.server_name
with PreserveLoggingContext():
- _, key_id, verify_key = yield verify_request.key_ready
+ _, key_id, verify_key = await verify_request.key_ready
json_object = verify_request.json_object
|