summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/e2e_keys.py76
1 files changed, 34 insertions, 42 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 9081c3f64c..53ca8330ad 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -17,6 +17,8 @@
 
 import logging
 
+import time
+
 from six import iteritems
 
 from canonicaljson import encode_canonical_json, json
@@ -132,7 +134,7 @@ class E2eKeysHandler(object):
                 r[user_id] = remote_queries[user_id]
 
         # Get cached cross-signing keys
-        cross_signing_keys = yield self.query_cross_signing_keys(
+        cross_signing_keys = yield self.get_cross_signing_keys_from_cache(
             device_keys_query, from_user_id
         )
 
@@ -200,11 +202,11 @@ class E2eKeysHandler(object):
 
                 for user_id, key in remote_result["master_keys"].items():
                     if user_id in destination_query:
-                        cross_signing_keys["master"][user_id] = key
+                        cross_signing_keys["master_keys"][user_id] = key
 
                 for user_id, key in remote_result["self_signing_keys"].items():
                     if user_id in destination_query:
-                        cross_signing_keys["self_signing"][user_id] = key
+                        cross_signing_keys["self_signing_keys"][user_id] = key
 
             except Exception as e:
                 failure = _exception_to_failure(e)
@@ -222,14 +224,13 @@ class E2eKeysHandler(object):
 
         ret = {"device_keys": results, "failures": failures}
 
-        for key, value in iteritems(cross_signing_keys):
-            ret[key + "_keys"] = value
+        ret.update(cross_signing_keys)
 
         return ret
 
     @defer.inlineCallbacks
-    def query_cross_signing_keys(self, query, from_user_id):
-        """Get cross-signing keys for users
+    def get_cross_signing_keys_from_cache(self, query, from_user_id):
+        """Get cross-signing keys for users from the database
 
         Args:
             query (Iterable[string]) an iterable of user IDs.  A dict whose keys
@@ -250,43 +251,32 @@ class E2eKeysHandler(object):
         for user_id in query:
             # XXX: consider changing the store functions to allow querying
             # multiple users simultaneously.
-            try:
-                key = yield self.store.get_e2e_cross_signing_key(
-                    user_id, "master", from_user_id
-                )
-                if key:
-                    master_keys[user_id] = key
-            except Exception as e:
-                logger.info("Error getting master key: %s", e)
+            key = yield self.store.get_e2e_cross_signing_key(
+                user_id, "master", from_user_id
+            )
+            if key:
+                master_keys[user_id] = key
 
-            try:
-                key = yield self.store.get_e2e_cross_signing_key(
-                    user_id, "self_signing", from_user_id
-                )
-                if key:
-                    self_signing_keys[user_id] = key
-            except Exception as e:
-                logger.info("Error getting self-signing key: %s", e)
+            key = yield self.store.get_e2e_cross_signing_key(
+                user_id, "self_signing", from_user_id
+            )
+            if key:
+                self_signing_keys[user_id] = key
 
             # users can see other users' master and self-signing keys, but can
             # only see their own user-signing keys
             if from_user_id == user_id:
-                try:
-                    key = yield self.store.get_e2e_cross_signing_key(
-                        user_id, "user_signing", from_user_id
-                    )
-                    if key:
-                        user_signing_keys[user_id] = key
-                except Exception as e:
-                    logger.info("Error getting user-signing key: %s", e)
+                key = yield self.store.get_e2e_cross_signing_key(
+                    user_id, "user_signing", from_user_id
+                )
+                if key:
+                    user_signing_keys[user_id] = key
 
-        defer.returnValue(
-            {
-                "master": master_keys,
-                "self_signing": self_signing_keys,
-                "user_signing": user_signing_keys,
-            }
-        )
+        return {
+            "master_keys": master_keys,
+            "self_signing_keys": self_signing_keys,
+            "user_signing_keys": user_signing_keys,
+        }
 
     @defer.inlineCallbacks
     def query_local_devices(self, query):
@@ -542,11 +532,13 @@ class E2eKeysHandler(object):
         # if everything checks out, then store the keys and send notifications
         deviceids = []
         if "master_key" in keys:
-            yield self.store.set_e2e_cross_signing_key(user_id, "master", master_key)
+            yield self.store.set_e2e_cross_signing_key(
+                user_id, "master", master_key, time.time() * 1000
+            )
             deviceids.append(master_verify_key.version)
         if "self_signing_key" in keys:
             yield self.store.set_e2e_cross_signing_key(
-                user_id, "self_signing", self_signing_key
+                user_id, "self_signing", self_signing_key, time.time() * 1000
             )
             try:
                 deviceids.append(
@@ -556,7 +548,7 @@ class E2eKeysHandler(object):
                 raise SynapseError(400, "Invalid self-signing key", Codes.INVALID_PARAM)
         if "user_signing_key" in keys:
             yield self.store.set_e2e_cross_signing_key(
-                user_id, "user_signing", user_signing_key
+                user_id, "user_signing", user_signing_key, time.time() * 1000
             )
             # the signature stream matches the semantics that we want for
             # user-signing key updates: only the user themselves is notified of
@@ -568,7 +560,7 @@ class E2eKeysHandler(object):
         if len(deviceids):
             yield self.device_handler.notify_device_update(user_id, deviceids)
 
-        defer.returnValue({})
+        return {}
 
 
 def _check_cross_signing_key(key, user_id, key_type, signing_key=None):