diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index f23eacc0d7..902128a23c 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -14,9 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
import logging
import urllib
from collections import defaultdict
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
import attr
from signedjson.key import (
@@ -40,6 +42,7 @@ from synapse.api.errors import (
RequestSendFailed,
SynapseError,
)
+from synapse.config.key import TrustedKeyServer
from synapse.logging.context import (
PreserveLoggingContext,
make_deferred_yieldable,
@@ -47,11 +50,15 @@ from synapse.logging.context import (
run_in_background,
)
from synapse.storage.keys import FetchKeyResult
+from synapse.types import JsonDict
from synapse.util import unwrapFirstError
from synapse.util.async_helpers import yieldable_gather_results
from synapse.util.metrics import Measure
from synapse.util.retryutils import NotRetryingDestination
+if TYPE_CHECKING:
+ from synapse.app.homeserver import HomeServer
+
logger = logging.getLogger(__name__)
@@ -61,16 +68,17 @@ class VerifyJsonRequest:
A request to verify a JSON object.
Attributes:
- server_name(str): The name of the server to verify against.
-
- key_ids(set[str]): The set of key_ids to that could be used to verify the
- JSON object
+ server_name: The name of the server to verify against.
- json_object(dict): The JSON object to verify.
+ json_object: The JSON object to verify.
- minimum_valid_until_ts (int): time at which we require the signing key to
+ minimum_valid_until_ts: time at which we require the signing key to
be valid. (0 implies we don't care)
+ request_name: The name of the request.
+
+ key_ids: The set of key_ids to that could be used to verify the JSON object
+
key_ready (Deferred[str, str, nacl.signing.VerifyKey]):
A deferred (server_name, key_id, verify_key) tuple that resolves when
a verify key has been fetched. The deferreds' callbacks are run with no
@@ -80,12 +88,12 @@ class VerifyJsonRequest:
errbacks with an M_UNAUTHORIZED SynapseError.
"""
- server_name = attr.ib()
- json_object = attr.ib()
- minimum_valid_until_ts = attr.ib()
- request_name = attr.ib()
- key_ids = attr.ib(init=False)
- key_ready = attr.ib(default=attr.Factory(defer.Deferred))
+ server_name = attr.ib(type=str)
+ json_object = attr.ib(type=JsonDict)
+ minimum_valid_until_ts = attr.ib(type=int)
+ request_name = attr.ib(type=str)
+ key_ids = attr.ib(init=False, type=List[str])
+ key_ready = attr.ib(default=attr.Factory(defer.Deferred), type=defer.Deferred)
def __attrs_post_init__(self):
self.key_ids = signature_ids(self.json_object, self.server_name)
@@ -96,7 +104,9 @@ class KeyLookupError(ValueError):
class Keyring:
- def __init__(self, hs, key_fetchers=None):
+ def __init__(
+ self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
+ ):
self.clock = hs.get_clock()
if key_fetchers is None:
@@ -112,22 +122,26 @@ class Keyring:
# completes.
#
# These are regular, logcontext-agnostic Deferreds.
- self.key_downloads = {}
+ self.key_downloads = {} # type: Dict[str, defer.Deferred]
def verify_json_for_server(
- self, server_name, json_object, validity_time, request_name
- ):
+ self,
+ server_name: str,
+ json_object: JsonDict,
+ validity_time: int,
+ request_name: str,
+ ) -> defer.Deferred:
"""Verify that a JSON object has been signed by a given server
Args:
- server_name (str): name of the server which must have signed this object
+ server_name: name of the server which must have signed this object
- json_object (dict): object to be checked
+ json_object: object to be checked
- validity_time (int): timestamp at which we require the signing key to
+ validity_time: timestamp at which we require the signing key to
be valid. (0 implies we don't care)
- request_name (str): an identifier for this json object (eg, an event id)
+ request_name: an identifier for this json object (eg, an event id)
for logging.
Returns:
@@ -138,12 +152,14 @@ class Keyring:
requests = (req,)
return make_deferred_yieldable(self._verify_objects(requests)[0])
- def verify_json_objects_for_server(self, server_and_json):
+ def verify_json_objects_for_server(
+ self, server_and_json: Iterable[Tuple[str, dict, int, str]]
+ ) -> List[defer.Deferred]:
"""Bulk verifies signatures of json objects, bulk fetching keys as
necessary.
Args:
- server_and_json (iterable[Tuple[str, dict, int, str]):
+ server_and_json:
Iterable of (server_name, json_object, validity_time, request_name)
tuples.
@@ -164,13 +180,14 @@ class Keyring:
for server_name, json_object, validity_time, request_name in server_and_json
)
- def _verify_objects(self, verify_requests):
+ def _verify_objects(
+ self, verify_requests: Iterable[VerifyJsonRequest]
+ ) -> List[defer.Deferred]:
"""Does the work of verify_json_[objects_]for_server
Args:
- verify_requests (iterable[VerifyJsonRequest]):
- Iterable of verification requests.
+ verify_requests: Iterable of verification requests.
Returns:
List<Deferred[None]>: for each input item, a deferred indicating success
@@ -182,7 +199,7 @@ class Keyring:
key_lookups = []
handle = preserve_fn(_handle_key_deferred)
- def process(verify_request):
+ def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
"""Process an entry in the request list
Adds a key request to key_lookups, and returns a deferred which
@@ -222,18 +239,20 @@ class Keyring:
return results
- async def _start_key_lookups(self, verify_requests):
+ async def _start_key_lookups(
+ self, verify_requests: List[VerifyJsonRequest]
+ ) -> None:
"""Sets off the key fetches for each verify request
Once each fetch completes, verify_request.key_ready will be resolved.
Args:
- verify_requests (List[VerifyJsonRequest]):
+ verify_requests:
"""
try:
# map from server name to a set of outstanding request ids
- server_to_request_ids = {}
+ server_to_request_ids = {} # type: Dict[str, Set[int]]
for verify_request in verify_requests:
server_name = verify_request.server_name
@@ -275,11 +294,11 @@ class Keyring:
except Exception:
logger.exception("Error starting key lookups")
- async def wait_for_previous_lookups(self, server_names) -> None:
+ async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
"""Waits for any previous key lookups for the given servers to finish.
Args:
- server_names (Iterable[str]): list of servers which we want to look up
+ server_names: list of servers which we want to look up
Returns:
Resolves once all key lookups for the given servers have
@@ -304,7 +323,7 @@ class Keyring:
loop_count += 1
- def _get_server_verify_keys(self, verify_requests):
+ def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
"""Tries to find at least one key for each verify request
For each verify_request, verify_request.key_ready is called back with
@@ -312,7 +331,7 @@ class Keyring:
with a SynapseError if none of the keys are found.
Args:
- verify_requests (list[VerifyJsonRequest]): list of verify requests
+ verify_requests: list of verify requests
"""
remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
@@ -366,17 +385,19 @@ class Keyring:
run_in_background(do_iterations)
- async def _attempt_key_fetches_with_fetcher(self, fetcher, remaining_requests):
+ async def _attempt_key_fetches_with_fetcher(
+ self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
+ ):
"""Use a key fetcher to attempt to satisfy some key requests
Args:
- fetcher (KeyFetcher): fetcher to use to fetch the keys
- remaining_requests (set[VerifyJsonRequest]): outstanding key requests.
+ fetcher: fetcher to use to fetch the keys
+ remaining_requests: outstanding key requests.
Any successfully-completed requests will be removed from the list.
"""
- # dict[str, dict[str, int]]: keys to fetch.
+ # The keys to fetch.
# server_name -> key_id -> min_valid_ts
- missing_keys = defaultdict(dict)
+ missing_keys = defaultdict(dict) # type: Dict[str, Dict[str, int]]
for verify_request in remaining_requests:
# any completed requests should already have been removed
@@ -438,16 +459,18 @@ class Keyring:
remaining_requests.difference_update(completed)
-class KeyFetcher:
- async def get_keys(self, keys_to_fetch):
+class KeyFetcher(metaclass=abc.ABCMeta):
+ @abc.abstractmethod
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
Returns:
- Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
raise NotImplementedError
@@ -455,31 +478,35 @@ class KeyFetcher:
class StoreKeyFetcher(KeyFetcher):
"""KeyFetcher impl which fetches keys from our data store"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- keys_to_fetch = (
+ key_ids_to_fetch = (
(server_name, key_id)
for server_name, keys_for_server in keys_to_fetch.items()
for key_id in keys_for_server.keys()
)
- res = await self.store.get_server_verify_keys(keys_to_fetch)
- keys = {}
+ res = await self.store.get_server_verify_keys(key_ids_to_fetch)
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for (server_name, key_id), key in res.items():
keys.setdefault(server_name, {})[key_id] = key
return keys
-class BaseV2KeyFetcher:
- def __init__(self, hs):
+class BaseV2KeyFetcher(KeyFetcher):
+ def __init__(self, hs: "HomeServer"):
self.store = hs.get_datastore()
self.config = hs.get_config()
- async def process_v2_response(self, from_server, response_json, time_added_ms):
+ async def process_v2_response(
+ self, from_server: str, response_json: JsonDict, time_added_ms: int
+ ) -> Dict[str, FetchKeyResult]:
"""Parse a 'Server Keys' structure from the result of a /key request
This is used to parse either the entirety of the response from
@@ -493,16 +520,16 @@ class BaseV2KeyFetcher:
to /_matrix/key/v2/query.
Args:
- from_server (str): the name of the server producing this result: either
+ from_server: the name of the server producing this result: either
the origin server for a /_matrix/key/v2/server request, or the notary
for a /_matrix/key/v2/query.
- response_json (dict): the json-decoded Server Keys response object
+ response_json: the json-decoded Server Keys response object
- time_added_ms (int): the timestamp to record in server_keys_json
+ time_added_ms: the timestamp to record in server_keys_json
Returns:
- Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
+ Map from key_id to result object
"""
ts_valid_until_ms = response_json["valid_until_ts"]
@@ -575,21 +602,22 @@ class BaseV2KeyFetcher:
class PerspectivesKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the "perspectives" servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
self.key_servers = self.config.key_servers
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""see KeyFetcher.get_keys"""
- async def get_key(key_server):
+ async def get_key(key_server: TrustedKeyServer) -> Dict:
try:
- result = await self.get_server_verify_key_v2_indirect(
+ return await self.get_server_verify_key_v2_indirect(
keys_to_fetch, key_server
)
- return result
except KeyLookupError as e:
logger.warning(
"Key lookup failed from %r: %s", key_server.server_name, e
@@ -611,25 +639,25 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
).addErrback(unwrapFirstError)
)
- union_of_keys = {}
+ union_of_keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
for result in results:
for server_name, keys in result.items():
union_of_keys.setdefault(server_name, {}).update(keys)
return union_of_keys
- async def get_server_verify_key_v2_indirect(self, keys_to_fetch, key_server):
+ async def get_server_verify_key_v2_indirect(
+ self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, dict[str, int]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_id -> min_valid_ts
- key_server (synapse.config.key.TrustedKeyServer): notary server to query for
- the keys
+ key_server: notary server to query for the keys
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]: map
- from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
Raises:
KeyLookupError if there was an error processing the entire response from
@@ -662,11 +690,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
- keys = {}
- added_keys = []
+ keys = {} # type: Dict[str, Dict[str, FetchKeyResult]]
+ added_keys = [] # type: List[Tuple[str, str, FetchKeyResult]]
time_now_ms = self.clock.time_msec()
+ assert isinstance(query_response, dict)
for response in query_response["server_keys"]:
# do this first, so that we can give useful errors thereafter
server_name = response.get("server_name")
@@ -704,14 +733,15 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
return keys
- def _validate_perspectives_response(self, key_server, response):
+ def _validate_perspectives_response(
+ self, key_server: TrustedKeyServer, response: JsonDict
+ ) -> None:
"""Optionally check the signature on the result of a /key/query request
Args:
- key_server (synapse.config.key.TrustedKeyServer): the notary server that
- produced this result
+ key_server: the notary server that produced this result
- response (dict): the json-decoded Server Keys response object
+ response: the json-decoded Server Keys response object
"""
perspective_name = key_server.server_name
perspective_keys = key_server.verify_keys
@@ -745,25 +775,26 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
class ServerKeyFetcher(BaseV2KeyFetcher):
"""KeyFetcher impl which fetches keys from the origin servers"""
- def __init__(self, hs):
+ def __init__(self, hs: "HomeServer"):
super().__init__(hs)
self.clock = hs.get_clock()
self.client = hs.get_federation_http_client()
- async def get_keys(self, keys_to_fetch):
+ async def get_keys(
+ self, keys_to_fetch: Dict[str, Dict[str, int]]
+ ) -> Dict[str, Dict[str, FetchKeyResult]]:
"""
Args:
- keys_to_fetch (dict[str, iterable[str]]):
+ keys_to_fetch:
the keys to be fetched. server_name -> key_ids
Returns:
- dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]:
- map from server_name -> key_id -> FetchKeyResult
+ Map from server_name -> key_id -> FetchKeyResult
"""
results = {}
- async def get_key(key_to_fetch_item):
+ async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
server_name, key_ids = key_to_fetch_item
try:
keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
@@ -778,20 +809,22 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
await yieldable_gather_results(get_key, keys_to_fetch.items())
return results
- async def get_server_verify_key_v2_direct(self, server_name, key_ids):
+ async def get_server_verify_key_v2_direct(
+ self, server_name: str, key_ids: Iterable[str]
+ ) -> Dict[str, FetchKeyResult]:
"""
Args:
- server_name (str):
- key_ids (iterable[str]):
+ server_name:
+ key_ids:
Returns:
- dict[str, FetchKeyResult]: map from key ID to lookup result
+ Map from key ID to lookup result
Raises:
KeyLookupError if there was a problem making the lookup
"""
- keys = {} # type: dict[str, FetchKeyResult]
+ keys = {} # type: Dict[str, FetchKeyResult]
for requested_key_id in key_ids:
# we may have found this key as a side-effect of asking for another.
@@ -825,6 +858,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
except HttpResponseException as e:
raise KeyLookupError("Remote server returned an error: %s" % (e,))
+ assert isinstance(response, dict)
if response["server_name"] != server_name:
raise KeyLookupError(
"Expected a response for server %r not %r"
@@ -846,11 +880,11 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
return keys
-async def _handle_key_deferred(verify_request) -> None:
+async def _handle_key_deferred(verify_request: VerifyJsonRequest) -> None:
"""Waits for the key to become available, and then performs a verification
Args:
- verify_request (VerifyJsonRequest):
+ verify_request:
Raises:
SynapseError if there was a problem performing the verification
|