summary refs log tree commit diff
path: root/synapse/handlers/auth.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r--synapse/handlers/auth.py265
1 files changed, 177 insertions, 88 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py
index 48a88d3c2a..7860f9625e 100644
--- a/synapse/handlers/auth.py
+++ b/synapse/handlers/auth.py
@@ -17,9 +17,11 @@
 import logging
 import time
 import unicodedata
+import urllib.parse
+from typing import Any, Dict, Iterable, List, Optional
 
 import attr
-import bcrypt
+import bcrypt  # type: ignore[import]
 import pymacaroons
 
 from twisted.internet import defer
@@ -38,9 +40,12 @@ from synapse.api.errors import (
 from synapse.api.ratelimiting import Ratelimiter
 from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS
 from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker
+from synapse.http.server import finish_request
+from synapse.http.site import SynapseRequest
 from synapse.logging.context import defer_to_thread
 from synapse.module_api import ModuleApi
-from synapse.types import UserID
+from synapse.push.mailer import load_jinja2_templates
+from synapse.types import Requester, UserID
 from synapse.util.caches.expiringcache import ExpiringCache
 
 from ._base import BaseHandler
@@ -58,11 +63,11 @@ class AuthHandler(BaseHandler):
         """
         super(AuthHandler, self).__init__(hs)
 
-        self.checkers = {}  # type: dict[str, UserInteractiveAuthChecker]
+        self.checkers = {}  # type: Dict[str, UserInteractiveAuthChecker]
         for auth_checker_class in INTERACTIVE_AUTH_CHECKERS:
             inst = auth_checker_class(hs)
             if inst.is_enabled():
-                self.checkers[inst.AUTH_TYPE] = inst
+                self.checkers[inst.AUTH_TYPE] = inst  # type: ignore
 
         self.bcrypt_rounds = hs.config.bcrypt_rounds
 
@@ -108,8 +113,20 @@ class AuthHandler(BaseHandler):
 
         self._clock = self.hs.get_clock()
 
+        # Load the SSO redirect confirmation page HTML template
+        self._sso_redirect_confirm_template = load_jinja2_templates(
+            hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"],
+        )[0]
+
+        self._server_name = hs.config.server_name
+
+        # cast to tuple for use with str.startswith
+        self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist)
+
     @defer.inlineCallbacks
-    def validate_user_via_ui_auth(self, requester, request_body, clientip):
+    def validate_user_via_ui_auth(
+        self, requester: Requester, request_body: Dict[str, Any], clientip: str
+    ):
         """
         Checks that the user is who they claim to be, via a UI auth.
 
@@ -118,11 +135,11 @@ class AuthHandler(BaseHandler):
         that it isn't stolen by re-authenticating them.
 
         Args:
-            requester (Requester): The user, as given by the access token
+            requester: The user, as given by the access token
 
-            request_body (dict): The body of the request sent by the client
+            request_body: The body of the request sent by the client
 
-            clientip (str): The IP address of the client.
+            clientip: The IP address of the client.
 
         Returns:
             defer.Deferred[dict]: the parameters for this request (which may
@@ -193,7 +210,9 @@ class AuthHandler(BaseHandler):
         return self.checkers.keys()
 
     @defer.inlineCallbacks
-    def check_auth(self, flows, clientdict, clientip):
+    def check_auth(
+        self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str
+    ):
         """
         Takes a dictionary sent by the client in the login / registration
         protocol and handles the User-Interactive Auth flow.
@@ -208,14 +227,14 @@ class AuthHandler(BaseHandler):
         decorator.
 
         Args:
-            flows (list): A list of login flows. Each flow is an ordered list of
-                          strings representing auth-types. At least one full
-                          flow must be completed in order for auth to be successful.
+            flows: A list of login flows. Each flow is an ordered list of
+                   strings representing auth-types. At least one full
+                   flow must be completed in order for auth to be successful.
 
             clientdict: The dictionary from the client root level, not the
                         'auth' key: this method prompts for auth if none is sent.
 
-            clientip (str): The IP address of the client.
+            clientip: The IP address of the client.
 
         Returns:
             defer.Deferred[dict, dict, str]: a deferred tuple of
@@ -235,7 +254,7 @@ class AuthHandler(BaseHandler):
         """
 
         authdict = None
-        sid = None
+        sid = None  # type: Optional[str]
         if clientdict and "auth" in clientdict:
             authdict = clientdict["auth"]
             del clientdict["auth"]
@@ -268,9 +287,9 @@ class AuthHandler(BaseHandler):
         creds = session["creds"]
 
         # check auth type currently being presented
-        errordict = {}
+        errordict = {}  # type: Dict[str, Any]
         if "type" in authdict:
-            login_type = authdict["type"]
+            login_type = authdict["type"]  # type: str
             try:
                 result = yield self._check_auth_dict(authdict, clientip)
                 if result:
@@ -311,7 +330,7 @@ class AuthHandler(BaseHandler):
         raise InteractiveAuthIncompleteError(ret)
 
     @defer.inlineCallbacks
-    def add_oob_auth(self, stagetype, authdict, clientip):
+    def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str):
         """
         Adds the result of out-of-band authentication into an existing auth
         session. Currently used for adding the result of fallback auth.
@@ -333,7 +352,7 @@ class AuthHandler(BaseHandler):
             return True
         return False
 
-    def get_session_id(self, clientdict):
+    def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]:
         """
         Gets the session ID for a client given the client dictionary
 
@@ -341,7 +360,7 @@ class AuthHandler(BaseHandler):
             clientdict: The dictionary sent by the client in the request
 
         Returns:
-            str|None: The string session ID the client sent. If the client did
+            The string session ID the client sent. If the client did
                 not send a session ID, returns None.
         """
         sid = None
@@ -351,40 +370,42 @@ class AuthHandler(BaseHandler):
                 sid = authdict["session"]
         return sid
 
-    def set_session_data(self, session_id, key, value):
+    def set_session_data(self, session_id: str, key: str, value: Any) -> None:
         """
         Store a key-value pair into the sessions data associated with this
         request. This data is stored server-side and cannot be modified by
         the client.
 
         Args:
-            session_id (string): The ID of this session as returned from check_auth
-            key (string): The key to store the data under
-            value (any): The data to store
+            session_id: The ID of this session as returned from check_auth
+            key: The key to store the data under
+            value: The data to store
         """
         sess = self._get_session_info(session_id)
         sess.setdefault("serverdict", {})[key] = value
         self._save_session(sess)
 
-    def get_session_data(self, session_id, key, default=None):
+    def get_session_data(
+        self, session_id: str, key: str, default: Optional[Any] = None
+    ) -> Any:
         """
         Retrieve data stored with set_session_data
 
         Args:
-            session_id (string): The ID of this session as returned from check_auth
-            key (string): The key to store the data under
-            default (any): Value to return if the key has not been set
+            session_id: The ID of this session as returned from check_auth
+            key: The key to store the data under
+            default: Value to return if the key has not been set
         """
         sess = self._get_session_info(session_id)
         return sess.setdefault("serverdict", {}).get(key, default)
 
     @defer.inlineCallbacks
-    def _check_auth_dict(self, authdict, clientip):
+    def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str):
         """Attempt to validate the auth dict provided by a client
 
         Args:
-            authdict (object): auth dict provided by the client
-            clientip (str): IP address of the client
+            authdict: auth dict provided by the client
+            clientip: IP address of the client
 
         Returns:
             Deferred: result of the stage verification.
@@ -410,10 +431,10 @@ class AuthHandler(BaseHandler):
         (canonical_id, callback) = yield self.validate_login(user_id, authdict)
         return canonical_id
 
-    def _get_params_recaptcha(self):
+    def _get_params_recaptcha(self) -> dict:
         return {"public_key": self.hs.config.recaptcha_public_key}
 
-    def _get_params_terms(self):
+    def _get_params_terms(self) -> dict:
         return {
             "policies": {
                 "privacy_policy": {
@@ -430,7 +451,9 @@ class AuthHandler(BaseHandler):
             }
         }
 
-    def _auth_dict_for_flows(self, flows, session):
+    def _auth_dict_for_flows(
+        self, flows: List[List[str]], session: Dict[str, Any]
+    ) -> Dict[str, Any]:
         public_flows = []
         for f in flows:
             public_flows.append(f)
@@ -440,7 +463,7 @@ class AuthHandler(BaseHandler):
             LoginType.TERMS: self._get_params_terms,
         }
 
-        params = {}
+        params = {}  # type: Dict[str, Any]
 
         for f in public_flows:
             for stage in f:
@@ -453,7 +476,13 @@ class AuthHandler(BaseHandler):
             "params": params,
         }
 
-    def _get_session_info(self, session_id):
+    def _get_session_info(self, session_id: Optional[str]) -> dict:
+        """
+        Gets or creates a session given a session ID.
+
+        The session can be used to track data across multiple requests, e.g. for
+        interactive authentication.
+        """
         if session_id not in self.sessions:
             session_id = None
 
@@ -466,7 +495,9 @@ class AuthHandler(BaseHandler):
         return self.sessions[session_id]
 
     @defer.inlineCallbacks
-    def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms):
+    def get_access_token_for_user_id(
+        self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int]
+    ):
         """
         Creates a new access token for the user with the given user ID.
 
@@ -476,11 +507,11 @@ class AuthHandler(BaseHandler):
         The device will be recorded in the table if it is not there already.
 
         Args:
-            user_id (str): canonical User ID
-            device_id (str|None): the device ID to associate with the tokens.
+            user_id: canonical User ID
+            device_id: the device ID to associate with the tokens.
                None to leave the tokens unassociated with a device (deprecated:
                we should always have a device ID)
-            valid_until_ms (int|None): when the token is valid until. None for
+            valid_until_ms: when the token is valid until. None for
                 no expiry.
         Returns:
               The access token for the user's session.
@@ -515,13 +546,13 @@ class AuthHandler(BaseHandler):
         return access_token
 
     @defer.inlineCallbacks
-    def check_user_exists(self, user_id):
+    def check_user_exists(self, user_id: str):
         """
         Checks to see if a user with the given id exists. Will check case
         insensitively, but return None if there are multiple inexact matches.
 
         Args:
-            (unicode|bytes) user_id: complete @user:id
+            user_id: complete @user:id
 
         Returns:
             defer.Deferred: (unicode) canonical_user_id, or None if zero or
@@ -536,7 +567,7 @@ class AuthHandler(BaseHandler):
         return None
 
     @defer.inlineCallbacks
-    def _find_user_id_and_pwd_hash(self, user_id):
+    def _find_user_id_and_pwd_hash(self, user_id: str):
         """Checks to see if a user with the given id exists. Will check case
         insensitively, but will return None if there are multiple inexact
         matches.
@@ -566,7 +597,7 @@ class AuthHandler(BaseHandler):
             )
         return result
 
-    def get_supported_login_types(self):
+    def get_supported_login_types(self) -> Iterable[str]:
         """Get a the login types supported for the /login API
 
         By default this is just 'm.login.password' (unless password_enabled is
@@ -574,20 +605,20 @@ class AuthHandler(BaseHandler):
         other login types.
 
         Returns:
-            Iterable[str]: login types
+            login types
         """
         return self._supported_login_types
 
     @defer.inlineCallbacks
-    def validate_login(self, username, login_submission):
+    def validate_login(self, username: str, login_submission: Dict[str, Any]):
         """Authenticates the user for the /login API
 
         Also used by the user-interactive auth flow to validate
         m.login.password auth types.
 
         Args:
-            username (str): username supplied by the user
-            login_submission (dict): the whole of the login submission
+            username: username supplied by the user
+            login_submission: the whole of the login submission
                 (including 'type' and other relevant fields)
         Returns:
             Deferred[str, func]: canonical user id, and optional callback
@@ -675,13 +706,13 @@ class AuthHandler(BaseHandler):
         raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN)
 
     @defer.inlineCallbacks
-    def check_password_provider_3pid(self, medium, address, password):
+    def check_password_provider_3pid(self, medium: str, address: str, password: str):
         """Check if a password provider is able to validate a thirdparty login
 
         Args:
-            medium (str): The medium of the 3pid (ex. email).
-            address (str): The address of the 3pid (ex. jdoe@example.com).
-            password (str): The password of the user.
+            medium: The medium of the 3pid (ex. email).
+            address: The address of the 3pid (ex. jdoe@example.com).
+            password: The password of the user.
 
         Returns:
             Deferred[(str|None, func|None)]: A tuple of `(user_id,
@@ -709,15 +740,15 @@ class AuthHandler(BaseHandler):
         return None, None
 
     @defer.inlineCallbacks
-    def _check_local_password(self, user_id, password):
+    def _check_local_password(self, user_id: str, password: str):
         """Authenticate a user against the local password database.
 
         user_id is checked case insensitively, but will return None if there are
         multiple inexact matches.
 
         Args:
-            user_id (unicode): complete @user:id
-            password (unicode): the provided password
+            user_id: complete @user:id
+            password: the provided password
         Returns:
             Deferred[unicode] the canonical_user_id, or Deferred[None] if
                 unknown user/bad password
@@ -740,7 +771,7 @@ class AuthHandler(BaseHandler):
         return user_id
 
     @defer.inlineCallbacks
-    def validate_short_term_login_token_and_get_user_id(self, login_token):
+    def validate_short_term_login_token_and_get_user_id(self, login_token: str):
         auth_api = self.hs.get_auth()
         user_id = None
         try:
@@ -754,11 +785,11 @@ class AuthHandler(BaseHandler):
         return user_id
 
     @defer.inlineCallbacks
-    def delete_access_token(self, access_token):
+    def delete_access_token(self, access_token: str):
         """Invalidate a single access token
 
         Args:
-            access_token (str): access token to be deleted
+            access_token: access token to be deleted
 
         Returns:
             Deferred
@@ -783,15 +814,17 @@ class AuthHandler(BaseHandler):
 
     @defer.inlineCallbacks
     def delete_access_tokens_for_user(
-        self, user_id, except_token_id=None, device_id=None
+        self,
+        user_id: str,
+        except_token_id: Optional[str] = None,
+        device_id: Optional[str] = None,
     ):
         """Invalidate access tokens belonging to a user
 
         Args:
-            user_id (str):  ID of user the tokens belong to
-            except_token_id (str|None): access_token ID which should *not* be
-                deleted
-            device_id (str|None):  ID of device the tokens are associated with.
+            user_id:  ID of user the tokens belong to
+            except_token_id: access_token ID which should *not* be deleted
+            device_id:  ID of device the tokens are associated with.
                 If None, tokens associated with any device (or no device) will
                 be deleted
         Returns:
@@ -815,7 +848,7 @@ class AuthHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def add_threepid(self, user_id, medium, address, validated_at):
+    def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int):
         # check if medium has a valid value
         if medium not in ["email", "msisdn"]:
             raise SynapseError(
@@ -841,19 +874,20 @@ class AuthHandler(BaseHandler):
         )
 
     @defer.inlineCallbacks
-    def delete_threepid(self, user_id, medium, address, id_server=None):
+    def delete_threepid(
+        self, user_id: str, medium: str, address: str, id_server: Optional[str] = None
+    ):
         """Attempts to unbind the 3pid on the identity servers and deletes it
         from the local database.
 
         Args:
-            user_id (str)
-            medium (str)
-            address (str)
-            id_server (str|None): Use the given identity server when unbinding
+            user_id: ID of user to remove the 3pid from.
+            medium: The medium of the 3pid being removed: "email" or "msisdn".
+            address: The 3pid address to remove.
+            id_server: Use the given identity server when unbinding
                 any threepids. If None then will attempt to unbind using the
                 identity server specified when binding (if known).
 
-
         Returns:
             Deferred[bool]: Returns True if successfully unbound the 3pid on
             the identity server, False if identity server doesn't support the
@@ -872,17 +906,18 @@ class AuthHandler(BaseHandler):
         yield self.store.user_delete_threepid(user_id, medium, address)
         return result
 
-    def _save_session(self, session):
+    def _save_session(self, session: Dict[str, Any]) -> None:
+        """Update the last used time on the session to now and add it back to the session store."""
         # TODO: Persistent storage
         logger.debug("Saving session %s", session)
         session["last_used"] = self.hs.get_clock().time_msec()
         self.sessions[session["id"]] = session
 
-    def hash(self, password):
+    def hash(self, password: str):
         """Computes a secure hash of password.
 
         Args:
-            password (unicode): Password to hash.
+            password: Password to hash.
 
         Returns:
             Deferred(unicode): Hashed password.
@@ -899,12 +934,12 @@ class AuthHandler(BaseHandler):
 
         return defer_to_thread(self.hs.get_reactor(), _do_hash)
 
-    def validate_hash(self, password, stored_hash):
+    def validate_hash(self, password: str, stored_hash: bytes):
         """Validates that self.hash(password) == stored_hash.
 
         Args:
-            password (unicode): Password to hash.
-            stored_hash (bytes): Expected hash value.
+            password: Password to hash.
+            stored_hash: Expected hash value.
 
         Returns:
             Deferred(bool): Whether self.hash(password) == stored_hash.
@@ -927,13 +962,74 @@ class AuthHandler(BaseHandler):
         else:
             return defer.succeed(False)
 
+    def complete_sso_login(
+        self,
+        registered_user_id: str,
+        request: SynapseRequest,
+        client_redirect_url: str,
+    ):
+        """Having figured out a mxid for this user, complete the HTTP request
+
+        Args:
+            registered_user_id: The registered user ID to complete SSO login for.
+            request: The request to complete.
+            client_redirect_url: The URL to which to redirect the user at the end of the
+                process.
+        """
+        # Create a login token
+        login_token = self.macaroon_gen.generate_short_term_login_token(
+            registered_user_id
+        )
+
+        # Append the login token to the original redirect URL (i.e. with its query
+        # parameters kept intact) to build the URL to which the template needs to
+        # redirect the users once they have clicked on the confirmation link.
+        redirect_url = self.add_query_param_to_url(
+            client_redirect_url, "loginToken", login_token
+        )
+
+        # if the client is whitelisted, we can redirect straight to it
+        if client_redirect_url.startswith(self._whitelisted_sso_clients):
+            request.redirect(redirect_url)
+            finish_request(request)
+            return
+
+        # Otherwise, serve the redirect confirmation page.
+
+        # Remove the query parameters from the redirect URL to get a shorter version of
+        # it. This is only to display a human-readable URL in the template, but not the
+        # URL we redirect users to.
+        redirect_url_no_params = client_redirect_url.split("?")[0]
+
+        html = self._sso_redirect_confirm_template.render(
+            display_url=redirect_url_no_params,
+            redirect_url=redirect_url,
+            server_name=self._server_name,
+        ).encode("utf-8")
+
+        request.setResponseCode(200)
+        request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
+        request.setHeader(b"Content-Length", b"%d" % (len(html),))
+        request.write(html)
+        finish_request(request)
+
+    @staticmethod
+    def add_query_param_to_url(url: str, param_name: str, param: Any):
+        url_parts = list(urllib.parse.urlparse(url))
+        query = dict(urllib.parse.parse_qsl(url_parts[4]))
+        query.update({param_name: param})
+        url_parts[4] = urllib.parse.urlencode(query)
+        return urllib.parse.urlunparse(url_parts)
+
 
 @attr.s
 class MacaroonGenerator(object):
 
     hs = attr.ib()
 
-    def generate_access_token(self, user_id, extra_caveats=None):
+    def generate_access_token(
+        self, user_id: str, extra_caveats: Optional[List[str]] = None
+    ) -> str:
         extra_caveats = extra_caveats or []
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = access")
@@ -946,16 +1042,9 @@ class MacaroonGenerator(object):
             macaroon.add_first_party_caveat(caveat)
         return macaroon.serialize()
 
-    def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)):
-        """
-
-        Args:
-            user_id (unicode):
-            duration_in_ms (int):
-
-        Returns:
-            unicode
-        """
+    def generate_short_term_login_token(
+        self, user_id: str, duration_in_ms: int = (2 * 60 * 1000)
+    ) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = login")
         now = self.hs.get_clock().time_msec()
@@ -963,12 +1052,12 @@ class MacaroonGenerator(object):
         macaroon.add_first_party_caveat("time < %d" % (expiry,))
         return macaroon.serialize()
 
-    def generate_delete_pusher_token(self, user_id):
+    def generate_delete_pusher_token(self, user_id: str) -> str:
         macaroon = self._generate_base_macaroon(user_id)
         macaroon.add_first_party_caveat("type = delete_pusher")
         return macaroon.serialize()
 
-    def _generate_base_macaroon(self, user_id):
+    def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon:
         macaroon = pymacaroons.Macaroon(
             location=self.hs.config.server_name,
             identifier="key",