diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index bc20b9c201..51e3fdea06 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -440,6 +440,16 @@ class FederationServer(FederationBase):
key_id: json.loads(json_bytes)
}
+ logger.info(
+ "Claimed one-time-keys: %s",
+ ",".join((
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in json_result.iteritems()
+ for device_id, device_keys in user_keys.iteritems()
+ for key_id, _ in device_keys.iteritems()
+ )),
+ )
+
defer.returnValue({"one_time_keys": json_result})
@defer.inlineCallbacks
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c2b38d72a9..668a90e495 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
from synapse.api.errors import SynapseError, CodeMessageException
from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
from synapse.util.retryutils import NotRetryingDestination
logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class E2eKeysHandler(object):
"status": 503, "message": e.message
}
- yield preserve_context_over_deferred(defer.gatherResults([
+ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
for destination in remote_queries_not_in_cache
]))
@@ -257,11 +257,21 @@ class E2eKeysHandler(object):
"status": 503, "message": e.message
}
- yield preserve_context_over_deferred(defer.gatherResults([
+ yield make_deferred_yieldable(defer.gatherResults([
preserve_fn(claim_client_keys)(destination)
for destination in remote_queries
]))
+ logger.info(
+ "Claimed one-time-keys: %s",
+ ",".join((
+ "%s for %s:%s" % (key_id, user_id, device_id)
+ for user_id, user_keys in json_result.iteritems()
+ for device_id, device_keys in user_keys.iteritems()
+ for key_id, _ in device_keys.iteritems()
+ )),
+ )
+
defer.returnValue({
"one_time_keys": json_result,
"failures": failures
@@ -288,19 +298,8 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
- logger.info(
- "Adding %d one_time_keys for device %r for user %r at %d",
- len(one_time_keys), device_id, user_id, time_now
- )
- key_list = []
- for key_id, key_json in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, encode_canonical_json(key_json)
- ))
-
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, key_list
+ yield self._upload_one_time_keys_for_user(
+ user_id, device_id, time_now, one_time_keys,
)
# the device should have been registered already, but it may have been
@@ -313,3 +312,58 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})
+
+ @defer.inlineCallbacks
+ def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
+ one_time_keys):
+ logger.info(
+ "Adding one_time_keys %r for device %r for user %r at %d",
+ one_time_keys.keys(), device_id, user_id, time_now,
+ )
+
+ # make a list of (alg, id, key) tuples
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, key_obj
+ ))
+
+ # First we check if we have already persisted any of the keys.
+ existing_key_map = yield self.store.get_e2e_one_time_keys(
+ user_id, device_id, [k_id for _, k_id, _ in key_list]
+ )
+
+ new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
+ for algorithm, key_id, key in key_list:
+ ex_json = existing_key_map.get((algorithm, key_id), None)
+ if ex_json:
+ if not _one_time_keys_match(ex_json, key):
+ raise SynapseError(
+ 400,
+ ("One time key %s:%s already exists. "
+ "Old key: %s; new key: %r") %
+ (algorithm, key_id, ex_json, key)
+ )
+ else:
+ new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, new_keys
+ )
+
+
+def _one_time_keys_match(old_key_json, new_key):
+ old_key = json.loads(old_key_json)
+
+ # if either is a string rather than an object, they must match exactly
+ if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+ return old_key == new_key
+
+ # otherwise, we strip off the 'signatures' if any, because it's legitimate
+ # for different upload attempts to have different signatures.
+ old_key.pop("signatures", None)
+ new_key_copy = dict(new_key)
+ new_key_copy.pop("signatures", None)
+
+ return old_key == new_key_copy
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index c96dae352d..e00f31da2b 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,7 +14,6 @@
# limitations under the License.
from twisted.internet import defer
-from synapse.api.errors import SynapseError
from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json
@@ -124,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
return result
@defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
- """Insert some new one time keys for a device.
+ def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ """Retrieve a number of one-time keys for a user
- Checks if any of the keys are already inserted, if they are then check
- if they match. If they don't then we raise an error.
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ key_ids(list[str]): list of key ids (excluding algorithm) to
+ retrieve
+
+ Returns:
+ deferred resolving to Dict[(str, str), str]: map from (algorithm,
+ key_id) to json string for key
"""
- # First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
- iterable=[key_id for _, key_id, _ in key_list],
+ iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",),
keyvalues={
"user_id": user_id,
@@ -144,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
desc="add_e2e_one_time_keys_check",
)
- existing_key_map = {
+ defer.returnValue({
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
- }
-
- new_keys = [] # Keys that we need to insert
- for algorithm, key_id, json_bytes in key_list:
- ex_bytes = existing_key_map.get((algorithm, key_id), None)
- if ex_bytes:
- if json_bytes != ex_bytes:
- raise SynapseError(
- 400, "One time key with key_id %r already exists" % (key_id,)
- )
- else:
- new_keys.append((algorithm, key_id, json_bytes))
+ })
+
+ @defer.inlineCallbacks
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ """Insert some new one time keys for a device. Errors if any of the
+ keys already exist.
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ time_now(long): insertion time to record (ms since epoch)
+ new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+ (algorithm, key_id, key json)
+ """
def _add_e2e_one_time_keys(txn):
# We are protected from race between lookup and insertion due to
|