summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/keys.py132
1 files changed, 88 insertions, 44 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index cea32a034a..a3b4744855 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,13 @@
 import itertools
 import json
 import logging
-from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Mapping, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
-from synapse.storage._base import SQLBaseStore
-from synapse.storage.database import LoggingTransaction
-from synapse.storage.keys import FetchKeyResult
+from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
+from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
 from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
@@ -34,7 +33,7 @@ logger = logging.getLogger(__name__)
 db_binary_type = memoryview
 
 
-class KeyStore(SQLBaseStore):
+class KeyStore(CacheInvalidationWorkerStore):
     """Persistence for signature verification keys"""
 
     @cached()
@@ -188,7 +187,12 @@ class KeyStore(SQLBaseStore):
         # invalidate takes a tuple corresponding to the params of
         # _get_server_keys_json. _get_server_keys_json only takes one
         # param, which is itself the 2-tuple (server_name, key_id).
-        self._get_server_keys_json.invalidate(((server_name, key_id),))
+        await self.invalidate_cache_and_stream(
+            "_get_server_keys_json", ((server_name, key_id),)
+        )
+        await self.invalidate_cache_and_stream(
+            "get_server_key_json_for_remote", (server_name, key_id)
+        )
 
     @cached()
     def _get_server_keys_json(
@@ -253,47 +257,87 @@ class KeyStore(SQLBaseStore):
 
         return await self.db_pool.runInteraction("get_server_keys_json", _txn)
 
-    async def get_server_keys_json_for_remote(
-        self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
-    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-        """Retrieve the key json for a list of server_keys and key ids.
-        If no keys are found for a given server, key_id and source then
-        that server, key_id, and source triplet entry will be an empty list.
-        The JSON is returned as a byte array so that it can be efficiently
-        used in an HTTP response.
+    @cached()
+    def get_server_key_json_for_remote(
+        self,
+        server_name: str,
+        key_id: str,
+    ) -> Optional[FetchKeyResultForRemote]:
+        raise NotImplementedError()
 
-        Args:
-            server_keys: List of (server_name, key_id, source) triplets.
+    @cachedList(
+        cached_method_name="get_server_key_json_for_remote", list_name="key_ids"
+    )
+    async def get_server_keys_json_for_remote(
+        self, server_name: str, key_ids: Iterable[str]
+    ) -> Dict[str, Optional[FetchKeyResultForRemote]]:
+        """Fetch the cached keys for the given server/key IDs.
 
-        Returns:
-            A mapping from (server_name, key_id, source) triplets to a list of dicts
+        If we have multiple entries for a given key ID, returns the most recent.
         """
+        rows = await self.db_pool.simple_select_many_batch(
+            table="server_keys_json",
+            column="key_id",
+            iterable=key_ids,
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
+        )
 
-        def _get_server_keys_json_txn(
-            txn: LoggingTransaction,
-        ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
-            results = {}
-            for server_name, key_id, from_server in server_keys:
-                keyvalues = {"server_name": server_name}
-                if key_id is not None:
-                    keyvalues["key_id"] = key_id
-                if from_server is not None:
-                    keyvalues["from_server"] = from_server
-                rows = self.db_pool.simple_select_list_txn(
-                    txn,
-                    "server_keys_json",
-                    keyvalues=keyvalues,
-                    retcols=(
-                        "key_id",
-                        "from_server",
-                        "ts_added_ms",
-                        "ts_valid_until_ms",
-                        "key_json",
-                    ),
-                )
-                results[(server_name, key_id, from_server)] = rows
-            return results
+        if not rows:
+            return {}
+
+        # We sort the rows so that the most recently added entry is picked up.
+        rows.sort(key=lambda r: r["ts_added_ms"])
+
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }
 
-        return await self.db_pool.runInteraction(
-            "get_server_keys_json", _get_server_keys_json_txn
+    async def get_all_server_keys_json_for_remote(
+        self,
+        server_name: str,
+    ) -> Dict[str, FetchKeyResultForRemote]:
+        """Fetch the cached keys for the given server.
+
+        If we have multiple entries for a given key ID, returns the most recent.
+        """
+        rows = await self.db_pool.simple_select_list(
+            table="server_keys_json",
+            keyvalues={"server_name": server_name},
+            retcols=(
+                "key_id",
+                "from_server",
+                "ts_added_ms",
+                "ts_valid_until_ms",
+                "key_json",
+            ),
+            desc="get_server_keys_json_for_remote",
         )
+
+        if not rows:
+            return {}
+
+        rows.sort(key=lambda r: r["ts_added_ms"])
+
+        return {
+            row["key_id"]: FetchKeyResultForRemote(
+                # Cast to bytes since postgresql returns a memoryview.
+                key_json=bytes(row["key_json"]),
+                valid_until_ts=row["ts_valid_until_ms"],
+                added_ts=row["ts_added_ms"],
+            )
+            for row in rows
+        }