summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py55
1 files changed, 49 insertions, 6 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..0a561eea60 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
 
         self.auth = hs.get_auth()
 
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
+        self._sso_handler = hs.get_sso_handler()
+
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            # While its valid for us to advertise this login type generally,
+            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+
+            if self._msc2858_enabled:
+                sso_flow["org.matrix.msc2858.identity_providers"] = [
+                    _get_auth_flow_dict_for_idp(idp)
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ]
+
+            flows.append(sso_flow)
+
+            # While it's valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet):
         return result
 
 
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+    """Return an entry for the login flow dict
+
+    Returns an entry suitable for inclusion in "identity_providers" in the
+    response to GET /_matrix/client/r0/login
+    """
+    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    if idp.idp_icon:
+        e["icon"] = idp.idp_icon
+    return e
+
+
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
 
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet):
         if hs.config.oidc_enabled:
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+    def register(self, http_server: HttpServer) -> None:
+        super().register(http_server)
+        if self._msc2858_enabled:
+            # expose additional endpoint for MSC2858 support
+            http_server.register_paths(
+                "GET",
+                client_patterns(
+                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+                    releases=(),
+                    unstable=True,
+                ),
+                self.on_GET,
+                self.__class__.__name__,
+            )
 
-    async def on_GET(self, request: SynapseRequest):
+    async def on_GET(
+        self, request: SynapseRequest, idp_id: Optional[str] = None
+    ) -> None:
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )
         sso_url = await self._sso_handler.handle_redirect_request(
-            request, client_redirect_url
+            request, client_redirect_url, idp_id,
         )
         logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)