diff options
Diffstat (limited to '')
-rw-r--r-- | synapse/rest/client/login.py | 29 |
1 files changed, 23 insertions, 6 deletions
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py index 0c8d8967b7..11d07776b2 100644 --- a/synapse/rest/client/login.py +++ b/synapse/rest/client/login.py @@ -104,6 +104,12 @@ class LoginRestServlet(RestServlet): burst_count=self.hs.config.rc_login_account.burst_count, ) + # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance. + # The reason for this is to ensure that the auth_provider_ids are registered + # with SsoHandler, which in turn ensures that the login/registration prometheus + # counters are initialised for the auth_provider_ids. + _load_sso_handlers(hs) + def on_GET(self, request: SynapseRequest): flows = [] if self.jwt_enabled: @@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet): def __init__(self, hs: "HomeServer"): # make sure that the relevant handlers are instantiated, so that they # register themselves with the main SSOHandler. - if hs.config.cas_enabled: - hs.get_cas_handler() - if hs.config.saml2_enabled: - hs.get_saml_handler() - if hs.config.oidc_enabled: - hs.get_oidc_handler() + _load_sso_handlers(hs) self._sso_handler = hs.get_sso_handler() self._msc2858_enabled = hs.config.experimental.msc2858_enabled self._public_baseurl = hs.config.public_baseurl @@ -598,3 +599,19 @@ def register_servlets(hs, http_server): SsoRedirectServlet(hs).register(http_server) if hs.config.cas_enabled: CasTicketServlet(hs).register(http_server) + + +def _load_sso_handlers(hs: "HomeServer"): + """Ensure that the SSO handlers are loaded, if they are enabled by configuration. + + This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves + with the main SsoHandler. + + It's safe to call this multiple times. + """ + if hs.config.cas.cas_enabled: + hs.get_cas_handler() + if hs.config.saml2.saml2_enabled: + hs.get_saml_handler() + if hs.config.oidc.oidc_enabled: + hs.get_oidc_handler() |