summary refs log tree commit diff
path: root/synapse/handlers
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers')
-rw-r--r--synapse/handlers/auth.py58
-rw-r--r--synapse/handlers/federation.py2
-rw-r--r--synapse/handlers/identity.py6
3 files changed, 47 insertions, 19 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index c7dc07008a..2e72298e05 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -193,9 +193,7 @@ class AuthHandler(BaseHandler):
         self.hs = hs  # FIXME better possibility to access registrationHandler later?
         self.macaroon_gen = hs.get_macaroon_generator()
         self._password_enabled = hs.config.password_enabled
-        self._sso_enabled = (
-            hs.config.cas_enabled or hs.config.saml2_enabled or hs.config.oidc_enabled
-        )
+        self._password_localdb_enabled = hs.config.password_localdb_enabled
 
         # we keep this as a list despite the O(N^2) implication so that we can
         # keep PASSWORD first and avoid confusing clients which pick the first
@@ -205,7 +203,7 @@ class AuthHandler(BaseHandler):
 
         # start out by assuming PASSWORD is enabled; we will remove it later if not.
         login_types = []
-        if hs.config.password_localdb_enabled:
+        if self._password_localdb_enabled:
             login_types.append(LoginType.PASSWORD)
 
         for provider in self.password_providers:
@@ -219,14 +217,6 @@ class AuthHandler(BaseHandler):
 
         self._supported_login_types = login_types
 
-        # Login types and UI Auth types have a heavy overlap, but are not
-        # necessarily identical. Login types have SSO (and other login types)
-        # added in the rest layer, see synapse.rest.client.v1.login.LoginRestServerlet.on_GET.
-        ui_auth_types = login_types.copy()
-        if self._sso_enabled:
-            ui_auth_types.append(LoginType.SSO)
-        self._supported_ui_auth_types = ui_auth_types
-
         # Ratelimiter for failed auth during UIA. Uses same ratelimit config
         # as per `rc_login.failed_attempts`.
         self._failed_uia_attempts_ratelimiter = Ratelimiter(
@@ -339,7 +329,10 @@ class AuthHandler(BaseHandler):
         self._failed_uia_attempts_ratelimiter.ratelimit(user_id, update=False)
 
         # build a list of supported flows
-        flows = [[login_type] for login_type in self._supported_ui_auth_types]
+        supported_ui_auth_types = await self._get_available_ui_auth_types(
+            requester.user
+        )
+        flows = [[login_type] for login_type in supported_ui_auth_types]
 
         try:
             result, params, session_id = await self.check_ui_auth(
@@ -351,7 +344,7 @@ class AuthHandler(BaseHandler):
             raise
 
         # find the completed login type
-        for login_type in self._supported_ui_auth_types:
+        for login_type in supported_ui_auth_types:
             if login_type not in result:
                 continue
 
@@ -367,6 +360,41 @@ class AuthHandler(BaseHandler):
 
         return params, session_id
 
+    async def _get_available_ui_auth_types(self, user: UserID) -> Iterable[str]:
+        """Get a list of the authentication types this user can use
+        """
+
+        ui_auth_types = set()
+
+        # if the HS supports password auth, and the user has a non-null password, we
+        # support password auth
+        if self._password_localdb_enabled and self._password_enabled:
+            lookupres = await self._find_user_id_and_pwd_hash(user.to_string())
+            if lookupres:
+                _, password_hash = lookupres
+                if password_hash:
+                    ui_auth_types.add(LoginType.PASSWORD)
+
+        # also allow auth from password providers
+        for provider in self.password_providers:
+            for t in provider.get_supported_login_types().keys():
+                if t == LoginType.PASSWORD and not self._password_enabled:
+                    continue
+                ui_auth_types.add(t)
+
+        # if sso is enabled, allow the user to log in via SSO iff they have a mapping
+        # from sso to mxid.
+        if self.hs.config.saml2.saml2_enabled or self.hs.config.oidc.oidc_enabled:
+            if await self.store.get_external_ids_by_user(user.to_string()):
+                ui_auth_types.add(LoginType.SSO)
+
+        # Our CAS impl does not (yet) correctly register users in user_external_ids,
+        # so always offer that if it's available.
+        if self.hs.config.cas.cas_enabled:
+            ui_auth_types.add(LoginType.SSO)
+
+        return ui_auth_types
+
     def get_enabled_auth_types(self):
         """Return the enabled user-interactive authentication types
 
@@ -1029,7 +1057,7 @@ class AuthHandler(BaseHandler):
             if result:
                 return result
 
-        if login_type == LoginType.PASSWORD and self.hs.config.password_localdb_enabled:
+        if login_type == LoginType.PASSWORD and self._password_localdb_enabled:
             known_login_type = True
 
             # we've already checked that there is a (valid) password field
diff --git a/synapse/handlers/federation.py b/synapse/handlers/federation.py
index b9799090f7..df82e60b33 100644
--- a/synapse/handlers/federation.py
+++ b/synapse/handlers/federation.py
@@ -140,7 +140,7 @@ class FederationHandler(BaseHandler):
         self._message_handler = hs.get_message_handler()
         self._server_notices_mxid = hs.config.server_notices_mxid
         self.config = hs.config
-        self.http_client = hs.get_simple_http_client()
+        self.http_client = hs.get_proxied_blacklisted_http_client()
         self._instance_name = hs.get_instance_name()
         self._replication = hs.get_replication_data_handler()
 
diff --git a/synapse/handlers/identity.py b/synapse/handlers/identity.py
index 9b3c6b4551..7301c24710 100644
--- a/synapse/handlers/identity.py
+++ b/synapse/handlers/identity.py
@@ -46,13 +46,13 @@ class IdentityHandler(BaseHandler):
     def __init__(self, hs):
         super().__init__(hs)
 
+        # An HTTP client for contacting trusted URLs.
         self.http_client = SimpleHttpClient(hs)
-        # We create a blacklisting instance of SimpleHttpClient for contacting identity
-        # servers specified by clients
+        # An HTTP client for contacting identity servers specified by clients.
         self.blacklisting_http_client = SimpleHttpClient(
             hs, ip_blacklist=hs.config.federation_ip_range_blacklist
         )
-        self.federation_http_client = hs.get_http_client()
+        self.federation_http_client = hs.get_federation_http_client()
         self.hs = hs
 
     async def threepid_from_creds(