diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 71ad7aee32..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
from twisted.internet.defer import Deferred
+from synapse.storage.keys import FetchKeyResult
+
import tests.unittest
KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
"from_server",
10,
[
- ("server1", key_id_1, KEY_1),
- ("server1", key_id_2, KEY_2),
+ ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
self.get_success(d)
@@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
self.assertEqual(len(res.keys()), 3)
res1 = res[("server1", key_id_1)]
- self.assertEqual(res1, KEY_1)
- self.assertEqual(res1.version, "key1")
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.verify_key.version, "key1")
+ self.assertEqual(res1.valid_until_ts, 100)
res2 = res[("server1", key_id_2)]
- self.assertEqual(res2, KEY_2)
+ self.assertEqual(res2.verify_key, KEY_2)
# version comes from the ID it was stored with
- self.assertEqual(res2.version, "KEY_ID_2")
+ self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+ self.assertEqual(res2.valid_until_ts, 200)
# non-existent result gives None
self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
"from_server",
0,
[
- ("srv1", key_id_1, KEY_1),
- ("srv1", key_id_2, KEY_2),
+ ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+ ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
],
)
self.get_success(d)
@@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, KEY_2)
+ self.assertEqual(res2.valid_until_ts, 200)
# we should be able to look up the same thing again without a db hit
res = store.get_server_verify_keys([("srv1", key_id_1)])
if isinstance(res, Deferred):
res = self.successResultOf(res)
self.assertEqual(len(res.keys()), 1)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+ self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
new_key_2 = signedjson.key.get_verify_key(
signedjson.key.generate_signing_key("key2")
)
d = store.store_server_verify_keys(
- "from_server", 10, [("srv1", key_id_2, new_key_2)]
+ "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
)
self.get_success(d)
d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
res = self.get_success(d)
self.assertEqual(len(res.keys()), 2)
- self.assertEqual(res[("srv1", key_id_1)], KEY_1)
- self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+ res1 = res[("srv1", key_id_1)]
+ self.assertEqual(res1.verify_key, KEY_1)
+ self.assertEqual(res1.valid_until_ts, 100)
+
+ res2 = res[("srv1", key_id_2)]
+ self.assertEqual(res2.verify_key, new_key_2)
+ self.assertEqual(res2.valid_until_ts, 300)
|