From a3bad89d57645b2ea304d2900adab71a786b0172 Mon Sep 17 00:00:00 2001 From: Warren Bailey Date: Thu, 30 Mar 2023 12:09:41 +0100 Subject: Add the ability to enable/disable registrations when in the OIDC flow (#14978) Signed-off-by: Warren Bailey --- synapse/handlers/oidc.py | 1 + synapse/handlers/sso.py | 17 +++++++++++++++-- 2 files changed, 16 insertions(+), 2 deletions(-) (limited to 'synapse/handlers') diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py index 0fc829acf7..e7e0b5e049 100644 --- a/synapse/handlers/oidc.py +++ b/synapse/handlers/oidc.py @@ -1239,6 +1239,7 @@ class OidcProvider: grandfather_existing_users, extra_attributes, auth_provider_session_id=sid, + registration_enabled=self._config.enable_registration, ) def _remote_id_from_userinfo(self, userinfo: UserInfo) -> str: diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py index 4a27c0f051..c28325323c 100644 --- a/synapse/handlers/sso.py +++ b/synapse/handlers/sso.py @@ -383,6 +383,7 @@ class SsoHandler: grandfather_existing_users: Callable[[], Awaitable[Optional[str]]], extra_login_attributes: Optional[JsonDict] = None, auth_provider_session_id: Optional[str] = None, + registration_enabled: bool = True, ) -> None: """ Given an SSO ID, retrieve the user ID for it and possibly register the user. @@ -435,6 +436,10 @@ class SsoHandler: auth_provider_session_id: An optional session ID from the IdP. + registration_enabled: An optional boolean to enable/disable automatic + registrations of new users. If false and the user does not exist then the + flow is aborted. Defaults to true. + Raises: MappingException if there was a problem mapping the response to a user. RedirectException: if the mapping provider needs to redirect the user @@ -462,8 +467,16 @@ class SsoHandler: auth_provider_id, remote_user_id, user_id ) - # Otherwise, generate a new user. - if not user_id: + if not user_id and not registration_enabled: + logger.info( + "User does not exist and registration are disabled for IdP '%s' and remote_user_id '%s'", + auth_provider_id, + remote_user_id, + ) + raise MappingException( + "User does not exist and registrations are disabled" + ) + elif not user_id: # Otherwise, generate a new user. attributes = await self._call_attribute_mapper(sso_to_matrix_id_mapper) next_step_url = self._get_url_for_next_new_user_step( -- cgit 1.5.1 From d9f694932c64d68e791ecb4c860e911e21a0baeb Mon Sep 17 00:00:00 2001 From: Sean Quah <8349537+squahtx@users.noreply.github.com> Date: Thu, 30 Mar 2023 13:36:41 +0100 Subject: Fix spinloop during partial state sync when a prev event is in backoff (#15351) Previously, we would spin in a tight loop until `update_state_for_partial_state_event` stopped raising `FederationPullAttemptBackoffError`s. Replace the spinloop with a wait until the backoff period has expired. Signed-off-by: Sean Quah --- changelog.d/15351.bugfix | 1 + synapse/api/errors.py | 17 +++++++--- synapse/handlers/federation.py | 36 +++++++++++----------- synapse/handlers/federation_event.py | 24 ++++++++++----- synapse/storage/databases/main/event_federation.py | 35 ++++++++++++--------- tests/storage/test_event_federation.py | 13 +++++--- 6 files changed, 79 insertions(+), 47 deletions(-) create mode 100644 changelog.d/15351.bugfix (limited to 'synapse/handlers') diff --git a/changelog.d/15351.bugfix b/changelog.d/15351.bugfix new file mode 100644 index 0000000000..e68023c671 --- /dev/null +++ b/changelog.d/15351.bugfix @@ -0,0 +1 @@ +Fix a bug introduced in Synapse 1.70.0 where the background sync from a faster join could spin for hours when one of the events involved had been marked for backoff. diff --git a/synapse/api/errors.py b/synapse/api/errors.py index 8c6822f3c6..f2d6f9ab2d 100644 --- a/synapse/api/errors.py +++ b/synapse/api/errors.py @@ -27,7 +27,7 @@ from synapse.util import json_decoder if typing.TYPE_CHECKING: from synapse.config.homeserver import HomeServerConfig - from synapse.types import JsonDict + from synapse.types import JsonDict, StrCollection logger = logging.getLogger(__name__) @@ -682,18 +682,27 @@ class FederationPullAttemptBackoffError(RuntimeError): Attributes: event_id: The event_id which we are refusing to pull message: A custom error message that gives more context + retry_after_ms: The remaining backoff interval, in milliseconds """ - def __init__(self, event_ids: List[str], message: Optional[str]): - self.event_ids = event_ids + def __init__( + self, event_ids: "StrCollection", message: Optional[str], retry_after_ms: int + ): + event_ids = list(event_ids) if message: error_message = message else: - error_message = f"Not attempting to pull event_ids={self.event_ids} because we already tried to pull them recently (backing off)." + error_message = ( + f"Not attempting to pull event_ids={event_ids} because we already " + "tried to pull them recently (backing off)." + ) super().__init__(error_message) + self.event_ids = event_ids + self.retry_after_ms = retry_after_ms + class HttpResponseException(CodeMessageException): """ diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py index 80156ef343..65461a0787 100644 --- a/synapse/handlers/federation.py +++ b/synapse/handlers/federation.py @@ -1949,27 +1949,25 @@ class FederationHandler: ) for event in events: for attempt in itertools.count(): + # We try a new destination on every iteration. try: - await self._federation_event_handler.update_state_for_partial_state_event( - destination, event - ) + while True: + try: + await self._federation_event_handler.update_state_for_partial_state_event( + destination, event + ) + break + except FederationPullAttemptBackoffError as e: + # We are in the backoff period for one of the event's + # prev_events. Wait it out and try again after. + logger.warning( + "%s; waiting for %d ms...", e, e.retry_after_ms + ) + await self.clock.sleep(e.retry_after_ms / 1000) + + # Success, no need to try the rest of the destinations. break - except FederationPullAttemptBackoffError as exc: - # Log a warning about why we failed to process the event (the error message - # for `FederationPullAttemptBackoffError` is pretty good) - logger.warning("_sync_partial_state_room: %s", exc) - # We do not record a failed pull attempt when we backoff fetching a missing - # `prev_event` because not being able to fetch the `prev_events` just means - # we won't be able to de-outlier the pulled event. But we can still use an - # `outlier` in the state/auth chain for another event. So we shouldn't stop - # a downstream event from trying to pull it. - # - # This avoids a cascade of backoff for all events in the DAG downstream from - # one event backoff upstream. except FederationError as e: - # TODO: We should `record_event_failed_pull_attempt` here, - # see https://github.com/matrix-org/synapse/issues/13700 - if attempt == len(destinations) - 1: # We have tried every remote server for this event. Give up. # TODO(faster_joins) giving up isn't the right thing to do @@ -1986,6 +1984,8 @@ class FederationHandler: destination, e, ) + # TODO: We should `record_event_failed_pull_attempt` here, + # see https://github.com/matrix-org/synapse/issues/13700 raise # Try the next remote server. diff --git a/synapse/handlers/federation_event.py b/synapse/handlers/federation_event.py index 648843cdbe..982c8d3b2f 100644 --- a/synapse/handlers/federation_event.py +++ b/synapse/handlers/federation_event.py @@ -140,6 +140,7 @@ class FederationEventHandler: """ def __init__(self, hs: "HomeServer"): + self._clock = hs.get_clock() self._store = hs.get_datastores().main self._storage_controllers = hs.get_storage_controllers() self._state_storage_controller = self._storage_controllers.state @@ -1038,8 +1039,8 @@ class FederationEventHandler: Raises: FederationPullAttemptBackoffError if we are are deliberately not attempting - to pull the given event over federation because we've already done so - recently and are backing off. + to pull one of the given event's `prev_event`s over federation because + we've already done so recently and are backing off. FederationError if we fail to get the state from the remote server after any missing `prev_event`s. """ @@ -1053,13 +1054,22 @@ class FederationEventHandler: # If we've already recently attempted to pull this missing event, don't # try it again so soon. Since we have to fetch all of the prev_events, we can # bail early here if we find any to ignore. - prevs_to_ignore = await self._store.get_event_ids_to_not_pull_from_backoff( - room_id, missing_prevs + prevs_with_pull_backoff = ( + await self._store.get_event_ids_to_not_pull_from_backoff( + room_id, missing_prevs + ) ) - if len(prevs_to_ignore) > 0: + if len(prevs_with_pull_backoff) > 0: raise FederationPullAttemptBackoffError( - event_ids=prevs_to_ignore, - message=f"While computing context for event={event_id}, not attempting to pull missing prev_event={prevs_to_ignore[0]} because we already tried to pull recently (backing off).", + event_ids=prevs_with_pull_backoff.keys(), + message=( + f"While computing context for event={event_id}, not attempting to " + f"pull missing prev_events={list(prevs_with_pull_backoff.keys())} " + "because we already tried to pull recently (backing off)." + ), + retry_after_ms=( + max(prevs_with_pull_backoff.values()) - self._clock.time_msec() + ), ) if not missing_prevs: diff --git a/synapse/storage/databases/main/event_federation.py b/synapse/storage/databases/main/event_federation.py index ff3edeb716..a19ba88bf8 100644 --- a/synapse/storage/databases/main/event_federation.py +++ b/synapse/storage/databases/main/event_federation.py @@ -1544,7 +1544,7 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas self, room_id: str, event_ids: Collection[str], - ) -> List[str]: + ) -> Dict[str, int]: """ Filter down the events to ones that we've failed to pull before recently. Uses exponential backoff. @@ -1554,7 +1554,8 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas event_ids: A list of events to filter down Returns: - List of event_ids that should not be attempted to be pulled + A dictionary of event_ids that should not be attempted to be pulled and the + next timestamp at which we may try pulling them again. """ event_failed_pull_attempts = await self.db_pool.simple_select_many_batch( table="event_failed_pull_attempts", @@ -1570,22 +1571,28 @@ class EventFederationWorkerStore(SignatureWorkerStore, EventsWorkerStore, SQLBas ) current_time = self._clock.time_msec() - return [ - event_failed_pull_attempt["event_id"] - for event_failed_pull_attempt in event_failed_pull_attempts + + event_ids_with_backoff = {} + for event_failed_pull_attempt in event_failed_pull_attempts: + event_id = event_failed_pull_attempt["event_id"] # Exponential back-off (up to the upper bound) so we don't try to # pull the same event over and over. ex. 2hr, 4hr, 8hr, 16hr, etc. - if current_time - < event_failed_pull_attempt["last_attempt_ts"] - + ( - 2 - ** min( - event_failed_pull_attempt["num_attempts"], - BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + backoff_end_time = ( + event_failed_pull_attempt["last_attempt_ts"] + + ( + 2 + ** min( + event_failed_pull_attempt["num_attempts"], + BACKFILL_EVENT_EXPONENTIAL_BACKOFF_MAXIMUM_DOUBLING_STEPS, + ) ) + * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS ) - * BACKFILL_EVENT_EXPONENTIAL_BACKOFF_STEP_MILLISECONDS - ] + + if current_time < backoff_end_time: # `backoff_end_time` is exclusive + event_ids_with_backoff[event_id] = backoff_end_time + + return event_ids_with_backoff async def get_missing_events( self, diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py index 3e1984c15c..81e50bdd55 100644 --- a/tests/storage/test_event_federation.py +++ b/tests/storage/test_event_federation.py @@ -1143,19 +1143,24 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): tok = self.login("alice", "test") room_id = self.helper.create_room_as(room_creator=user_id, tok=tok) + failure_time = self.clock.time_msec() self.get_success( self.store.record_event_failed_pull_attempt( room_id, "$failed_event_id", "fake cause" ) ) - event_ids_to_backoff = self.get_success( + event_ids_with_backoff = self.get_success( self.store.get_event_ids_to_not_pull_from_backoff( room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] ) ) - self.assertEqual(event_ids_to_backoff, ["$failed_event_id"]) + self.assertEqual( + event_ids_with_backoff, + # We expect a 2^1 hour backoff after a single failed attempt. + {"$failed_event_id": failure_time + 2 * 60 * 60 * 1000}, + ) def test_get_event_ids_to_not_pull_from_backoff_retry_after_backoff_duration( self, @@ -1179,14 +1184,14 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase): # attempt (2^1 hours). self.reactor.advance(datetime.timedelta(hours=2).total_seconds()) - event_ids_to_backoff = self.get_success( + event_ids_with_backoff = self.get_success( self.store.get_event_ids_to_not_pull_from_backoff( room_id=room_id, event_ids=["$failed_event_id", "$normal_event_id"] ) ) # Since this function only returns events we should backoff from, time has # elapsed past the backoff range so there is no events to backoff from. - self.assertEqual(event_ids_to_backoff, []) + self.assertEqual(event_ids_with_backoff, {}) @attr.s(auto_attribs=True) -- cgit 1.5.1 From ae4acda1bb903f504e946442bfc66dd0e5757dad Mon Sep 17 00:00:00 2001 From: Patrick Cloke Date: Thu, 30 Mar 2023 08:39:38 -0400 Subject: Implement MSC3984 to proxy /keys/query requests to appservices. (#15321) If enabled, for users which are exclusively owned by an application service then the appservice will be queried for devices in addition to any information stored in the Synapse database. --- changelog.d/15314.feature | 2 +- changelog.d/15321.feature | 1 + synapse/appservice/api.py | 54 ++++++++++++-- synapse/config/experimental.py | 5 ++ synapse/federation/federation_client.py | 48 ++----------- synapse/handlers/appservice.py | 61 ++++++++++++++++ synapse/handlers/e2e_keys.py | 16 +++++ synapse/http/client.py | 38 ++++++++++ tests/handlers/test_e2e_keys.py | 121 +++++++++++++++++++++++++++++++- 9 files changed, 298 insertions(+), 48 deletions(-) create mode 100644 changelog.d/15321.feature (limited to 'synapse/handlers') diff --git a/changelog.d/15314.feature b/changelog.d/15314.feature index 68b289b0cc..5ce0c029ce 100644 --- a/changelog.d/15314.feature +++ b/changelog.d/15314.feature @@ -1 +1 @@ -Experimental support for passing One Time Key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983)). +Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)). diff --git a/changelog.d/15321.feature b/changelog.d/15321.feature new file mode 100644 index 0000000000..5ce0c029ce --- /dev/null +++ b/changelog.d/15321.feature @@ -0,0 +1 @@ +Experimental support for passing One Time Key and device key requests to application services ([MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983) and [MSC3984](https://github.com/matrix-org/matrix-spec-proposals/pull/3984)). diff --git a/synapse/appservice/api.py b/synapse/appservice/api.py index 51ee0e79df..b27eedef99 100644 --- a/synapse/appservice/api.py +++ b/synapse/appservice/api.py @@ -30,7 +30,7 @@ from prometheus_client import Counter from typing_extensions import TypeGuard from synapse.api.constants import EventTypes, Membership, ThirdPartyEntityKind -from synapse.api.errors import CodeMessageException +from synapse.api.errors import CodeMessageException, HttpResponseException from synapse.appservice import ( ApplicationService, TransactionOneTimeKeysCount, @@ -38,7 +38,7 @@ from synapse.appservice import ( ) from synapse.events import EventBase from synapse.events.utils import SerializeEventConfig, serialize_event -from synapse.http.client import SimpleHttpClient +from synapse.http.client import SimpleHttpClient, is_unknown_endpoint from synapse.types import DeviceListUpdates, JsonDict, ThirdPartyInstanceID from synapse.util.caches.response_cache import ResponseCache @@ -393,7 +393,11 @@ class ApplicationServiceApi(SimpleHttpClient): ) -> Tuple[Dict[str, Dict[str, Dict[str, JsonDict]]], List[Tuple[str, str, str]]]: """Claim one time keys from an application service. + Note that any error (including a timeout) is treated as the application + service having no information. + Args: + service: The application service to query. query: An iterable of tuples of (user ID, device ID, algorithm). Returns: @@ -422,9 +426,9 @@ class ApplicationServiceApi(SimpleHttpClient): body, headers={"Authorization": [f"Bearer {service.hs_token}"]}, ) - except CodeMessageException as e: + except HttpResponseException as e: # The appservice doesn't support this endpoint. - if e.code == 404 or e.code == 405: + if is_unknown_endpoint(e): return {}, query logger.warning("claim_keys to %s received %s", uri, e.code) return {}, query @@ -444,6 +448,48 @@ class ApplicationServiceApi(SimpleHttpClient): return response, missing + async def query_keys( + self, service: "ApplicationService", query: Dict[str, List[str]] + ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: + """Query the application service for keys. + + Note that any error (including a timeout) is treated as the application + service having no information. + + Args: + service: The application service to query. + query: An iterable of tuples of (user ID, device ID, algorithm). + + Returns: + A map of device_keys/master_keys/self_signing_keys/user_signing_keys: + + device_keys is a map of user ID -> a map device ID -> device info. + """ + if service.url is None: + return {} + + # This is required by the configuration. + assert service.hs_token is not None + + uri = f"{service.url}/_matrix/app/unstable/org.matrix.msc3984/keys/query" + try: + response = await self.post_json_get_json( + uri, + query, + headers={"Authorization": [f"Bearer {service.hs_token}"]}, + ) + except HttpResponseException as e: + # The appservice doesn't support this endpoint. + if is_unknown_endpoint(e): + return {} + logger.warning("query_keys to %s received %s", uri, e.code) + return {} + except Exception as ex: + logger.warning("query_keys to %s threw exception %s", uri, ex) + return {} + + return response + def _serialize( self, service: "ApplicationService", events: Iterable[EventBase] ) -> List[JsonDict]: diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py index 53e6fc2b54..7687c80ea0 100644 --- a/synapse/config/experimental.py +++ b/synapse/config/experimental.py @@ -79,6 +79,11 @@ class ExperimentalConfig(Config): "msc3983_appservice_otk_claims", False ) + # MSC3984: Proxying key queries to exclusive ASes. + self.msc3984_appservice_key_query: bool = experimental.get( + "msc3984_appservice_key_query", False + ) + # MSC3706 (server-side support for partial state in /send_join responses) # Synapse will always serve partial state responses to requests using the stable # query parameter `omit_members`. If this flag is set, Synapse will also serve diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py index 7d04560dca..4cf4957a42 100644 --- a/synapse/federation/federation_client.py +++ b/synapse/federation/federation_client.py @@ -61,6 +61,7 @@ from synapse.federation.federation_base import ( event_from_pdu_json, ) from synapse.federation.transport.client import SendJoinResponse +from synapse.http.client import is_unknown_endpoint from synapse.http.types import QueryParams from synapse.logging.opentracing import SynapseTags, log_kv, set_tag, tag_args, trace from synapse.types import JsonDict, UserID, get_domain_from_id @@ -759,43 +760,6 @@ class FederationClient(FederationBase): return signed_auth - def _is_unknown_endpoint( - self, e: HttpResponseException, synapse_error: Optional[SynapseError] = None - ) -> bool: - """ - Returns true if the response was due to an endpoint being unimplemented. - - Args: - e: The error response received from the remote server. - synapse_error: The above error converted to a SynapseError. This is - automatically generated if not provided. - - """ - if synapse_error is None: - synapse_error = e.to_synapse_error() - # MSC3743 specifies that servers should return a 404 or 405 with an errcode - # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or - # to an unknown method, respectively. - # - # Older versions of servers don't properly handle this. This needs to be - # rather specific as some endpoints truly do return 404 errors. - return ( - # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method. - (e.code == 404 or e.code == 405) - and ( - # Older Dendrites returned a text or empty body. - # Older Conduit returned an empty body. - not e.response - or e.response == b"404 page not found" - # The proper response JSON with M_UNRECOGNIZED errcode. - or synapse_error.errcode == Codes.UNRECOGNIZED - ) - ) or ( - # Older Synapses returned a 400 error. - e.code == 400 - and synapse_error.errcode == Codes.UNRECOGNIZED - ) - async def _try_destination_list( self, description: str, @@ -887,7 +851,7 @@ class FederationClient(FederationBase): elif 400 <= e.code < 500 and synapse_error.errcode in failover_errcodes: failover = True - elif failover_on_unknown_endpoint and self._is_unknown_endpoint( + elif failover_on_unknown_endpoint and is_unknown_endpoint( e, synapse_error ): failover = True @@ -1223,7 +1187,7 @@ class FederationClient(FederationBase): # If an error is received that is due to an unrecognised endpoint, # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. - if not self._is_unknown_endpoint(e): + if not is_unknown_endpoint(e): raise logger.debug("Couldn't send_join with the v2 API, falling back to the v1 API") @@ -1297,7 +1261,7 @@ class FederationClient(FederationBase): # fallback to the v1 endpoint if the room uses old-style event IDs. # Otherwise, consider it a legitimate error and raise. err = e.to_synapse_error() - if self._is_unknown_endpoint(e, err): + if is_unknown_endpoint(e, err): if room_version.event_format != EventFormatVersions.ROOM_V1_V2: raise SynapseError( 400, @@ -1358,7 +1322,7 @@ class FederationClient(FederationBase): # If an error is received that is due to an unrecognised endpoint, # fallback to the v1 endpoint. Otherwise, consider it a legitimate error # and raise. - if not self._is_unknown_endpoint(e): + if not is_unknown_endpoint(e): raise logger.debug("Couldn't send_leave with the v2 API, falling back to the v1 API") @@ -1629,7 +1593,7 @@ class FederationClient(FederationBase): # If an error is received that is due to an unrecognised endpoint, # fallback to the unstable endpoint. Otherwise, consider it a # legitimate error and raise. - if not self._is_unknown_endpoint(e): + if not is_unknown_endpoint(e): raise logger.debug( diff --git a/synapse/handlers/appservice.py b/synapse/handlers/appservice.py index 953df4d9cd..da887647d4 100644 --- a/synapse/handlers/appservice.py +++ b/synapse/handlers/appservice.py @@ -18,6 +18,7 @@ from typing import ( Dict, Iterable, List, + Mapping, Optional, Tuple, Union, @@ -846,6 +847,10 @@ class ApplicationServicesHandler: ]: """Claim one time keys from application services. + Users which are exclusively owned by an application service are sent a + key claim request to check if the application service provides keys + directly. + Args: query: An iterable of tuples of (user ID, device ID, algorithm). @@ -901,3 +906,59 @@ class ApplicationServicesHandler: missing.extend(result[1]) return claimed_keys, missing + + async def query_keys( + self, query: Mapping[str, Optional[List[str]]] + ) -> Dict[str, Dict[str, Dict[str, JsonDict]]]: + """Query application services for device keys. + + Users which are exclusively owned by an application service are queried + for keys to check if the application service provides keys directly. + + Args: + query: map from user_id to a list of devices to query + + Returns: + A map from user_id -> device_id -> device details + """ + services = self.store.get_app_services() + + # Partition the users by appservice. + query_by_appservice: Dict[str, Dict[str, List[str]]] = {} + for user_id, device_ids in query.items(): + if not self.store.get_if_app_services_interested_in_user(user_id): + continue + + # Find the associated appservice. + for service in services: + if service.is_exclusive_user(user_id): + query_by_appservice.setdefault(service.id, {})[user_id] = ( + device_ids or [] + ) + continue + + # Query each service in parallel. + results = await make_deferred_yieldable( + defer.DeferredList( + [ + run_in_background( + self.appservice_api.query_keys, + # We know this must be an app service. + self.store.get_app_service_by_id(service_id), # type: ignore[arg-type] + service_query, + ) + for service_id, service_query in query_by_appservice.items() + ], + consumeErrors=True, + ) + ) + + # Patch together the results -- they are all independent (since they + # require exclusive control over the users). They get returned as a single + # dictionary. + key_queries: Dict[str, Dict[str, Dict[str, JsonDict]]] = {} + for success, result in results: + if success: + key_queries.update(result) + + return key_queries diff --git a/synapse/handlers/e2e_keys.py b/synapse/handlers/e2e_keys.py index 9e7c2c45b5..0073667470 100644 --- a/synapse/handlers/e2e_keys.py +++ b/synapse/handlers/e2e_keys.py @@ -91,6 +91,9 @@ class E2eKeysHandler: self._query_appservices_for_otks = ( hs.config.experimental.msc3983_appservice_otk_claims ) + self._query_appservices_for_keys = ( + hs.config.experimental.msc3984_appservice_key_query + ) @trace @cancellable @@ -497,6 +500,19 @@ class E2eKeysHandler: local_query, include_displaynames ) + # Check if the application services have any additional results. + if self._query_appservices_for_keys: + # Query the appservices for any keys. + appservice_results = await self._appservice_handler.query_keys(query) + + # Merge results, overriding with what the appservice returned. + for user_id, devices in appservice_results.get("device_keys", {}).items(): + # Copy the appservice device info over the homeserver device info, but + # don't completely overwrite it. + results.setdefault(user_id, {}).update(devices) + + # TODO Handle cross-signing keys. + # Build the result structure for user_id, device_keys in results.items(): for device_id, device_info in device_keys.items(): diff --git a/synapse/http/client.py b/synapse/http/client.py index d777d59ccf..5ee55981d9 100644 --- a/synapse/http/client.py +++ b/synapse/http/client.py @@ -966,3 +966,41 @@ class InsecureInterceptableContextFactory(ssl.ContextFactory): def creatorForNetloc(self, hostname: bytes, port: int) -> IOpenSSLContextFactory: return self + + +def is_unknown_endpoint( + e: HttpResponseException, synapse_error: Optional[SynapseError] = None +) -> bool: + """ + Returns true if the response was due to an endpoint being unimplemented. + + Args: + e: The error response received from the remote server. + synapse_error: The above error converted to a SynapseError. This is + automatically generated if not provided. + + """ + if synapse_error is None: + synapse_error = e.to_synapse_error() + # MSC3743 specifies that servers should return a 404 or 405 with an errcode + # of M_UNRECOGNIZED when they receive a request to an unknown endpoint or + # to an unknown method, respectively. + # + # Older versions of servers don't properly handle this. This needs to be + # rather specific as some endpoints truly do return 404 errors. + return ( + # 404 is an unknown endpoint, 405 is a known endpoint, but unknown method. + (e.code == 404 or e.code == 405) + and ( + # Older Dendrites returned a text body or empty body. + # Older Conduit returned an empty body. + not e.response + or e.response == b"404 page not found" + # The proper response JSON with M_UNRECOGNIZED errcode. + or synapse_error.errcode == Codes.UNRECOGNIZED + ) + ) or ( + # Older Synapses returned a 400 error. + e.code == 400 + and synapse_error.errcode == Codes.UNRECOGNIZED + ) diff --git a/tests/handlers/test_e2e_keys.py b/tests/handlers/test_e2e_keys.py index 4ff04fc66b..013b9ee550 100644 --- a/tests/handlers/test_e2e_keys.py +++ b/tests/handlers/test_e2e_keys.py @@ -960,7 +960,7 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): appservice = ApplicationService( token="i_am_an_app_service", id="1234", - namespaces={"users": [{"regex": r"@boris:*", "exclusive": True}]}, + namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]}, # Note: this user does not have to match the regex above sender="@as_main:test", ) @@ -1015,3 +1015,122 @@ class E2eKeysHandlerTestCase(unittest.HomeserverTestCase): }, }, ) + + @override_config({"experimental_features": {"msc3984_appservice_key_query": True}}) + def test_query_local_devices_appservice(self) -> None: + """Test that querying of appservices for keys overrides responses from the database.""" + local_user = "@boris:" + self.hs.hostname + device_1 = "abc" + device_2 = "def" + device_3 = "ghi" + + # There are 3 devices: + # + # 1. One which is uploaded to the homeserver. + # 2. One which is uploaded to the homeserver, but a newer copy is returned + # by the appservice. + # 3. One which is only returned by the appservice. + device_key_1: JsonDict = { + "user_id": local_user, + "device_id": device_1, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "ed25519:abc": "base64+ed25519+key", + "curve25519:abc": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:abc": "base64+signature"}}, + } + device_key_2a: JsonDict = { + "user_id": local_user, + "device_id": device_2, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "ed25519:def": "base64+ed25519+key", + "curve25519:def": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:def": "base64+signature"}}, + } + + device_key_2b: JsonDict = { + "user_id": local_user, + "device_id": device_2, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + # The device ID is the same (above), but the keys are different. + "keys": { + "ed25519:xyz": "base64+ed25519+key", + "curve25519:xyz": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:xyz": "base64+signature"}}, + } + device_key_3: JsonDict = { + "user_id": local_user, + "device_id": device_3, + "algorithms": [ + "m.olm.curve25519-aes-sha2", + RoomEncryptionAlgorithms.MEGOLM_V1_AES_SHA2, + ], + "keys": { + "ed25519:jkl": "base64+ed25519+key", + "curve25519:jkl": "base64+curve25519+key", + }, + "signatures": {local_user: {"ed25519:jkl": "base64+signature"}}, + } + + # Upload keys for devices 1 & 2a. + self.get_success( + self.handler.upload_keys_for_user( + local_user, device_1, {"device_keys": device_key_1} + ) + ) + self.get_success( + self.handler.upload_keys_for_user( + local_user, device_2, {"device_keys": device_key_2a} + ) + ) + + # Inject an appservice interested in this user. + appservice = ApplicationService( + token="i_am_an_app_service", + id="1234", + namespaces={"users": [{"regex": r"@boris:.+", "exclusive": True}]}, + # Note: this user does not have to match the regex above + sender="@as_main:test", + ) + self.hs.get_datastores().main.services_cache = [appservice] + self.hs.get_datastores().main.exclusive_user_regex = _make_exclusive_regex( + [appservice] + ) + + # Setup a response. + self.appservice_api.query_keys.return_value = make_awaitable( + { + "device_keys": { + local_user: {device_2: device_key_2b, device_3: device_key_3} + } + } + ) + + # Request all devices. + res = self.get_success(self.handler.query_local_devices({local_user: None})) + self.assertIn(local_user, res) + for res_key in res[local_user].values(): + res_key.pop("unsigned", None) + self.assertDictEqual( + res, + { + local_user: { + device_1: device_key_1, + device_2: device_key_2b, + device_3: device_key_3, + } + }, + ) -- cgit 1.5.1