summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py57
1 files changed, 49 insertions, 8 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 4e9c8d8db0..9e7c2c45b5 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,7 +13,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 import logging
 from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
 
@@ -53,6 +52,7 @@ class E2eKeysHandler:
         self.store = hs.get_datastores().main
         self.federation = hs.get_federation_client()
         self.device_handler = hs.get_device_handler()
+        self._appservice_handler = hs.get_application_service_handler()
         self.is_mine = hs.is_mine
         self.clock = hs.get_clock()
 
@@ -88,6 +88,10 @@ class E2eKeysHandler:
             max_count=10,
         )
 
+        self._query_appservices_for_otks = (
+            hs.config.experimental.msc3983_appservice_otk_claims
+        )
+
     @trace
     @cancellable
     async def query_devices(
@@ -542,6 +546,42 @@ class E2eKeysHandler:
 
         return ret
 
+    async def claim_local_one_time_keys(
+        self, local_query: List[Tuple[str, str, str]]
+    ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
+        """Claim one time keys for local users.
+
+        1. Attempt to claim OTKs from the database.
+        2. Ask application services if they provide OTKs.
+        3. Attempt to fetch fallback keys from the database.
+
+        Args:
+            local_query: An iterable of tuples of (user ID, device ID, algorithm).
+
+        Returns:
+            An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+        """
+
+        otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
+
+        # If the application services have not provided any keys via the C-S
+        # API, query it directly for one-time keys.
+        if self._query_appservices_for_otks:
+            (
+                appservice_results,
+                not_found,
+            ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
+        else:
+            appservice_results = []
+
+        # For each user that does not have a one-time keys available, see if
+        # there is a fallback key.
+        fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
+
+        # Return the results in order, each item from the input query should
+        # only appear once in the combined list.
+        return (otk_results, *appservice_results, fallback_results)
+
     @trace
     async def claim_one_time_keys(
         self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
@@ -561,17 +601,18 @@ class E2eKeysHandler:
         set_tag("local_key_query", str(local_query))
         set_tag("remote_key_query", str(remote_queries))
 
-        results = await self.store.claim_e2e_one_time_keys(local_query)
+        results = await self.claim_local_one_time_keys(local_query)
 
         # A map of user ID -> device ID -> key ID -> key.
         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+        for result in results:
+            for user_id, device_keys in result.items():
+                for device_id, keys in device_keys.items():
+                    for key_id, key in keys.items():
+                        json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+
+        # Remote failures.
         failures: Dict[str, JsonDict] = {}
-        for user_id, device_keys in results.items():
-            for device_id, keys in device_keys.items():
-                for key_id, json_str in keys.items():
-                    json_result.setdefault(user_id, {})[device_id] = {
-                        key_id: json_decoder.decode(json_str)
-                    }
 
         @trace
         async def claim_client_keys(destination: str) -> None: