diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index b63a660c06..e40495d1ab 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,10 +73,9 @@ class E2eKeysHandler(object):
if self.is_mine_id(user_id):
local_query[user_id] = device_ids
else:
- domain = get_domain_from_id(user_id)
- remote_queries.setdefault(domain, {})[user_id] = device_ids
+ remote_queries[user_id] = device_ids
- # do the queries
+ # Firt get local devices.
failures = {}
results = {}
if local_query:
@@ -85,9 +84,42 @@ class E2eKeysHandler(object):
if user_id in local_query:
results[user_id] = keys
+ # Now attempt to get any remote devices from our local cache.
+ remote_queries_not_in_cache = {}
+ if remote_queries:
+ query_list = []
+ for user_id, device_ids in remote_queries.iteritems():
+ if device_ids:
+ query_list.extend((user_id, device_id) for device_id in device_ids)
+ else:
+ query_list.append((user_id, None))
+
+ user_ids_not_in_cache, remote_results = (
+ yield self.store.get_user_devices_from_cache(
+ query_list
+ )
+ )
+ for user_id, devices in remote_results.iteritems():
+ user_devices = results.setdefault(user_id, {})
+ for device_id, device in devices.iteritems():
+ keys = device.get("keys", None)
+ device_display_name = device.get("device_display_name", None)
+ if keys:
+ result = dict(keys)
+ unsigned = result.setdefault("unsigned", {})
+ if device_display_name:
+ unsigned["device_display_name"] = device_display_name
+ user_devices[device_id] = result
+
+ for user_id in user_ids_not_in_cache:
+ domain = get_domain_from_id(user_id)
+ r = remote_queries_not_in_cache.setdefault(domain, {})
+ r[user_id] = remote_queries[user_id]
+
+ # Now fetch any devices that we don't have in our cache
@defer.inlineCallbacks
def do_remote_query(destination):
- destination_query = remote_queries[destination]
+ destination_query = remote_queries_not_in_cache[destination]
try:
limiter = yield get_retry_limiter(
destination, self.clock, self.store
@@ -119,7 +151,7 @@ class E2eKeysHandler(object):
yield preserve_context_over_deferred(defer.gatherResults([
preserve_fn(do_remote_query)(destination)
- for destination in remote_queries
+ for destination in remote_queries_not_in_cache
]))
defer.returnValue({
@@ -162,7 +194,7 @@ class E2eKeysHandler(object):
# "unsigned" section
for user_id, device_keys in results.items():
for device_id, device_info in device_keys.items():
- r = json.loads(device_info["key_json"])
+ r = dict(device_info["keys"])
r["unsigned"] = {}
display_name = device_info["device_display_name"]
if display_name is not None:
@@ -255,10 +287,12 @@ class E2eKeysHandler(object):
device_id, user_id, time_now
)
# TODO: Sign the JSON with the server key
- yield self.store.set_e2e_device_keys(
- user_id, device_id, time_now,
- encode_canonical_json(device_keys)
+ changed = yield self.store.set_e2e_device_keys(
+ user_id, device_id, time_now, device_keys,
)
+ if changed:
+ # Only notify about device updates *if* the keys actually changed
+ yield self.device_handler.notify_device_update(user_id, [device_id])
one_time_keys = keys.get("one_time_keys", None)
if one_time_keys:
|