diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index ec3ab968e9..953df4d9cd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
@@ -829,3 +838,66 @@ class ApplicationServicesHandler:
if unknown_user:
return await self.query_user_exists(user_id)
return True
+
+ async def claim_e2e_one_time_keys(
+ self, query: Iterable[Tuple[str, str, str]]
+ ) -> Tuple[
+ Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
+ ]:
+ """Claim one time keys from application services.
+
+ Args:
+ query: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A tuple of:
+ An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+
+ A copy of the input which has not been fulfilled (either because
+ they are not appservice users or the appservice does not support
+ providing OTKs).
+ """
+ services = self.store.get_app_services()
+
+ # Partition the users by appservice.
+ query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
+ missing = []
+ for user_id, device, algorithm in query:
+ if not self.store.get_if_app_services_interested_in_user(user_id):
+ missing.append((user_id, device, algorithm))
+ continue
+
+ # Find the associated appservice.
+ for service in services:
+ if service.is_exclusive_user(user_id):
+ query_by_appservice.setdefault(service.id, []).append(
+ (user_id, device, algorithm)
+ )
+ continue
+
+ # Query each service in parallel.
+ results = await make_deferred_yieldable(
+ defer.DeferredList(
+ [
+ run_in_background(
+ self.appservice_api.claim_client_keys,
+ # We know this must be an app service.
+ self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
+ service_query,
+ )
+ for service_id, service_query in query_by_appservice.items()
+ ],
+ consumeErrors=True,
+ )
+ )
+
+ # Patch together the results -- they are all independent (since they
+ # require exclusive control over the users). They get returned as a list
+ # and the caller combines them.
+ claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
+ for success, result in results:
+ if success:
+ claimed_keys.append(result[0])
+ missing.extend(result[1])
+
+ return claimed_keys, missing
|