diff --git a/synapse/crypto/keyring.py b/synapse/crypto/keyring.py
index aff69c5f83..873c9b40fa 100644
--- a/synapse/crypto/keyring.py
+++ b/synapse/crypto/keyring.py
@@ -25,11 +25,11 @@ from syutil.base64util import decode_base64, encode_base64
from synapse.api.errors import SynapseError, Codes
from synapse.util.retryutils import get_retry_limiter
-
-from synapse.util.async import ObservableDeferred
+from synapse.util import unwrapFirstError
from OpenSSL import crypto
+from collections import namedtuple
import urllib
import hashlib
import logging
@@ -38,6 +38,9 @@ import logging
logger = logging.getLogger(__name__)
+KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids"))
+
+
class Keyring(object):
def __init__(self, hs):
self.store = hs.get_datastore()
@@ -49,141 +52,274 @@ class Keyring(object):
self.key_downloads = {}
- @defer.inlineCallbacks
def verify_json_for_server(self, server_name, json_object):
- logger.debug("Verifying for %s", server_name)
- key_ids = signature_ids(json_object, server_name)
- if not key_ids:
- raise SynapseError(
- 400,
- "Not signed with a supported algorithm",
- Codes.UNAUTHORIZED,
- )
- try:
- verify_key = yield self.get_server_verify_key(server_name, key_ids)
- except IOError as e:
- logger.warn(
- "Got IOError when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 502,
- "Error downloading keys for %s" % (server_name,),
- Codes.UNAUTHORIZED,
- )
- except Exception as e:
- logger.warn(
- "Got Exception when downloading keys for %s: %s %s",
- server_name, type(e).__name__, str(e.message),
- )
- raise SynapseError(
- 401,
- "No key for %s with id %s" % (server_name, key_ids),
- Codes.UNAUTHORIZED,
- )
+ return self.verify_json_objects_for_server(
+ [(server_name, json_object)]
+ )[0]
- try:
- verify_signed_json(json_object, server_name, verify_key)
- except:
- raise SynapseError(
- 401,
- "Invalid signature for server %s with key %s:%s" % (
- server_name, verify_key.alg, verify_key.version
- ),
- Codes.UNAUTHORIZED,
- )
+ def verify_json_objects_for_server(self, server_and_json):
+ """Bulk verfies signatures of json objects, bulk fetching keys as
+ necessary.
- @defer.inlineCallbacks
- def get_server_verify_key(self, server_name, key_ids):
- """Finds a verification key for the server with one of the key ids.
- Trys to fetch the key from a trusted perspective server first.
Args:
- server_name(str): The name of the server to fetch a key for.
- keys_ids (list of str): The key_ids to check for.
+ server_and_json (list): List of pairs of (server_name, json_object)
+
+ Returns:
+ list of deferreds indicating success or failure to verify each
+ json object's signature for the given server_name.
"""
- cached = yield self.store.get_server_verify_keys(server_name, key_ids)
+ group_id_to_json = {}
+ group_id_to_group = {}
+ group_ids = []
+
+ next_group_id = 0
+ deferreds = {}
+
+ for server_name, json_object in server_and_json:
+ logger.debug("Verifying for %s", server_name)
+ group_id = next_group_id
+ next_group_id += 1
+ group_ids.append(group_id)
+
+ key_ids = signature_ids(json_object, server_name)
+ if not key_ids:
+ deferreds[group_id] = defer.fail(SynapseError(
+ 400,
+ "Not signed with a supported algorithm",
+ Codes.UNAUTHORIZED,
+ ))
+
+ group = KeyGroup(server_name, group_id, key_ids)
- if cached:
- defer.returnValue(cached[0])
- return
+ group_id_to_group[group_id] = group
+ group_id_to_json[group_id] = json_object
- download = self.key_downloads.get(server_name)
+ @defer.inlineCallbacks
+ def handle_key_deferred(group, deferred):
+ server_name = group.server_name
+ try:
+ _, _, key_id, verify_key = yield deferred
+ except IOError as e:
+ logger.warn(
+ "Got IOError when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 502,
+ "Error downloading keys for %s" % (server_name,),
+ Codes.UNAUTHORIZED,
+ )
+ except Exception as e:
+ logger.exception(
+ "Got Exception when downloading keys for %s: %s %s",
+ server_name, type(e).__name__, str(e.message),
+ )
+ raise SynapseError(
+ 401,
+ "No key for %s with id %s" % (server_name, key_ids),
+ Codes.UNAUTHORIZED,
+ )
+
+ json_object = group_id_to_json[group.group_id]
- if download is None:
- download = self._get_server_verify_key_impl(server_name, key_ids)
- download = ObservableDeferred(
- download,
- consumeErrors=True
+ try:
+ verify_signed_json(json_object, server_name, verify_key)
+ except:
+ raise SynapseError(
+ 401,
+ "Invalid signature for server %s with key %s:%s" % (
+ server_name, verify_key.alg, verify_key.version
+ ),
+ Codes.UNAUTHORIZED,
+ )
+
+ deferreds.update(self.get_server_verify_keys(
+ group_id_to_group
+ ))
+
+ return [
+ handle_key_deferred(
+ group_id_to_group[g_id],
+ deferreds[g_id],
)
- self.key_downloads[server_name] = download
+ for g_id in group_ids
+ ]
- @download.addBoth
- def callback(ret):
- del self.key_downloads[server_name]
- return ret
+ def get_server_verify_keys(self, group_id_to_group):
+ """Takes a dict of KeyGroups and tries to find at least one key for
+ each group.
+ """
+
+ # These are functions that produce keys given a list of key ids
+ key_fetch_fns = (
+ self.get_keys_from_store, # First try the local store
+ self.get_keys_from_perspectives, # Then try via perspectives
+ self.get_keys_from_server, # Then try directly
+ )
+
+ group_deferreds = {
+ group_id: defer.Deferred()
+ for group_id in group_id_to_group
+ }
+
+ @defer.inlineCallbacks
+ def do_iterations():
+ merged_results = {}
+
+ missing_keys = {
+ group.server_name: key_id
+ for group in group_id_to_group.values()
+ for key_id in group.key_ids
+ }
+
+ for fn in key_fetch_fns:
+ results = yield fn(missing_keys.items())
+ merged_results.update(results)
+
+ # We now need to figure out which groups we have keys for
+ # and which we don't
+ missing_groups = {}
+ for group in group_id_to_group.values():
+ for key_id in group.key_ids:
+ if key_id in merged_results[group.server_name]:
+ group_deferreds.pop(group.group_id).callback((
+ group.group_id,
+ group.server_name,
+ key_id,
+ merged_results[group.server_name][key_id],
+ ))
+ break
+ else:
+ missing_groups.setdefault(
+ group.server_name, []
+ ).append(group)
+
+ if not missing_groups:
+ break
+
+ missing_keys = {
+ server_name: set(
+ key_id for group in groups for key_id in group.key_ids
+ )
+ for server_name, groups in missing_groups.items()
+ }
- r = yield download.observe()
- defer.returnValue(r)
+ for group in missing_groups.values():
+ group_deferreds.pop(group.group_id).errback(SynapseError(
+ 401,
+ "No key for %s with id %s" % (
+ group.server_name, group.key_ids,
+ ),
+ Codes.UNAUTHORIZED,
+ ))
+
+ def on_err(err):
+ for deferred in group_deferreds.values():
+ deferred.errback(err)
+ group_deferreds.clear()
+
+ do_iterations().addErrback(on_err)
+
+ return group_deferreds
@defer.inlineCallbacks
- def _get_server_verify_key_impl(self, server_name, key_ids):
- keys = None
+ def get_keys_from_store(self, server_name_and_key_ids):
+ res = yield defer.gatherResults(
+ [
+ self.store.get_server_verify_keys(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
+
+ defer.returnValue(dict(zip(
+ [server_name for server_name, _ in server_name_and_key_ids],
+ res
+ )))
+ @defer.inlineCallbacks
+ def get_keys_from_perspectives(self, server_name_and_key_ids):
@defer.inlineCallbacks
def get_key(perspective_name, perspective_keys):
try:
result = yield self.get_server_verify_key_v2_indirect(
- server_name, key_ids, perspective_name, perspective_keys
+ server_name_and_key_ids, perspective_name, perspective_keys
)
defer.returnValue(result)
except Exception as e:
- logging.info(
- "Unable to getting key %r for %r from %r: %s %s",
- key_ids, server_name, perspective_name,
+ logger.exception(
+ "Unable to get key from %r: %s %s",
+ perspective_name,
type(e).__name__, str(e.message),
)
+ defer.returnValue({})
- perspective_results = yield defer.gatherResults([
- get_key(p_name, p_keys)
- for p_name, p_keys in self.perspective_servers.items()
- ])
+ results = yield defer.gatherResults(
+ [
+ get_key(p_name, p_keys)
+ for p_name, p_keys in self.perspective_servers.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- for results in perspective_results:
- if results is not None:
- keys = results
+ union_of_keys = {}
+ for result in results:
+ for server_name, keys in result.items():
+ union_of_keys.setdefault(server_name, {}).update(keys)
- limiter = yield get_retry_limiter(
- server_name,
- self.clock,
- self.store,
- )
+ defer.returnValue(union_of_keys)
- with limiter:
- if not keys:
+ @defer.inlineCallbacks
+ def get_keys_from_server(self, server_name_and_key_ids):
+ @defer.inlineCallbacks
+ def get_key(server_name, key_ids):
+ limiter = yield get_retry_limiter(
+ server_name,
+ self.clock,
+ self.store,
+ )
+ with limiter:
+ keys = None
try:
keys = yield self.get_server_verify_key_v2_direct(
server_name, key_ids
)
except Exception as e:
- logging.info(
+ logger.info(
"Unable to getting key %r for %r directly: %s %s",
key_ids, server_name,
type(e).__name__, str(e.message),
)
- if not keys:
- keys = yield self.get_server_verify_key_v1_direct(
- server_name, key_ids
- )
+ if not keys:
+ keys = yield self.get_server_verify_key_v1_direct(
+ server_name, key_ids
+ )
+
+ keys = {server_name: keys}
+
+ defer.returnValue(keys)
+
+ results = yield defer.gatherResults(
+ [
+ get_key(server_name, key_ids)
+ for server_name, key_ids in server_name_and_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- for key_id in key_ids:
- if key_id in keys:
- defer.returnValue(keys[key_id])
- return
- raise ValueError("No verification key found for given key ids")
+ merged = {}
+ for result in results:
+ merged.update(result)
+
+ defer.returnValue({
+ server_name: keys
+ for server_name, keys in merged.items()
+ if keys
+ })
@defer.inlineCallbacks
- def get_server_verify_key_v2_indirect(self, server_name, key_ids,
+ def get_server_verify_key_v2_indirect(self, server_names_and_key_ids,
perspective_name,
perspective_keys):
limiter = yield get_retry_limiter(
@@ -204,6 +340,7 @@ class Keyring(object):
u"minimum_valid_until_ts": 0
} for key_id in key_ids
}
+ for server_name, key_ids in server_names_and_key_ids
}
},
)
@@ -243,23 +380,29 @@ class Keyring(object):
" server %r" % (perspective_name,)
)
- response_keys = yield self.process_v2_response(
- server_name, perspective_name, response
+ processed_response = yield self.process_v2_response(
+ perspective_name, response
)
- keys.update(response_keys)
+ for server_name, response_keys in processed_response.items():
+ keys.setdefault(server_name, {}).update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=perspective_name,
- verify_keys=keys,
- )
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=server_name,
+ from_server=perspective_name,
+ verify_keys=response_keys,
+ )
+ for server_name, response_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
def get_server_verify_key_v2_direct(self, server_name, key_ids):
-
keys = {}
for requested_key_id in key_ids:
@@ -295,25 +438,30 @@ class Keyring(object):
raise ValueError("TLS certificate not allowed by fingerprints")
response_keys = yield self.process_v2_response(
- server_name=server_name,
from_server=server_name,
- requested_id=requested_key_id,
+ requested_ids=[requested_key_id],
response_json=response,
)
keys.update(response_keys)
- yield self.store_keys(
- server_name=server_name,
- from_server=server_name,
- verify_keys=keys,
- )
+ yield defer.gatherResults(
+ [
+ self.store_keys(
+ server_name=key_server_name,
+ from_server=server_name,
+ verify_keys=verify_keys,
+ )
+ for key_server_name, verify_keys in keys.items()
+ ],
+ consumeErrors=True
+ ).addErrback(unwrapFirstError)
defer.returnValue(keys)
@defer.inlineCallbacks
- def process_v2_response(self, server_name, from_server, response_json,
- requested_id=None):
+ def process_v2_response(self, from_server, response_json,
+ requested_ids=[]):
time_now_ms = self.clock.time_msec()
response_keys = {}
verify_keys = {}
@@ -335,6 +483,8 @@ class Keyring(object):
verify_key.time_added = time_now_ms
old_verify_keys[key_id] = verify_key
+ results = {}
+ server_name = response_json["server_name"]
for key_id in response_json["signatures"].get(server_name, {}):
if key_id not in response_json["verify_keys"]:
raise ValueError(
@@ -357,28 +507,31 @@ class Keyring(object):
signed_key_json_bytes = encode_canonical_json(signed_key_json)
ts_valid_until_ms = signed_key_json[u"valid_until_ts"]
- updated_key_ids = set()
- if requested_id is not None:
- updated_key_ids.add(requested_id)
+ updated_key_ids = set(requested_ids)
updated_key_ids.update(verify_keys)
updated_key_ids.update(old_verify_keys)
response_keys.update(verify_keys)
response_keys.update(old_verify_keys)
- for key_id in updated_key_ids:
- yield self.store.store_server_keys_json(
- server_name=server_name,
- key_id=key_id,
- from_server=server_name,
- ts_now_ms=time_now_ms,
- ts_expires_ms=ts_valid_until_ms,
- key_json_bytes=signed_key_json_bytes,
- )
+ yield defer.gatherResults(
+ [
+ self.store.store_server_keys_json(
+ server_name=server_name,
+ key_id=key_id,
+ from_server=server_name,
+ ts_now_ms=time_now_ms,
+ ts_expires_ms=ts_valid_until_ms,
+ key_json_bytes=signed_key_json_bytes,
+ )
+ for key_id in updated_key_ids
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
- defer.returnValue(response_keys)
+ results[server_name] = response_keys
- raise ValueError("No verification key found for given key ids")
+ defer.returnValue(results)
@defer.inlineCallbacks
def get_server_verify_key_v1_direct(self, server_name, key_ids):
@@ -462,8 +615,13 @@ class Keyring(object):
Returns:
A deferred that completes when the keys are stored.
"""
- for key_id, key in verify_keys.items():
- # TODO(markjh): Store whether the keys have expired.
- yield self.store.store_server_verify_key(
- server_name, server_name, key.time_added, key
- )
+ # TODO(markjh): Store whether the keys have expired.
+ yield defer.gatherResults(
+ [
+ self.store.store_server_verify_key(
+ server_name, server_name, key.time_added, key
+ )
+ for key_id, key in verify_keys.items()
+ ],
+ consumeErrors=True,
+ ).addErrback(unwrapFirstError)
diff --git a/synapse/federation/federation_base.py b/synapse/federation/federation_base.py
index 299493af91..bdfa247604 100644
--- a/synapse/federation/federation_base.py
+++ b/synapse/federation/federation_base.py
@@ -32,7 +32,8 @@ logger = logging.getLogger(__name__)
class FederationBase(object):
@defer.inlineCallbacks
- def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False):
+ def _check_sigs_and_hash_and_fetch(self, origin, pdus, outlier=False,
+ include_none=False):
"""Takes a list of PDUs and checks the signatures and hashs of each
one. If a PDU fails its signature check then we check if we have it in
the database and if not then request if from the originating server of
@@ -50,84 +51,108 @@ class FederationBase(object):
Returns:
Deferred : A list of PDUs that have valid signatures and hashes.
"""
+ deferreds = self._check_sigs_and_hashes(pdus)
- signed_pdus = []
+ def callback(pdu):
+ return pdu
- @defer.inlineCallbacks
- def do(pdu):
- try:
- new_pdu = yield self._check_sigs_and_hash(pdu)
- signed_pdus.append(new_pdu)
- except SynapseError:
- # FIXME: We should handle signature failures more gracefully.
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
+ return None
+ def try_local_db(res, pdu):
+ if not res:
# Check local db.
- new_pdu = yield self.store.get_event(
+ return self.store.get_event(
pdu.event_id,
allow_rejected=True,
allow_none=True,
)
- if new_pdu:
- signed_pdus.append(new_pdu)
- return
-
- # Check pdu.origin
- if pdu.origin != origin:
- try:
- new_pdu = yield self.get_pdu(
- destinations=[pdu.origin],
- event_id=pdu.event_id,
- outlier=outlier,
- timeout=10000,
- )
-
- if new_pdu:
- signed_pdus.append(new_pdu)
- return
- except:
- pass
-
+ return res
+
+ def try_remote(res, pdu):
+ if not res and pdu.origin != origin:
+ return self.get_pdu(
+ destinations=[pdu.origin],
+ event_id=pdu.event_id,
+ outlier=outlier,
+ timeout=10000,
+ ).addErrback(lambda e: None)
+ return res
+
+ def warn(res, pdu):
+ if not res:
logger.warn(
"Failed to find copy of %s with valid signature",
pdu.event_id,
)
+ return res
+
+ for pdu, deferred in zip(pdus, deferreds):
+ deferred.addCallbacks(
+ callback, errback, errbackArgs=[pdu]
+ ).addCallback(
+ try_local_db, pdu
+ ).addCallback(
+ try_remote, pdu
+ ).addCallback(
+ warn, pdu
+ )
- yield defer.gatherResults(
- [do(pdu) for pdu in pdus],
+ valid_pdus = yield defer.gatherResults(
+ deferreds,
consumeErrors=True
).addErrback(unwrapFirstError)
- defer.returnValue(signed_pdus)
+ if include_none:
+ defer.returnValue(valid_pdus)
+ else:
+ defer.returnValue([p for p in valid_pdus if p])
- @defer.inlineCallbacks
def _check_sigs_and_hash(self, pdu):
- """Throws a SynapseError if the PDU does not have the correct
+ return self._check_sigs_and_hashes([pdu])[0]
+
+ def _check_sigs_and_hashes(self, pdus):
+ """Throws a SynapseError if a PDU does not have the correct
signatures.
Returns:
FrozenEvent: Either the given event or it redacted if it failed the
content hash check.
"""
- # Check signatures are correct.
- redacted_event = prune_event(pdu)
- redacted_pdu_json = redacted_event.get_pdu_json()
- try:
- yield self.keyring.verify_json_for_server(
- pdu.origin, redacted_pdu_json
- )
- except SynapseError:
+ redacted_pdus = [
+ prune_event(pdu)
+ for pdu in pdus
+ ]
+
+ deferreds = self.keyring.verify_json_objects_for_server([
+ (p.origin, p.get_pdu_json())
+ for p in redacted_pdus
+ ])
+
+ def callback(_, pdu, redacted):
+ if not check_event_content_hash(pdu):
+ logger.warn(
+ "Event content has been tampered, redacting %s: %s",
+ pdu.event_id, pdu.get_pdu_json()
+ )
+ return redacted
+ return pdu
+
+ def errback(failure, pdu):
+ failure.trap(SynapseError)
logger.warn(
"Signature check failed for %s",
pdu.event_id,
)
- raise
+ return failure
- if not check_event_content_hash(pdu):
- logger.warn(
- "Event content has been tampered, redacting.",
- pdu.event_id,
+ for deferred, pdu, redacted in zip(deferreds, pdus, redacted_pdus):
+ deferred.addCallbacks(
+ callback, errback,
+ callbackArgs=[pdu, redacted],
+ errbackArgs=[pdu],
)
- defer.returnValue(redacted_event)
- defer.returnValue(pdu)
+ return deferreds
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index d3b46b24c1..7736d14fb5 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -30,6 +30,7 @@ import synapse.metrics
from synapse.util.retryutils import get_retry_limiter, NotRetryingDestination
+import copy
import itertools
import logging
import random
@@ -167,7 +168,7 @@ class FederationClient(FederationBase):
# FIXME: We should handle signature failures more gracefully.
pdus[:] = yield defer.gatherResults(
- [self._check_sigs_and_hash(pdu) for pdu in pdus],
+ self._check_sigs_and_hashes(pdus),
consumeErrors=True,
).addErrback(unwrapFirstError)
@@ -230,7 +231,7 @@ class FederationClient(FederationBase):
pdu = pdu_list[0]
# Check signatures are correct.
- pdu = yield self._check_sigs_and_hash(pdu)
+ pdu = yield self._check_sigs_and_hashes([pdu])[0]
break
@@ -327,6 +328,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def make_join(self, destinations, room_id, user_id):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
ret = yield self.transport_layer.make_join(
destination, room_id, user_id
@@ -353,6 +357,9 @@ class FederationClient(FederationBase):
@defer.inlineCallbacks
def send_join(self, destinations, pdu):
for destination in destinations:
+ if destination == self.server_name:
+ continue
+
try:
time_now = self._clock.time_msec()
_, content = yield self.transport_layer.send_join(
@@ -374,17 +381,39 @@ class FederationClient(FederationBase):
for p in content.get("auth_chain", [])
]
- signed_state, signed_auth = yield defer.gatherResults(
- [
- self._check_sigs_and_hash_and_fetch(
- destination, state, outlier=True
- ),
- self._check_sigs_and_hash_and_fetch(
- destination, auth_chain, outlier=True
- )
- ],
- consumeErrors=True
- ).addErrback(unwrapFirstError)
+ pdus = {
+ p.event_id: p
+ for p in itertools.chain(state, auth_chain)
+ }
+
+ valid_pdus = yield self._check_sigs_and_hash_and_fetch(
+ destination, pdus.values(),
+ outlier=True,
+ )
+
+ valid_pdus_map = {
+ p.event_id: p
+ for p in valid_pdus
+ }
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ signed_state = [
+ copy.copy(valid_pdus_map[p.event_id])
+ for p in state
+ if p.event_id in valid_pdus_map
+ ]
+
+ signed_auth = [
+ valid_pdus_map[p.event_id]
+ for p in auth_chain
+ if p.event_id in valid_pdus_map
+ ]
+
+ # NB: We *need* to copy to ensure that we don't have multiple
+ # references being passed on, as that causes... issues.
+ for s in signed_state:
+ s.internal_metadata = copy.deepcopy(s.internal_metadata)
auth_chain.sort(key=lambda e: e.depth)
@@ -396,7 +425,7 @@ class FederationClient(FederationBase):
except CodeMessageException:
raise
except Exception as e:
- logger.warn(
+ logger.exception(
"Failed to send_join via %s: %s",
destination, e.message
)
diff --git a/synapse/storage/keys.py b/synapse/storage/keys.py
index 5bdf497b93..940a5f7e08 100644
--- a/synapse/storage/keys.py
+++ b/synapse/storage/keys.py
@@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from _base import SQLBaseStore
+from _base import SQLBaseStore, cached
from twisted.internet import defer
@@ -71,6 +71,25 @@ class KeyStore(SQLBaseStore):
desc="store_server_certificate",
)
+ @cached()
+ @defer.inlineCallbacks
+ def get_all_server_verify_keys(self, server_name):
+ rows = yield self._simple_select_list(
+ table="server_signature_keys",
+ keyvalues={
+ "server_name": server_name,
+ },
+ retcols=["key_id", "verify_key"],
+ desc="get_all_server_verify_keys",
+ )
+
+ defer.returnValue({
+ row["key_id"]: decode_verify_key_bytes(
+ row["key_id"], str(row["verify_key"])
+ )
+ for row in rows
+ })
+
@defer.inlineCallbacks
def get_server_verify_keys(self, server_name, key_ids):
"""Retrieve the NACL verification key for a given server for the given
@@ -81,24 +100,14 @@ class KeyStore(SQLBaseStore):
Returns:
(list of VerifyKey): The verification keys.
"""
- sql = (
- "SELECT key_id, verify_key FROM server_signature_keys"
- " WHERE server_name = ?"
- " AND key_id in (" + ",".join("?" for key_id in key_ids) + ")"
- )
-
- rows = yield self._execute_and_decode(
- "get_server_verify_keys", sql, server_name, *key_ids
- )
-
- keys = []
- for row in rows:
- key_id = row["key_id"]
- key_bytes = row["verify_key"]
- key = decode_verify_key_bytes(key_id, str(key_bytes))
- keys.append(key)
- defer.returnValue(keys)
+ keys = yield self.get_all_server_verify_keys(server_name)
+ defer.returnValue({
+ k: keys[k]
+ for k in key_ids
+ if k in keys and keys[k]
+ })
+ @defer.inlineCallbacks
def store_server_verify_key(self, server_name, from_server, time_now_ms,
verify_key):
"""Stores a NACL verification key for the given server.
@@ -109,7 +118,7 @@ class KeyStore(SQLBaseStore):
ts_now_ms (int): The time now in milliseconds
verification_key (VerifyKey): The NACL verify key.
"""
- return self._simple_upsert(
+ yield self._simple_upsert(
table="server_signature_keys",
keyvalues={
"server_name": server_name,
@@ -123,6 +132,8 @@ class KeyStore(SQLBaseStore):
desc="store_server_verify_key",
)
+ self.get_all_server_verify_keys.invalidate(server_name)
+
def store_server_keys_json(self, server_name, key_id, from_server,
ts_now_ms, ts_expires_ms, key_json_bytes):
"""Stores the JSON bytes for a set of keys from a server
@@ -152,6 +163,7 @@ class KeyStore(SQLBaseStore):
"ts_valid_until_ms": ts_expires_ms,
"key_json": buffer(key_json_bytes),
},
+ desc="store_server_keys_json",
)
def get_server_keys_json(self, server_keys):
|