diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index da887647d4..4ca2bc0420 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -842,9 +842,7 @@ class ApplicationServicesHandler:
async def claim_e2e_one_time_keys(
self, query: Iterable[Tuple[str, str, str]]
- ) -> Tuple[
- Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
- ]:
+ ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Claim one time keys from application services.
Users which are exclusively owned by an application service are sent a
@@ -856,7 +854,7 @@ class ApplicationServicesHandler:
Returns:
A tuple of:
- An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+ A map of user ID -> a map device ID -> a map of key ID -> JSON.
A copy of the input which has not been fulfilled (either because
they are not appservice users or the appservice does not support
@@ -897,12 +895,11 @@ class ApplicationServicesHandler:
)
# Patch together the results -- they are all independent (since they
- # require exclusive control over the users). They get returned as a list
- # and the caller combines them.
- claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
+ # require exclusive control over the users, which is the outermost key).
+ claimed_keys: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
for success, result in results:
if success:
- claimed_keys.append(result[0])
+ claimed_keys.update(result[0])
missing.extend(result[1])
return claimed_keys, missing
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 0073667470..d1ab95126c 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -563,7 +563,9 @@ class E2eKeysHandler:
return ret
async def claim_local_one_time_keys(
- self, local_query: List[Tuple[str, str, str]]
+ self,
+ local_query: List[Tuple[str, str, str]],
+ always_include_fallback_keys: bool,
) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
"""Claim one time keys for local users.
@@ -573,6 +575,7 @@ class E2eKeysHandler:
Args:
local_query: An iterable of tuples of (user ID, device ID, algorithm).
+ always_include_fallback_keys: True to always include fallback keys.
Returns:
An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
@@ -583,24 +586,73 @@ class E2eKeysHandler:
# If the application services have not provided any keys via the C-S
# API, query it directly for one-time keys.
if self._query_appservices_for_otks:
+ # TODO Should this query for fallback keys of uploaded OTKs if
+ # always_include_fallback_keys is True? The MSC is ambiguous.
(
appservice_results,
not_found,
) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
else:
- appservice_results = []
+ appservice_results = {}
+
+ # Calculate which user ID / device ID / algorithm tuples to get fallback
+ # keys for. This can be either only missing results *or* all results
+ # (which don't already have a fallback key).
+ if always_include_fallback_keys:
+ # Build the fallback query as any part of the original query where
+ # the appservice didn't respond with a fallback key.
+ fallback_query = []
+
+ # Iterate each item in the original query and search the results
+ # from the appservice for that user ID / device ID. If it is found,
+ # check if any of the keys match the requested algorithm & are a
+ # fallback key.
+ for user_id, device_id, algorithm in local_query:
+ # Check if the appservice responded for this query.
+ as_result = appservice_results.get(user_id, {}).get(device_id, {})
+ found_otk = False
+ for key_id, key_json in as_result.items():
+ if key_id.startswith(f"{algorithm}:"):
+ # A OTK or fallback key was found for this query.
+ found_otk = True
+ # A fallback key was found for this query, no need to
+ # query further.
+ if key_json.get("fallback", False):
+ break
+
+ else:
+ # No fallback key was found from appservices, query for it.
+ # Only mark the fallback key as used if no OTK was found
+ # (from either the database or appservices).
+ mark_as_used = not found_otk and not any(
+ key_id.startswith(f"{algorithm}:")
+ for key_id in otk_results.get(user_id, {})
+ .get(device_id, {})
+ .keys()
+ )
+ fallback_query.append((user_id, device_id, algorithm, mark_as_used))
+
+ else:
+ # All fallback keys get marked as used.
+ fallback_query = [
+ (user_id, device_id, algorithm, True)
+ for user_id, device_id, algorithm in not_found
+ ]
# For each user that does not have a one-time keys available, see if
# there is a fallback key.
- fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
+ fallback_results = await self.store.claim_e2e_fallback_keys(fallback_query)
# Return the results in order, each item from the input query should
# only appear once in the combined list.
- return (otk_results, *appservice_results, fallback_results)
+ return (otk_results, appservice_results, fallback_results)
@trace
async def claim_one_time_keys(
- self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
+ self,
+ query: Dict[str, Dict[str, Dict[str, str]]],
+ timeout: Optional[int],
+ always_include_fallback_keys: bool,
) -> JsonDict:
local_query: List[Tuple[str, str, str]] = []
remote_queries: Dict[str, Dict[str, Dict[str, str]]] = {}
@@ -617,7 +669,9 @@ class E2eKeysHandler:
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
- results = await self.claim_local_one_time_keys(local_query)
+ results = await self.claim_local_one_time_keys(
+ local_query, always_include_fallback_keys
+ )
# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
@@ -625,7 +679,9 @@ class E2eKeysHandler:
for user_id, device_keys in result.items():
for device_id, keys in device_keys.items():
for key_id, key in keys.items():
- json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+ json_result.setdefault(user_id, {}).setdefault(
+ device_id, {}
+ ).update({key_id: key})
# Remote failures.
failures: Dict[str, JsonDict] = {}
|