summary refs log tree commit diff
path: root/synapse/rest/client/v1/login.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/rest/client/v1/login.py')
-rw-r--r--synapse/rest/client/v1/login.py105
1 files changed, 76 insertions, 29 deletions
diff --git a/synapse/rest/client/v1/login.py b/synapse/rest/client/v1/login.py
index 011e84e8b1..b7c5b58b01 100644
--- a/synapse/rest/client/v1/login.py
+++ b/synapse/rest/client/v1/login.py
@@ -23,8 +23,12 @@ from twisted.web.client import PartialDownloadError
 
 from synapse.api.errors import Codes, LoginError, SynapseError
 from synapse.http.server import finish_request
-from synapse.http.servlet import RestServlet, parse_json_object_from_request
-from synapse.types import UserID
+from synapse.http.servlet import (
+    RestServlet,
+    parse_json_object_from_request,
+    parse_string,
+)
+from synapse.types import UserID, map_username_to_mxid_localpart
 from synapse.util.msisdn import phone_number_to_msisdn
 
 from .base import ClientV1RestServlet, client_path_patterns
@@ -358,17 +362,15 @@ class CasTicketServlet(ClientV1RestServlet):
         self.cas_server_url = hs.config.cas_server_url
         self.cas_service_url = hs.config.cas_service_url
         self.cas_required_attributes = hs.config.cas_required_attributes
-        self.auth_handler = hs.get_auth_handler()
-        self.handlers = hs.get_handlers()
-        self.macaroon_gen = hs.get_macaroon_generator()
+        self._sso_auth_handler = SSOAuthHandler(hs)
 
     @defer.inlineCallbacks
     def on_GET(self, request):
-        client_redirect_url = request.args[b"redirectUrl"][0]
+        client_redirect_url = parse_string(request, "redirectUrl", required=True)
         http_client = self.hs.get_simple_http_client()
         uri = self.cas_server_url + "/proxyValidate"
         args = {
-            "ticket": request.args[b"ticket"][0].decode('ascii'),
+            "ticket": parse_string(request, "ticket", required=True),
             "service": self.cas_service_url
         }
         try:
@@ -380,7 +382,6 @@ class CasTicketServlet(ClientV1RestServlet):
         result = yield self.handle_cas_response(request, body, client_redirect_url)
         defer.returnValue(result)
 
-    @defer.inlineCallbacks
     def handle_cas_response(self, request, cas_response_body, client_redirect_url):
         user, attributes = self.parse_cas_response(cas_response_body)
 
@@ -396,28 +397,9 @@ class CasTicketServlet(ClientV1RestServlet):
                 if required_value != actual_value:
                     raise LoginError(401, "Unauthorized", errcode=Codes.UNAUTHORIZED)
 
-        user_id = UserID(user, self.hs.hostname).to_string()
-        auth_handler = self.auth_handler
-        registered_user_id = yield auth_handler.check_user_exists(user_id)
-        if not registered_user_id:
-            registered_user_id, _ = (
-                yield self.handlers.registration_handler.register(localpart=user)
-            )
-
-        login_token = self.macaroon_gen.generate_short_term_login_token(
-            registered_user_id
+        return self._sso_auth_handler.on_successful_auth(
+            user, request, client_redirect_url,
         )
-        redirect_url = self.add_login_token_to_redirect_url(client_redirect_url,
-                                                            login_token)
-        request.redirect(redirect_url)
-        finish_request(request)
-
-    def add_login_token_to_redirect_url(self, url, token):
-        url_parts = list(urllib.parse.urlparse(url))
-        query = dict(urllib.parse.parse_qsl(url_parts[4]))
-        query.update({"loginToken": token})
-        url_parts[4] = urllib.parse.urlencode(query).encode('ascii')
-        return urllib.parse.urlunparse(url_parts)
 
     def parse_cas_response(self, cas_response_body):
         user = None
@@ -452,6 +434,71 @@ class CasTicketServlet(ClientV1RestServlet):
         return user, attributes
 
 
+class SSOAuthHandler(object):
+    """
+    Utility class for Resources and Servlets which handle the response from a SSO
+    service
+
+    Args:
+        hs (synapse.server.HomeServer)
+    """
+    def __init__(self, hs):
+        self._hostname = hs.hostname
+        self._auth_handler = hs.get_auth_handler()
+        self._registration_handler = hs.get_handlers().registration_handler
+        self._macaroon_gen = hs.get_macaroon_generator()
+
+    @defer.inlineCallbacks
+    def on_successful_auth(
+        self, username, request, client_redirect_url,
+    ):
+        """Called once the user has successfully authenticated with the SSO.
+
+        Registers the user if necessary, and then returns a redirect (with
+        a login token) to the client.
+
+        Args:
+            username (unicode|bytes): the remote user id. We'll map this onto
+                something sane for a MXID localpath.
+
+            request (SynapseRequest): the incoming request from the browser. We'll
+                respond to it with a redirect.
+
+            client_redirect_url (unicode): the redirect_url the client gave us when
+                it first started the process.
+
+        Returns:
+            Deferred[none]: Completes once we have handled the request.
+        """
+        localpart = map_username_to_mxid_localpart(username)
+        user_id = UserID(localpart, self._hostname).to_string()
+        registered_user_id = yield self._auth_handler.check_user_exists(user_id)
+        if not registered_user_id:
+            registered_user_id, _ = (
+                yield self._registration_handler.register(
+                    localpart=localpart,
+                    generate_token=False,
+                )
+            )
+
+        login_token = self._macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
+        redirect_url = self._add_login_token_to_redirect_url(
+            client_redirect_url, login_token
+        )
+        request.redirect(redirect_url)
+        finish_request(request)
+
+    @staticmethod
+    def _add_login_token_to_redirect_url(url, token):
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({"loginToken": token})
+        url_parts[4] = urllib.parse.urlencode(query)
+        return urllib.parse.urlunparse(url_parts)
+
+
 def register_servlets(hs, http_server):
     LoginRestServlet(hs).register(http_server)
     if hs.config.cas_enabled: