diff options
Diffstat (limited to 'tests/storage/test_keys.py')
-rw-r--r-- | tests/storage/test_keys.py | 44 |
1 files changed, 30 insertions, 14 deletions
diff --git a/tests/storage/test_keys.py b/tests/storage/test_keys.py index 71ad7aee32..e07ff01201 100644 --- a/tests/storage/test_keys.py +++ b/tests/storage/test_keys.py @@ -17,6 +17,8 @@ import signedjson.key from twisted.internet.defer import Deferred +from synapse.storage.keys import FetchKeyResult + import tests.unittest KEY_1 = signedjson.key.decode_verify_key_base64( @@ -37,8 +39,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): "from_server", 10, [ - ("server1", key_id_1, KEY_1), - ("server1", key_id_2, KEY_2), + ("server1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("server1", key_id_2, FetchKeyResult(KEY_2, 200)), ], ) self.get_success(d) @@ -50,13 +52,15 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): self.assertEqual(len(res.keys()), 3) res1 = res[("server1", key_id_1)] - self.assertEqual(res1, KEY_1) - self.assertEqual(res1.version, "key1") + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.verify_key.version, "key1") + self.assertEqual(res1.valid_until_ts, 100) res2 = res[("server1", key_id_2)] - self.assertEqual(res2, KEY_2) + self.assertEqual(res2.verify_key, KEY_2) # version comes from the ID it was stored with - self.assertEqual(res2.version, "KEY_ID_2") + self.assertEqual(res2.verify_key.version, "KEY_ID_2") + self.assertEqual(res2.valid_until_ts, 200) # non-existent result gives None self.assertIsNone(res[("server1", "ed25519:key3")]) @@ -73,8 +77,8 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): "from_server", 0, [ - ("srv1", key_id_1, KEY_1), - ("srv1", key_id_2, KEY_2), + ("srv1", key_id_1, FetchKeyResult(KEY_1, 100)), + ("srv1", key_id_2, FetchKeyResult(KEY_2, 200)), ], ) self.get_success(d) @@ -82,26 +86,38 @@ class KeyStoreTestCase(tests.unittest.HomeserverTestCase): d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], KEY_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, KEY_2) + self.assertEqual(res2.valid_until_ts, 200) # we should be able to look up the same thing again without a db hit res = store.get_server_verify_keys([("srv1", key_id_1)]) if isinstance(res, Deferred): res = self.successResultOf(res) self.assertEqual(len(res.keys()), 1) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) + self.assertEqual(res[("srv1", key_id_1)].verify_key, KEY_1) new_key_2 = signedjson.key.get_verify_key( signedjson.key.generate_signing_key("key2") ) d = store.store_server_verify_keys( - "from_server", 10, [("srv1", key_id_2, new_key_2)] + "from_server", 10, [("srv1", key_id_2, FetchKeyResult(new_key_2, 300))] ) self.get_success(d) d = store.get_server_verify_keys([("srv1", key_id_1), ("srv1", key_id_2)]) res = self.get_success(d) self.assertEqual(len(res.keys()), 2) - self.assertEqual(res[("srv1", key_id_1)], KEY_1) - self.assertEqual(res[("srv1", key_id_2)], new_key_2) + + res1 = res[("srv1", key_id_1)] + self.assertEqual(res1.verify_key, KEY_1) + self.assertEqual(res1.valid_until_ts, 100) + + res2 = res[("srv1", key_id_2)] + self.assertEqual(res2.verify_key, new_key_2) + self.assertEqual(res2.valid_until_ts, 300) |