summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/config/_base.pyi2
-rw-r--r--synapse/config/experimental.py29
-rw-r--r--synapse/config/homeserver.py2
-rw-r--r--synapse/handlers/sso.py23
-rw-r--r--synapse/http/server.py44
-rw-r--r--synapse/rest/client/v1/login.py55
6 files changed, 136 insertions, 19 deletions
diff --git a/synapse/config/_base.pyi b/synapse/config/_base.pyi
index 29aa064e57..3ccea4b02d 100644
--- a/synapse/config/_base.pyi
+++ b/synapse/config/_base.pyi
@@ -9,6 +9,7 @@ from synapse.config import (
     consent_config,
     database,
     emailconfig,
+    experimental,
     groups,
     jwt_config,
     key,
@@ -48,6 +49,7 @@ def path_exists(file_path: str): ...
 
 class RootConfig:
     server: server.ServerConfig
+    experimental: experimental.ExperimentalConfig
     tls: tls.TlsConfig
     database: database.DatabaseConfig
     logging: logger.LoggingConfig
diff --git a/synapse/config/experimental.py b/synapse/config/experimental.py
new file mode 100644
index 0000000000..b1c1c51e4d
--- /dev/null
+++ b/synapse/config/experimental.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# Copyright 2021 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from synapse.config._base import Config
+from synapse.types import JsonDict
+
+
+class ExperimentalConfig(Config):
+    """Config section for enabling experimental features"""
+
+    section = "experimental"
+
+    def read_config(self, config: JsonDict, **kwargs):
+        experimental = config.get("experimental_features") or {}
+
+        # MSC2858 (multiple SSO identity providers)
+        self.msc2858_enabled = experimental.get("msc2858_enabled", False)  # type: bool
diff --git a/synapse/config/homeserver.py b/synapse/config/homeserver.py
index 4bd2b3587b..64a2429f77 100644
--- a/synapse/config/homeserver.py
+++ b/synapse/config/homeserver.py
@@ -24,6 +24,7 @@ from .cas import CasConfig
 from .consent_config import ConsentConfig
 from .database import DatabaseConfig
 from .emailconfig import EmailConfig
+from .experimental import ExperimentalConfig
 from .federation import FederationConfig
 from .groups import GroupsConfig
 from .jwt_config import JWTConfig
@@ -57,6 +58,7 @@ class HomeServerConfig(RootConfig):
 
     config_classes = [
         ServerConfig,
+        ExperimentalConfig,
         TlsConfig,
         FederationConfig,
         CacheConfig,
diff --git a/synapse/handlers/sso.py b/synapse/handlers/sso.py
index d493327a10..afc1341d09 100644
--- a/synapse/handlers/sso.py
+++ b/synapse/handlers/sso.py
@@ -23,7 +23,7 @@ from typing_extensions import NoReturn, Protocol
 from twisted.web.http import Request
 
 from synapse.api.constants import LoginType
-from synapse.api.errors import Codes, RedirectException, SynapseError
+from synapse.api.errors import Codes, NotFoundError, RedirectException, SynapseError
 from synapse.handlers.ui_auth import UIAuthSessionDataConstants
 from synapse.http import get_request_user_agent
 from synapse.http.server import respond_with_html
@@ -235,7 +235,10 @@ class SsoHandler:
         respond_with_html(request, code, html)
 
     async def handle_redirect_request(
-        self, request: SynapseRequest, client_redirect_url: bytes,
+        self,
+        request: SynapseRequest,
+        client_redirect_url: bytes,
+        idp_id: Optional[str],
     ) -> str:
         """Handle a request to /login/sso/redirect
 
@@ -243,6 +246,7 @@ class SsoHandler:
             request: incoming HTTP request
             client_redirect_url: the URL that we should redirect the
                 client to after login.
+            idp_id: optional identity provider chosen by the client
 
         Returns:
              the URI to redirect to
@@ -252,10 +256,19 @@ class SsoHandler:
                 400, "Homeserver not configured for SSO.", errcode=Codes.UNRECOGNIZED
             )
 
+        # if the client chose an IdP, use that
+        idp = None  # type: Optional[SsoIdentityProvider]
+        if idp_id:
+            idp = self._identity_providers.get(idp_id)
+            if not idp:
+                raise NotFoundError("Unknown identity provider")
+
         # if we only have one auth provider, redirect to it directly
-        if len(self._identity_providers) == 1:
-            ap = next(iter(self._identity_providers.values()))
-            return await ap.handle_redirect_request(request, client_redirect_url)
+        elif len(self._identity_providers) == 1:
+            idp = next(iter(self._identity_providers.values()))
+
+        if idp:
+            return await idp.handle_redirect_request(request, client_redirect_url)
 
         # otherwise, redirect to the IDP picker
         return "/_synapse/client/pick_idp?" + urlencode(
diff --git a/synapse/http/server.py b/synapse/http/server.py
index e464bfe6c7..d69d579b3a 100644
--- a/synapse/http/server.py
+++ b/synapse/http/server.py
@@ -22,10 +22,22 @@ import types
 import urllib
 from http import HTTPStatus
 from io import BytesIO
-from typing import Any, Callable, Dict, Iterator, List, Tuple, Union
+from typing import (
+    Any,
+    Awaitable,
+    Callable,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Pattern,
+    Tuple,
+    Union,
+)
 
 import jinja2
 from canonicaljson import iterencode_canonical_json
+from typing_extensions import Protocol
 from zope.interface import implementer
 
 from twisted.internet import defer, interfaces
@@ -168,11 +180,25 @@ def wrap_async_request_handler(h):
     return preserve_fn(wrapped_async_request_handler)
 
 
-class HttpServer:
+# Type of a callback method for processing requests
+# it is actually called with a SynapseRequest and a kwargs dict for the params,
+# but I can't figure out how to represent that.
+ServletCallback = Callable[
+    ..., Union[None, Awaitable[None], Tuple[int, Any], Awaitable[Tuple[int, Any]]]
+]
+
+
+class HttpServer(Protocol):
     """ Interface for registering callbacks on a HTTP server
     """
 
-    def register_paths(self, method, path_patterns, callback):
+    def register_paths(
+        self,
+        method: str,
+        path_patterns: Iterable[Pattern],
+        callback: ServletCallback,
+        servlet_classname: str,
+    ) -> None:
         """ Register a callback that gets fired if we receive a http request
         with the given method for a path that matches the given regex.
 
@@ -180,12 +206,14 @@ class HttpServer:
         an unpacked tuple.
 
         Args:
-            method (str): The method to listen to.
-            path_patterns (list<SRE_Pattern>): The regex used to match requests.
-            callback (function): The function to fire if we receive a matched
+            method: The HTTP method to listen to.
+            path_patterns: The regex used to match requests.
+            callback: The function to fire if we receive a matched
                 request. The first argument will be the request object and
                 subsequent arguments will be any matched groups from the regex.
-                This should return a tuple of (code, response).
+                This should return either tuple of (code, response), or None.
+            servlet_classname (str): The name of the handler to be used in prometheus
+                and opentracing logs.
         """
         pass
 
@@ -354,7 +382,7 @@ class JsonResource(DirectServeJsonResource):
 
     def _get_handler_for_request(
         self, request: SynapseRequest
-    ) -> Tuple[Callable, str, Dict[str, str]]:
+    ) -> Tuple[ServletCallback, str, Dict[str, str]]:
         """Finds a callback method to handle the given request.
 
         Returns:
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index be938df962..0a561eea60 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, Awaitable, Callable, Dict, Optional
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.appservice import ApplicationService
-from synapse.http.server import finish_request
+from synapse.handlers.sso import SsoIdentityProvider
+from synapse.http.server import HttpServer, finish_request
 from synapse.http.servlet import (
     RestServlet,
     parse_json_object_from_request,
@@ -60,11 +61,14 @@ class LoginRestServlet(RestServlet):
         self.saml2_enabled = hs.config.saml2_enabled
         self.cas_enabled = hs.config.cas_enabled
         self.oidc_enabled = hs.config.oidc_enabled
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
 
         self.auth = hs.get_auth()
 
         self.auth_handler = self.hs.get_auth_handler()
         self.registration_handler = hs.get_registration_handler()
+        self._sso_handler = hs.get_sso_handler()
+
         self._well_known_builder = WellKnownBuilder(hs)
         self._address_ratelimiter = Ratelimiter(
             clock=hs.get_clock(),
@@ -89,8 +93,17 @@ class LoginRestServlet(RestServlet):
             flows.append({"type": LoginRestServlet.CAS_TYPE})
 
         if self.cas_enabled or self.saml2_enabled or self.oidc_enabled:
-            flows.append({"type": LoginRestServlet.SSO_TYPE})
-            # While its valid for us to advertise this login type generally,
+            sso_flow = {"type": LoginRestServlet.SSO_TYPE}  # type: JsonDict
+
+            if self._msc2858_enabled:
+                sso_flow["org.matrix.msc2858.identity_providers"] = [
+                    _get_auth_flow_dict_for_idp(idp)
+                    for idp in self._sso_handler.get_identity_providers().values()
+                ]
+
+            flows.append(sso_flow)
+
+            # While it's valid for us to advertise this login type generally,
             # synapse currently only gives out these tokens as part of the
             # SSO login flow.
             # Generally we don't want to advertise login flows that clients
@@ -311,8 +324,20 @@ class LoginRestServlet(RestServlet):
         return result
 
 
+def _get_auth_flow_dict_for_idp(idp: SsoIdentityProvider) -> JsonDict:
+    """Return an entry for the login flow dict
+
+    Returns an entry suitable for inclusion in "identity_providers" in the
+    response to GET /_matrix/client/r0/login
+    """
+    e = {"id": idp.idp_id, "name": idp.idp_name}  # type: JsonDict
+    if idp.idp_icon:
+        e["icon"] = idp.idp_icon
+    return e
+
+
 class SsoRedirectServlet(RestServlet):
-    PATTERNS = client_patterns("/login/(cas|sso)/redirect", v1=True)
+    PATTERNS = client_patterns("/login/(cas|sso)/redirect$", v1=True)
 
     def __init__(self, hs: "HomeServer"):
         # make sure that the relevant handlers are instantiated, so that they
@@ -324,13 +349,31 @@ class SsoRedirectServlet(RestServlet):
         if hs.config.oidc_enabled:
             hs.get_oidc_handler()
         self._sso_handler = hs.get_sso_handler()
+        self._msc2858_enabled = hs.config.experimental.msc2858_enabled
+
+    def register(self, http_server: HttpServer) -> None:
+        super().register(http_server)
+        if self._msc2858_enabled:
+            # expose additional endpoint for MSC2858 support
+            http_server.register_paths(
+                "GET",
+                client_patterns(
+                    "/org.matrix.msc2858/login/sso/redirect/(?P<idp_id>[A-Za-z0-9_.~-]+)$",
+                    releases=(),
+                    unstable=True,
+                ),
+                self.on_GET,
+                self.__class__.__name__,
+            )
 
-    async def on_GET(self, request: SynapseRequest):
+    async def on_GET(
+        self, request: SynapseRequest, idp_id: Optional[str] = None
+    ) -> None:
         client_redirect_url = parse_string(
             request, "redirectUrl", required=True, encoding=None
         )
         sso_url = await self._sso_handler.handle_redirect_request(
-            request, client_redirect_url
+            request, client_redirect_url, idp_id,
         )
         logger.info("Redirecting to %s", sso_url)
         request.redirect(sso_url)