diff options
Diffstat (limited to 'tests/storage/test_keys.py')
-rw-r--r-- | tests/storage/test_keys.py | 70 |
1 files changed, 51 insertions, 19 deletions
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 6bfaa00fe9..e07ff01201 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -17,6 +17,8 @@ import signedjson.key from twisted.internet.defer import Deferred +from synapse.storage.keys import FetchKeyResult + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -31,23 +33,34 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): def test_get_server_verify_keys(self): store = self.hs.get_datastore() - d = store.store_server_verify_key("server1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("server1", "from_server", 0, KEY_2) + key_id_1 = "ed25519:key1" + key_id_2 = "ed25519:KEY_ID_2" + d = store.store_server_verify_keys( + "from_server", + 10, + [ + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) self.get_success(d) d = store.get_server_verify_keys( - [ - ("server1", "ed25519:key1"), - ("server1", "ed25519:key2"), - ("server1", "ed25519:key3"), - ] + [("server1", key_id_1), ("server1", key_id_2), ("server1", "ed25519:key3")] ) res = self.get_success(d) self.assertEqual(len(res.keys()), 3) - self.assertEqual(res[("server1", "ed25519:key1")].version, "key1") - self.assertEqual(res[("server1", "ed25519:key2")].version, "key2") + res1 = res[("server1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.verify_key.version, "key1") + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("server1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + # version comes from the ID it was stored with + self.assertEqual(res2.verify_key.version, "KEY_ID_2") + self.assertEqual(res2.valid_until_ts, 200) # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -60,32 +73,51 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): key_id_1 = "ed25519:key1" key_id_2 = "ed25519:key2" - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_1) - self.get_success(d) - d = store.store_server_verify_key("srv1", "from_server", 0, KEY_2) + d = store.store_server_verify_keys( + "from_server", + 0, + [ + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), + ], + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit res = store.get_server_verify_keys([("srv1", key_id_1)]) if isinstance(res, Deferred): res = self.successResultOf(res) self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) - d = store.store_server_verify_key("srv1", "from_server", 10, new_key_2) + d = store.store_server_verify_keys( + "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] + ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], new_key_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, new_key_2) + self.assertEqual(res2.valid_until_ts, 300) |