summary refs log tree commit diff
path: root/tests/storage/test_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_keys.py')
-rw-r--r--tests/storage/test_keys.py44
1 files changed, 30 insertions, 14 deletions
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py
index 71ad7aee32..e07ff01201 100644
--- a/tests/storage/test_keys.py
+++ b/tests/storage/test_keys.py
@@ -17,6 +17,8 @@ import signedjson.key
 
 from twisted.internet.defer import Deferred
 
+from synapse.storage.keys import FetchKeyResult
+
 import tests.unittest
 
 KEY_1 = signedjson.key.decode_verify_key_base64(
@@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             "from_server",
             10,
             [
-                ("server1", key_id_1, KEY_1),
-                ("server1", key_id_2, KEY_2),
+                ("server1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("server1", key_id_2, FetchKeyResult(KEY_2, 200)),
             ],
         )
         self.get_success(d)
@@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
 
         self.assertEqual(len(res.keys()), 3)
         res1 = res[("server1", key_id_1)]
-        self.assertEqual(res1, KEY_1)
-        self.assertEqual(res1.version, "key1")
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.verify_key.version, "key1")
+        self.assertEqual(res1.valid_until_ts, 100)
 
         res2 = res[("server1", key_id_2)]
-        self.assertEqual(res2, KEY_2)
+        self.assertEqual(res2.verify_key, KEY_2)
         # version comes from the ID it was stored with
-        self.assertEqual(res2.version, "KEY_ID_2")
+        self.assertEqual(res2.verify_key.version, "KEY_ID_2")
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # non-existent result gives None
         self.assertIsNone(res[("server1", "ed25519:key3")])
@@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
             "from_server",
             0,
             [
-                ("srv1", key_id_1, KEY_1),
-                ("srv1", key_id_2, KEY_2),
+                ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)),
+                ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)),
             ],
         )
         self.get_success(d)
@@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase):
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], KEY_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, KEY_2)
+        self.assertEqual(res2.valid_until_ts, 200)
 
         # we should be able to look up the same thing again without a db hit
         res = store.get_server_verify_keys([("srv1", key_id_1)])
         if isinstance(res, Deferred):
             res = self.successResultOf(res)
         self.assertEqual(len(res.keys()), 1)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
+        self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1)
 
         new_key_2 = signedjson.key.get_verify_key(
             signedjson.key.generate_signing_key("key2")
         )
         d = store.store_server_verify_keys(
-            "from_server", 10, [("srv1", key_id_2, new_key_2)]
+            "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))]
         )
         self.get_success(d)
 
         d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)])
         res = self.get_success(d)
         self.assertEqual(len(res.keys()), 2)
-        self.assertEqual(res[("srv1", key_id_1)], KEY_1)
-        self.assertEqual(res[("srv1", key_id_2)], new_key_2)
+
+        res1 = res[("srv1", key_id_1)]
+        self.assertEqual(res1.verify_key, KEY_1)
+        self.assertEqual(res1.valid_until_ts, 100)
+
+        res2 = res[("srv1", key_id_2)]
+        self.assertEqual(res2.verify_key, new_key_2)
+        self.assertEqual(res2.valid_until_ts, 300)