summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/16123.misc1
-rw-r--r--synapse/rest/key/v2/remote_key_resource.py44
-rw-r--r--synapse/storage/databases/main/keys.py132
-rw-r--r--synapse/storage/keys.py7
-rw-r--r--tests/crypto/test_keyring.py61
5 files changed, 144 insertions, 101 deletions
diff --git a/changelog.d/16123.misc b/changelog.d/16123.misc
new file mode 100644
index 0000000000..b7c6b7c2f2
--- /dev/null
+++ b/changelog.d/16123.misc
@@ -0,0 +1 @@
+Add cache to `get_server_keys_json_for_remote`.
diff --git a/synapse/rest/key/v2/remote_key_resource.py b/synapse/rest/key/v2/remote_key_resource.py
index 8f3865d412..981fd1f58a 100644
--- a/synapse/rest/key/v2/remote_key_resource.py
+++ b/synapse/rest/key/v2/remote_key_resource.py
@@ -14,7 +14,7 @@
 
 import logging
 import re
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Mapping, Optional, Set, Tuple
 
 from signedjson.sign import sign_json
 
@@ -27,6 +27,7 @@ from synapse.http.servlet import (
     parse_integer,
     parse_json_object_from_request,
 )
+from synapse.storage.keys import FetchKeyResultForRemote
 from synapse.types import JsonDict
 from synapse.util import json_decoder
 from synapse.util.async_helpers import yieldable_gather_results
@@ -157,14 +158,22 @@ class RemoteKey(RestServlet):
     ) -> JsonDict:
         logger.info("Handling query for keys %r", query)
 
-        store_queries = []
+        server_keys: Dict[Tuple[str, str], Optional[FetchKeyResultForRemote]] = {}
         for server_name, key_ids in query.items():
-            if not key_ids:
-                key_ids = (None,)
-            for key_id in key_ids:
-                store_queries.append((server_name, key_id, None))
+            if key_ids:
+                results: Mapping[
+                    str, Optional[FetchKeyResultForRemote]
+                ] = await self.store.get_server_keys_json_for_remote(
+                    server_name, key_ids
+                )
+            else:
+                results = await self.store.get_all_server_keys_json_for_remote(
+                    server_name
+                )
 
-        cached = await self.store.get_server_keys_json_for_remote(store_queries)
+            server_keys.update(
+                ((server_name, key_id), res) for key_id, res in results.items()
+            )
 
         json_results: Set[bytes] = set()
 
@@ -173,23 +182,20 @@ class RemoteKey(RestServlet):
         # Map server_name->key_id->int. Note that the value of the int is unused.
         # XXX: why don't we just use a set?
         cache_misses: Dict[str, Dict[str, int]] = {}
-        for (server_name, key_id, _), key_results in cached.items():
-            results = [(result["ts_added_ms"], result) for result in key_results]
-
-            if key_id is None:
+        for (server_name, key_id), key_result in server_keys.items():
+            if not query[server_name]:
                 # all keys were requested. Just return what we have without worrying
                 # about validity
-                for _, result in results:
-                    # Cast to bytes since postgresql returns a memoryview.
-                    json_results.add(bytes(result["key_json"]))
+                if key_result:
+                    json_results.add(key_result.key_json)
                 continue
 
             miss = False
-            if not results:
+            if key_result is None:
                 miss = True
             else:
-                ts_added_ms, most_recent_result = max(results)
-                ts_valid_until_ms = most_recent_result["ts_valid_until_ms"]
+                ts_added_ms = key_result.added_ts
+                ts_valid_until_ms = key_result.valid_until_ts
                 req_key = query.get(server_name, {}).get(key_id, {})
                 req_valid_until = req_key.get("minimum_valid_until_ts")
                 if req_valid_until is not None:
@@ -235,8 +241,8 @@ class RemoteKey(RestServlet):
                         ts_valid_until_ms,
                         time_now_ms,
                     )
-                # Cast to bytes since postgresql returns a memoryview.
-                json_results.add(bytes(most_recent_result["key_json"]))
+
+                json_results.add(key_result.key_json)
 
             if miss and query_remote_on_cache_miss:
                 # only bother attempting to fetch keys from servers on our whitelist
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index cea32a034a..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
 import itertools
 import json
 import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
 from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
 db_binary_type = memoryview
 
 
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
     """Persistence for signature verification keys"""
 
     @cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
         # invalidate takes a tuple corresponding to the params of
         # _get_server_keys_json. _get_server_keys_json only takes one
         # param, which is itself the 2-tuple (server_name, key_id).
-        self._get_server_keys_json.invalidate(((server_name, key_id),))
+        await self.invalidate_cache_and_stream(
+            "_get_server_keys_json", ((server_name, key_id),)
+        )
+        await self.invalidate_cache_and_stream(
+            "get_server_key_json_for_remote", (server_name, key_id)
+        )
 
     @cached()
     def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("get_server_keys_json", _txn)
 
-    async def get_server_keys_json_for_remote(
-        self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
-    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-        """Retrieve the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
-        The JSON is returned as a byte array so that it can be efficiently
-        used in an HTTP response.
+    @cached()
+    def get_server_key_json_for_remote(
+        self,
+        server_name: str,
+        key_id: str,
+    ) -> Optional[FetchKeyResultForRemote]:
+        raise NotImplementedError()
 
-        Args:
-            server_keys: List of (server_name, key_id, source) triplets.
+    @cachedList(
+        cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+    )
+    async def get_server_keys_json_for_remote(
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+        """Fetch the cached keys for the given server/key IDs.
 
-        Returns:
-            A mapping from (server_name, key_id, source) triplets to a list of dicts
+        If we have multiple entries for a given key ID, returns the most recent.
         """
+        rows = await self.db_pool.simple_select_many_batch(
+            table="server_keys_json",
+            column="key_id",
+            iterable=key_ids,
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
+        )
 
-        def _get_server_keys_json_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-            results = {}
-            for server_name, key_id, from_server in server_keys:
-                keyvalues = {"server_name": server_name}
-                if key_id is not None:
-                    keyvalues["key_id"] = key_id
-                if from_server is not None:
-                    keyvalues["from_server"] = from_server
-                rows = self.db_pool.simple_select_list_txn(
-                    txn,
-                    "server_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=(
-                        "key_id",
-                        "from_server",
-                        "ts_added_ms",
-                        "ts_valid_until_ms",
-                        "key_json",
-                    ),
-                )
-                results[(server_name, key_id, from_server)] = rows
-            return results
+        if not rows:
+            return {}
+
+        # We sort the rows so that the most recently added entry is picked up.
+        rows.sort(key=lambda r: r["ts_added_ms"])
+
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
 
-        return await self.db_pool.runInteraction(
-            "get_server_keys_json", _get_server_keys_json_txn
+    async def get_all_server_keys_json_for_remote(
+        self,
+        server_name: str,
+    ) -> Dict[str, FetchKeyResultForRemote]:
+        """Fetch the cached keys for the given server.
+
+        If we have multiple entries for a given key ID, returns the most recent.
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="server_keys_json",
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
         )
+
+        if not rows:
+            return {}
+
+        rows.sort(key=lambda r: r["ts_added_ms"])
+
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 71584f3f74..e74b2269d2 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -25,3 +25,10 @@ logger = logging.getLogger(__name__)
 class FetchKeyResult:
     verify_key: VerifyKey  # the key itself
     valid_until_ts: int  # how long we can use this key for
+
+
+@attr.s(slots=True, frozen=True, auto_attribs=True)
+class FetchKeyResultForRemote:
+    key_json: bytes  # the full key JSON
+    valid_until_ts: int  # how long we can use this key for, in milliseconds.
+    added_ts: int  # When we added this key, in milliseconds.
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index fdfd4f911d..2be341ac7b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], SERVER_NAME)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
         # we expect it to be encoded as canonical json *before* it hits the db
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_get_multiple_keys_from_perspectives(self) -> None:
         """Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_invalid_perspectives_responses(self) -> None:
         """Check that invalid responses from the perspectives server are rejected"""