diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py
index 4812fb4496..51ee0e79df 100644
--- a/synapse/appservice/api.py
+++ b/synapse/appservice/api.py
@@ -388,6 +388,62 @@ class ApplicationServiceApi(SimpleHttpClient):
failed_transactions_counter.labels(service.id).inc()
return False
+ async def claim_client_keys(
+ self, service: "ApplicationService", query: List[Tuple[str, str, str]]
+ ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
+ """Claim one time keys from an application service.
+
+ Args:
+ query: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A tuple of:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON dict.
+
+ A copy of the input which has not been fulfilled because the
+ appservice doesn't support this endpoint or has not returned
+ data for that tuple.
+ """
+ if service.url is None:
+ return {}, query
+
+ # This is required by the configuration.
+ assert service.hs_token is not None
+
+ # Create the expected payload shape.
+ body: Dict[str, Dict[str, List[str]]] = {}
+ for user_id, device, algorithm in query:
+ body.setdefault(user_id, {}).setdefault(device, []).append(algorithm)
+
+ uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim"
+ try:
+ response = await self.post_json_get_json(
+ uri,
+ body,
+ headers={"Authorization": [f"Bearer {service.hs_token}"]},
+ )
+ except CodeMessageException as e:
+ # The appservice doesn't support this endpoint.
+ if e.code == 404 or e.code == 405:
+ return {}, query
+ logger.warning("claim_keys to %s received %s", uri, e.code)
+ return {}, query
+ except Exception as ex:
+ logger.warning("claim_keys to %s threw exception %s", uri, ex)
+ return {}, query
+
+ # Check if the appservice fulfilled all of the queried user/device/algorithms
+ # or if some are still missing.
+ #
+ # TODO This places a lot of faith in the response shape being correct.
+ missing = [
+ (user_id, device, algorithm)
+ for user_id, device, algorithm in query
+ if algorithm not in response.get(user_id, {}).get(device, [])
+ ]
+
+ return response, missing
+
def _serialize(
self, service: "ApplicationService", events: Iterable[EventBase]
) -> List[JsonDict]:
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
index 99dcd27c74..53e6fc2b54 100644
--- a/synapse/config/experimental.py
+++ b/synapse/config/experimental.py
@@ -74,6 +74,11 @@ class ExperimentalConfig(Config):
"msc3202_transaction_extensions", False
)
+ # MSC3983: Proxying OTK claim requests to exclusive ASes.
+ self.msc3983_appservice_otk_claims: bool = experimental.get(
+ "msc3983_appservice_otk_claims", False
+ )
+
# MSC3706 (server-side support for partial state in /send_join responses)
# Synapse will always serve partial state responses to requests using the stable
# query parameter `omit_members`. If this flag is set, Synapse will also serve
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 6d99845de5..64e99292ec 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -86,7 +86,7 @@ from synapse.storage.databases.main.lock import Lock
from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary
from synapse.storage.roommember import MemberSummary
from synapse.types import JsonDict, StateMap, get_domain_from_id
-from synapse.util import json_decoder, unwrapFirstError
+from synapse.util import unwrapFirstError
from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results
from synapse.util.caches.response_cache import ResponseCache
from synapse.util.stringutils import parse_server_name
@@ -135,6 +135,7 @@ class FederationServer(FederationBase):
self.state = hs.get_state_handler()
self._event_auth_handler = hs.get_event_auth_handler()
self._room_member_handler = hs.get_room_member_handler()
+ self._e2e_keys_handler = hs.get_e2e_keys_handler()
self._state_storage_controller = hs.get_storage_controllers().state
@@ -1012,15 +1013,14 @@ class FederationServer(FederationBase):
query.append((user_id, device_id, algorithm))
log_kv({"message": "Claiming one time keys.", "user, device pairs": query})
- results = await self.store.claim_e2e_one_time_keys(query)
-
- json_result: Dict[str, Dict[str, dict]] = {}
- for user_id, device_keys in results.items():
- for device_id, keys in device_keys.items():
- for key_id, json_str in keys.items():
- json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_str)
- }
+ results = await self._e2e_keys_handler.claim_local_one_time_keys(query)
+
+ json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ for result in results:
+ for user_id, device_keys in result.items():
+ for device_id, keys in device_keys.items():
+ for key_id, key in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {key_id: key}
logger.info(
"Claimed one-time-keys: %s",
diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py
index ec3ab968e9..953df4d9cd 100644
--- a/synapse/handlers/appservice.py
+++ b/synapse/handlers/appservice.py
@@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
-from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union
+from typing import (
+ TYPE_CHECKING,
+ Collection,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+)
from prometheus_client import Counter
@@ -829,3 +838,66 @@ class ApplicationServicesHandler:
if unknown_user:
return await self.query_user_exists(user_id)
return True
+
+ async def claim_e2e_one_time_keys(
+ self, query: Iterable[Tuple[str, str, str]]
+ ) -> Tuple[
+ Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]]
+ ]:
+ """Claim one time keys from application services.
+
+ Args:
+ query: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A tuple of:
+ An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+
+ A copy of the input which has not been fulfilled (either because
+ they are not appservice users or the appservice does not support
+ providing OTKs).
+ """
+ services = self.store.get_app_services()
+
+ # Partition the users by appservice.
+ query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {}
+ missing = []
+ for user_id, device, algorithm in query:
+ if not self.store.get_if_app_services_interested_in_user(user_id):
+ missing.append((user_id, device, algorithm))
+ continue
+
+ # Find the associated appservice.
+ for service in services:
+ if service.is_exclusive_user(user_id):
+ query_by_appservice.setdefault(service.id, []).append(
+ (user_id, device, algorithm)
+ )
+ continue
+
+ # Query each service in parallel.
+ results = await make_deferred_yieldable(
+ defer.DeferredList(
+ [
+ run_in_background(
+ self.appservice_api.claim_client_keys,
+ # We know this must be an app service.
+ self.store.get_app_service_by_id(service_id), # type: ignore[arg-type]
+ service_query,
+ )
+ for service_id, service_query in query_by_appservice.items()
+ ],
+ consumeErrors=True,
+ )
+ )
+
+ # Patch together the results -- they are all independent (since they
+ # require exclusive control over the users). They get returned as a list
+ # and the caller combines them.
+ claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = []
+ for success, result in results:
+ if success:
+ claimed_keys.append(result[0])
+ missing.extend(result[1])
+
+ return claimed_keys, missing
diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py
index 4e9c8d8db0..9e7c2c45b5 100644
--- a/synapse/handlers/e2e_keys.py
+++ b/synapse/handlers/e2e_keys.py
@@ -13,7 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import logging
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple
@@ -53,6 +52,7 @@ class E2eKeysHandler:
self.store = hs.get_datastores().main
self.federation = hs.get_federation_client()
self.device_handler = hs.get_device_handler()
+ self._appservice_handler = hs.get_application_service_handler()
self.is_mine = hs.is_mine
self.clock = hs.get_clock()
@@ -88,6 +88,10 @@ class E2eKeysHandler:
max_count=10,
)
+ self._query_appservices_for_otks = (
+ hs.config.experimental.msc3983_appservice_otk_claims
+ )
+
@trace
@cancellable
async def query_devices(
@@ -542,6 +546,42 @@ class E2eKeysHandler:
return ret
+ async def claim_local_one_time_keys(
+ self, local_query: List[Tuple[str, str, str]]
+ ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]:
+ """Claim one time keys for local users.
+
+ 1. Attempt to claim OTKs from the database.
+ 2. Ask application services if they provide OTKs.
+ 3. Attempt to fetch fallback keys from the database.
+
+ Args:
+ local_query: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+ """
+
+ otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query)
+
+ # If the application services have not provided any keys via the C-S
+ # API, query it directly for one-time keys.
+ if self._query_appservices_for_otks:
+ (
+ appservice_results,
+ not_found,
+ ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found)
+ else:
+ appservice_results = []
+
+ # For each user that does not have a one-time keys available, see if
+ # there is a fallback key.
+ fallback_results = await self.store.claim_e2e_fallback_keys(not_found)
+
+ # Return the results in order, each item from the input query should
+ # only appear once in the combined list.
+ return (otk_results, *appservice_results, fallback_results)
+
@trace
async def claim_one_time_keys(
self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int]
@@ -561,17 +601,18 @@ class E2eKeysHandler:
set_tag("local_key_query", str(local_query))
set_tag("remote_key_query", str(remote_queries))
- results = await self.store.claim_e2e_one_time_keys(local_query)
+ results = await self.claim_local_one_time_keys(local_query)
# A map of user ID -> device ID -> key ID -> key.
json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ for result in results:
+ for user_id, device_keys in result.items():
+ for device_id, keys in device_keys.items():
+ for key_id, key in keys.items():
+ json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+
+ # Remote failures.
failures: Dict[str, JsonDict] = {}
- for user_id, device_keys in results.items():
- for device_id, keys in device_keys.items():
- for key_id, json_str in keys.items():
- json_result.setdefault(user_id, {})[device_id] = {
- key_id: json_decoder.decode(json_str)
- }
@trace
async def claim_client_keys(destination: str) -> None:
diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py
index a3b6c8ae8e..dc7768c50c 100644
--- a/synapse/storage/databases/main/end_to_end_keys.py
+++ b/synapse/storage/databases/main/end_to_end_keys.py
@@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore
from synapse.storage.engines import PostgresEngine
from synapse.storage.util.id_generators import StreamIdGenerator
from synapse.types import JsonDict
-from synapse.util import json_encoder
+from synapse.util import json_decoder, json_encoder
from synapse.util.caches.descriptors import cached, cachedList
from synapse.util.cancellation import cancellable
from synapse.util.iterutils import batch_iter
@@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
async def claim_e2e_one_time_keys(
self, query_list: Iterable[Tuple[str, str, str]]
- ) -> Dict[str, Dict[str, Dict[str, str]]]:
+ ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]:
"""Take a list of one time keys out of the database.
Args:
query_list: An iterable of tuples of (user ID, device ID, algorithm).
Returns:
- A map of user ID -> a map device ID -> a map of key ID -> JSON bytes.
+ A tuple pf:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON.
+
+ A copy of the input which has not been fulfilled.
"""
@trace
@@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
key_id, key_json = otk_row
return f"{algorithm}:{key_id}", key_json
- results: Dict[str, Dict[str, Dict[str, str]]] = {}
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ missing: List[Tuple[str, str, str]] = []
for user_id, device_id, algorithm in query_list:
if self.database_engine.supports_returning:
# If we support RETURNING clause we can use a single query that
@@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
device_results = results.setdefault(user_id, {}).setdefault(
device_id, {}
)
- device_results[claim_row[0]] = claim_row[1]
- continue
+ device_results[claim_row[0]] = json_decoder.decode(claim_row[1])
+ else:
+ missing.append((user_id, device_id, algorithm))
+
+ return results, missing
+
+ async def claim_e2e_fallback_keys(
+ self, query_list: Iterable[Tuple[str, str, str]]
+ ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]:
+ """Take a list of fallback keys out of the database.
- # No one-time key available, so see if there's a fallback
- # key
+ Args:
+ query_list: An iterable of tuples of (user ID, device ID, algorithm).
+
+ Returns:
+ A map of user ID -> a map device ID -> a map of key ID -> JSON.
+ """
+ results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {}
+ for user_id, device_id, algorithm in query_list:
row = await self.db_pool.simple_select_one(
table="e2e_fallback_keys_json",
keyvalues={
@@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker
)
device_results = results.setdefault(user_id, {}).setdefault(device_id, {})
- device_results[f"{algorithm}:{key_id}"] = key_json
+ device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json)
return results
|