diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c2b38d72a9..9d994a8f71 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -288,19 +288,8 @@ class E2eKeysHandler(object):
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
- logger.info(
- "Adding %d one_time_keys for device %r for user %r at %d",
- len(one_time_keys), device_id, user_id, time_now
- )
- key_list = []
- for key_id, key_json in one_time_keys.items():
- algorithm, key_id = key_id.split(":")
- key_list.append((
- algorithm, key_id, encode_canonical_json(key_json)
- ))
-
- yield self.store.add_e2e_one_time_keys(
- user_id, device_id, time_now, key_list
+ yield self._upload_one_time_keys_for_user(
+ user_id, device_id, time_now, one_time_keys,
)
# the device should have been registered already, but it may have been
@@ -313,3 +302,58 @@ class E2eKeysHandler(object):
result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
defer.returnValue({"one_time_key_counts": result})
+
+ @defer.inlineCallbacks
+ def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
+ one_time_keys):
+ logger.info(
+ "Adding one_time_keys %r for device %r for user %r at %d",
+ one_time_keys.keys(), device_id, user_id, time_now,
+ )
+
+ # make a list of (alg, id, key) tuples
+ key_list = []
+ for key_id, key_obj in one_time_keys.items():
+ algorithm, key_id = key_id.split(":")
+ key_list.append((
+ algorithm, key_id, key_obj
+ ))
+
+ # First we check if we have already persisted any of the keys.
+ existing_key_map = yield self.store.get_e2e_one_time_keys(
+ user_id, device_id, [k_id for _, k_id, _ in key_list]
+ )
+
+ new_keys = [] # Keys that we need to insert. (alg, id, json) tuples.
+ for algorithm, key_id, key in key_list:
+ ex_json = existing_key_map.get((algorithm, key_id), None)
+ if ex_json:
+ if not _one_time_keys_match(ex_json, key):
+ raise SynapseError(
+ 400,
+ ("One time key %s:%s already exists. "
+ "Old key: %s; new key: %r") %
+ (algorithm, key_id, ex_json, key)
+ )
+ else:
+ new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+
+ yield self.store.add_e2e_one_time_keys(
+ user_id, device_id, time_now, new_keys
+ )
+
+
+def _one_time_keys_match(old_key_json, new_key):
+ old_key = json.loads(old_key_json)
+
+ # if either is a string rather than an object, they must match exactly
+ if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+ return old_key == new_key
+
+ # otherwise, we strip off the 'signatures' if any, because it's legitimate
+ # for different upload attempts to have different signatures.
+ old_key.pop("signatures", None)
+ new_key_copy = dict(new_key)
+ new_key_copy.pop("signatures", None)
+
+ return old_key == new_key_copy
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index c96dae352d..e00f31da2b 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,7 +14,6 @@
# limitations under the License.
from twisted.internet import defer
-from synapse.api.errors import SynapseError
from synapse.util.caches.descriptors import cached
from canonicaljson import encode_canonical_json
@@ -124,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
return result
@defer.inlineCallbacks
- def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
- """Insert some new one time keys for a device.
+ def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+ """Retrieve a number of one-time keys for a user
- Checks if any of the keys are already inserted, if they are then check
- if they match. If they don't then we raise an error.
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ key_ids(list[str]): list of key ids (excluding algorithm) to
+ retrieve
+
+ Returns:
+ deferred resolving to Dict[(str, str), str]: map from (algorithm,
+ key_id) to json string for key
"""
- # First we check if we have already persisted any of the keys.
rows = yield self._simple_select_many_batch(
table="e2e_one_time_keys_json",
column="key_id",
- iterable=[key_id for _, key_id, _ in key_list],
+ iterable=key_ids,
retcols=("algorithm", "key_id", "key_json",),
keyvalues={
"user_id": user_id,
@@ -144,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
desc="add_e2e_one_time_keys_check",
)
- existing_key_map = {
+ defer.returnValue({
(row["algorithm"], row["key_id"]): row["key_json"] for row in rows
- }
-
- new_keys = [] # Keys that we need to insert
- for algorithm, key_id, json_bytes in key_list:
- ex_bytes = existing_key_map.get((algorithm, key_id), None)
- if ex_bytes:
- if json_bytes != ex_bytes:
- raise SynapseError(
- 400, "One time key with key_id %r already exists" % (key_id,)
- )
- else:
- new_keys.append((algorithm, key_id, json_bytes))
+ })
+
+ @defer.inlineCallbacks
+ def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+ """Insert some new one time keys for a device. Errors if any of the
+ keys already exist.
+
+ Args:
+ user_id(str): id of user to get keys for
+ device_id(str): id of device to get keys for
+ time_now(long): insertion time to record (ms since epoch)
+ new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+ (algorithm, key_id, key json)
+ """
def _add_e2e_one_time_keys(txn):
# We are protected from race between lookup and insertion due to
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 878a54dc34..f10a80a8e1 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,6 +14,7 @@
# limitations under the License.
import mock
+from synapse.api import errors
from twisted.internet import defer
import synapse.api.errors
@@ -44,3 +45,100 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
local_user = "@boris:" + self.hs.hostname
res = yield self.handler.query_local_devices({local_user: None})
self.assertDictEqual(res, {local_user: {}})
+
+ @defer.inlineCallbacks
+ def test_reupload_one_time_keys(self):
+ """we should be able to re-upload the same keys"""
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ keys = {
+ "alg1:k1": "key1",
+ "alg2:k2": {
+ "key": "key2",
+ "signatures": {"k1": "sig1"}
+ },
+ "alg2:k3": {
+ "key": "key3",
+ },
+ }
+
+ res = yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys},
+ )
+ self.assertDictEqual(res, {
+ "one_time_key_counts": {"alg1": 1, "alg2": 2}
+ })
+
+ # we should be able to change the signature without a problem
+ keys["alg2:k2"]["signatures"]["k1"] = "sig2"
+ res = yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys},
+ )
+ self.assertDictEqual(res, {
+ "one_time_key_counts": {"alg1": 1, "alg2": 2}
+ })
+
+ @defer.inlineCallbacks
+ def test_change_one_time_keys(self):
+ """attempts to change one-time-keys should be rejected"""
+
+ local_user = "@boris:" + self.hs.hostname
+ device_id = "xyz"
+ keys = {
+ "alg1:k1": "key1",
+ "alg2:k2": {
+ "key": "key2",
+ "signatures": {"k1": "sig1"}
+ },
+ "alg2:k3": {
+ "key": "key3",
+ },
+ }
+
+ res = yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": keys},
+ )
+ self.assertDictEqual(res, {
+ "one_time_key_counts": {"alg1": 1, "alg2": 2}
+ })
+
+ try:
+ yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
+ )
+ self.fail("No error when changing string key")
+ except errors.SynapseError:
+ pass
+
+ try:
+ yield self.handler.upload_keys_for_user(
+ local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
+ )
+ self.fail("No error when replacing dict key with string")
+ except errors.SynapseError:
+ pass
+
+ try:
+ yield self.handler.upload_keys_for_user(
+ local_user, device_id, {
+ "one_time_keys": {"alg1:k1": {"key": "key"}}
+ },
+ )
+ self.fail("No error when replacing string key with dict")
+ except errors.SynapseError:
+ pass
+
+ try:
+ yield self.handler.upload_keys_for_user(
+ local_user, device_id, {
+ "one_time_keys": {
+ "alg2:k2": {
+ "key": "key3",
+ "signatures": {"k1": "sig1"},
+ }
+ },
+ },
+ )
+ self.fail("No error when replacing dict key")
+ except errors.SynapseError:
+ pass
|