summary refs log tree commit diff
path: root/synapse/rest/client/v2_alpha/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v2_alpha/auth.py')
-rw-r--r--synapse/rest/client/v2_alpha/auth.py53
1 files changed, 8 insertions, 45 deletions
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index fab077747f..75ece1c911 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 import logging
+from typing import TYPE_CHECKING
 
 from synapse.api.constants import LoginType
 from synapse.api.errors import SynapseError
@@ -23,6 +24,9 @@ from synapse.http.servlet import RestServlet, parse_string
 
 from ._base import client_patterns
 
+if TYPE_CHECKING:
+    from synapse.server import HomeServer
+
 logger = logging.getLogger(__name__)
 
 
@@ -35,28 +39,12 @@ class AuthRestServlet(RestServlet):
 
     PATTERNS = client_patterns(r"/auth/(?P<stagetype>[\w\.]*)/fallback/web")
 
-    def __init__(self, hs):
+    def __init__(self, hs: "HomeServer"):
         super().__init__()
         self.hs = hs
         self.auth = hs.get_auth()
         self.auth_handler = hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
-
-        # SSO configuration.
-        self._cas_enabled = hs.config.cas_enabled
-        if self._cas_enabled:
-            self._cas_handler = hs.get_cas_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
-        self._oidc_enabled = hs.config.oidc_enabled
-        if self._oidc_enabled:
-            self._oidc_handler = hs.get_oidc_handler()
-            self._cas_server_url = hs.config.cas_server_url
-            self._cas_service_url = hs.config.cas_service_url
-
         self.recaptcha_template = hs.config.recaptcha_template
         self.terms_template = hs.config.terms_template
         self.success_template = hs.config.fallback_success_template
@@ -85,32 +73,7 @@ class AuthRestServlet(RestServlet):
         elif stagetype == LoginType.SSO:
             # Display a confirmation page which prompts the user to
             # re-authenticate with their SSO provider.
-            if self._cas_enabled:
-                # Generate a request to CAS that redirects back to an endpoint
-                # to verify the successful authentication.
-                sso_redirect_url = self._cas_handler.get_redirect_url(
-                    {"session": session},
-                )
-
-            elif self._saml_enabled:
-                # Some SAML identity providers (e.g. Google) require a
-                # RelayState parameter on requests. It is not necessary here, so
-                # pass in a dummy redirect URL (which will never get used).
-                client_redirect_url = b"unused"
-                sso_redirect_url = self._saml_handler.handle_redirect_request(
-                    client_redirect_url, session
-                )
-
-            elif self._oidc_enabled:
-                client_redirect_url = b""
-                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
-                    request, client_redirect_url, session
-                )
-
-            else:
-                raise SynapseError(400, "Homeserver not configured for SSO.")
-
-            html = await self.auth_handler.start_sso_ui_auth(sso_redirect_url, session)
+            html = await self.auth_handler.start_sso_ui_auth(request, session)
 
         else:
             raise SynapseError(404, "Unknown auth stage type")
@@ -134,7 +97,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"response": response, "session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.RECAPTCHA, authdict, self.hs.get_ip_from_request(request)
+                LoginType.RECAPTCHA, authdict, request.getClientIP()
             )
 
             if success:
@@ -150,7 +113,7 @@ class AuthRestServlet(RestServlet):
             authdict = {"session": session}
 
             success = await self.auth_handler.add_oob_auth(
-                LoginType.TERMS, authdict, self.hs.get_ip_from_request(request)
+                LoginType.TERMS, authdict, request.getClientIP()
             )
 
             if success: