summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorMathieu Velten <matmaul@gmail.com>2024-03-19 18:20:10 +0100
committerGitHub <noreply@github.com>2024-03-19 17:20:10 +0000
commit74ab329eaa50348d3ff65fc97d7fbc9cd9773311 (patch)
treef54f66631c011296b5ff2d745d02058a25bca72a /synapse
parentSpecify IP subnet literals in canonical form (#16953) (diff)
downloadsynapse-74ab329eaa50348d3ff65fc97d7fbc9cd9773311.tar.xz
Pass module API to OIDC mapping provider (#16974)
As done for SAML mapping provider, let's pass the module API to the OIDC
one so the mapper can do more logic in its code.
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/oidc.py17
1 files changed, 14 insertions, 3 deletions
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index fe13d82b66..ba67cc4768 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -65,6 +65,7 @@ from synapse.http.server import finish_request
 from synapse.http.servlet import parse_string
 from synapse.http.site import SynapseRequest
 from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
 from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
 from synapse.util import Clock, json_decoder
 from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
@@ -421,9 +422,19 @@ class OidcProvider:
         # from the IdP's jwks_uri, if required.
         self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
 
-        self._user_mapping_provider = provider.user_mapping_provider_class(
-            provider.user_mapping_provider_config
+        user_mapping_provider_init_method = (
+            provider.user_mapping_provider_class.__init__
         )
+        if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
+            self._user_mapping_provider = provider.user_mapping_provider_class(
+                provider.user_mapping_provider_config,
+                ModuleApi(hs, hs.get_auth_handler()),
+            )
+        else:
+            self._user_mapping_provider = provider.user_mapping_provider_class(
+                provider.user_mapping_provider_config,
+            )
+
         self._skip_verification = provider.skip_verification
         self._allow_existing_users = provider.allow_existing_users
 
@@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
     This is the default mapping provider.
     """
 
-    def __init__(self, config: JinjaOidcMappingConfig):
+    def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
         self._config = config
 
     @staticmethod