summary refs log tree commit diff
path: root/tests/crypto/test_keyring.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/crypto/test_keyring.py')
-rw-r--r--tests/crypto/test_keyring.py61
1 files changed, 23 insertions, 38 deletions
diff --git a/tests/crypto/test_keyring.py b/tests/crypto/test_keyring.py
index fdfd4f911d..2be341ac7b 100644
--- a/tests/crypto/test_keyring.py
+++ b/tests/crypto/test_keyring.py
@@ -456,24 +456,19 @@ class ServerKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], SERVER_NAME)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
 
         # we expect it to be encoded as canonical json *before* it hits the db
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
         # change the server name: the result should be ignored
         response["server_name"] = "OTHER_SERVER"
@@ -576,23 +571,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_get_multiple_keys_from_perspectives(self) -> None:
         """Check that we can correctly request multiple keys for the same server"""
@@ -699,23 +689,18 @@ class PerspectivesKeyFetcherTestCase(unittest.HomeserverTestCase):
         self.assertEqual(k.verify_key.version, "ver1")
 
         # check that the perspectives store is correctly updated
-        lookup_triplet = (SERVER_NAME, testverifykey_id, None)
         key_json = self.get_success(
             self.hs.get_datastores().main.get_server_keys_json_for_remote(
-                [lookup_triplet]
+                SERVER_NAME, [testverifykey_id]
             )
         )
-        res_keys = key_json[lookup_triplet]
-        self.assertEqual(len(res_keys), 1)
-        res = res_keys[0]
-        self.assertEqual(res["key_id"], testverifykey_id)
-        self.assertEqual(res["from_server"], self.mock_perspective_server.server_name)
-        self.assertEqual(res["ts_added_ms"], self.reactor.seconds() * 1000)
-        self.assertEqual(res["ts_valid_until_ms"], VALID_UNTIL_TS)
-
-        self.assertEqual(
-            bytes(res["key_json"]), canonicaljson.encode_canonical_json(response)
-        )
+        res = key_json[testverifykey_id]
+        self.assertIsNotNone(res)
+        assert res is not None
+        self.assertEqual(res.added_ts, self.reactor.seconds() * 1000)
+        self.assertEqual(res.valid_until_ts, VALID_UNTIL_TS)
+
+        self.assertEqual(res.key_json, canonicaljson.encode_canonical_json(response))
 
     def test_invalid_perspectives_responses(self) -> None:
         """Check that invalid responses from the perspectives server are rejected"""