diff options
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r-- | synapse/handlers/e2e_keys.py | 43 |
1 files changed, 38 insertions, 5 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index b63a660c06..a16b9def8d 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -73,10 +73,9 @@ class E2eKeysHandler(object): if self.is_mine_id(user_id): local_query[user_id] = device_ids else: - domain = get_domain_from_id(user_id) - remote_queries.setdefault(domain, {})[user_id] = device_ids + remote_queries[user_id] = device_ids - # do the queries + # Firt get local devices. failures = {} results = {} if local_query: @@ -85,9 +84,42 @@ class E2eKeysHandler(object): if user_id in local_query: results[user_id] = keys + # Now attempt to get any remote devices from our local cache. + remote_queries_not_in_cache = {} + if remote_queries: + query_list = [] + for user_id, device_ids in remote_queries.iteritems(): + if device_ids: + query_list.extend((user_id, device_id) for device_id in device_ids) + else: + query_list.append((user_id, None)) + + user_ids_not_in_cache, remote_results = ( + yield self.store.get_user_devices_from_cache( + query_list + ) + ) + for user_id, devices in remote_results.iteritems(): + user_devices = results.setdefault(user_id, {}) + for device_id, device in devices.iteritems(): + keys = device.get("keys", None) + device_display_name = device.get("device_display_name", None) + if keys: + result = dict(keys) + unsigned = result.setdefault("unsigned", {}) + if device_display_name: + unsigned["device_display_name"] = device_display_name + user_devices[device_id] = result + + for user_id in user_ids_not_in_cache: + domain = get_domain_from_id(user_id) + r = remote_queries_not_in_cache.setdefault(domain, {}) + r[user_id] = remote_queries[user_id] + + # Now fetch any devices that we don't have in our cache @defer.inlineCallbacks def do_remote_query(destination): - destination_query = remote_queries[destination] + destination_query = remote_queries_not_in_cache[destination] try: limiter = yield get_retry_limiter( destination, self.clock, self.store @@ -119,7 +151,7 @@ class E2eKeysHandler(object): yield preserve_context_over_deferred(defer.gatherResults([ preserve_fn(do_remote_query)(destination) - for destination in remote_queries + for destination in remote_queries_not_in_cache ])) defer.returnValue({ @@ -259,6 +291,7 @@ class E2eKeysHandler(object): user_id, device_id, time_now, encode_canonical_json(device_keys) ) + yield self.device_handler.notify_device_update(user_id, [device_id]) one_time_keys = keys.get("one_time_keys", None) if one_time_keys: |