diff options
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r-- | synapse/handlers/e2e_keys.py | 86 |
1 files changed, 70 insertions, 16 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index c2b38d72a9..668a90e495 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -21,7 +21,7 @@ from twisted.internet import defer from synapse.api.errors import SynapseError, CodeMessageException from synapse.types import get_domain_from_id -from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred +from synapse.util.logcontext import preserve_fn, make_deferred_yieldable from synapse.util.retryutils import NotRetryingDestination logger = logging.getLogger(__name__) @@ -145,7 +145,7 @@ class E2eKeysHandler(object): "status": 503, "message": e.message } - yield preserve_context_over_deferred(defer.gatherResults([ + yield make_deferred_yieldable(defer.gatherResults([ preserve_fn(do_remote_query)(destination) for destination in remote_queries_not_in_cache ])) @@ -257,11 +257,21 @@ class E2eKeysHandler(object): "status": 503, "message": e.message } - yield preserve_context_over_deferred(defer.gatherResults([ + yield make_deferred_yieldable(defer.gatherResults([ preserve_fn(claim_client_keys)(destination) for destination in remote_queries ])) + logger.info( + "Claimed one-time-keys: %s", + ",".join(( + "%s for %s:%s" % (key_id, user_id, device_id) + for user_id, user_keys in json_result.iteritems() + for device_id, device_keys in user_keys.iteritems() + for key_id, _ in device_keys.iteritems() + )), + ) + defer.returnValue({ "one_time_keys": json_result, "failures": failures @@ -288,19 +298,8 @@ class E2eKeysHandler(object): one_time_keys = keys.get("one_time_keys", None) if one_time_keys: - logger.info( - "Adding %d one_time_keys for device %r for user %r at %d", - len(one_time_keys), device_id, user_id, time_now - ) - key_list = [] - for key_id, key_json in one_time_keys.items(): - algorithm, key_id = key_id.split(":") - key_list.append(( - algorithm, key_id, encode_canonical_json(key_json) - )) - - yield self.store.add_e2e_one_time_keys( - user_id, device_id, time_now, key_list + yield self._upload_one_time_keys_for_user( + user_id, device_id, time_now, one_time_keys, ) # the device should have been registered already, but it may have been @@ -313,3 +312,58 @@ class E2eKeysHandler(object): result = yield self.store.count_e2e_one_time_keys(user_id, device_id) defer.returnValue({"one_time_key_counts": result}) + + @defer.inlineCallbacks + def _upload_one_time_keys_for_user(self, user_id, device_id, time_now, + one_time_keys): + logger.info( + "Adding one_time_keys %r for device %r for user %r at %d", + one_time_keys.keys(), device_id, user_id, time_now, + ) + + # make a list of (alg, id, key) tuples + key_list = [] + for key_id, key_obj in one_time_keys.items(): + algorithm, key_id = key_id.split(":") + key_list.append(( + algorithm, key_id, key_obj + )) + + # First we check if we have already persisted any of the keys. + existing_key_map = yield self.store.get_e2e_one_time_keys( + user_id, device_id, [k_id for _, k_id, _ in key_list] + ) + + new_keys = [] # Keys that we need to insert. (alg, id, json) tuples. + for algorithm, key_id, key in key_list: + ex_json = existing_key_map.get((algorithm, key_id), None) + if ex_json: + if not _one_time_keys_match(ex_json, key): + raise SynapseError( + 400, + ("One time key %s:%s already exists. " + "Old key: %s; new key: %r") % + (algorithm, key_id, ex_json, key) + ) + else: + new_keys.append((algorithm, key_id, encode_canonical_json(key))) + + yield self.store.add_e2e_one_time_keys( + user_id, device_id, time_now, new_keys + ) + + +def _one_time_keys_match(old_key_json, new_key): + old_key = json.loads(old_key_json) + + # if either is a string rather than an object, they must match exactly + if not isinstance(old_key, dict) or not isinstance(new_key, dict): + return old_key == new_key + + # otherwise, we strip off the 'signatures' if any, because it's legitimate + # for different upload attempts to have different signatures. + old_key.pop("signatures", None) + new_key_copy = dict(new_key) + new_key_copy.pop("signatures", None) + + return old_key == new_key_copy |