summary refs log tree commit diff
path: root/synapse/storage/databases/main
diff options
context:
space:
mode:
authorErik Johnston <erikj@matrix.org>2023-09-12 11:08:04 +0100
committerGitHub <noreply@github.com>2023-09-12 11:08:04 +0100
commit2b35626b6b7aed52a626734a5a85fe77c847251d (patch)
treecde4e127a463a82a37039818facc63358a29f787 /synapse/storage/databases/main
parentAdd the List-Unsubscribe header for notification emails. (#16274) (diff)
downloadsynapse-2b35626b6b7aed52a626734a5a85fe77c847251d.tar.xz
Refactor storing of server keys (#16261)
Diffstat (limited to 'synapse/storage/databases/main')
-rw-r--r--synapse/storage/databases/main/keys.py219
1 files changed, 72 insertions, 147 deletions
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 57aa4921e1..41563371dc 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -16,14 +16,17 @@
 import itertools
 import json
 import logging
-from typing import Dict, Iterable, Mapping, Optional, Tuple
+from typing import Dict, Iterable, Optional, Tuple
 
+from canonicaljson import encode_canonical_json
 from signedjson.key import decode_verify_key_bytes
 from unpaddedbase64 import decode_base64
 
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
 from synapse.storage.keys import FetchKeyResult, FetchKeyResultForRemote
 from synapse.storage.types import Cursor
+from synapse.types import JsonDict
 from synapse.util.caches.descriptors import cached, cachedList
 from synapse.util.iterutils import batch_iter
 
@@ -36,162 +39,84 @@ db_binary_type = memoryview
 class KeyStore(CacheInvalidationWorkerStore):
     """Persistence for signature verification keys"""
 
-    @cached()
-    def _get_server_signature_key(
-        self, server_name_and_key_id: Tuple[str, str]
-    ) -> FetchKeyResult:
-        raise NotImplementedError()
-
-    @cachedList(
-        cached_method_name="_get_server_signature_key",
-        list_name="server_name_and_key_ids",
-    )
-    async def get_server_signature_keys(
-        self, server_name_and_key_ids: Iterable[Tuple[str, str]]
-    ) -> Dict[Tuple[str, str], FetchKeyResult]:
-        """
-        Args:
-            server_name_and_key_ids:
-                iterable of (server_name, key-id) tuples to fetch keys for
-
-        Returns:
-            A map from (server_name, key_id) -> FetchKeyResult, or None if the
-            key is unknown
-        """
-        keys = {}
-
-        def _get_keys(txn: Cursor, batch: Tuple[Tuple[str, str], ...]) -> None:
-            """Processes a batch of keys to fetch, and adds the result to `keys`."""
-
-            # batch_iter always returns tuples so it's safe to do len(batch)
-            sql = """
-            SELECT server_name, key_id, verify_key, ts_valid_until_ms
-            FROM server_signature_keys WHERE 1=0
-            """ + " OR (server_name=? AND key_id=?)" * len(
-                batch
-            )
-
-            txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
-
-            for row in txn:
-                server_name, key_id, key_bytes, ts_valid_until_ms = row
-
-                if ts_valid_until_ms is None:
-                    # Old keys may be stored with a ts_valid_until_ms of null,
-                    # in which case we treat this as if it was set to `0`, i.e.
-                    # it won't match key requests that define a minimum
-                    # `ts_valid_until_ms`.
-                    ts_valid_until_ms = 0
-
-                keys[(server_name, key_id)] = FetchKeyResult(
-                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
-                    valid_until_ts=ts_valid_until_ms,
-                )
-
-        def _txn(txn: Cursor) -> Dict[Tuple[str, str], FetchKeyResult]:
-            for batch in batch_iter(server_name_and_key_ids, 50):
-                _get_keys(txn, batch)
-            return keys
-
-        return await self.db_pool.runInteraction("get_server_signature_keys", _txn)
-
-    async def store_server_signature_keys(
+    async def store_server_keys_response(
         self,
+        server_name: str,
         from_server: str,
         ts_added_ms: int,
-        verify_keys: Mapping[Tuple[str, str], FetchKeyResult],
+        verify_keys: Dict[str, FetchKeyResult],
+        response_json: JsonDict,
     ) -> None:
-        """Stores NACL verification keys for remote servers.
+        """Stores the keys for the given server that we got from `from_server`.
+
         Args:
-            from_server: Where the verification keys were looked up
-            ts_added_ms: The time to record that the key was added
-            verify_keys:
-                keys to be stored. Each entry is a triplet of
-                (server_name, key_id, key).
+            server_name: The owner of the keys
+            from_server: Which server we got the keys from
+            ts_added_ms: When we're adding the keys
+            verify_keys: The decoded keys
+            response_json: The full *signed* response JSON that contains the keys.
         """
-        key_values = []
-        value_values = []
-        invalidations = []
-        for (server_name, key_id), fetch_result in verify_keys.items():
-            key_values.append((server_name, key_id))
-            value_values.append(
-                (
-                    from_server,
-                    ts_added_ms,
-                    fetch_result.valid_until_ts,
-                    db_binary_type(fetch_result.verify_key.encode()),
-                )
-            )
-            # invalidate takes a tuple corresponding to the params of
-            # _get_server_signature_key. _get_server_signature_key only takes one
-            # param, which is itself the 2-tuple (server_name, key_id).
-            invalidations.append((server_name, key_id))
 
-        await self.db_pool.simple_upsert_many(
-            table="server_signature_keys",
-            key_names=("server_name", "key_id"),
-            key_values=key_values,
-            value_names=(
-                "from_server",
-                "ts_added_ms",
-                "ts_valid_until_ms",
-                "verify_key",
-            ),
-            value_values=value_values,
-            desc="store_server_signature_keys",
-        )
+        key_json_bytes = encode_canonical_json(response_json)
+
+        def store_server_keys_response_txn(txn: LoggingTransaction) -> None:
+            self.db_pool.simple_upsert_many_txn(
+                txn,
+                table="server_signature_keys",
+                key_names=("server_name", "key_id"),
+                key_values=[(server_name, key_id) for key_id in verify_keys],
+                value_names=(
+                    "from_server",
+                    "ts_added_ms",
+                    "ts_valid_until_ms",
+                    "verify_key",
+                ),
+                value_values=[
+                    (
+                        from_server,
+                        ts_added_ms,
+                        fetch_result.valid_until_ts,
+                        db_binary_type(fetch_result.verify_key.encode()),
+                    )
+                    for fetch_result in verify_keys.values()
+                ],
+            )
 
-        invalidate = self._get_server_signature_key.invalidate
-        for i in invalidations:
-            invalidate((i,))
+            self.db_pool.simple_upsert_many_txn(
+                txn,
+                table="server_keys_json",
+                key_names=("server_name", "key_id", "from_server"),
+                key_values=[
+                    (server_name, key_id, from_server) for key_id in verify_keys
+                ],
+                value_names=(
+                    "ts_added_ms",
+                    "ts_valid_until_ms",
+                    "key_json",
+                ),
+                value_values=[
+                    (
+                        ts_added_ms,
+                        fetch_result.valid_until_ts,
+                        db_binary_type(key_json_bytes),
+                    )
+                    for fetch_result in verify_keys.values()
+                ],
+            )
 
-    async def store_server_keys_json(
-        self,
-        server_name: str,
-        key_id: str,
-        from_server: str,
-        ts_now_ms: int,
-        ts_expires_ms: int,
-        key_json_bytes: bytes,
-    ) -> None:
-        """Stores the JSON bytes for a set of keys from a server
-        The JSON should be signed by the originating server, the intermediate
-        server, and by this server. Updates the value for the
-        (server_name, key_id, from_server) triplet if one already existed.
-        Args:
-            server_name: The name of the server.
-            key_id: The identifier of the key this JSON is for.
-            from_server: The server this JSON was fetched from.
-            ts_now_ms: The time now in milliseconds.
-            ts_valid_until_ms: The time when this json stops being valid.
-            key_json_bytes: The encoded JSON.
-        """
-        await self.db_pool.simple_upsert(
-            table="server_keys_json",
-            keyvalues={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-            },
-            values={
-                "server_name": server_name,
-                "key_id": key_id,
-                "from_server": from_server,
-                "ts_added_ms": ts_now_ms,
-                "ts_valid_until_ms": ts_expires_ms,
-                "key_json": db_binary_type(key_json_bytes),
-            },
-            desc="store_server_keys_json",
-        )
+            # invalidate takes a tuple corresponding to the params of
+            # _get_server_keys_json. _get_server_keys_json only takes one
+            # param, which is itself the 2-tuple (server_name, key_id).
+            for key_id in verify_keys:
+                self._invalidate_cache_and_stream(
+                    txn, self._get_server_keys_json, ((server_name, key_id),)
+                )
+                self._invalidate_cache_and_stream(
+                    txn, self.get_server_key_json_for_remote, (server_name, key_id)
+                )
 
-        # invalidate takes a tuple corresponding to the params of
-        # _get_server_keys_json. _get_server_keys_json only takes one
-        # param, which is itself the 2-tuple (server_name, key_id).
-        await self.invalidate_cache_and_stream(
-            "_get_server_keys_json", ((server_name, key_id),)
-        )
-        await self.invalidate_cache_and_stream(
-            "get_server_key_json_for_remote", (server_name, key_id)
+        await self.db_pool.runInteraction(
+            "store_server_keys_response", store_server_keys_response_txn
         )
 
     @cached()