diff --git a/changelog.d/16974.misc b/changelog.d/16974.misc
new file mode 100644
index 0000000000..bf0a13786c
--- /dev/null
+++ b/changelog.d/16974.misc
@@ -0,0 +1 @@
+As done for SAML mapping provider, let's pass the module API to the OIDC one so the mapper can do more logic in its code.
diff --git a/docs/sso_mapping_providers.md b/docs/sso_mapping_providers.md
index 77cc02c541..10c695029f 100644
--- a/docs/sso_mapping_providers.md
+++ b/docs/sso_mapping_providers.md
@@ -50,11 +50,13 @@ comment these options out and use those specified by the module instead.
A custom mapping provider must specify the following methods:
-* `def __init__(self, parsed_config)`
+* `def __init__(self, parsed_config, module_api)`
- Arguments:
- `parsed_config` - A configuration object that is the return value of the
`parse_config` method. You should set any configuration options needed by
the module here.
+ - `module_api` - a `synapse.module_api.ModuleApi` object which provides the
+ stable API available for extension modules.
* `def parse_config(config)`
- This method should have the `@staticmethod` decoration.
- Arguments:
diff --git a/synapse/handlers/oidc.py b/synapse/handlers/oidc.py
index fe13d82b66..ba67cc4768 100644
--- a/synapse/handlers/oidc.py
+++ b/synapse/handlers/oidc.py
@@ -65,6 +65,7 @@ from synapse.http.server import finish_request
from synapse.http.servlet import parse_string
from synapse.http.site import SynapseRequest
from synapse.logging.context import make_deferred_yieldable
+from synapse.module_api import ModuleApi
from synapse.types import JsonDict, UserID, map_username_to_mxid_localpart
from synapse.util import Clock, json_decoder
from synapse.util.caches.cached_call import RetryOnExceptionCachedCall
@@ -421,9 +422,19 @@ class OidcProvider:
# from the IdP's jwks_uri, if required.
self._jwks = RetryOnExceptionCachedCall(self._load_jwks)
- self._user_mapping_provider = provider.user_mapping_provider_class(
- provider.user_mapping_provider_config
+ user_mapping_provider_init_method = (
+ provider.user_mapping_provider_class.__init__
)
+ if len(inspect.signature(user_mapping_provider_init_method).parameters) == 3:
+ self._user_mapping_provider = provider.user_mapping_provider_class(
+ provider.user_mapping_provider_config,
+ ModuleApi(hs, hs.get_auth_handler()),
+ )
+ else:
+ self._user_mapping_provider = provider.user_mapping_provider_class(
+ provider.user_mapping_provider_config,
+ )
+
self._skip_verification = provider.skip_verification
self._allow_existing_users = provider.allow_existing_users
@@ -1583,7 +1594,7 @@ class JinjaOidcMappingProvider(OidcMappingProvider[JinjaOidcMappingConfig]):
This is the default mapping provider.
"""
- def __init__(self, config: JinjaOidcMappingConfig):
+ def __init__(self, config: JinjaOidcMappingConfig, module_api: ModuleApi):
self._config = config
@staticmethod
|