summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2017-01-26 16:06:54 +0000
committerErik Johnston <erik@matrix.org>2017-01-26 16:07:24 +0000
commitc974116f197d211ba9b42159fe61cfd5957411b5 (patch)
treeb4d6e71850ba0d371e0463a5cc99439ce8072c40 /synapse/handlers/e2e_keys.py
parentFix up sending of m.device_list_update edus (diff)
downloadsynapse-c974116f197d211ba9b42159fe61cfd5957411b5.tar.xz
Implement device key caching over federation
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py40
1 files changed, 35 insertions, 5 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 38c2a2d39e..832998a6d3 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -73,8 +73,7 @@ class E2eKeysHandler(object):
             if self.is_mine_id(user_id):
                 local_query[user_id] = device_ids
             else:
-                domain = get_domain_from_id(user_id)
-                remote_queries.setdefault(domain, {})[user_id] = device_ids
+                remote_queries[user_id] = device_ids
 
         # do the queries
         failures = {}
@@ -85,9 +84,40 @@ class E2eKeysHandler(object):
                 if user_id in local_query:
                     results[user_id] = keys
 
+        remote_queries_not_in_cache = {}
+        if remote_queries:
+            query_list = []
+            for user_id, device_ids in remote_queries.iteritems():
+                if device_ids:
+                    query_list.extend((user_id, device_id) for device_id in device_ids)
+                else:
+                    query_list.append((user_id, None))
+
+            user_ids_not_in_cache, remote_results = (
+                yield self.store.get_user_devices_from_cache(
+                    query_list
+                )
+            )
+            for user_id, devices in remote_results.iteritems():
+                user_devices = results.setdefault(user_id, {})
+                for device_id, device in devices.iteritems():
+                    keys = device.get("keys", None)
+                    device_display_name = device.get("device_display_name", None)
+                    if keys:
+                        result = dict(keys)
+                        unsigned = result.setdefault("unsigned", {})
+                        if device_display_name:
+                            unsigned["device_display_name"] = device_display_name
+                        user_devices[device_id] = result
+
+            for user_id in user_ids_not_in_cache:
+                domain = get_domain_from_id(user_id)
+                r = remote_queries_not_in_cache.setdefault(domain, {})
+                r[user_id] = remote_queries[user_id]
+
         @defer.inlineCallbacks
         def do_remote_query(destination):
-            destination_query = remote_queries[destination]
+            destination_query = remote_queries_not_in_cache[destination]
             try:
                 limiter = yield get_retry_limiter(
                     destination, self.clock, self.store
@@ -119,7 +149,7 @@ class E2eKeysHandler(object):
 
         yield preserve_context_over_deferred(defer.gatherResults([
             preserve_fn(do_remote_query)(destination)
-            for destination in remote_queries
+            for destination in remote_queries_not_in_cache
         ]))
 
         defer.returnValue({
@@ -259,7 +289,7 @@ class E2eKeysHandler(object):
                 user_id, device_id, time_now,
                 encode_canonical_json(device_keys)
             )
-            yield self.device_handler.notify_device_update(user_id, device_id)
+            yield self.device_handler.notify_device_update(user_id, [device_id])
 
         one_time_keys = keys.get("one_time_keys", None)
         if one_time_keys: