summary refs log tree commit diff
path: root/synapse/handlers/e2e_keys.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/e2e_keys.py')
-rw-r--r--synapse/handlers/e2e_keys.py70
1 files changed, 63 insertions, 7 deletions
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 0073667470..d1ab95126c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -563,7 +563,9 @@ class E2eKeysHandler:
         return ret
 
     async def claim_local_one_time_keys(
-        self, local_query: List[Tuple[str, str, str]]
+        self,
+        local_query: List[Tuple[str, str, str]],
+        always_include_fallback_keys: bool,
     ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
         """Claim one time keys for local users.
 
@@ -573,6 +575,7 @@ class E2eKeysHandler:
 
         Args:
             local_query: An iterable of tuples of (user ID, device ID, algorithm).
+            always_include_fallback_keys: True to always include fallback keys.
 
         Returns:
             An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@@ -583,24 +586,73 @@ class E2eKeysHandler:
         # If the application services have not provided any keys via the C-S
         # API, query it directly for one-time keys.
         if self._query_appservices_for_otks:
+            # TODO Should this query for fallback keys of uploaded OTKs if
+            #      always_include_fallback_keys is True? The MSC is ambiguous.
             (
                 appservice_results,
                 not_found,
             ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
         else:
-            appservice_results = []
+            appservice_results = {}
+
+        # Calculate which user ID / device ID / algorithm tuples to get fallback
+        # keys for. This can be either only missing results *or* all results
+        # (which don't already have a fallback key).
+        if always_include_fallback_keys:
+            # Build the fallback query as any part of the original query where
+            # the appservice didn't respond with a fallback key.
+            fallback_query = []
+
+            # Iterate each item in the original query and search the results
+            # from the appservice for that user ID / device ID. If it is found,
+            # check if any of the keys match the requested algorithm & are a
+            # fallback key.
+            for user_id, device_id, algorithm in local_query:
+                # Check if the appservice responded for this query.
+                as_result = appservice_results.get(user_id, {}).get(device_id, {})
+                found_otk = False
+                for key_id, key_json in as_result.items():
+                    if key_id.startswith(f"{algorithm}:"):
+                        # A OTK or fallback key was found for this query.
+                        found_otk = True
+                        # A fallback key was found for this query, no need to
+                        # query further.
+                        if key_json.get("fallback", False):
+                            break
+
+                else:
+                    # No fallback key was found from appservices, query for it.
+                    # Only mark the fallback key as used if no OTK was found
+                    # (from either the database or appservices).
+                    mark_as_used = not found_otk and not any(
+                        key_id.startswith(f"{algorithm}:")
+                        for key_id in otk_results.get(user_id, {})
+                        .get(device_id, {})
+                        .keys()
+                    )
+                    fallback_query.append((user_id, device_id, algorithm, mark_as_used))
+
+        else:
+            # All fallback keys get marked as used.
+            fallback_query = [
+                (user_id, device_id, algorithm, True)
+                for user_id, device_id, algorithm in not_found
+            ]
 
         # For each user that does not have a one-time keys available, see if
         # there is a fallback key.
-        fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
+        fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
 
         # Return the results in order, each item from the input query should
         # only appear once in the combined list.
-        return (otk_results, *appservice_results, fallback_results)
+        return (otk_results, appservice_results, fallback_results)
 
     @trace
     async def claim_one_time_keys(
-        self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
+        self,
+        query: Dict[str, Dict[str, Dict[str, str]]],
+        timeout: Optional[int],
+        always_include_fallback_keys: bool,
     ) -> JsonDict:
         local_query: List[Tuple[str, str, str]] = []
         remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -617,7 +669,9 @@ class E2eKeysHandler:
         set_tag("local_key_query", str(local_query))
         set_tag("remote_key_query", str(remote_queries))
 
-        results = await self.claim_local_one_time_keys(local_query)
+        results = await self.claim_local_one_time_keys(
+            local_query, always_include_fallback_keys
+        )
 
         # A map of user ID -> device ID -> key ID -> key.
         json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
@@ -625,7 +679,9 @@ class E2eKeysHandler:
             for user_id, device_keys in result.items():
                 for device_id, keys in device_keys.items():
                     for key_id, key in keys.items():
-                        json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+                        json_result.setdefault(user_id, {}).setdefault(
+                            device_id, {}
+                        ).update({key_id: key})
 
         # Remote failures.
         failures: Dict[str, JsonDict] = {}