From 5282ba1e2bbff2635dc09aec45fd42a56c1a4545 Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Tue, 28 Mar 2023 14:26:27 -0400 Subject: Implement MSC3983 to proxy /keys/claim queries to appservices. (#15314) Experimental support for MSC3983 is behind a configuration flag. If enabled, for users which are exclusively owned by an application service then the appservice will be queried for one-time keys *if* there are none uploaded to Synapse. --- synapse/appservice/api.py | 56 +++++++++++++++++ synapse/config/experimental.py | 5 ++ synapse/federation/federation_server.py | 20 +++--- synapse/handlers/appservice.py | 74 ++++++++++++++++++++++- synapse/handlers/e2e_keys.py | 57 ++++++++++++++--- synapse/storage/databases/main/end_to_end_keys.py | 36 ++++++++--- 6 files changed, 220 insertions(+), 28 deletions(-) (limited to 'synapse') diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 4812fb4496..51ee0e79df 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -388,6 +388,62 @@ class ApplicationServiceApi(SimpleHttpClient): failed_transactions_counter.labels(service.id).inc() return False + async def claim_client_keys( + self, service: "ApplicationService", query: List[Tuple[str, str, str]] + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: + """Claim one time keys from an application service. + + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + A map of user ID -> a map device ID -> a map of key ID -> JSON dict. + + A copy of the input which has not been fulfilled because the + appservice doesn't support this endpoint or has not returned + data for that tuple. + """ + if service.url is None: + return {}, query + + # This is required by the configuration. + assert service.hs_token is not None + + # Create the expected payload shape. + body: Dict[str, Dict[str, List[str]]] = {} + for user_id, device, algorithm in query: + body.setdefault(user_id, {}).setdefault(device, []).append(algorithm) + + uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3983/keys/claim" + try: + response = await self.post_json_get_json( + uri, + body, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, + ) + except CodeMessageException as e: + # The appservice doesn't support this endpoint. + if e.code == 404 or e.code == 405: + return {}, query + logger.warning("claim_keys to %s received %s", uri, e.code) + return {}, query + except Exception as ex: + logger.warning("claim_keys to %s threw exception %s", uri, ex) + return {}, query + + # Check if the appservice fulfilled all of the queried user/device/algorithms + # or if some are still missing. + # + # TODO This places a lot of faith in the response shape being correct. + missing = [ + (user_id, device, algorithm) + for user_id, device, algorithm in query + if algorithm not in response.get(user_id, {}).get(device, []) + ] + + return response, missing + def _serialize( self, service: "ApplicationService", events: Iterable[EventBase] ) -> List[JsonDict]: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 99dcd27c74..53e6fc2b54 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -74,6 +74,11 @@ class ExperimentalConfig(Config): "msc3202_transaction_extensions", False ) + # MSC3983: Proxying OTK claim requests to exclusive ASes. + self.msc3983_appservice_otk_claims: bool = experimental.get( + "msc3983_appservice_otk_claims", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) # Synapse will always serve partial state responses to requests using the stable # query parameter `omit_members`. If this flag is set, Synapse will also serve diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py index 6d99845de5..64e99292ec 100644 --- a/synapse/federation/federation_server.py +++ b/synapse/federation/federation_server.py @@ -86,7 +86,7 @@ from synapse.storage.databases.main.lock import Lock from synapse.storage.databases.main.roommember import extract_heroes_from_room_summary from synapse.storage.roommember import MemberSummary from synapse.types import JsonDict, StateMap, get_domain_from_id -from synapse.util import json_decoder, unwrapFirstError +from synapse.util import unwrapFirstError from synapse.util.async_helpers import Linearizer, concurrently_execute, gather_results from synapse.util.caches.response_cache import ResponseCache from synapse.util.stringutils import parse_server_name @@ -135,6 +135,7 @@ class FederationServer(FederationBase): self.state = hs.get_state_handler() self._event_auth_handler = hs.get_event_auth_handler() self._room_member_handler = hs.get_room_member_handler() + self._e2e_keys_handler = hs.get_e2e_keys_handler() self._state_storage_controller = hs.get_storage_controllers().state @@ -1012,15 +1013,14 @@ class FederationServer(FederationBase): query.append((user_id, device_id, algorithm)) log_kv({"message": "Claiming one time keys.", "user, device pairs": query}) - results = await self.store.claim_e2e_one_time_keys(query) - - json_result: Dict[str, Dict[str, dict]] = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_str in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_str) - } + results = await self._e2e_keys_handler.claim_local_one_time_keys(query) + + json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for result in results: + for user_id, device_keys in result.items(): + for device_id, keys in device_keys.items(): + for key_id, key in keys.items(): + json_result.setdefault(user_id, {})[device_id] = {key_id: key} logger.info( "Claimed one-time-keys: %s", diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index ec3ab968e9..953df4d9cd 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -12,7 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging -from typing import TYPE_CHECKING, Collection, Dict, Iterable, List, Optional, Union +from typing import ( + TYPE_CHECKING, + Collection, + Dict, + Iterable, + List, + Optional, + Tuple, + Union, +) from prometheus_client import Counter @@ -829,3 +838,66 @@ class ApplicationServicesHandler: if unknown_user: return await self.query_user_exists(user_id) return True + + async def claim_e2e_one_time_keys( + self, query: Iterable[Tuple[str, str, str]] + ) -> Tuple[ + Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]], List[Tuple[str, str, str]] + ]: + """Claim one time keys from application services. + + Args: + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A tuple of: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + + A copy of the input which has not been fulfilled (either because + they are not appservice users or the appservice does not support + providing OTKs). + """ + services = self.store.get_app_services() + + # Partition the users by appservice. + query_by_appservice: Dict[str, List[Tuple[str, str, str]]] = {} + missing = [] + for user_id, device, algorithm in query: + if not self.store.get_if_app_services_interested_in_user(user_id): + missing.append((user_id, device, algorithm)) + continue + + # Find the associated appservice. + for service in services: + if service.is_exclusive_user(user_id): + query_by_appservice.setdefault(service.id, []).append( + (user_id, device, algorithm) + ) + continue + + # Query each service in parallel. + results = await make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.claim_client_keys, + # We know this must be an app service. + self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + service_query, + ) + for service_id, service_query in query_by_appservice.items() + ], + consumeErrors=True, + ) + ) + + # Patch together the results -- they are all independent (since they + # require exclusive control over the users). They get returned as a list + # and the caller combines them. + claimed_keys: List[Dict[str, Dict[str, Dict[str, JsonDict]]]] = [] + for success, result in results: + if success: + claimed_keys.append(result[0]) + missing.extend(result[1]) + + return claimed_keys, missing diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 4e9c8d8db0..9e7c2c45b5 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import logging from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple @@ -53,6 +52,7 @@ class E2eKeysHandler: self.store = hs.get_datastores().main self.federation = hs.get_federation_client() self.device_handler = hs.get_device_handler() + self._appservice_handler = hs.get_application_service_handler() self.is_mine = hs.is_mine self.clock = hs.get_clock() @@ -88,6 +88,10 @@ class E2eKeysHandler: max_count=10, ) + self._query_appservices_for_otks = ( + hs.config.experimental.msc3983_appservice_otk_claims + ) + @trace @cancellable async def query_devices( @@ -542,6 +546,42 @@ class E2eKeysHandler: return ret + async def claim_local_one_time_keys( + self, local_query: List[Tuple[str, str, str]] + ) -> Iterable[Dict[str, Dict[str, Dict[str, JsonDict]]]]: + """Claim one time keys for local users. + + 1. Attempt to claim OTKs from the database. + 2. Ask application services if they provide OTKs. + 3. Attempt to fetch fallback keys from the database. + + Args: + local_query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + An iterable of maps of user ID -> a map device ID -> a map of key ID -> JSON bytes. + """ + + otk_results, not_found = await self.store.claim_e2e_one_time_keys(local_query) + + # If the application services have not provided any keys via the C-S + # API, query it directly for one-time keys. + if self._query_appservices_for_otks: + ( + appservice_results, + not_found, + ) = await self._appservice_handler.claim_e2e_one_time_keys(not_found) + else: + appservice_results = [] + + # For each user that does not have a one-time keys available, see if + # there is a fallback key. + fallback_results = await self.store.claim_e2e_fallback_keys(not_found) + + # Return the results in order, each item from the input query should + # only appear once in the combined list. + return (otk_results, *appservice_results, fallback_results) + @trace async def claim_one_time_keys( self, query: Dict[str, Dict[str, Dict[str, str]]], timeout: Optional[int] @@ -561,17 +601,18 @@ class E2eKeysHandler: set_tag("local_key_query", str(local_query)) set_tag("remote_key_query", str(remote_queries)) - results = await self.store.claim_e2e_one_time_keys(local_query) + results = await self.claim_local_one_time_keys(local_query) # A map of user ID -> device ID -> key ID -> key. json_result: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for result in results: + for user_id, device_keys in result.items(): + for device_id, keys in device_keys.items(): + for key_id, key in keys.items(): + json_result.setdefault(user_id, {})[device_id] = {key_id: key} + + # Remote failures. failures: Dict[str, JsonDict] = {} - for user_id, device_keys in results.items(): - for device_id, keys in device_keys.items(): - for key_id, json_str in keys.items(): - json_result.setdefault(user_id, {})[device_id] = { - key_id: json_decoder.decode(json_str) - } @trace async def claim_client_keys(destination: str) -> None: diff --git a/synapse/storage/databases/main/end_to_end_keys.py b/synapse/storage/databases/main/end_to_end_keys.py index a3b6c8ae8e..dc7768c50c 100644 --- a/synapse/storage/databases/main/end_to_end_keys.py +++ b/synapse/storage/databases/main/end_to_end_keys.py @@ -51,7 +51,7 @@ from synapse.storage.databases.main.cache import CacheInvalidationWorkerStore from synapse.storage.engines import PostgresEngine from synapse.storage.util.id_generators import StreamIdGenerator from synapse.types import JsonDict -from synapse.util import json_encoder +from synapse.util import json_decoder, json_encoder from synapse.util.caches.descriptors import cached, cachedList from synapse.util.cancellation import cancellable from synapse.util.iterutils import batch_iter @@ -1028,14 +1028,17 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker async def claim_e2e_one_time_keys( self, query_list: Iterable[Tuple[str, str, str]] - ) -> Dict[str, Dict[str, Dict[str, str]]]: + ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: """Take a list of one time keys out of the database. Args: query_list: An iterable of tuples of (user ID, device ID, algorithm). Returns: - A map of user ID -> a map device ID -> a map of key ID -> JSON bytes. + A tuple pf: + A map of user ID -> a map device ID -> a map of key ID -> JSON. + + A copy of the input which has not been fulfilled. """ @trace @@ -1115,7 +1118,8 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker key_id, key_json = otk_row return f"{algorithm}:{key_id}", key_json - results: Dict[str, Dict[str, Dict[str, str]]] = {} + results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + missing: List[Tuple[str, str, str]] = [] for user_id, device_id, algorithm in query_list: if self.database_engine.supports_returning: # If we support RETURNING clause we can use a single query that @@ -1138,11 +1142,25 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker device_results = results.setdefault(user_id, {}).setdefault( device_id, {} ) - device_results[claim_row[0]] = claim_row[1] - continue + device_results[claim_row[0]] = json_decoder.decode(claim_row[1]) + else: + missing.append((user_id, device_id, algorithm)) + + return results, missing + + async def claim_e2e_fallback_keys( + self, query_list: Iterable[Tuple[str, str, str]] + ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: + """Take a list of fallback keys out of the database. - # No one-time key available, so see if there's a fallback - # key + Args: + query_list: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of user ID -> a map device ID -> a map of key ID -> JSON. + """ + results: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for user_id, device_id, algorithm in query_list: row = await self.db_pool.simple_select_one( table="e2e_fallback_keys_json", keyvalues={ @@ -1179,7 +1197,7 @@ class EndToEndKeyWorkerStore(EndToEndKeyBackgroundStore, CacheInvalidationWorker ) device_results = results.setdefault(user_id, {}).setdefault(device_id, {}) - device_results[f"{algorithm}:{key_id}"] = key_json + device_results[f"{algorithm}:{key_id}"] = json_decoder.decode(key_json) return results -- cgit 1.4.1