summary refs log tree commit diff
path: root/synapse/api
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/api')
-rw-r--r--synapse/api/auth/msc3861_delegated.py (renamed from synapse/api/auth/oauth_delegated.py)53
1 files changed, 28 insertions, 25 deletions
diff --git a/synapse/api/auth/oauth_delegated.py b/synapse/api/auth/msc3861_delegated.py
index 9cb6eb7f79..4ca3280bd3 100644
--- a/synapse/api/auth/oauth_delegated.py
+++ b/synapse/api/auth/msc3861_delegated.py
@@ -65,7 +65,7 @@ class PrivateKeyJWTWithKid(PrivateKeyJWT):
         )
 
 
-class OAuthDelegatedAuth(BaseAuth):
+class MSC3861DelegatedAuth(BaseAuth):
     AUTH_METHODS = {
         "client_secret_post": encode_client_secret_post,
         "client_secret_basic": encode_client_secret_basic,
@@ -78,35 +78,38 @@ class OAuthDelegatedAuth(BaseAuth):
     def __init__(self, hs: "HomeServer"):
         super().__init__(hs)
 
-        self._config = hs.config.auth
-        assert self._config.oauth_delegation_enabled, "OAuth delegation is not enabled"
-        assert self._config.oauth_delegation_issuer, "No issuer provided"
-        assert self._config.oauth_delegation_client_id, "No client_id provided"
-        assert self._config.oauth_delegation_client_secret, "No client_secret provided"
-        assert (
-            self._config.oauth_delegation_client_auth_method
-            in OAuthDelegatedAuth.AUTH_METHODS
-        ), "Invalid client_auth_method"
+        self._config = hs.config.experimental.msc3861
+        auth_method = MSC3861DelegatedAuth.AUTH_METHODS.get(
+            self._config.client_auth_method.value, None
+        )
+        # Those assertions are already checked when parsing the config
+        assert self._config.enabled, "OAuth delegation is not enabled"
+        assert self._config.issuer, "No issuer provided"
+        assert self._config.client_id, "No client_id provided"
+        assert auth_method is not None, "Invalid client_auth_method provided"
 
         self._http_client = hs.get_proxied_http_client()
         self._hostname = hs.hostname
 
         self._issuer_metadata = RetryOnExceptionCachedCall(self._load_metadata)
-        secret = self._config.oauth_delegation_client_secret
-        self._client_auth = ClientAuth(
-            self._config.oauth_delegation_client_id,
-            secret,
-            OAuthDelegatedAuth.AUTH_METHODS[
-                self._config.oauth_delegation_client_auth_method
-            ],
-        )
 
-    async def _load_metadata(self) -> OpenIDProviderMetadata:
-        if self._config.oauth_delegation_issuer_metadata is not None:
-            return OpenIDProviderMetadata(
-                **self._config.oauth_delegation_issuer_metadata
+        if isinstance(auth_method, PrivateKeyJWTWithKid):
+            # Use the JWK as the client secret when using the private_key_jwt method
+            assert self._config.jwk, "No JWK provided"
+            self._client_auth = ClientAuth(
+                self._config.client_id, self._config.jwk, auth_method
             )
-        url = get_well_known_url(self._config.oauth_delegation_issuer, external=True)
+        else:
+            # Else use the client secret
+            assert self._config.client_secret, "No client_secret provided"
+            self._client_auth = ClientAuth(
+                self._config.client_id, self._config.client_secret, auth_method
+            )
+
+    async def _load_metadata(self) -> OpenIDProviderMetadata:
+        if self._config.issuer_metadata is not None:
+            return OpenIDProviderMetadata(**self._config.issuer_metadata)
+        url = get_well_known_url(self._config.issuer, external=True)
         response = await self._http_client.get_json(url)
         metadata = OpenIDProviderMetadata(**response)
         # metadata.validate_introspection_endpoint()
@@ -203,7 +206,7 @@ class OAuthDelegatedAuth(BaseAuth):
             )
 
         user_id_str = await self.store.get_user_by_external_id(
-            OAuthDelegatedAuth.EXTERNAL_ID_PROVIDER, sub
+            MSC3861DelegatedAuth.EXTERNAL_ID_PROVIDER, sub
         )
         if user_id_str is None:
             # If we could not find a user via the external_id, it either does not exist,
@@ -236,7 +239,7 @@ class OAuthDelegatedAuth(BaseAuth):
 
             # And record the sub as external_id
             await self.store.record_user_external_id(
-                OAuthDelegatedAuth.EXTERNAL_ID_PROVIDER, sub, user_id.to_string()
+                MSC3861DelegatedAuth.EXTERNAL_ID_PROVIDER, sub, user_id.to_string()
             )
         else:
             user_id = UserID.from_string(user_id_str)