summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py31
1 files changed, 19 insertions, 12 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index de7eca21f8..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -401,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request: SynapseRequest):
+    async def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
         client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = self.get_sso_url(client_redirect_url)
+        sso_url = await self.get_sso_url(request, client_redirect_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
+            request: The client request to redirect.
             client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
@@ -428,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._cas_handler.get_redirect_url(
             {"redirectUrl": client_redirect_url}
         ).encode("ascii")
@@ -465,11 +470,13 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
-class OIDCRedirectServlet(RestServlet):
+class OIDCRedirectServlet(BaseSSORedirectServlet):
     """Implementation for /login/sso/redirect for the OIDC login flow."""
 
     PATTERNS = client_patterns("/login/sso/redirect", v1=True)
@@ -477,12 +484,12 @@ class OIDCRedirectServlet(RestServlet):
     def __init__(self, hs):
         self._oidc_handler = hs.get_oidc_handler()
 
-    async def on_GET(self, request):
-        args = request.args
-        if b"redirectUrl" not in args:
-            return 400, "Redirect URL not specified for SSO auth"
-        client_redirect_url = args[b"redirectUrl"][0]
-        await self._oidc_handler.handle_redirect_request(request, client_redirect_url)
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
+        return await self._oidc_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
 
 
 def register_servlets(hs, http_server):