diff --git a/synapse/api/auth/msc3861_delegated.py b/synapse/api/auth/msc3861_delegated.py
index 7361666c77..e500a06afe 100644
--- a/synapse/api/auth/msc3861_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -19,7 +19,8 @@
#
#
import logging
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
from urllib.parse import urlencode
from authlib.oauth2 import ClientAuth
@@ -38,15 +39,16 @@ from synapse.api.errors import (
HttpResponseException,
InvalidClientTokenError,
OAuthInsufficientScopeError,
- StoreError,
SynapseError,
UnrecognizedRequestError,
)
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
+from synapse.logging.opentracing import active_span, force_tracing, start_active_span
from synapse.types import Requester, UserID, create_requester
from synapse.util import json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
+from synapse.util.caches.response_cache import ResponseCache, ResponseCacheContext
if TYPE_CHECKING:
from synapse.rest.admin.experimental_features import ExperimentalFeature
@@ -76,6 +78,61 @@ def scope_to_list(scope: str) -> List[str]:
return scope.strip().split(" ")
+@dataclass
+class IntrospectionResult:
+ _inner: IntrospectionToken
+
+ # when we retrieved this token,
+ # in milliseconds since the Unix epoch
+ retrieved_at_ms: int
+
+ def is_active(self, now_ms: int) -> bool:
+ if not self._inner.get("active"):
+ return False
+
+ expires_in = self._inner.get("expires_in")
+ if expires_in is None:
+ return True
+ if not isinstance(expires_in, int):
+ raise InvalidClientTokenError("token `expires_in` is not an int")
+
+ absolute_expiry_ms = expires_in * 1000 + self.retrieved_at_ms
+ return now_ms < absolute_expiry_ms
+
+ def get_scope_list(self) -> List[str]:
+ value = self._inner.get("scope")
+ if not isinstance(value, str):
+ return []
+ return scope_to_list(value)
+
+ def get_sub(self) -> Optional[str]:
+ value = self._inner.get("sub")
+ if not isinstance(value, str):
+ return None
+ return value
+
+ def get_username(self) -> Optional[str]:
+ value = self._inner.get("username")
+ if not isinstance(value, str):
+ return None
+ return value
+
+ def get_name(self) -> Optional[str]:
+ value = self._inner.get("name")
+ if not isinstance(value, str):
+ return None
+ return value
+
+ def get_device_id(self) -> Optional[str]:
+ value = self._inner.get("device_id")
+ if value is not None and not isinstance(value, str):
+ raise AuthError(
+ 500,
+ "Invalid device ID in introspection result",
+ )
+ return value
+
+
class PrivateKeyJWTWithKid(PrivateKeyJWT): # type: ignore[misc]
"""An implementation of the private_key_jwt client auth method that includes a kid header.
@@ -119,9 +176,39 @@ class MSC3861DelegatedAuth(BaseAuth):
self._clock = hs.get_clock()
self._http_client = hs.get_proxied_http_client()
self._hostname = hs.hostname
- self._admin_token = self._config.admin_token
+ self._admin_token: Callable[[], Optional[str]] = self._config.admin_token
+ self._force_tracing_for_users = hs.config.tracing.force_tracing_for_users
+
+ # # Token Introspection Cache
+ # This remembers what users/devices are represented by which access tokens,
+ # in order to reduce overall system load:
+ # - on Synapse (as requests are relatively expensive)
+ # - on the network
+ # - on MAS
+ #
+ # Since there is no invalidation mechanism currently,
+ # the entries expire after 2 minutes.
+ # This does mean tokens can be treated as valid by Synapse
+ # for longer than reality.
+ #
+ # Ideally, tokens should logically be invalidated in the following circumstances:
+ # - If a session logout happens.
+ # In this case, MAS will delete the device within Synapse
+ # anyway and this is good enough as an invalidation.
+ # - If the client refreshes their token in MAS.
+ # In this case, the device still exists and it's not the end of the world for
+ # the old access token to continue working for a short time.
+ self._introspection_cache: ResponseCache[str] = ResponseCache(
+ self._clock,
+ "token_introspection",
+ timeout_ms=120_000,
+ # don't log because the keys are access tokens
+ enable_logging=False,
+ )
- self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
+ self._issuer_metadata = RetryOnExceptionCachedCall[OpenIDProviderMetadata](
+ self._load_metadata
+ )
if isinstance(auth_method, PrivateKeyJWTWithKid):
# Use the JWK as the client secret when using the private_key_jwt method
@@ -131,9 +218,10 @@ class MSC3861DelegatedAuth(BaseAuth):
)
else:
# Else use the client secret
- assert self._config.client_secret, "No client_secret provided"
+ client_secret = self._config.client_secret()
+ assert client_secret, "No client_secret provided"
self._client_auth = ClientAuth(
- self._config.client_id, self._config.client_secret, auth_method
+ self._config.client_id, client_secret, auth_method
)
async def _load_metadata(self) -> OpenIDProviderMetadata:
@@ -145,6 +233,39 @@ class MSC3861DelegatedAuth(BaseAuth):
# metadata.validate_introspection_endpoint()
return metadata
+ async def issuer(self) -> str:
+ """
+ Get the configured issuer
+
+ This will use the issuer value set in the metadata,
+ falling back to the one set in the config if not set in the metadata
+ """
+ metadata = await self._issuer_metadata.get()
+ return metadata.issuer or self._config.issuer
+
+ async def account_management_url(self) -> Optional[str]:
+ """
+ Get the configured account management URL
+
+ This will discover the account management URL from the issuer if it's not set in the config
+ """
+ if self._config.account_management_url is not None:
+ return self._config.account_management_url
+
+ try:
+ metadata = await self._issuer_metadata.get()
+ return metadata.get("account_management_uri", None)
+ # We don't want to raise here if we can't load the metadata
+ except Exception:
+ logger.warning("Failed to load metadata:", exc_info=True)
+ return None
+
+ async def auth_metadata(self) -> Dict[str, Any]:
+ """
+ Returns the auth metadata dict
+ """
+ return await self._issuer_metadata.get()
+
async def _introspection_endpoint(self) -> str:
"""
Returns the introspection endpoint of the issuer
@@ -154,10 +275,12 @@ class MSC3861DelegatedAuth(BaseAuth):
if self._config.introspection_endpoint is not None:
return self._config.introspection_endpoint
- metadata = await self._load_metadata()
+ metadata = await self._issuer_metadata.get()
return metadata.get("introspection_endpoint")
- async def _introspect_token(self, token: str) -> IntrospectionToken:
+ async def _introspect_token(
+ self, token: str, cache_context: ResponseCacheContext[str]
+ ) -> IntrospectionResult:
"""
Send a token to the introspection endpoint and returns the introspection response
@@ -173,11 +296,16 @@ class MSC3861DelegatedAuth(BaseAuth):
Returns:
The introspection response
"""
+ # By default, we shouldn't cache the result unless we know it's valid
+ cache_context.should_cache = False
introspection_endpoint = await self._introspection_endpoint()
raw_headers: Dict[str, str] = {
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": str(self._http_client.user_agent, "utf-8"),
"Accept": "application/json",
+ # Tell MAS that we support reading the device ID as an explicit
+ # value, not encoded in the scope. This is supported by MAS 0.15+
+ "X-MAS-Supports-Device-Id": "1",
}
args = {"token": token, "token_type_hint": "access_token"}
@@ -227,7 +355,11 @@ class MSC3861DelegatedAuth(BaseAuth):
"The introspection endpoint returned an invalid JSON response."
)
- return IntrospectionToken(**resp)
+ # We had a valid response, so we can cache it
+ cache_context.should_cache = True
+ return IntrospectionResult(
+ IntrospectionToken(**resp), retrieved_at_ms=self._clock.time_msec()
+ )
async def is_server_admin(self, requester: Requester) -> bool:
return "urn:synapse:admin:*" in requester.scope
@@ -239,6 +371,55 @@ class MSC3861DelegatedAuth(BaseAuth):
allow_expired: bool = False,
allow_locked: bool = False,
) -> Requester:
+ """Get a registered user's ID.
+
+ Args:
+ request: An HTTP request with an access_token query parameter.
+ allow_guest: If False, will raise an AuthError if the user making the
+ request is a guest.
+ allow_expired: If True, allow the request through even if the account
+ is expired, or session token lifetime has ended. Note that
+ /login will deliver access tokens regardless of expiration.
+
+ Returns:
+ Resolves to the requester
+ Raises:
+ InvalidClientCredentialsError if no user by that token exists or the token
+ is invalid.
+ AuthError if access is denied for the user in the access token
+ """
+ parent_span = active_span()
+ with start_active_span("get_user_by_req"):
+ requester = await self._wrapped_get_user_by_req(
+ request, allow_guest, allow_expired, allow_locked
+ )
+
+ if parent_span:
+ if requester.authenticated_entity in self._force_tracing_for_users:
+ # request tracing is enabled for this user, so we need to force it
+ # tracing on for the parent span (which will be the servlet span).
+ #
+ # It's too late for the get_user_by_req span to inherit the setting,
+ # so we also force it on for that.
+ force_tracing()
+ force_tracing(parent_span)
+ parent_span.set_tag(
+ "authenticated_entity", requester.authenticated_entity
+ )
+ parent_span.set_tag("user_id", requester.user.to_string())
+ if requester.device_id is not None:
+ parent_span.set_tag("device_id", requester.device_id)
+ if requester.app_service is not None:
+ parent_span.set_tag("appservice_id", requester.app_service.id)
+ return requester
+
+ async def _wrapped_get_user_by_req(
+ self,
+ request: SynapseRequest,
+ allow_guest: bool = False,
+ allow_expired: bool = False,
+ allow_locked: bool = False,
+ ) -> Requester:
access_token = self.get_access_token_from_request(request)
requester = await self.get_appservice_user(request, access_token)
@@ -248,7 +429,7 @@ class MSC3861DelegatedAuth(BaseAuth):
requester = await self.get_user_by_access_token(access_token, allow_expired)
# Do not record requests from MAS using the virtual `__oidc_admin` user.
- if access_token != self._admin_token:
+ if access_token != self._admin_token():
await self._record_request(request, requester)
if not allow_guest and requester.is_guest:
@@ -289,7 +470,8 @@ class MSC3861DelegatedAuth(BaseAuth):
token: str,
allow_expired: bool = False,
) -> Requester:
- if self._admin_token is not None and token == self._admin_token:
+ admin_token = self._admin_token()
+ if admin_token is not None and token == admin_token:
# XXX: This is a temporary solution so that the admin API can be called by
# the OIDC provider. This will be removed once we have OIDC client
# credentials grant support in matrix-authentication-service.
@@ -304,20 +486,22 @@ class MSC3861DelegatedAuth(BaseAuth):
)
try:
- introspection_result = await self._introspect_token(token)
+ introspection_result = await self._introspection_cache.wrap(
+ token, self._introspect_token, token, cache_context=True
+ )
except Exception:
logger.exception("Failed to introspect token")
raise SynapseError(503, "Unable to introspect the access token")
- logger.info(f"Introspection result: {introspection_result!r}")
+ logger.debug("Introspection result: %r", introspection_result)
# TODO: introspection verification should be more extensive, especially:
# - verify the audience
- if not introspection_result.get("active"):
+ if not introspection_result.is_active(self._clock.time_msec()):
raise InvalidClientTokenError("Token is not active")
# Let's look at the scope
- scope: List[str] = scope_to_list(introspection_result.get("scope", ""))
+ scope: List[str] = introspection_result.get_scope_list()
# Determine type of user based on presence of particular scopes
has_user_scope = SCOPE_MATRIX_API in scope
@@ -327,7 +511,7 @@ class MSC3861DelegatedAuth(BaseAuth):
raise InvalidClientTokenError("No scope in token granting user rights")
# Match via the sub claim
- sub: Optional[str] = introspection_result.get("sub")
+ sub = introspection_result.get_sub()
if sub is None:
raise InvalidClientTokenError(
"Invalid sub claim in the introspection result"
@@ -340,29 +524,20 @@ class MSC3861DelegatedAuth(BaseAuth):
# If we could not find a user via the external_id, it either does not exist,
# or the external_id was never recorded
- # TODO: claim mapping should be configurable
- username: Optional[str] = introspection_result.get("username")
- if username is None or not isinstance(username, str):
+ username = introspection_result.get_username()
+ if username is None:
raise AuthError(
500,
"Invalid username claim in the introspection result",
)
user_id = UserID(username, self._hostname)
- # First try to find a user from the username claim
+ # Try to find a user from the username claim
user_info = await self.store.get_user_by_id(user_id=user_id.to_string())
if user_info is None:
- # If the user does not exist, we should create it on the fly
- # TODO: we could use SCIM to provision users ahead of time and listen
- # for SCIM SET events if those ever become standard:
- # https://datatracker.ietf.org/doc/html/draft-hunt-scim-notify-00
-
- # TODO: claim mapping should be configurable
- # If present, use the name claim as the displayname
- name: Optional[str] = introspection_result.get("name")
-
- await self.store.register_user(
- user_id=user_id.to_string(), create_profile_with_displayname=name
+ raise AuthError(
+ 500,
+ "User not found",
)
# And record the sub as external_id
@@ -372,42 +547,40 @@ class MSC3861DelegatedAuth(BaseAuth):
else:
user_id = UserID.from_string(user_id_str)
- # Find device_ids in scope
- # We only allow a single device_id in the scope, so we find them all in the
- # scope list, and raise if there are more than one. The OIDC server should be
- # the one enforcing valid scopes, so we raise a 500 if we find an invalid scope.
- device_ids = [
- tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :]
- for tok in scope
- if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX)
- ]
-
- if len(device_ids) > 1:
- raise AuthError(
- 500,
- "Multiple device IDs in scope",
- )
+ # MAS 0.15+ will give us the device ID as an explicit value for compatibility sessions
+ # If present, we get it from here, if not we get it in thee scope
+ device_id = introspection_result.get_device_id()
+ if device_id is None:
+ # Find device_ids in scope
+ # We only allow a single device_id in the scope, so we find them all in the
+ # scope list, and raise if there are more than one. The OIDC server should be
+ # the one enforcing valid scopes, so we raise a 500 if we find an invalid scope.
+ device_ids = [
+ tok[len(SCOPE_MATRIX_DEVICE_PREFIX) :]
+ for tok in scope
+ if tok.startswith(SCOPE_MATRIX_DEVICE_PREFIX)
+ ]
+
+ if len(device_ids) > 1:
+ raise AuthError(
+ 500,
+ "Multiple device IDs in scope",
+ )
+
+ device_id = device_ids[0] if device_ids else None
- device_id = device_ids[0] if device_ids else None
if device_id is not None:
# Sanity check the device_id
if len(device_id) > 255 or len(device_id) < 1:
raise AuthError(
500,
- "Invalid device ID in scope",
+ "Invalid device ID in introspection result",
)
- # Create the device on the fly if it does not exist
- try:
- await self.store.get_device(
- user_id=user_id.to_string(), device_id=device_id
- )
- except StoreError:
- await self.store.store_device(
- user_id=user_id.to_string(),
- device_id=device_id,
- initial_device_display_name="OIDC-native client",
- )
+ # Make sure the device exists
+ await self.store.get_device(
+ user_id=user_id.to_string(), device_id=device_id
+ )
# TODO: there is a few things missing in the requester here, which still need
# to be figured out, like:
|