summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py71
1 files changed, 63 insertions, 8 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 6187f879ef..39f4ec8e60 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -31,6 +31,7 @@ from synapse.types import (
     get_domain_from_id,
     get_verify_key_from_cross_signing_key,
 )
+from synapse.util import unwrapFirstError
 from synapse.util.retryutils import NotRetryingDestination
 
 logger = logging.getLogger(__name__)
@@ -76,6 +77,7 @@ class E2eKeysHandler(object):
                 adding cross-signing signatures to limit what signatures users
                 can see.
         """
+
         device_keys_query = query_body.get("device_keys", {})
 
         # separate users by domain.
@@ -137,7 +139,56 @@ class E2eKeysHandler(object):
         # Now fetch any devices that we don't have in our cache
         @defer.inlineCallbacks
         def do_remote_query(destination):
+            """This is called when we are querying the device list of a user on
+            a remote homeserver and their device list is not in the device list
+            cache. If we share a room with this user and we're not querying for
+            specific user we will update the cache
+            with their device list."""
+
             destination_query = remote_queries_not_in_cache[destination]
+
+            # We first consider whether we wish to update the device list cache with
+            # the users device list. We want to track a user's devices when the
+            # authenticated user shares a room with the queried user and the query
+            # has not specified a particular device.
+            # If we update the cache for the queried user we remove them from further
+            # queries. We use the more efficient batched query_client_keys for all
+            # remaining users
+            user_ids_updated = []
+            for (user_id, device_list) in destination_query.items():
+                if user_id in user_ids_updated:
+                    continue
+
+                if device_list:
+                    continue
+
+                room_ids = yield self.store.get_rooms_for_user(user_id)
+                if not room_ids:
+                    continue
+
+                # We've decided we're sharing a room with this user and should
+                # probably be tracking their device lists. However, we haven't
+                # done an initial sync on the device list so we do it now.
+                try:
+                    user_devices = yield self.device_handler.device_list_updater.user_device_resync(
+                        user_id
+                    )
+                    user_devices = user_devices["devices"]
+                    for device in user_devices:
+                        results[user_id] = {device["device_id"]: device["keys"]}
+                    user_ids_updated.append(user_id)
+                except Exception as e:
+                    failures[destination] = _exception_to_failure(e)
+
+            if len(destination_query) == len(user_ids_updated):
+                # We've updated all the users in the query and we do not need to
+                # make any further remote calls.
+                return
+
+            # Remove all the users from the query which we have updated
+            for user_id in user_ids_updated:
+                destination_query.pop(user_id)
+
             try:
                 remote_result = yield self.federation.query_client_keys(
                     destination, {"device_keys": destination_query}, timeout=timeout
@@ -156,7 +207,8 @@ class E2eKeysHandler(object):
                         cross_signing_keys["self_signing"][user_id] = key
 
             except Exception as e:
-                failures[destination] = _exception_to_failure(e)
+                failure = _exception_to_failure(e)
+                failures[destination] = failure
 
         yield make_deferred_yieldable(
             defer.gatherResults(
@@ -165,7 +217,7 @@ class E2eKeysHandler(object):
                     for destination in remote_queries_not_in_cache
                 ],
                 consumeErrors=True,
-            )
+            ).addErrback(unwrapFirstError)
         )
 
         ret = {"device_keys": results, "failures": failures}
@@ -173,7 +225,7 @@ class E2eKeysHandler(object):
         for key, value in iteritems(cross_signing_keys):
             ret[key + "_keys"] = value
 
-        defer.returnValue(ret)
+        return ret
 
     @defer.inlineCallbacks
     def query_cross_signing_keys(self, query, from_user_id):
@@ -279,7 +331,7 @@ class E2eKeysHandler(object):
                     r["unsigned"]["device_display_name"] = display_name
                 result_dict[user_id][device_id] = r
 
-        defer.returnValue(result_dict)
+        return result_dict
 
     @defer.inlineCallbacks
     def on_federation_query_client_keys(self, query_body):
@@ -287,7 +339,7 @@ class E2eKeysHandler(object):
         """
         device_keys_query = query_body.get("device_keys", {})
         res = yield self.query_local_devices(device_keys_query)
-        defer.returnValue({"device_keys": res})
+        return {"device_keys": res}
 
     @defer.inlineCallbacks
     def claim_one_time_keys(self, query, timeout):
@@ -324,8 +376,10 @@ class E2eKeysHandler(object):
                 for user_id, keys in remote_result["one_time_keys"].items():
                     if user_id in device_keys:
                         json_result[user_id] = keys
+
             except Exception as e:
-                failures[destination] = _exception_to_failure(e)
+                failure = _exception_to_failure(e)
+                failures[destination] = failure
 
         yield make_deferred_yieldable(
             defer.gatherResults(
@@ -349,10 +403,11 @@ class E2eKeysHandler(object):
             ),
         )
 
-        defer.returnValue({"one_time_keys": json_result, "failures": failures})
+        return {"one_time_keys": json_result, "failures": failures}
 
     @defer.inlineCallbacks
     def upload_keys_for_user(self, user_id, device_id, keys):
+
         time_now = self.clock.time_msec()
 
         # TODO: Validate the JSON to make sure it has the right keys.
@@ -387,7 +442,7 @@ class E2eKeysHandler(object):
 
         result = yield self.store.count_e2e_one_time_keys(user_id, device_id)
 
-        defer.returnValue({"one_time_key_counts": result})
+        return {"one_time_key_counts": result}
 
     @defer.inlineCallbacks
     def _upload_one_time_keys_for_user(