summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/federation/federation_server.py10
-rw-r--r--synapse/handlers/e2e_keys.py86
-rw-r--r--synapse/rest/client/v2_alpha/register.py2
-rw-r--r--synapse/storage/end_to_end_keys.py47
-rw-r--r--tests/handlers/test_e2e_keys.py132
5 files changed, 240 insertions, 37 deletions
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index bc20b9c201..51e3fdea06 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -440,6 +440,16 @@ class FederationServer(FederationBase):
                         key_id: json.loads(json_bytes)
                     }
 
+        logger.info(
+            "Claimed one-time-keys: %s",
+            ",".join((
+                "%s for %s:%s" % (key_id, user_id, device_id)
+                for user_id, user_keys in json_result.iteritems()
+                for device_id, device_keys in user_keys.iteritems()
+                for key_id, _ in device_keys.iteritems()
+            )),
+        )
+
         defer.returnValue({"one_time_keys": json_result})
 
     @defer.inlineCallbacks
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index c2b38d72a9..668a90e495 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -21,7 +21,7 @@ from twisted.internet import defer
 
 from synapse.api.errors import SynapseError, CodeMessageException
 from synapse.types import get_domain_from_id
-from synapse.util.logcontext import preserve_fn, preserve_context_over_deferred
+from synapse.util.logcontext import preserve_fn, make_deferred_yieldable
 from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
@@ -145,7 +145,7 @@ class E2eKeysHandler(object):
                     "status": 503, "message": e.message
                 }
 
-        yield preserve_context_over_deferred(defer.gatherResults([
+        yield make_deferred_yieldable(defer.gatherResults([
             preserve_fn(do_remote_query)(destination)
             for destination in remote_queries_not_in_cache
         ]))
@@ -257,11 +257,21 @@ class E2eKeysHandler(object):
                     "status": 503, "message": e.message
                 }
 
-        yield preserve_context_over_deferred(defer.gatherResults([
+        yield make_deferred_yieldable(defer.gatherResults([
             preserve_fn(claim_client_keys)(destination)
             for destination in remote_queries
         ]))
 
+        logger.info(
+            "Claimed one-time-keys: %s",
+            ",".join((
+                "%s for %s:%s" % (key_id, user_id, device_id)
+                for user_id, user_keys in json_result.iteritems()
+                for device_id, device_keys in user_keys.iteritems()
+                for key_id, _ in device_keys.iteritems()
+            )),
+        )
+
         defer.returnValue({
             "one_time_keys": json_result,
             "failures": failures
@@ -288,19 +298,8 @@ class E2eKeysHandler(object):
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys:
-            logger.info(
-                "Adding %d one_time_keys for device %r for user %r at %d",
-                len(one_time_keys), device_id, user_id, time_now
-            )
-            key_list = []
-            for key_id, key_json in one_time_keys.items():
-                algorithm, key_id = key_id.split(":")
-                key_list.append((
-                    algorithm, key_id, encode_canonical_json(key_json)
-                ))
-
-            yield self.store.add_e2e_one_time_keys(
-                user_id, device_id, time_now, key_list
+            yield self._upload_one_time_keys_for_user(
+                user_id, device_id, time_now, one_time_keys,
             )
 
         # the device should have been registered already, but it may have been
@@ -313,3 +312,58 @@ class E2eKeysHandler(object):
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
 
         defer.returnValue({"one_time_key_counts": result})
+
+    @defer.inlineCallbacks
+    def _upload_one_time_keys_for_user(self, user_id, device_id, time_now,
+                                       one_time_keys):
+        logger.info(
+            "Adding one_time_keys %r for device %r for user %r at %d",
+            one_time_keys.keys(), device_id, user_id, time_now,
+        )
+
+        # make a list of (alg, id, key) tuples
+        key_list = []
+        for key_id, key_obj in one_time_keys.items():
+            algorithm, key_id = key_id.split(":")
+            key_list.append((
+                algorithm, key_id, key_obj
+            ))
+
+        # First we check if we have already persisted any of the keys.
+        existing_key_map = yield self.store.get_e2e_one_time_keys(
+            user_id, device_id, [k_id for _, k_id, _ in key_list]
+        )
+
+        new_keys = []  # Keys that we need to insert. (alg, id, json) tuples.
+        for algorithm, key_id, key in key_list:
+            ex_json = existing_key_map.get((algorithm, key_id), None)
+            if ex_json:
+                if not _one_time_keys_match(ex_json, key):
+                    raise SynapseError(
+                        400,
+                        ("One time key %s:%s already exists. "
+                         "Old key: %s; new key: %r") %
+                        (algorithm, key_id, ex_json, key)
+                    )
+            else:
+                new_keys.append((algorithm, key_id, encode_canonical_json(key)))
+
+        yield self.store.add_e2e_one_time_keys(
+            user_id, device_id, time_now, new_keys
+        )
+
+
+def _one_time_keys_match(old_key_json, new_key):
+    old_key = json.loads(old_key_json)
+
+    # if either is a string rather than an object, they must match exactly
+    if not isinstance(old_key, dict) or not isinstance(new_key, dict):
+        return old_key == new_key
+
+    # otherwise, we strip off the 'signatures' if any, because it's legitimate
+    # for different upload attempts to have different signatures.
+    old_key.pop("signatures", None)
+    new_key_copy = dict(new_key)
+    new_key_copy.pop("signatures", None)
+
+    return old_key == new_key_copy
diff --git a/synapse/rest/client/v2_alpha/register.py b/synapse/rest/client/v2_alpha/register.py
index 38a739f2f8..6a7cd96ea5 100644
--- a/synapse/rest/client/v2_alpha/register.py
+++ b/synapse/rest/client/v2_alpha/register.py
@@ -142,7 +142,7 @@ class UsernameAvailabilityRestServlet(RestServlet):
         )
 
     @defer.inlineCallbacks
-    def on_GET(self, request):
+    def on_POST(self, request):
         ip = self.hs.get_ip_from_request(request)
         with self.ratelimiter.ratelimit(ip) as wait_deferred:
             yield wait_deferred
diff --git a/synapse/storage/end_to_end_keys.py b/synapse/storage/end_to_end_keys.py
index c96dae352d..e00f31da2b 100644
--- a/synapse/storage/end_to_end_keys.py
+++ b/synapse/storage/end_to_end_keys.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 from twisted.internet import defer
 
-from synapse.api.errors import SynapseError
 from synapse.util.caches.descriptors import cached
 
 from canonicaljson import encode_canonical_json
@@ -124,18 +123,24 @@ class EndToEndKeyStore(SQLBaseStore):
         return result
 
     @defer.inlineCallbacks
-    def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list):
-        """Insert some new one time keys for a device.
+    def get_e2e_one_time_keys(self, user_id, device_id, key_ids):
+        """Retrieve a number of one-time keys for a user
 
-        Checks if any of the keys are already inserted, if they are then check
-        if they match. If they don't then we raise an error.
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            key_ids(list[str]): list of key ids (excluding algorithm) to
+                retrieve
+
+        Returns:
+            deferred resolving to Dict[(str, str), str]: map from (algorithm,
+            key_id) to json string for key
         """
 
-        # First we check if we have already persisted any of the keys.
         rows = yield self._simple_select_many_batch(
             table="e2e_one_time_keys_json",
             column="key_id",
-            iterable=[key_id for _, key_id, _ in key_list],
+            iterable=key_ids,
             retcols=("algorithm", "key_id", "key_json",),
             keyvalues={
                 "user_id": user_id,
@@ -144,20 +149,22 @@ class EndToEndKeyStore(SQLBaseStore):
             desc="add_e2e_one_time_keys_check",
         )
 
-        existing_key_map = {
+        defer.returnValue({
             (row["algorithm"], row["key_id"]): row["key_json"] for row in rows
-        }
-
-        new_keys = []  # Keys that we need to insert
-        for algorithm, key_id, json_bytes in key_list:
-            ex_bytes = existing_key_map.get((algorithm, key_id), None)
-            if ex_bytes:
-                if json_bytes != ex_bytes:
-                    raise SynapseError(
-                        400, "One time key with key_id %r already exists" % (key_id,)
-                    )
-            else:
-                new_keys.append((algorithm, key_id, json_bytes))
+        })
+
+    @defer.inlineCallbacks
+    def add_e2e_one_time_keys(self, user_id, device_id, time_now, new_keys):
+        """Insert some new one time keys for a device. Errors if any of the
+        keys already exist.
+
+        Args:
+            user_id(str): id of user to get keys for
+            device_id(str): id of device to get keys for
+            time_now(long): insertion time to record (ms since epoch)
+            new_keys(iterable[(str, str, str)]: keys to add - each a tuple of
+                (algorithm, key_id, key json)
+        """
 
         def _add_e2e_one_time_keys(txn):
             # We are protected from race between lookup and insertion due to
diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py
index 878a54dc34..19f5ed6bce 100644
--- a/tests/handlers/test_e2e_keys.py
+++ b/tests/handlers/test_e2e_keys.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import mock
+from synapse.api import errors
 from twisted.internet import defer
 
 import synapse.api.errors
@@ -44,3 +45,134 @@ class E2eKeysHandlerTestCase(unittest.TestCase):
         local_user = "@boris:" + self.hs.hostname
         res = yield self.handler.query_local_devices({local_user: None})
         self.assertDictEqual(res, {local_user: {}})
+
+    @defer.inlineCallbacks
+    def test_reupload_one_time_keys(self):
+        """we should be able to re-upload the same keys"""
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        keys = {
+            "alg1:k1": "key1",
+            "alg2:k2": {
+                "key": "key2",
+                "signatures": {"k1": "sig1"}
+            },
+            "alg2:k3": {
+                "key": "key3",
+            },
+        }
+
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1, "alg2": 2}
+        })
+
+        # we should be able to change the signature without a problem
+        keys["alg2:k2"]["signatures"]["k1"] = "sig2"
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1, "alg2": 2}
+        })
+
+    @defer.inlineCallbacks
+    def test_change_one_time_keys(self):
+        """attempts to change one-time-keys should be rejected"""
+
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        keys = {
+            "alg1:k1": "key1",
+            "alg2:k2": {
+                "key": "key2",
+                "signatures": {"k1": "sig1"}
+            },
+            "alg2:k3": {
+                "key": "key3",
+            },
+        }
+
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1, "alg2": 2}
+        })
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg1:k1": "key2"}},
+            )
+            self.fail("No error when changing string key")
+        except errors.SynapseError:
+            pass
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {"one_time_keys": {"alg2:k3": "key2"}},
+            )
+            self.fail("No error when replacing dict key with string")
+        except errors.SynapseError:
+            pass
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {
+                    "one_time_keys": {"alg1:k1": {"key": "key"}}
+                },
+            )
+            self.fail("No error when replacing string key with dict")
+        except errors.SynapseError:
+            pass
+
+        try:
+            yield self.handler.upload_keys_for_user(
+                local_user, device_id, {
+                    "one_time_keys": {
+                        "alg2:k2": {
+                            "key": "key3",
+                            "signatures": {"k1": "sig1"},
+                        }
+                    },
+                },
+            )
+            self.fail("No error when replacing dict key")
+        except errors.SynapseError:
+            pass
+
+    @unittest.DEBUG
+    @defer.inlineCallbacks
+    def test_claim_one_time_key(self):
+        local_user = "@boris:" + self.hs.hostname
+        device_id = "xyz"
+        keys = {
+            "alg1:k1": "key1",
+        }
+
+        res = yield self.handler.upload_keys_for_user(
+            local_user, device_id, {"one_time_keys": keys},
+        )
+        self.assertDictEqual(res, {
+            "one_time_key_counts": {"alg1": 1}
+        })
+
+        res2 = yield self.handler.claim_one_time_keys({
+            "one_time_keys": {
+                local_user: {
+                    device_id: "alg1"
+                }
+            }
+        }, timeout=None)
+        self.assertEqual(res2, {
+            "failures": {},
+            "one_time_keys": {
+                local_user: {
+                    device_id: {
+                        "alg1:k1": "key1"
+                    }
+                }
+            }
+        })