summary refs log tree commit diff
path: root/synapse/rest/client
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client')
-rw-r--r--synapse/rest/client/v1/login.py45
-rw-r--r--synapse/rest/client/v1/logout.py6
-rw-r--r--synapse/rest/client/v2_alpha/auth.py19
3 files changed, 55 insertions, 15 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 4de2f97d06..d89b2e5532 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -83,6 +83,7 @@ class LoginRestServlet(RestServlet):
         self.jwt_algorithm = hs.config.jwt_algorithm
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
+        self.oidc_enabled = hs.config.oidc_enabled
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
         self.handlers = hs.get_handlers()
@@ -96,9 +97,7 @@ class LoginRestServlet(RestServlet):
         flows = []
         if self.jwt_enabled:
             flows.append({"type": LoginRestServlet.JWT_TYPE})
-        if self.saml2_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+
         if self.cas_enabled:
             flows.append({"type": LoginRestServlet.SSO_TYPE})
 
@@ -114,6 +113,11 @@ class LoginRestServlet(RestServlet):
             # fall back to the fallback API if they don't understand one of the
             # login flow types returned.
             flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+        elif self.saml2_enabled:
+            flows.append({"type": LoginRestServlet.SSO_TYPE})
+            flows.append({"type": LoginRestServlet.TOKEN_TYPE})
+        elif self.oidc_enabled:
+            flows.append({"type": LoginRestServlet.SSO_TYPE})
 
         flows.extend(
             ({"type": t} for t in self.auth_handler.get_supported_login_types())
@@ -397,19 +401,22 @@ class BaseSSORedirectServlet(RestServlet):
 
     PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
 
-    def on_GET(self, request: SynapseRequest):
+    async def on_GET(self, request: SynapseRequest):
         args = request.args
         if b"redirectUrl" not in args:
             return 400, "Redirect URL not specified for SSO auth"
         client_redirect_url = args[b"redirectUrl"][0]
-        sso_url = self.get_sso_url(client_redirect_url)
+        sso_url = await self.get_sso_url(request, client_redirect_url)
         request.redirect(sso_url)
         finish_request(request)
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         """Get the URL to redirect to, to perform SSO auth
 
         Args:
+            request: The client request to redirect.
             client_redirect_url: the URL that we should redirect the
                 client to when everything is done
 
@@ -424,7 +431,9 @@ class CasRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._cas_handler = hs.get_cas_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._cas_handler.get_redirect_url(
             {"redirectUrl": client_redirect_url}
         ).encode("ascii")
@@ -461,10 +470,28 @@ class SAMLRedirectServlet(BaseSSORedirectServlet):
     def __init__(self, hs):
         self._saml_handler = hs.get_saml_handler()
 
-    def get_sso_url(self, client_redirect_url: bytes) -> bytes:
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
         return self._saml_handler.handle_redirect_request(client_redirect_url)
 
 
+class OIDCRedirectServlet(BaseSSORedirectServlet):
+    """Implementation for /login/sso/redirect for the OIDC login flow."""
+
+    PATTERNS = client_patterns("/login/sso/redirect", v1=True)
+
+    def __init__(self, hs):
+        self._oidc_handler = hs.get_oidc_handler()
+
+    async def get_sso_url(
+        self, request: SynapseRequest, client_redirect_url: bytes
+    ) -> bytes:
+        return await self._oidc_handler.handle_redirect_request(
+            request, client_redirect_url
+        )
+
+
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled:
@@ -472,3 +499,5 @@ def register_servlets(hs, http_server):
         CasTicketServlet(hs).register(http_server)
     elif hs.config.saml2_enabled:
         SAMLRedirectServlet(hs).register(http_server)
+    elif hs.config.oidc_enabled:
+        OIDCRedirectServlet(hs).register(http_server)
diff --git a/synapse/rest/client/v1/logout.py b/synapse/rest/client/v1/logout.py
index 1cf3caf832..b0c30b65be 100644
--- a/synapse/rest/client/v1/logout.py
+++ b/synapse/rest/client/v1/logout.py
@@ -34,10 +34,10 @@ class LogoutRestServlet(RestServlet):
         return 200, {}
 
     async def on_POST(self, request):
-        requester = await self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request, allow_expired=True)
 
         if requester.device_id is None:
-            # the acccess token wasn't associated with a device.
+            # The access token wasn't associated with a device.
             # Just delete the access token
             access_token = self.auth.get_access_token_from_request(request)
             await self._auth_handler.delete_access_token(access_token)
@@ -62,7 +62,7 @@ class LogoutAllRestServlet(RestServlet):
         return 200, {}
 
     async def on_POST(self, request):
-        requester = await self.auth.get_user_by_req(request)
+        requester = await self.auth.get_user_by_req(request, allow_expired=True)
         user_id = requester.user.to_string()
 
         # first delete all of the user's devices
diff --git a/synapse/rest/client/v2_alpha/auth.py b/synapse/rest/client/v2_alpha/auth.py
index 24dd3d3e96..7bca1326d5 100644
--- a/synapse/rest/client/v2_alpha/auth.py
+++ b/synapse/rest/client/v2_alpha/auth.py
@@ -131,14 +131,19 @@ class AuthRestServlet(RestServlet):
         self.registration_handler = hs.get_registration_handler()
 
         # SSO configuration.
-        self._saml_enabled = hs.config.saml2_enabled
-        if self._saml_enabled:
-            self._saml_handler = hs.get_saml_handler()
         self._cas_enabled = hs.config.cas_enabled
         if self._cas_enabled:
             self._cas_handler = hs.get_cas_handler()
             self._cas_server_url = hs.config.cas_server_url
             self._cas_service_url = hs.config.cas_service_url
+        self._saml_enabled = hs.config.saml2_enabled
+        if self._saml_enabled:
+            self._saml_handler = hs.get_saml_handler()
+        self._oidc_enabled = hs.config.oidc_enabled
+        if self._oidc_enabled:
+            self._oidc_handler = hs.get_oidc_handler()
+            self._cas_server_url = hs.config.cas_server_url
+            self._cas_service_url = hs.config.cas_service_url
 
     async def on_GET(self, request, stagetype):
         session = parse_string(request, "session")
@@ -172,11 +177,17 @@ class AuthRestServlet(RestServlet):
                 )
 
             elif self._saml_enabled:
-                client_redirect_url = ""
+                client_redirect_url = b""
                 sso_redirect_url = self._saml_handler.handle_redirect_request(
                     client_redirect_url, session
                 )
 
+            elif self._oidc_enabled:
+                client_redirect_url = b""
+                sso_redirect_url = await self._oidc_handler.handle_redirect_request(
+                    request, client_redirect_url, session
+                )
+
             else:
                 raise SynapseError(400, "Homeserver not configured for SSO.")