summary refs log tree commit diff
path: root/synapse/storage/keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/keys.py')
-rw-r--r--synapse/storage/keys.py31
1 files changed, 21 insertions, 10 deletions
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 3c5f52009b..5300720dbb 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -19,6 +19,7 @@ import logging
 
 import six
 
+import attr
 from signedjson.key import decode_verify_key_bytes
 
 from synapse.util import batch_iter
@@ -36,6 +37,12 @@ else:
     db_binary_type = memoryview
 
 
+@attr.s(slots=True, frozen=True)
+class FetchKeyResult(object):
+    verify_key = attr.ib()  # VerifyKey: the key itself
+    valid_until_ts = attr.ib()  # int: how long we can use this key for
+
+
 class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys
     """
@@ -54,8 +61,8 @@ class KeyStore(SQLBaseStore):
                 iterable of (server_name, key-id) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[Tuple[str, str], VerifyKey|None]:
-                map from (server_name, key_id) -> VerifyKey, or None if the key is
+            Deferred: resolves to dict[Tuple[str, str], FetchKeyResult|None]:
+                map from (server_name, key_id) -> FetchKeyResult, or None if the key is
                 unknown
         """
         keys = {}
@@ -65,17 +72,19 @@ class KeyStore(SQLBaseStore):
 
             # batch_iter always returns tuples so it's safe to do len(batch)
             sql = (
-                "SELECT server_name, key_id, verify_key FROM server_signature_keys "
-                "WHERE 1=0"
+                "SELECT server_name, key_id, verify_key, ts_valid_until_ms "
+                "FROM server_signature_keys WHERE 1=0"
             ) + " OR (server_name=? AND key_id=?)" * len(batch)
 
             txn.execute(sql, tuple(itertools.chain.from_iterable(batch)))
 
             for row in txn:
-                server_name, key_id, key_bytes = row
-                keys[(server_name, key_id)] = decode_verify_key_bytes(
-                    key_id, bytes(key_bytes)
+                server_name, key_id, key_bytes, ts_valid_until_ms = row
+                res = FetchKeyResult(
+                    verify_key=decode_verify_key_bytes(key_id, bytes(key_bytes)),
+                    valid_until_ts=ts_valid_until_ms,
                 )
+                keys[(server_name, key_id)] = res
 
         def _txn(txn):
             for batch in batch_iter(server_name_and_key_ids, 50):
@@ -89,20 +98,21 @@ class KeyStore(SQLBaseStore):
         Args:
             from_server (str): Where the verification keys were looked up
             ts_added_ms (int): The time to record that the key was added
-            verify_keys (iterable[tuple[str, str, nacl.signing.VerifyKey]]):
+            verify_keys (iterable[tuple[str, str, FetchKeyResult]]):
                 keys to be stored. Each entry is a triplet of
                 (server_name, key_id, key).
         """
         key_values = []
         value_values = []
         invalidations = []
-        for server_name, key_id, verify_key in verify_keys:
+        for server_name, key_id, fetch_result in verify_keys:
             key_values.append((server_name, key_id))
             value_values.append(
                 (
                     from_server,
                     ts_added_ms,
-                    db_binary_type(verify_key.encode()),
+                    fetch_result.valid_until_ts,
+                    db_binary_type(fetch_result.verify_key.encode()),
                 )
             )
             # invalidate takes a tuple corresponding to the params of
@@ -125,6 +135,7 @@ class KeyStore(SQLBaseStore):
             value_names=(
                 "from_server",
                 "ts_added_ms",
+                "ts_valid_until_ms",
                 "verify_key",
             ),
             value_values=value_values,