summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
authorRichard van der Hoff <richard@matrix.org>2019-04-08 14:51:07 +0100
committerRichard van der Hoff <richard@matrix.org>2019-04-09 00:00:10 +0100
commit18b69be00f9fa79cf2b237992ef1f0094d1dc453 (patch)
treea71af9dc0dcee3c322a704754fa8d610014b74d6 /synapse/crypto/keyring.py
parentRewrite test_keys as a HomeserverTestCase (diff)
downloadsynapse-18b69be00f9fa79cf2b237992ef1f0094d1dc453.tar.xz
Rewrite Datastore.get_server_verify_keys
Rewrite this so that it doesn't hammer the database.
Diffstat (limited to 'synapse/crypto/keyring.py')
-rw-r--r--synapse/crypto/keyring.py38
1 files changed, 17 insertions, 21 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index ede120b2a6..834b107705 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -301,13 +301,12 @@ class Keyring(object):
                         # complete this VerifyKeyRequest.
                         result_keys = results.get(server_name, {})
                         for key_id in verify_request.key_ids:
-                            if key_id in result_keys:
+                            key = result_keys.get(key_id)
+                            if key:
                                 with PreserveLoggingContext():
-                                    verify_request.deferred.callback((
-                                        server_name,
-                                        key_id,
-                                        result_keys[key_id],
-                                    ))
+                                    verify_request.deferred.callback(
+                                        (server_name, key_id, key)
+                                    )
                                 break
                         else:
                             # The else block is only reached if the loop above
@@ -341,27 +340,24 @@ class Keyring(object):
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):
         """
-
         Args:
-            server_name_and_key_ids (list[(str, iterable[str])]):
+            server_name_and_key_ids (iterable(Tuple[str, iterable[str]]):
                 list of (server_name, iterable[key_id]) tuples to fetch keys for
 
         Returns:
-            Deferred: resolves to dict[str, dict[str, VerifyKey]]: map from
+            Deferred: resolves to dict[str, dict[str, VerifyKey|None]]: map from
                 server_name -> key_id -> VerifyKey
         """
-        res = yield logcontext.make_deferred_yieldable(defer.gatherResults(
-            [
-                run_in_background(
-                    self.store.get_server_verify_keys,
-                    server_name, key_ids,
-                ).addCallback(lambda ks, server: (server, ks), server_name)
-                for server_name, key_ids in server_name_and_key_ids
-            ],
-            consumeErrors=True,
-        ).addErrback(unwrapFirstError))
-
-        defer.returnValue(dict(res))
+        keys_to_fetch = (
+            (server_name, key_id)
+            for server_name, key_ids in server_name_and_key_ids
+            for key_id in key_ids
+        )
+        res = yield self.store.get_server_verify_keys(keys_to_fetch)
+        keys = {}
+        for (server_name, key_id), key in res.items():
+            keys.setdefault(server_name, {})[key_id] = key
+        defer.returnValue(keys)
 
     @defer.inlineCallbacks
     def get_keys_from_perspectives(self, server_name_and_key_ids):