summary refs log tree commit diff
path: root/synapse/crypto/keyring.py
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--synapse/crypto/keyring.py83
1 files changed, 68 insertions, 15 deletions
diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index 873c9b40fa..aa74d4d0cb 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -27,6 +27,8 @@ from synapse.api.errors import SynapseError, Codes
 from synapse.util.retryutils import get_retry_limiter
 from synapse.util import unwrapFirstError
 
+from synapse.util.async import ObservableDeferred
+
 from OpenSSL import crypto
 
 from collections import namedtuple
@@ -88,6 +90,8 @@ class Keyring(object):
                     "Not signed with a supported algorithm",
                     Codes.UNAUTHORIZED,
                 ))
+            else:
+                deferreds[group_id] = defer.Deferred()
 
             group = KeyGroup(server_name, group_id, key_ids)
 
@@ -133,10 +137,41 @@ class Keyring(object):
                     Codes.UNAUTHORIZED,
                 )
 
-        deferreds.update(self.get_server_verify_keys(
-            group_id_to_group
-        ))
+        server_to_deferred = {
+            server_name: defer.Deferred()
+            for server_name, _ in server_and_json
+        }
+
+        # We want to wait for any previous lookups to complete before
+        # proceeding.
+        wait_on_deferred = self.wait_for_previous_lookups(
+            [server_name for server_name, _ in server_and_json],
+            server_to_deferred,
+        )
+
+        # Actually start fetching keys.
+        wait_on_deferred.addBoth(
+            lambda _: self.get_server_verify_keys(group_id_to_group, deferreds)
+        )
+
+        # When we've finished fetching all the keys for a given server_name,
+        # resolve the deferred passed to `wait_for_previous_lookups` so that
+        # any lookups waiting will proceed.
+        server_to_gids = {}
+
+        def remove_deferreds(res, server_name, group_id):
+            server_to_gids[server_name].discard(group_id)
+            if not server_to_gids[server_name]:
+                server_to_deferred.pop(server_name).callback(None)
+            return res
 
+        for g_id, deferred in deferreds.items():
+            server_name = group_id_to_group[g_id].server_name
+            server_to_gids.setdefault(server_name, set()).add(g_id)
+            deferred.addBoth(remove_deferreds, server_name, g_id)
+
+        # Pass those keys to handle_key_deferred so that the json object
+        # signatures can be verified
         return [
             handle_key_deferred(
                 group_id_to_group[g_id],
@@ -145,7 +180,30 @@ class Keyring(object):
             for g_id in group_ids
         ]
 
-    def get_server_verify_keys(self, group_id_to_group):
+    @defer.inlineCallbacks
+    def wait_for_previous_lookups(self, server_names, server_to_deferred):
+        """Waits for any previous key lookups for the given servers to finish.
+
+        Args:
+            server_names (list): list of server_names we want to lookup
+            server_to_deferred (dict): server_name to deferred which gets
+                resolved once we've finished looking up keys for that server
+        """
+        while True:
+            wait_on = [
+                self.key_downloads[server_name]
+                for server_name in server_names
+                if server_name in self.key_downloads
+            ]
+            if wait_on:
+                yield defer.DeferredList(wait_on)
+            else:
+                break
+
+        for server_name, deferred in server_to_deferred:
+            self.key_downloads[server_name] = ObservableDeferred(deferred)
+
+    def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred):
         """Takes a dict of KeyGroups and tries to find at least one key for
         each group.
         """
@@ -157,11 +215,6 @@ class Keyring(object):
             self.get_keys_from_server,  # Then try directly
         )
 
-        group_deferreds = {
-            group_id: defer.Deferred()
-            for group_id in group_id_to_group
-        }
-
         @defer.inlineCallbacks
         def do_iterations():
             merged_results = {}
@@ -182,7 +235,7 @@ class Keyring(object):
                 for group in group_id_to_group.values():
                     for key_id in group.key_ids:
                         if key_id in merged_results[group.server_name]:
-                            group_deferreds.pop(group.group_id).callback((
+                            group_id_to_deferred[group.group_id].callback((
                                 group.group_id,
                                 group.server_name,
                                 key_id,
@@ -205,7 +258,7 @@ class Keyring(object):
                 }
 
             for group in missing_groups.values():
-                group_deferreds.pop(group.group_id).errback(SynapseError(
+                group_id_to_deferred[group.group_id].errback(SynapseError(
                     401,
                     "No key for %s with id %s" % (
                         group.server_name, group.key_ids,
@@ -214,13 +267,13 @@ class Keyring(object):
                 ))
 
         def on_err(err):
-            for deferred in group_deferreds.values():
-                deferred.errback(err)
-            group_deferreds.clear()
+            for deferred in group_id_to_deferred.values():
+                if not deferred.called:
+                    deferred.errback(err)
 
         do_iterations().addErrback(on_err)
 
-        return group_deferreds
+        return group_id_to_deferred
 
     @defer.inlineCallbacks
     def get_keys_from_store(self, server_name_and_key_ids):