summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 48a88d3c2a..7ca90f91c4 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -17,6 +17,8 @@
 import logging
 import time
 import unicodedata
+import urllib.parse
+from typing import Any
 
 import attr
 import bcrypt
@@ -38,8 +40,11 @@ from synapse.api.errors import (
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.module_api import ModuleApi
+from synapse.push.mailer import load_jinja2_templates
 from synapse.types import UserID
 from synapse.util.caches.expiringcache import ExpiringCache
 
@@ -108,6 +113,16 @@ class AuthHandler(BaseHandler):
 
         self._clock = self.hs.get_clock()
 
+        # Load the SSO redirect confirmation page HTML template
+        self._sso_redirect_confirm_template = load_jinja2_templates(
+            hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+        )[0]
+
+        self._server_name = hs.config.server_name
+
+        # cast to tuple for use with str.startswith
+        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
     @defer.inlineCallbacks
     def validate_user_via_ui_auth(self, requester, request_body, clientip):
         """
@@ -927,6 +942,65 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)
 
+    def complete_sso_login(
+        self,
+        registered_user_id: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ):
+        """Having figured out a mxid for this user, complete the HTTP request
+
+        Args:
+            registered_user_id: The registered user ID to complete SSO login for.
+            request: The request to complete.
+            client_redirect_url: The URL to which to redirect the user at the end of the
+                process.
+        """
+        # Create a login token
+        login_token = self.macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
+
+        # Append the login token to the original redirect URL (i.e. with its query
+        # parameters kept intact) to build the URL to which the template needs to
+        # redirect the users once they have clicked on the confirmation link.
+        redirect_url = self.add_query_param_to_url(
+            client_redirect_url, "loginToken", login_token
+        )
+
+        # if the client is whitelisted, we can redirect straight to it
+        if client_redirect_url.startswith(self._whitelisted_sso_clients):
+            request.redirect(redirect_url)
+            finish_request(request)
+            return
+
+        # Otherwise, serve the redirect confirmation page.
+
+        # Remove the query parameters from the redirect URL to get a shorter version of
+        # it. This is only to display a human-readable URL in the template, but not the
+        # URL we redirect users to.
+        redirect_url_no_params = client_redirect_url.split("?")[0]
+
+        html = self._sso_redirect_confirm_template.render(
+            display_url=redirect_url_no_params,
+            redirect_url=redirect_url,
+            server_name=self._server_name,
+        ).encode("utf-8")
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(html),))
+        request.write(html)
+        finish_request(request)
+
+    @staticmethod
+    def add_query_param_to_url(url: str, param_name: str, param: Any):
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({param_name: param})
+        url_parts[4] = urllib.parse.urlencode(query)
+        return urllib.parse.urlunparse(url_parts)
+
 
 @attr.s
 class MacaroonGenerator(object):