diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index fdfd4f911d..2be341ac7b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], SERVER_NAME)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
# we expect it to be encoded as canonical json *before* it hits the db
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
# change the server name: the result should be ignored
response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_get_multiple_keys_from_perspectives(self) -> None:
"""Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
self.assertEqual(k.verify_key.version, "ver1")
# check that the perspectives store is correctly updated
- lookup_triplet = (SERVER_NAME, testverifykey_id, None)
key_json = self.get_success(
self.hs.get_datastores().main.get_server_keys_json_for_remote(
- [lookup_triplet]
+ SERVER_NAME, [testverifykey_id]
)
)
- res_keys = key_json[lookup_triplet]
- self.assertEqual(len(res_keys), 1)
- res = res_keys[0]
- self.assertEqual(res["key_id"], testverifykey_id)
- self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
- self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
- self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
- self.assertEqual(
- bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
- )
+ res = key_json[testverifykey_id]
+ self.assertIsNotNone(res)
+ assert res is not None
+ self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+ self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+ self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
def test_invalid_perspectives_responses(self) -> None:
"""Check that invalid responses from the perspectives server are rejected"""
|