summary refs log tree commit diff
path: root/synapse/handlers/oidc_handler.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/oidc_handler.py')
-rw-r--r--synapse/handlers/oidc_handler.py27
1 files changed, 20 insertions, 7 deletions
diff --git a/synapse/handlers/oidc_handler.py b/synapse/handlers/oidc_handler.py
index f63a90ec5c..5e5fda7b2f 100644
--- a/synapse/handlers/oidc_handler.py
+++ b/synapse/handlers/oidc_handler.py
@@ -78,21 +78,28 @@ class OidcHandler:
     def __init__(self, hs: "HomeServer"):
         self._sso_handler = hs.get_sso_handler()
 
-        provider_conf = hs.config.oidc.oidc_provider
+        provider_confs = hs.config.oidc.oidc_providers
         # we should not have been instantiated if there is no configured provider.
-        assert provider_conf is not None
+        assert provider_confs
 
         self._token_generator = OidcSessionTokenGenerator(hs)
-
-        self._provider = OidcProvider(hs, self._token_generator, provider_conf)
+        self._providers = {
+            p.idp_id: OidcProvider(hs, self._token_generator, p) for p in provider_confs
+        }
 
     async def load_metadata(self) -> None:
         """Validate the config and load the metadata from the remote endpoint.
 
         Called at startup to ensure we have everything we need.
         """
-        await self._provider.load_metadata()
-        await self._provider.load_jwks()
+        for idp_id, p in self._providers.items():
+            try:
+                await p.load_metadata()
+                await p.load_jwks()
+            except Exception as e:
+                raise Exception(
+                    "Error while initialising OIDC provider %r" % (idp_id,)
+                ) from e
 
     async def handle_oidc_callback(self, request: SynapseRequest) -> None:
         """Handle an incoming request to /_synapse/oidc/callback
@@ -184,6 +191,12 @@ class OidcHandler:
             self._sso_handler.render_error(request, "mismatching_session", str(e))
             return
 
+        oidc_provider = self._providers.get(session_data.idp_id)
+        if not oidc_provider:
+            logger.error("OIDC session uses unknown IdP %r", oidc_provider)
+            self._sso_handler.render_error(request, "unknown_idp", "Unknown IdP")
+            return
+
         if b"code" not in request.args:
             logger.info("Code parameter is missing")
             self._sso_handler.render_error(
@@ -193,7 +206,7 @@ class OidcHandler:
 
         code = request.args[b"code"][0].decode()
 
-        await self._provider.handle_oidc_callback(request, session_data, code)
+        await oidc_provider.handle_oidc_callback(request, session_data, code)
 
 
 class OidcError(Exception):