diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c2b38d72a9..668a90e495 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class E2eKeysHandler(object):
"status": 503, "message": e.message
}
- yield preserve_context_over_deferred(defer.gatherResults([
+ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries_not_in_cache
]))
@@ -257,11 +257,21 @@ class E2eKeysHandler(object):
"status": 503, "message": e.message
}
- yield preserve_context_over_deferred(defer.gatherResults([
+ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
+ logger.info(
+ "Claimed one-time-keys: %s",
+ ",".join((
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in json_result.iteritems()
+ for device_id, device_keys in user_keys.iteritems()
+ for key_id, _ in device_keys.iteritems()
+ )),
+ )
+
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
@@ -288,19 +298,8 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
- logger.info(
- "Adding %d one_time_keys for device %r for user %r at %d",
- len(one_time_keys), device_id, user_id, time_now
- )
- key_list = []
- for key_id, key_json in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, encode_canonical_json(key_json)
- ))
-
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, key_list
+ yield self._upload_one_time_keys_for_user(
+ user_id, device_id, time_now, one_time_keys,
)
# the device should have been registered already, but it may have been
@@ -313,3 +312,58 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})
+
+ @defer.inlineCallbacks
+ def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
+ one_time_keys):
+ logger.info(
+ "Adding one_time_keys %r for device %r for user %r at %d",
+ one_time_keys.keys(), device_id, user_id, time_now,
+ )
+
+ # make a list of (alg, id, key) tuples
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, key_obj
+ ))
+
+ # First we check if we have already persisted any of the keys.
+ existing_key_map = yield self.store.get_e2e_one_time_keys(
+ user_id, device_id, [k_id for _, k_id, _ in key_list]
+ )
+
+ new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
+ for algorithm, key_id, key in key_list:
+ ex_json = existing_key_map.get((algorithm, key_id), None)
+ if ex_json:
+ if not _one_time_keys_match(ex_json, key):
+ raise SynapseError(
+ 400,
+ ("One time key %s:%s already exists. "
+ "Old key: %s; new key: %r") %
+ (algorithm, key_id, ex_json, key)
+ )
+ else:
+ new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, new_keys
+ )
+
+
+def _one_time_keys_match(old_key_json, new_key):
+ old_key = json.loads(old_key_json)
+
+ # if either is a string rather than an object, they must match exactly
+ if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+ return old_key == new_key
+
+ # otherwise, we strip off the 'signatures' if any, because it's legitimate
+ # for different upload attempts to have different signatures.
+ old_key.pop("signatures", None)
+ new_key_copy = dict(new_key)
+ new_key_copy.pop("signatures", None)
+
+ return old_key == new_key_copy
|