summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2021-05-04 17:36:02 +0100
committerErik Johnston <erik@matrix.org>2021-05-04 17:36:02 +0100
commit3bfd3c55f931c6d8093f9b11e0bda4ffc4d2ef22 (patch)
treee05618a1e964c43c865269c6d2aa01a7994d64b9
parentTime external cache response time (#9904) (diff)
downloadsynapse-3bfd3c55f931c6d8093f9b11e0bda4ffc4d2ef22.tar.xz
Refactor keyring
-rw-r--r--synapse/crypto/keyring.py529
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py9
2 files changed, 182 insertions, 356 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 5f18ef7748..313d577f11 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -16,8 +16,7 @@
 import abc
 import logging
 import urllib
-from collections import defaultdict
-from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple
 
 import attr
 from signedjson.key import (
@@ -45,14 +44,13 @@ from synapse.config.key import TrustedKeyServer
 from synapse.logging.context import (
     PreserveLoggingContext,
     make_deferred_yieldable,
-    preserve_fn,
     run_in_background,
 )
+from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.keys import FetchKeyResult
 from synapse.types import JsonDict
 from synapse.util import unwrapFirstError
-from synapse.util.async_helpers import yieldable_gather_results
-from synapse.util.metrics import Measure
+from synapse.util.async_helpers import Linearizer, yieldable_gather_results
 from synapse.util.retryutils import NotRetryingDestination
 
 if TYPE_CHECKING:
@@ -102,6 +100,62 @@ class KeyLookupError(ValueError):
     pass
 
 
+@attr.s(slots=True)
+class _QueueValue:
+    server_name = attr.ib(type=str)
+    minimum_valid_until_ts = attr.ib(type=int)
+    key_ids = attr.ib(type=List[str])
+
+
+class _Queue:
+    def __init__(self, name, clock, process_items):
+        self._name = name
+        self._clock = clock
+        self._is_processing = False
+        self._next_values = []
+
+        self.process_items = process_items
+
+    async def add_to_queue(self, value: _QueueValue) -> Dict[str, FetchKeyResult]:
+        d = defer.Deferred()
+        self._next_values.append((value, d))
+
+        if self._is_processing:
+            return await d
+
+        run_as_background_process(self._name, self._unsafe_process)
+
+        return await d
+
+    async def _unsafe_process(self):
+        # We purposefully defer to the next loop.
+        await self._clock.sleep(0)
+
+        try:
+            if self._is_processing:
+                return
+
+            self._is_processing = True
+
+            while self._next_values:
+                next_values = self._next_values
+                self._next_values = []
+
+                try:
+                    values = [value for value, _ in next_values]
+                    results = await self.process_items(values)
+
+                    for value, deferred in next_values:
+                        deferred.callback(results.get(value.server_name, {}))
+
+                except Exception as e:
+                    for _, deferred in next_values:
+                        deferred.errback(e)
+
+        finally:
+            self._is_processing = False
+
+
 class Keyring:
     def __init__(
         self, hs: "HomeServer", key_fetchers: "Optional[Iterable[KeyFetcher]]" = None
@@ -116,12 +170,7 @@ class Keyring:
             )
         self._key_fetchers = key_fetchers
 
-        # map from server name to Deferred. Has an entry for each server with
-        # an ongoing key download; the Deferred completes once the download
-        # completes.
-        #
-        # These are regular, logcontext-agnostic Deferreds.
-        self.key_downloads = {}  # type: Dict[str, defer.Deferred]
+        self._server_queue = Linearizer("keyring_server")
 
     def verify_json_for_server(
         self,
@@ -130,365 +179,129 @@ class Keyring:
         validity_time: int,
         request_name: str,
     ) -> defer.Deferred:
-        """Verify that a JSON object has been signed by a given server
-
-        Args:
-            server_name: name of the server which must have signed this object
-
-            json_object: object to be checked
-
-            validity_time: timestamp at which we require the signing key to
-                be valid. (0 implies we don't care)
-
-            request_name: an identifier for this json object (eg, an event id)
-                for logging.
-
-        Returns:
-            Deferred[None]: completes if the the object was correctly signed, otherwise
-                errbacks with an error
-        """
-        req = VerifyJsonRequest(server_name, json_object, validity_time, request_name)
-        requests = (req,)
-        return make_deferred_yieldable(self._verify_objects(requests)[0])
+        request = VerifyJsonRequest(
+            server_name, json_object, validity_time, request_name
+        )
+        return defer.ensureDeferred(self._verify_object(request))
 
     def verify_json_objects_for_server(
         self, server_and_json: Iterable[Tuple[str, dict, int, str]]
     ) -> List[defer.Deferred]:
-        """Bulk verifies signatures of json objects, bulk fetching keys as
-        necessary.
-
-        Args:
-            server_and_json:
-                Iterable of (server_name, json_object, validity_time, request_name)
-                tuples.
-
-                validity_time is a timestamp at which the signing key must be
-                valid.
-
-                request_name is an identifier for this json object (eg, an event id)
-                for logging.
-
-        Returns:
-            List<Deferred[None]>: for each input triplet, a deferred indicating success
-                or failure to verify each json object's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
-        """
-        return self._verify_objects(
-            VerifyJsonRequest(server_name, json_object, validity_time, request_name)
-            for server_name, json_object, validity_time, request_name in server_and_json
-        )
-
-    def _verify_objects(
-        self, verify_requests: Iterable[VerifyJsonRequest]
-    ) -> List[defer.Deferred]:
-        """Does the work of verify_json_[objects_]for_server
-
-
-        Args:
-            verify_requests: Iterable of verification requests.
-
-        Returns:
-            List<Deferred[None]>: for each input item, a deferred indicating success
-                or failure to verify each json object's signature for the given
-                server_name. The deferreds run their callbacks in the sentinel
-                logcontext.
-        """
-        # a list of VerifyJsonRequests which are awaiting a key lookup
-        key_lookups = []
-        handle = preserve_fn(_handle_key_deferred)
-
-        def process(verify_request: VerifyJsonRequest) -> defer.Deferred:
-            """Process an entry in the request list
-
-            Adds a key request to key_lookups, and returns a deferred which
-            will complete or fail (in the sentinel context) when verification completes.
-            """
-            if not verify_request.key_ids:
-                return defer.fail(
-                    SynapseError(
-                        400,
-                        "Not signed by %s" % (verify_request.server_name,),
-                        Codes.UNAUTHORIZED,
+        return [
+            defer.ensureDeferred(
+                self._verify_object(
+                    VerifyJsonRequest(
+                        server_name, json_object, validity_time, request_name
                     )
                 )
-
-            logger.debug(
-                "Verifying %s for %s with key_ids %s, min_validity %i",
-                verify_request.request_name,
-                verify_request.server_name,
-                verify_request.key_ids,
-                verify_request.minimum_valid_until_ts,
             )
+            for server_name, json_object, validity_time, request_name in server_and_json
+        ]
+
+    async def _verify_object(self, verify_request: VerifyJsonRequest):
+        # TODO: Use a batching thing.
+        with (await self._server_queue.queue(verify_request.server_name)):
+            found_keys: Dict[str, FetchKeyResult] = {}
+            missing_key_ids = set(verify_request.key_ids)
+            for fetcher in self._key_fetchers:
+                if not missing_key_ids:
+                    break
+
+                keys = await fetcher.get_keys(
+                    verify_request.server_name,
+                    list(missing_key_ids),
+                    verify_request.minimum_valid_until_ts,
+                )
 
-            # add the key request to the queue, but don't start it off yet.
-            key_lookups.append(verify_request)
-
-            # now run _handle_key_deferred, which will wait for the key request
-            # to complete and then do the verification.
-            #
-            # We want _handle_key_request to log to the right context, so we
-            # wrap it with preserve_fn (aka run_in_background)
-            return handle(verify_request)
-
-        results = [process(r) for r in verify_requests]
-
-        if key_lookups:
-            run_in_background(self._start_key_lookups, key_lookups)
-
-        return results
-
-    async def _start_key_lookups(
-        self, verify_requests: List[VerifyJsonRequest]
-    ) -> None:
-        """Sets off the key fetches for each verify request
-
-        Once each fetch completes, verify_request.key_ready will be resolved.
-
-        Args:
-            verify_requests:
-        """
-
-        try:
-            # map from server name to a set of outstanding request ids
-            server_to_request_ids = {}  # type: Dict[str, Set[int]]
-
-            for verify_request in verify_requests:
-                server_name = verify_request.server_name
-                request_id = id(verify_request)
-                server_to_request_ids.setdefault(server_name, set()).add(request_id)
-
-            # Wait for any previous lookups to complete before proceeding.
-            await self.wait_for_previous_lookups(server_to_request_ids.keys())
-
-            # take out a lock on each of the servers by sticking a Deferred in
-            # key_downloads
-            for server_name in server_to_request_ids.keys():
-                self.key_downloads[server_name] = defer.Deferred()
-                logger.debug("Got key lookup lock on %s", server_name)
-
-            # When we've finished fetching all the keys for a given server_name,
-            # drop the lock by resolving the deferred in key_downloads.
-            def drop_server_lock(server_name):
-                d = self.key_downloads.pop(server_name)
-                d.callback(None)
-
-            def lookup_done(res, verify_request):
-                server_name = verify_request.server_name
-                server_requests = server_to_request_ids[server_name]
-                server_requests.remove(id(verify_request))
-
-                # if there are no more requests for this server, we can drop the lock.
-                if not server_requests:
-                    logger.debug("Releasing key lookup lock on %s", server_name)
-                    drop_server_lock(server_name)
-
-                return res
-
-            for verify_request in verify_requests:
-                verify_request.key_ready.addBoth(lookup_done, verify_request)
-
-            # Actually start fetching keys.
-            self._get_server_verify_keys(verify_requests)
-        except Exception:
-            logger.exception("Error starting key lookups")
-
-    async def wait_for_previous_lookups(self, server_names: Iterable[str]) -> None:
-        """Waits for any previous key lookups for the given servers to finish.
-
-        Args:
-            server_names: list of servers which we want to look up
-
-        Returns:
-            Resolves once all key lookups for the given servers have
-                completed. Follows the synapse rules of logcontext preservation.
-        """
-        loop_count = 1
-        while True:
-            wait_on = [
-                (server_name, self.key_downloads[server_name])
-                for server_name in server_names
-                if server_name in self.key_downloads
-            ]
-            if not wait_on:
-                break
-            logger.info(
-                "Waiting for existing lookups for %s to complete [loop %i]",
-                [w[0] for w in wait_on],
-                loop_count,
-            )
-            with PreserveLoggingContext():
-                await defer.DeferredList((w[1] for w in wait_on))
-
-            loop_count += 1
-
-    def _get_server_verify_keys(self, verify_requests: List[VerifyJsonRequest]) -> None:
-        """Tries to find at least one key for each verify request
-
-        For each verify_request, verify_request.key_ready is called back with
-        params (server_name, key_id, VerifyKey) if a key is found, or errbacked
-        with a SynapseError if none of the keys are found.
-
-        Args:
-            verify_requests: list of verify requests
-        """
+                for key_id, key in keys.items():
+                    if not key:
+                        continue
 
-        remaining_requests = {rq for rq in verify_requests if not rq.key_ready.called}
+                    if key.valid_until_ts < verify_request.minimum_valid_until_ts:
+                        continue
 
-        async def do_iterations():
-            try:
-                with Measure(self.clock, "get_server_verify_keys"):
-                    for f in self._key_fetchers:
-                        if not remaining_requests:
-                            return
-                        await self._attempt_key_fetches_with_fetcher(
-                            f, remaining_requests
-                        )
-
-                    # look for any requests which weren't satisfied
-                    while remaining_requests:
-                        verify_request = remaining_requests.pop()
-                        rq_str = (
-                            "VerifyJsonRequest(server=%s, key_ids=%s, min_valid=%i)"
-                            % (
-                                verify_request.server_name,
-                                verify_request.key_ids,
-                                verify_request.minimum_valid_until_ts,
-                            )
-                        )
-
-                        # If we run the errback immediately, it may cancel our
-                        # loggingcontext while we are still in it, so instead we
-                        # schedule it for the next time round the reactor.
-                        #
-                        # (this also ensures that we don't get a stack overflow if we
-                        # has a massive queue of lookups waiting for this server).
-                        self.clock.call_later(
-                            0,
-                            verify_request.key_ready.errback,
-                            SynapseError(
-                                401,
-                                "Failed to find any key to satisfy %s" % (rq_str,),
-                                Codes.UNAUTHORIZED,
-                            ),
-                        )
-            except Exception as err:
-                # we don't really expect to get here, because any errors should already
-                # have been caught and logged. But if we do, let's log the error and make
-                # sure that all of the deferreds are resolved.
-                logger.error("Unexpected error in _get_server_verify_keys: %s", err)
-                with PreserveLoggingContext():
-                    for verify_request in remaining_requests:
-                        if not verify_request.key_ready.called:
-                            verify_request.key_ready.errback(err)
-
-        run_in_background(do_iterations)
-
-    async def _attempt_key_fetches_with_fetcher(
-        self, fetcher: "KeyFetcher", remaining_requests: Set[VerifyJsonRequest]
-    ):
-        """Use a key fetcher to attempt to satisfy some key requests
+                    existing_key = found_keys.get(key_id)
+                    if existing_key:
+                        if key.valid_until_ts <= existing_key.valid_until_ts:
+                            continue
 
-        Args:
-            fetcher: fetcher to use to fetch the keys
-            remaining_requests: outstanding key requests.
-                Any successfully-completed requests will be removed from the list.
-        """
-        # The keys to fetch.
-        # server_name -> key_id -> min_valid_ts
-        missing_keys = defaultdict(dict)  # type: Dict[str, Dict[str, int]]
+                    found_keys[key_id] = key
 
-        for verify_request in remaining_requests:
-            # any completed requests should already have been removed
-            assert not verify_request.key_ready.called
-            keys_for_server = missing_keys[verify_request.server_name]
+                missing_key_ids.difference_update(found_keys)
 
-            for key_id in verify_request.key_ids:
-                # If we have several requests for the same key, then we only need to
-                # request that key once, but we should do so with the greatest
-                # min_valid_until_ts of the requests, so that we can satisfy all of
-                # the requests.
-                keys_for_server[key_id] = max(
-                    keys_for_server.get(key_id, -1),
-                    verify_request.minimum_valid_until_ts,
+            if missing_key_ids:
+                raise SynapseError(
+                    400,
+                    "Missing keys for %s: %s"
+                    % (verify_request.server_name, missing_key_ids),
+                    Codes.UNAUTHORIZED,
                 )
 
-        results = await fetcher.get_keys(missing_keys)
-
-        completed = []
-        for verify_request in remaining_requests:
-            server_name = verify_request.server_name
-
-            # see if any of the keys we got this time are sufficient to
-            # complete this VerifyJsonRequest.
-            result_keys = results.get(server_name, {})
             for key_id in verify_request.key_ids:
-                fetch_key_result = result_keys.get(key_id)
-                if not fetch_key_result:
-                    # we didn't get a result for this key
-                    continue
-
-                if (
-                    fetch_key_result.valid_until_ts
-                    < verify_request.minimum_valid_until_ts
-                ):
-                    # key was not valid at this point
-                    continue
-
-                # we have a valid key for this request. If we run the callback
-                # immediately, it may cancel our loggingcontext while we are still in
-                # it, so instead we schedule it for the next time round the reactor.
-                #
-                # (this also ensures that we don't get a stack overflow if we had
-                # a massive queue of lookups waiting for this server).
-                logger.debug(
-                    "Found key %s:%s for %s",
-                    server_name,
-                    key_id,
-                    verify_request.request_name,
-                )
-                self.clock.call_later(
-                    0,
-                    verify_request.key_ready.callback,
-                    (server_name, key_id, fetch_key_result.verify_key),
-                )
-                completed.append(verify_request)
-                break
-
-        remaining_requests.difference_update(completed)
+                verify_key = found_keys[key_id].verify_key
+                try:
+                    verify_signed_json(
+                        verify_request.json_object,
+                        verify_request.server_name,
+                        verify_key,
+                    )
+                except SignatureVerifyException as e:
+                    logger.debug(
+                        "Error verifying signature for %s:%s:%s with key %s: %s",
+                        verify_request.server_name,
+                        verify_key.alg,
+                        verify_key.version,
+                        encode_verify_key_base64(verify_key),
+                        str(e),
+                    )
+                    raise SynapseError(
+                        401,
+                        "Invalid signature for server %s with key %s:%s: %s"
+                        % (
+                            verify_request.server_name,
+                            verify_key.alg,
+                            verify_key.version,
+                            str(e),
+                        ),
+                        Codes.UNAUTHORIZED,
+                    )
 
 
 class KeyFetcher(metaclass=abc.ABCMeta):
-    @abc.abstractmethod
+    def __init__(self, hs: "HomeServer"):
+        self._queue = _Queue(self.__class__.__name__, hs.get_clock(), self._fetch_keys)
+
     async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """
-        Args:
-            keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+        self, server_name: str, key_ids: List[str], minimum_valid_until_ts: int
+    ) -> Dict[str, FetchKeyResult]:
+        return await self._queue.add_to_queue(
+            _QueueValue(
+                server_name=server_name,
+                key_ids=key_ids,
+                minimum_valid_until_ts=minimum_valid_until_ts,
+            )
+        )
 
-        Returns:
-            Map from server_name -> key_id -> FetchKeyResult
-        """
-        raise NotImplementedError
+    @abc.abstractmethod
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_QueueValue]
+    ) -> Dict[str, Dict[str, FetchKeyResult]]:
+        pass
 
 
 class StoreKeyFetcher(KeyFetcher):
     """KeyFetcher impl which fetches keys from our data store"""
 
     def __init__(self, hs: "HomeServer"):
-        self.store = hs.get_datastore()
+        super().__init__(hs)
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
-    ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        self.store = hs.get_datastore()
 
+    async def _fetch_keys(self, keys_to_fetch: List[_QueueValue]):
         key_ids_to_fetch = (
-            (server_name, key_id)
-            for server_name, keys_for_server in keys_to_fetch.items()
-            for key_id in keys_for_server.keys()
+            (queue_value.server_name, key_id)
+            for queue_value in keys_to_fetch
+            for key_id in queue_value.key_ids
         )
 
         res = await self.store.get_server_verify_keys(key_ids_to_fetch)
@@ -500,6 +313,8 @@ class StoreKeyFetcher(KeyFetcher):
 
 class BaseV2KeyFetcher(KeyFetcher):
     def __init__(self, hs: "HomeServer"):
+        super().__init__(hs)
+
         self.store = hs.get_datastore()
         self.config = hs.config
 
@@ -607,10 +422,10 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         self.client = hs.get_federation_http_client()
         self.key_servers = self.config.key_servers
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_QueueValue]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
-        """see KeyFetcher.get_keys"""
+        """see KeyFetcher._fetch_keys"""
 
         async def get_key(key_server: TrustedKeyServer) -> Dict:
             try:
@@ -646,12 +461,12 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         return union_of_keys
 
     async def get_server_verify_key_v2_indirect(
-        self, keys_to_fetch: Dict[str, Dict[str, int]], key_server: TrustedKeyServer
+        self, keys_to_fetch: List[_QueueValue], key_server: TrustedKeyServer
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
             keys_to_fetch:
-                the keys to be fetched. server_name -> key_id -> min_valid_ts
+                the keys to be fetched.
 
             key_server: notary server to query for the keys
 
@@ -665,7 +480,7 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
         perspective_name = key_server.server_name
         logger.info(
             "Requesting keys %s from notary server %s",
-            keys_to_fetch.items(),
+            keys_to_fetch,
             perspective_name,
         )
 
@@ -675,11 +490,13 @@ class PerspectivesKeyFetcher(BaseV2KeyFetcher):
                 path="/_matrix/key/v2/query",
                 data={
                     "server_keys": {
-                        server_name: {
-                            key_id: {"minimum_valid_until_ts": min_valid_ts}
-                            for key_id, min_valid_ts in server_keys.items()
+                        queue_value.server_name: {
+                            key_id: {
+                                "minimum_valid_until_ts": queue_value.minimum_valid_until_ts,
+                            }
+                            for key_id in queue_value.key_ids
                         }
-                        for server_name, server_keys in keys_to_fetch.items()
+                        for queue_value in keys_to_fetch
                     }
                 },
             )
@@ -779,8 +596,8 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
         self.clock = hs.get_clock()
         self.client = hs.get_federation_http_client()
 
-    async def get_keys(
-        self, keys_to_fetch: Dict[str, Dict[str, int]]
+    async def _fetch_keys(
+        self, keys_to_fetch: List[_QueueValue]
     ) -> Dict[str, Dict[str, FetchKeyResult]]:
         """
         Args:
@@ -793,8 +610,10 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
 
         results = {}
 
-        async def get_key(key_to_fetch_item: Tuple[str, Dict[str, int]]) -> None:
-            server_name, key_ids = key_to_fetch_item
+        async def get_key(key_to_fetch_item: _QueueValue) -> None:
+            server_name = key_to_fetch_item.server_name
+            key_ids = key_to_fetch_item.key_ids
+
             try:
                 keys = await self.get_server_verify_key_v2_direct(server_name, key_ids)
                 results[server_name] = keys
@@ -805,7 +624,7 @@ class ServerKeyFetcher(BaseV2KeyFetcher):
             except Exception:
                 logger.exception("Error getting keys %s from %s", key_ids, server_name)
 
-        await yieldable_gather_results(get_key, keys_to_fetch.items())
+        await yieldable_gather_results(get_key, keys_to_fetch)
         return results
 
     async def get_server_verify_key_v2_direct(
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index f648678b09..24af3c28ea 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -22,6 +22,7 @@ from synapse.crypto.keyring import ServerKeyFetcher
 from synapse.http.server import DirectServeJsonResource, respond_with_json
 from synapse.http.servlet import parse_integer, parse_json_object_from_request
 from synapse.util import json_decoder
+from synapse.util.async_helpers import yieldable_gather_results
 
 logger = logging.getLogger(__name__)
 
@@ -213,7 +214,13 @@ class RemoteKey(DirectServeJsonResource):
         # If there is a cache miss, request the missing keys, then recurse (and
         # ensure the result is sent).
         if cache_misses and query_remote_on_cache_miss:
-            await self.fetcher.get_keys(cache_misses)
+            await yieldable_gather_results(
+                self.fetcher.get_keys,
+                (
+                    (server_name, list(keys), 0)
+                    for server_name, keys in cache_misses.items()
+                ),
+            )
             await self.query_keys(request, query, query_remote_on_cache_miss=False)
         else:
             signed_keys = []