summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/5237.misc1
-rw-r--r--synapse/crypto/keyring.py47
-rw-r--r--synapse/storage/keys.py31
-rw-r--r--synapse/storage/schema/delta/54/add_validity_to_server_keys.sql23
-rw-r--r--tests/crypto/test_keyring.py22
-rw-r--r--tests/storage/test_keys.py44
6 files changed, 122 insertions, 46 deletions
diff --git a/changelog.d/5237.misc b/changelog.d/5237.misc
new file mode 100644
index 0000000000..f4fe3b821b
--- /dev/null
+++ b/changelog.d/5237.misc
@@ -0,0 +1 @@
+Store key validity time in the storage layer.
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 9d629b2238..14a27288fd 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -20,7 +20,6 @@ from collections import namedtuple
 from six import raise_from
 from six.moves import urllib
 
-import nacl.signing
 from signedjson.key import (
     decode_verify_key_bytes,
     encode_verify_key_base64,
@@ -43,6 +42,7 @@ from synapse.api.errors import (
     RequestSendFailed,
     SynapseError,
 )
+from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext, unwrapFirstError
 from synapse.util.logcontext import (
     LoggingContext,
@@ -307,11 +307,15 @@ class Keyring(object):
                         # complete this VerifyKeyRequest.
                         result_keys = results.get(server_name, {})
                         for key_id in verify_request.key_ids:
-                            key = result_keys.get(key_id)
-                            if key:
+                            fetch_key_result = result_keys.get(key_id)
+                            if fetch_key_result:
                                 with PreserveLoggingContext():
                                     verify_request.deferred.callback(
-                                        (server_name, key_id, key)
+                                        (
+                                            server_name,
+                                            key_id,
+                                            fetch_key_result.verify_key,
+                                        )
                                     )
                                 break
                         else:
@@ -348,12 +352,12 @@ class Keyring(object):
     def get_keys_from_store(self, server_name_and_key_ids):
         """
         Args:
-            server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
+            server_name_and_key_ids (iterable[Tuple[str, iterable[str]]]):
                 list of (server_name, iterable[key_id]) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
-                server_name -> key_id -> VerifyKey
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult|None]]]:
+                map from server_name -> key_id -> FetchKeyResult
         """
         keys_to_fetch = (
             (server_name, key_id)
@@ -430,6 +434,18 @@ class Keyring(object):
     def get_server_verify_key_v2_indirect(
         self, server_names_and_key_ids, perspective_name, perspective_keys
     ):
+        """
+        Args:
+            server_names_and_key_ids (iterable[Tuple[str, iterable[str]]]):
+                list of (server_name, iterable[key_id]) tuples to fetch keys for
+            perspective_name (str): name of the notary server to query for the keys
+            perspective_keys (dict[str, VerifyKey]): map of key_id->key for the
+                notary server
+
+        Returns:
+            Deferred[dict[str, dict[str, synapse.storage.keys.FetchKeyResult]]]: map
+                from server_name -> key_id -> FetchKeyResult
+        """
         # TODO(mark): Set the minimum_valid_until_ts to that needed by
         # the events being validated or the current time if validating
         # an incoming request.
@@ -506,7 +522,7 @@ class Keyring(object):
 
     @defer.inlineCallbacks
     def get_server_verify_key_v2_direct(self, server_name, key_ids):
-        keys = {}  # type: dict[str, nacl.signing.VerifyKey]
+        keys = {}  # type: dict[str, FetchKeyResult]
 
         for requested_key_id in key_ids:
             if requested_key_id in keys:
@@ -583,9 +599,9 @@ class Keyring(object):
                 actually in the response
 
         Returns:
-            Deferred[dict[str, nacl.signing.VerifyKey]]:
-                map from key_id to key object
+            Deferred[dict[str, FetchKeyResult]]: map from key_id to result object
         """
+        ts_valid_until_ms = response_json[u"valid_until_ts"]
 
         # start by extracting the keys from the response, since they may be required
         # to validate the signature on the response.
@@ -595,7 +611,9 @@ class Keyring(object):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
                 verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_keys[key_id] = verify_key
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=ts_valid_until_ms
+                )
 
         # TODO: improve this signature checking
         server_name = response_json["server_name"]
@@ -606,7 +624,7 @@ class Keyring(object):
                 )
 
             verify_signed_json(
-                response_json, server_name, verify_keys[key_id]
+                response_json, server_name, verify_keys[key_id].verify_key
             )
 
         for key_id, key_data in response_json["old_verify_keys"].items():
@@ -614,7 +632,9 @@ class Keyring(object):
                 key_base64 = key_data["key"]
                 key_bytes = decode_base64(key_base64)
                 verify_key = decode_verify_key_bytes(key_id, key_bytes)
-                verify_keys[key_id] = verify_key
+                verify_keys[key_id] = FetchKeyResult(
+                    verify_key=verify_key, valid_until_ts=key_data["expired_ts"]
+                )
 
         # re-sign the json with our own key, so that it is ready if we are asked to
         # give it out as a notary server
@@ -623,7 +643,6 @@ class Keyring(object):
         )
 
         signed_key_json_bytes = encode_canonical_json(signed_key_json)
-        ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
 
         # for reasons I don't quite understand, we store this json for the key ids we
         # requested, as well as those we got.
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3c5f52009b..5300720dbb 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,6 +19,7 @@ import logging
 
 import six
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 
 from synapse.util import batch_iter
@@ -36,6 +37,12 @@ else:
     db_binary_type = memoryview
 
 
+@attr.s(slots=True, frozen=True)
+class FetchKeyResult(object):
+    verify_key = attr.ib()  # VerifyKey: the key itself
+    valid_until_ts = attr.ib()  # int: how long we can use this key for
+
+
 class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys
     """
@@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
                 iterable of (server_name, key-id) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
-                map from (server_name, key_id) -> VerifyKey, or None if the key is
+            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
                 unknown
         """
         keys = {}
@@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
 
             # batch_iter always returns tuples so it's safe to do len(batch)
             sql = (
-                "SELECT server_name, key_id, verify_key FROM server_signature_keys "
-                "WHERE 1=0"
+                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+                "FROM server_signature_keys WHERE 1=0"
             ) + " OR (server_name=? AND key_id=?)" * len(batch)
 
             txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
 
             for row in txn:
-                server_name, key_id, key_bytes = row
-                keys[(server_name, key_id)] = decode_verify_key_bytes(
-                    key_id, bytes(key_bytes)
+                server_name, key_id, key_bytes, ts_valid_until_ms = row
+                res = FetchKeyResult(
+                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+                    valid_until_ts=ts_valid_until_ms,
                 )
+                keys[(server_name, key_id)] = res
 
         def _txn(txn):
             for batch in batch_iter(server_name_and_key_ids, 50):
@@ -89,20 +98,21 @@ class KeyStore(SQLBaseStore):
         Args:
             from_server (str): Where the verification keys were looked up
             ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
+            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
         key_values = []
         value_values = []
         invalidations = []
-        for server_name, key_id, verify_key in verify_keys:
+        for server_name, key_id, fetch_result in verify_keys:
             key_values.append((server_name, key_id))
             value_values.append(
                 (
                     from_server,
                     ts_added_ms,
-                    db_binary_type(verify_key.encode()),
+                    fetch_result.valid_until_ts,
+                    db_binary_type(fetch_result.verify_key.encode()),
                 )
             )
             # invalidate takes a tuple corresponding to the params of
@@ -125,6 +135,7 @@ class KeyStore(SQLBaseStore):
             value_names=(
                 "from_server",
                 "ts_added_ms",
+                "ts_valid_until_ms",
                 "verify_key",
             ),
             value_values=value_values,
diff --git a/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
new file mode 100644
index 0000000000..c01aa9d2d9
--- /dev/null
+++ b/synapse/storage/schema/delta/54/add_validity_to_server_keys.sql
@@ -0,0 +1,23 @@
+/* Copyright 2019 New Vector Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/* When we can use this key until, before we have to refresh it. */
+ALTER TABLE server_signature_keys ADD COLUMN ts_valid_until_ms BIGINT;
+
+UPDATE server_signature_keys SET ts_valid_until_ms = (
+    SELECT MAX(ts_valid_until_ms) FROM server_keys_json skj WHERE
+        skj.server_name = server_signature_keys.server_name AND
+        skj.key_id = server_signature_keys.key_id
+);
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index bcffe53a91..83de32b05d 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -25,6 +25,7 @@ from twisted.internet import defer
 from synapse.api.errors import SynapseError
 from synapse.crypto import keyring
 from synapse.crypto.keyring import KeyLookupError
+from synapse.storage.keys import FetchKeyResult
 from synapse.util import logcontext
 from synapse.util.logcontext import LoggingContext
 
@@ -201,7 +202,7 @@ class KeyringTestCase(unittest.HomeserverTestCase):
                 (
                     "server9",
                     key1_id,
-                    signedjson.key.get_verify_key(key1),
+                    FetchKeyResult(signedjson.key.get_verify_key(key1), 1000),
                 ),
             ],
         )
@@ -251,9 +252,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         server_name_and_key_ids = [(SERVER_NAME, ("key1",))]
         keys = self.get_success(kr.get_keys_from_server(server_name_and_key_ids))
         k = keys[SERVER_NAME][testverifykey_id]
-        self.assertEqual(k, testverifykey)
-        self.assertEqual(k.alg, "ed25519")
-        self.assertEqual(k.version, "ver1")
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -321,9 +323,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
         keys = self.get_success(kr.get_keys_from_perspectives(server_name_and_key_ids))
         self.assertIn(SERVER_NAME, keys)
         k = keys[SERVER_NAME][testverifykey_id]
-        self.assertEqual(k, testverifykey)
-        self.assertEqual(k.alg, "ed25519")
-        self.assertEqual(k.version, "ver1")
+        self.assertEqual(k.valid_until_ts, VALID_UNTIL_TS)
+        self.assertEqual(k.verify_key, testverifykey)
+        self.assertEqual(k.verify_key.alg, "ed25519")
+        self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
         lookup_triplet = (SERVER_NAME, testverifykey_id, None)
@@ -346,7 +349,10 @@ class KeyringTestCase(unittest.HomeserverTestCase):
 
 @defer.inlineCallbacks
 def run_in_context(f, *args, **kwargs):
-    with LoggingContext("testctx"):
+    with LoggingContext("testctx") as ctx:
+        # we set the "request" prop to make it easier to follow what's going on in the
+        # logs.
+        ctx.request = "testctx"
         rv = yield f(*args, **kwargs)
     defer.returnValue(rv)
 
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 71ad7aee32..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
 
 from twisted.internet.defer import Deferred
 
+from synapse.storage.keys import FetchKeyResult
+
 import tests.unittest
 
 KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             "from_server",
             10,
             [
-                ("server1", key_id_1, KEY_1),
-                ("server1", key_id_2, KEY_2),
+                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
             ],
         )
         self.get_success(d)
@@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 
         self.assertEqual(len(res.keys()), 3)
         res1 = res[("server1", key_id_1)]
-        self.assertEqual(res1, KEY_1)
-        self.assertEqual(res1.version, "key1")
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.verify_key.version, "key1")
+        self.assertEqual(res1.valid_until_ts, 100)
 
         res2 = res[("server1", key_id_2)]
-        self.assertEqual(res2, KEY_2)
+        self.assertEqual(res2.verify_key, KEY_2)
         # version comes from the ID it was stored with
-        self.assertEqual(res2.version, "KEY_ID_2")
+        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # non-existent result gives None
         self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             "from_server",
             0,
             [
-                ("srv1", key_id_1, KEY_1),
-                ("srv1", key_id_2, KEY_2),
+                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
             ],
         )
         self.get_success(d)
@@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
         res = store.get_server_verify_keys([("srv1", key_id_1)])
         if isinstance(res, Deferred):
             res = self.successResultOf(res)
         self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
         new_key_2 = signedjson.key.get_verify_key(
             signedjson.key.generate_signing_key("key2")
         )
         d = store.store_server_verify_keys(
-            "from_server", 10, [("srv1", key_id_2, new_key_2)]
+            "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
         )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, new_key_2)
+        self.assertEqual(res2.valid_until_ts, 300)