summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9109.feature1
-rw-r--r--synapse/config/oidc_config.py26
-rw-r--r--synapse/handlers/oidc_handler.py22
-rw-r--r--tests/handlers/test_oidc.py3
4 files changed, 42 insertions, 10 deletions
diff --git a/changelog.d/9109.feature b/changelog.d/9109.feature
new file mode 100644
index 0000000000..01a24dcf49
--- /dev/null
+++ b/changelog.d/9109.feature
@@ -0,0 +1 @@
+Add support for multiple SSO Identity Providers.
diff --git a/synapse/config/oidc_config.py b/synapse/config/oidc_config.py
index c705de5694..fddca19223 100644
--- a/synapse/config/oidc_config.py
+++ b/synapse/config/oidc_config.py
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 # Copyright 2020 Quentin Gliech
-# Copyright 2020 The Matrix.org Foundation C.I.C.
+# Copyright 2020-2021 The Matrix.org Foundation C.I.C.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -14,6 +14,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import string
 from typing import Optional, Type
 
 import attr
@@ -38,7 +39,7 @@ class OIDCConfig(Config):
 
         oidc_config = config.get("oidc_config")
         if oidc_config and oidc_config.get("enabled", False):
-            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, "oidc_config")
+            validate_config(OIDC_PROVIDER_CONFIG_SCHEMA, oidc_config, ("oidc_config",))
             self.oidc_provider = _parse_oidc_config_dict(oidc_config)
 
         if not self.oidc_provider:
@@ -205,6 +206,8 @@ OIDC_PROVIDER_CONFIG_SCHEMA = {
     "type": "object",
     "required": ["issuer", "client_id", "client_secret"],
     "properties": {
+        "idp_id": {"type": "string", "minLength": 1, "maxLength": 128},
+        "idp_name": {"type": "string"},
         "discover": {"type": "boolean"},
         "issuer": {"type": "string"},
         "client_id": {"type": "string"},
@@ -277,7 +280,17 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
             "methods: %s" % (", ".join(missing_methods),)
         )
 
+    # MSC2858 will appy certain limits in what can be used as an IdP id, so let's
+    # enforce those limits now.
+    idp_id = oidc_config.get("idp_id", "oidc")
+    valid_idp_chars = set(string.ascii_letters + string.digits + "-._~")
+
+    if any(c not in valid_idp_chars for c in idp_id):
+        raise ConfigError('idp_id may only contain A-Z, a-z, 0-9, "-", ".", "_", "~"')
+
     return OidcProviderConfig(
+        idp_id=idp_id,
+        idp_name=oidc_config.get("idp_name", "OIDC"),
         discover=oidc_config.get("discover", True),
         issuer=oidc_config["issuer"],
         client_id=oidc_config["client_id"],
@@ -296,8 +309,15 @@ def _parse_oidc_config_dict(oidc_config: JsonDict) -> "OidcProviderConfig":
     )
 
 
-@attr.s
+@attr.s(slots=True, frozen=True)
 class OidcProviderConfig:
+    # a unique identifier for this identity provider. Used in the 'user_external_ids'
+    # table, as well as the query/path parameter used in the login protocol.
+    idp_id = attr.ib(type=str)
+
+    # user-facing name for this identity provider.
+    idp_name = attr.ib(type=str)
+
     # whether the OIDC discovery mechanism is used to discover endpoints
     discover = attr.ib(type=bool)
 
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index d6347bb1b8..f63a90ec5c 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -175,7 +175,7 @@ class OidcHandler:
             session_data = self._token_generator.verify_oidc_session_token(
                 session, state
             )
-        except MacaroonDeserializationException as e:
+        except (MacaroonDeserializationException, ValueError) as e:
             logger.exception("Invalid session")
             self._sso_handler.render_error(request, "invalid_session", str(e))
             return
@@ -253,10 +253,10 @@ class OidcProvider:
         self._server_name = hs.config.server_name  # type: str
 
         # identifier for the external_ids table
-        self.idp_id = "oidc"
+        self.idp_id = provider.idp_id
 
         # user-facing name of this auth provider
-        self.idp_name = "OIDC"
+        self.idp_name = provider.idp_name
 
         self._sso_handler = hs.get_sso_handler()
 
@@ -656,6 +656,7 @@ class OidcProvider:
         cookie = self._token_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
+                idp_id=self.idp_id,
                 nonce=nonce,
                 client_redirect_url=client_redirect_url.decode(),
                 ui_auth_session_id=ui_auth_session_id,
@@ -924,6 +925,7 @@ class OidcSessionTokenGenerator:
         macaroon.add_first_party_caveat("gen = 1")
         macaroon.add_first_party_caveat("type = session")
         macaroon.add_first_party_caveat("state = %s" % (state,))
+        macaroon.add_first_party_caveat("idp_id = %s" % (session_data.idp_id,))
         macaroon.add_first_party_caveat("nonce = %s" % (session_data.nonce,))
         macaroon.add_first_party_caveat(
             "client_redirect_url = %s" % (session_data.client_redirect_url,)
@@ -952,6 +954,9 @@ class OidcSessionTokenGenerator:
 
         Returns:
             The data extracted from the session cookie
+
+        Raises:
+            ValueError if an expected caveat is missing from the macaroon.
         """
         macaroon = pymacaroons.Macaroon.deserialize(session)
 
@@ -960,6 +965,7 @@ class OidcSessionTokenGenerator:
         v.satisfy_exact("type = session")
         v.satisfy_exact("state = %s" % (state,))
         v.satisfy_general(lambda c: c.startswith("nonce = "))
+        v.satisfy_general(lambda c: c.startswith("idp_id = "))
         v.satisfy_general(lambda c: c.startswith("client_redirect_url = "))
         # Sometimes there's a UI auth session ID, it seems to be OK to attempt
         # to always satisfy this.
@@ -968,9 +974,9 @@ class OidcSessionTokenGenerator:
 
         v.verify(macaroon, self._macaroon_secret_key)
 
-        # Extract the `nonce`, `client_redirect_url`, and maybe the
-        # `ui_auth_session_id` from the token.
+        # Extract the session data from the token.
         nonce = self._get_value_from_macaroon(macaroon, "nonce")
+        idp_id = self._get_value_from_macaroon(macaroon, "idp_id")
         client_redirect_url = self._get_value_from_macaroon(
             macaroon, "client_redirect_url"
         )
@@ -983,6 +989,7 @@ class OidcSessionTokenGenerator:
 
         return OidcSessionData(
             nonce=nonce,
+            idp_id=idp_id,
             client_redirect_url=client_redirect_url,
             ui_auth_session_id=ui_auth_session_id,
         )
@@ -998,7 +1005,7 @@ class OidcSessionTokenGenerator:
             The extracted value
 
         Raises:
-            Exception: if the caveat was not in the macaroon
+            ValueError: if the caveat was not in the macaroon
         """
         prefix = key + " = "
         for caveat in macaroon.caveats:
@@ -1019,6 +1026,9 @@ class OidcSessionTokenGenerator:
 class OidcSessionData:
     """The attributes which are stored in a OIDC session cookie"""
 
+    # the Identity Provider being used
+    idp_id = attr.ib(type=str)
+
     # The `nonce` parameter passed to the OIDC provider.
     nonce = attr.ib(type=str)
 
diff --git a/tests/handlers/test_oidc.py b/tests/handlers/test_oidc.py
index 5d338bea87..38ae8ca19e 100644
--- a/tests/handlers/test_oidc.py
+++ b/tests/handlers/test_oidc.py
@@ -848,6 +848,7 @@ class OidcHandlerTestCase(HomeserverTestCase):
         return self.handler._token_generator.generate_oidc_session_token(
             state=state,
             session_data=OidcSessionData(
+                idp_id="oidc",
                 nonce=nonce,
                 client_redirect_url=client_redirect_url,
                 ui_auth_session_id=ui_auth_session_id,
@@ -990,7 +991,7 @@ async def _make_callback_with_userinfo(
     session = handler._token_generator.generate_oidc_session_token(
         state=state,
         session_data=OidcSessionData(
-            nonce="nonce", client_redirect_url=client_redirect_url,
+            idp_id="oidc", nonce="nonce", client_redirect_url=client_redirect_url,
         ),
     )
     request = _build_callback_request("code", state, session)