summary refs log tree commit diff
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-04-03 18:10:24 +0100
committerRichard van der Hoff <richard@matrix.org>2019-05-23 11:52:22 +0100
commitb75537beaf841089f9f07c9dbed04a7a420a8b1f (patch)
tree81ceb1b2de43ed5e3195976337b9445cae9ca4a0
parentSimplify process_v2_response (#5236) (diff)
downloadsynapse-b75537beaf841089f9f07c9dbed04a7a420a8b1f.tar.xz
Store key validity time in the storage layer
This is a first step to checking that the key is valid at the required moment.

The idea here is that, rather than passing VerifyKey objects in and out of the
storage layer, we instead pass FetchKeyResult objects, which simply wrap the
VerifyKey and add a valid_until_ts field.
-rw-r--r--changelog.d/5237.misc1
-rw-r--r--synapse/crypto/keyring.py47
-rw-r--r--synapse/storage/keys.py31
-rw-r--r--synapse/storage/schema/delta/54/add_validity_to_server_keys.sql23
-rw-r--r--tests/crypto/test_keyring.py22
-rw-r--r--tests/storage/test_keys.py44
6 files changed, 122 insertions, 46 deletions
diff --git a/changelog.d/5237.misc b/changelog.d/5237.misc
new file mode 100644
index 0000000000..f4fe3b821b
--- /dev/null
+++ b/changelog.d/5237.misc
@@ -0,0 +1 @@
+Store key validity time in the storage layer.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 9d629b2238..14a27288fd 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -20,7 +20,6 @@ from collections import namedtuple
 from six import raise_from
 from six.moves import urllib
 
-import nacl.signing
 from signedjson.key import (
     decode_verify_key_bytes,
     encode_verify_key_base64,
@@ -43,6 +42,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.logcontext import (
     LoggingContext,
@@ -307,11 +307,15 @@ class Keyring(object):
                         # complete this VerifyKeyRequest.
                         result_keys = results.get(server_name, {})
                         for key_id in verify_request.key_ids:
-                            key = result_keys.get(key_id)
-                            if key:
+                            fetch_key_result = result_keys.get(key_id)
+                            if fetch_key_result:
                                 with PreserveLoggingContext():
                                     verify_request.deferred.callback(
-                                        (server_name, key_id, key)
+                                        (
+                                            server_name,
+                                            key_id,
+                                            fetch_key_result.verify_key,
+                                        )
                                     )
                                 break
                         else:
@@ -348,12 +352,12 @@ class Keyring(object):
     def get_keys_from_store(self, server_name_and_key_ids):
         """
         Args:
-            server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
+            server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
                 list of (server_name, iterable[key_id]) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
-                server_name -> key_id -> VerifyKey
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+                map from server_name -> key_id -> FetchKeyResult
         """
         keys_to_fetch = (
             (server_name, key_id)
@@ -430,6 +434,18 @@ class Keyring(object):
     def get_server_verify_key_v2_indirect(
         self, server_names_and_key_ids, perspective_name, perspective_keys
     ):
+        """
+        Args:
+            server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
+                list of (server_name, iterable[key_id]) tuples to fetch keys for
+            perspective_name (str): name of the notary server to query for the keys
+            perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+                notary server
+
+        Returns:
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+                from server_name -> key_id -> FetchKeyResult
+        """
         # TODO(mark): Set the minimum_valid_until_ts to that needed by
         # the events being validated or the current time if validating
         # an incoming request.
@@ -506,7 +522,7 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
-        keys = {}  # type: dict[str, nacl.signing.VerifyKey]
+        keys = {}  # type: dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
             if requested_key_id in keys:
@@ -583,9 +599,9 @@ class Keyring(object):
                 actually in the response
 
         Returns:
-            Deferred[dict[str, nacl.signing.VerifyKey]]:
-                map from key_id to key object
+            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
         """
+        ts_valid_until_ms = response_json[u"valid_until_ts"]
 
         # start by extracting the keys from the response, since they may be required
         # to validate the signature on the response.
@@ -595,7 +611,9 @@ class Keyring(object):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
                 verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_keys[key_id] = verify_key
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+                )
 
         # TODO: improve this signature checking
         server_name = response_json["server_name"]
@@ -606,7 +624,7 @@ class Keyring(object):
                 )
 
             verify_signed_json(
-                response_json, server_name, verify_keys[key_id]
+                response_json, server_name, verify_keys[key_id].verify_key
             )
 
         for key_id, key_data in response_json["old_verify_keys"].items():
@@ -614,7 +632,9 @@ class Keyring(object):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
                 verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_keys[key_id] = verify_key
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+                )
 
         # re-sign the json with our own key, so that it is ready if we are asked to
         # give it out as a notary server
@@ -623,7 +643,6 @@ class Keyring(object):
         )
 
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
-        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
 
         # for reasons I don't quite understand, we store this json for the key ids we
         # requested, as well as those we got.
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3c5f52009b..5300720dbb 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,6 +19,7 @@ import logging
 
 import six
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 
 from synapse.util import batch_iter
@@ -36,6 +37,12 @@ else:
     db_binary_type = memoryview
 
 
+@attr.s(slots=True, frozen=True)
+class FetchKeyResult(object):
+    verify_key = attr.ib()  # VerifyKey: the key itself
+    valid_until_ts = attr.ib()  # int: how long we can use this key for
+
+
 class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys
     """
@@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
                 iterable of (server_name, key-id) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
-                map from (server_name, key_id) -> VerifyKey, or None if the key is
+            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
                 unknown
         """
         keys = {}
@@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
 
             # batch_iter always returns tuples so it's safe to do len(batch)
             sql = (
-                "SELECT server_name, key_id, verify_key FROM server_signature_keys "
-                "WHERE 1=0"
+                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+                "FROM server_signature_keys WHERE 1=0"
             ) + " OR (server_name=? AND key_id=?)" * len(batch)
 
             txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
 
             for row in txn:
-                server_name, key_id, key_bytes = row
-                keys[(server_name, key_id)] = decode_verify_key_bytes(
-                    key_id, bytes(key_bytes)
+                server_name, key_id, key_bytes, ts_valid_until_ms = row
+                res = FetchKeyResult(
+                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+                    valid_until_ts=ts_valid_until_ms,
                 )
+                keys[(server_name, key_id)] = res
 
         def _txn(txn):
             for batch in batch_iter(server_name_and_key_ids, 50):
@@ -89,20 +98,21 @@ class KeyStore(SQLBaseStore):
         Args:
             from_server (str): Where the verification keys were looked up
             ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
+            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
         key_values = []
         value_values = []
         invalidations = []
-        for server_name, key_id, verify_key in verify_keys:
+        for server_name, key_id, fetch_result in verify_keys:
             key_values.append((server_name, key_id))
             value_values.append(
                 (
                     from_server,
                     ts_added_ms,
-                    db_binary_type(verify_key.encode()),
+                    fetch_result.valid_until_ts,
+                    db_binary_type(fetch_result.verify_key.encode()),
                 )
             )
             # invalidate takes a tuple corresponding to the params of
@@ -125,6 +135,7 @@ class KeyStore(SQLBaseStore):
             value_names=(
                 "from_server",
                 "ts_added_ms",
+                "ts_valid_until_ms",
                 "verify_key",
             ),
             value_values=value_values,
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
new file mode 100644
index 0000000000..c01aa9d2d9
--- /dev/null
+++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* When we can use this key until, before we have to refresh it. */
+ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
+
+UPDATE server_signature_keys SET ts_valid_until_ms = (
+    SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
+        skj.server_name = server_signature_keys.server_name AND
+        skj.key_id = server_signature_keys.key_id
+);
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index bcffe53a91..83de32b05d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
 from synapse.crypto.keyring import KeyLookupError
+from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
 
@@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 (
                     "server9",
                     key1_id,
-                    signedjson.key.get_verify_key(key1),
+                    FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
                 ),
             ],
         )
@@ -251,9 +252,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
         keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
         k = keys[SERVER_NAME][testverifykey_id]
-        self.assertEqual(k, testverifykey)
-        self.assertEqual(k.alg, "ed25519")
-        self.assertEqual(k.version, "ver1")
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -321,9 +323,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
-        self.assertEqual(k, testverifykey)
-        self.assertEqual(k.alg, "ed25519")
-        self.assertEqual(k.version, "ver1")
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -346,7 +349,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx"):
+    with LoggingContext("testctx") as ctx:
+        # we set the "request" prop to make it easier to follow what's going on in the
+        # logs.
+        ctx.request = "testctx"
         rv = yield f(*args, **kwargs)
     defer.returnValue(rv)
 
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 71ad7aee32..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
 
 from twisted.internet.defer import Deferred
 
+from synapse.storage.keys import FetchKeyResult
+
 import tests.unittest
 
 KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             "from_server",
             10,
             [
-                ("server1", key_id_1, KEY_1),
-                ("server1", key_id_2, KEY_2),
+                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
             ],
         )
         self.get_success(d)
@@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 
         self.assertEqual(len(res.keys()), 3)
         res1 = res[("server1", key_id_1)]
-        self.assertEqual(res1, KEY_1)
-        self.assertEqual(res1.version, "key1")
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.verify_key.version, "key1")
+        self.assertEqual(res1.valid_until_ts, 100)
 
         res2 = res[("server1", key_id_2)]
-        self.assertEqual(res2, KEY_2)
+        self.assertEqual(res2.verify_key, KEY_2)
         # version comes from the ID it was stored with
-        self.assertEqual(res2.version, "KEY_ID_2")
+        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # non-existent result gives None
         self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             "from_server",
             0,
             [
-                ("srv1", key_id_1, KEY_1),
-                ("srv1", key_id_2, KEY_2),
+                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
             ],
         )
         self.get_success(d)
@@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
         res = store.get_server_verify_keys([("srv1", key_id_1)])
         if isinstance(res, Deferred):
             res = self.successResultOf(res)
         self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
         new_key_2 = signedjson.key.get_verify_key(
             signedjson.key.generate_signing_key("key2")
         )
         d = store.store_server_verify_keys(
-            "from_server", 10, [("srv1", key_id_2, new_key_2)]
+            "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
         )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, new_key_2)
+        self.assertEqual(res2.valid_until_ts, 300)