summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/register.py18
-rw-r--r--synapse/handlers/sso.py2
-rw-r--r--synapse/rest/client/login.py29
3 files changed, 43 insertions, 6 deletions
diff --git a/synapse/handlers/register.py b/synapse/handlers/register.py
index 8cf614136e..0ed59d757b 100644
--- a/synapse/handlers/register.py
+++ b/synapse/handlers/register.py
@@ -56,6 +56,22 @@ login_counter = Counter(
 )
 
 
+def init_counters_for_auth_provider(auth_provider_id: str) -> None:
+    """Ensure the prometheus counters for the given auth provider are initialised
+
+    This fixes a problem where the counters are not reported for a given auth provider
+    until the user first logs in/registers.
+    """
+    for is_guest in (True, False):
+        login_counter.labels(guest=is_guest, auth_provider=auth_provider_id)
+        for shadow_banned in (True, False):
+            registration_counter.labels(
+                guest=is_guest,
+                shadow_banned=shadow_banned,
+                auth_provider=auth_provider_id,
+            )
+
+
 class LoginDict(TypedDict):
     device_id: str
     access_token: str
@@ -96,6 +112,8 @@ class RegistrationHandler(BaseHandler):
         self.session_lifetime = hs.config.session_lifetime
         self.access_token_lifetime = hs.config.access_token_lifetime
 
+        init_counters_for_auth_provider("")
+
     async def check_username(
         self,
         localpart: str,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index 1b855a685c..0e6ebb574e 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -37,6 +37,7 @@ from twisted.web.server import Request
 from synapse.api.constants import LoginType
 from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.config.sso import SsoAttributeRequirement
+from synapse.handlers.register import init_counters_for_auth_provider
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html, respond_with_redirect
@@ -213,6 +214,7 @@ class SsoHandler:
         p_id = p.idp_id
         assert p_id not in self._identity_providers
         self._identity_providers[p_id] = p
+        init_counters_for_auth_provider(p_id)
 
     def get_identity_providers(self) -> Mapping[str, SsoIdentityProvider]:
         """Get the configured identity providers"""
diff --git a/synapse/rest/client/login.py b/synapse/rest/client/login.py
index 0c8d8967b7..11d07776b2 100644
--- a/synapse/rest/client/login.py
+++ b/synapse/rest/client/login.py
@@ -104,6 +104,12 @@ class LoginRestServlet(RestServlet):
             burst_count=self.hs.config.rc_login_account.burst_count,
         )
 
+        # ensure the CAS/SAML/OIDC handlers are loaded on this worker instance.
+        # The reason for this is to ensure that the auth_provider_ids are registered
+        # with SsoHandler, which in turn ensures that the login/registration prometheus
+        # counters are initialised for the auth_provider_ids.
+        _load_sso_handlers(hs)
+
     def on_GET(self, request: SynapseRequest):
         flows = []
         if self.jwt_enabled:
@@ -499,12 +505,7 @@ class SsoRedirectServlet(RestServlet):
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
         # register themselves with the main SSOHandler.
-        if hs.config.cas_enabled:
-            hs.get_cas_handler()
-        if hs.config.saml2_enabled:
-            hs.get_saml_handler()
-        if hs.config.oidc_enabled:
-            hs.get_oidc_handler()
+        _load_sso_handlers(hs)
         self._sso_handler = hs.get_sso_handler()
         self._msc2858_enabled = hs.config.experimental.msc2858_enabled
         self._public_baseurl = hs.config.public_baseurl
@@ -598,3 +599,19 @@ def register_servlets(hs, http_server):
     SsoRedirectServlet(hs).register(http_server)
     if hs.config.cas_enabled:
         CasTicketServlet(hs).register(http_server)
+
+
+def _load_sso_handlers(hs: "HomeServer"):
+    """Ensure that the SSO handlers are loaded, if they are enabled by configuration.
+
+    This is mostly useful to ensure that the CAS/SAML/OIDC handlers register themselves
+    with the main SsoHandler.
+
+    It's safe to call this multiple times.
+    """
+    if hs.config.cas.cas_enabled:
+        hs.get_cas_handler()
+    if hs.config.saml2.saml2_enabled:
+        hs.get_saml_handler()
+    if hs.config.oidc.oidc_enabled:
+        hs.get_oidc_handler()