diff options
Diffstat (limited to 'synapse/handlers/auth.py')
-rw-r--r-- | synapse/handlers/auth.py | 265 |
1 files changed, 177 insertions, 88 deletions
diff --git a/synapse/handlers/auth.py b/synapse/handlers/auth.py index 48a88d3c2a..7860f9625e 100644 --- a/synapse/handlers/auth.py +++ b/synapse/handlers/auth.py @@ -17,9 +17,11 @@ import logging import time import unicodedata +import urllib.parse +from typing import Any, Dict, Iterable, List, Optional import attr -import bcrypt +import bcrypt # type: ignore[import] import pymacaroons from twisted.internet import defer @@ -38,9 +40,12 @@ from synapse.api.errors import ( from synapse.api.ratelimiting import Ratelimiter from synapse.handlers.ui_auth import INTERACTIVE_AUTH_CHECKERS from synapse.handlers.ui_auth.checkers import UserInteractiveAuthChecker +from synapse.http.server import finish_request +from synapse.http.site import SynapseRequest from synapse.logging.context import defer_to_thread from synapse.module_api import ModuleApi -from synapse.types import UserID +from synapse.push.mailer import load_jinja2_templates +from synapse.types import Requester, UserID from synapse.util.caches.expiringcache import ExpiringCache from ._base import BaseHandler @@ -58,11 +63,11 @@ class AuthHandler(BaseHandler): """ super(AuthHandler, self).__init__(hs) - self.checkers = {} # type: dict[str, UserInteractiveAuthChecker] + self.checkers = {} # type: Dict[str, UserInteractiveAuthChecker] for auth_checker_class in INTERACTIVE_AUTH_CHECKERS: inst = auth_checker_class(hs) if inst.is_enabled(): - self.checkers[inst.AUTH_TYPE] = inst + self.checkers[inst.AUTH_TYPE] = inst # type: ignore self.bcrypt_rounds = hs.config.bcrypt_rounds @@ -108,8 +113,20 @@ class AuthHandler(BaseHandler): self._clock = self.hs.get_clock() + # Load the SSO redirect confirmation page HTML template + self._sso_redirect_confirm_template = load_jinja2_templates( + hs.config.sso_redirect_confirm_template_dir, ["sso_redirect_confirm.html"], + )[0] + + self._server_name = hs.config.server_name + + # cast to tuple for use with str.startswith + self._whitelisted_sso_clients = tuple(hs.config.sso_client_whitelist) + @defer.inlineCallbacks - def validate_user_via_ui_auth(self, requester, request_body, clientip): + def validate_user_via_ui_auth( + self, requester: Requester, request_body: Dict[str, Any], clientip: str + ): """ Checks that the user is who they claim to be, via a UI auth. @@ -118,11 +135,11 @@ class AuthHandler(BaseHandler): that it isn't stolen by re-authenticating them. Args: - requester (Requester): The user, as given by the access token + requester: The user, as given by the access token - request_body (dict): The body of the request sent by the client + request_body: The body of the request sent by the client - clientip (str): The IP address of the client. + clientip: The IP address of the client. Returns: defer.Deferred[dict]: the parameters for this request (which may @@ -193,7 +210,9 @@ class AuthHandler(BaseHandler): return self.checkers.keys() @defer.inlineCallbacks - def check_auth(self, flows, clientdict, clientip): + def check_auth( + self, flows: List[List[str]], clientdict: Dict[str, Any], clientip: str + ): """ Takes a dictionary sent by the client in the login / registration protocol and handles the User-Interactive Auth flow. @@ -208,14 +227,14 @@ class AuthHandler(BaseHandler): decorator. Args: - flows (list): A list of login flows. Each flow is an ordered list of - strings representing auth-types. At least one full - flow must be completed in order for auth to be successful. + flows: A list of login flows. Each flow is an ordered list of + strings representing auth-types. At least one full + flow must be completed in order for auth to be successful. clientdict: The dictionary from the client root level, not the 'auth' key: this method prompts for auth if none is sent. - clientip (str): The IP address of the client. + clientip: The IP address of the client. Returns: defer.Deferred[dict, dict, str]: a deferred tuple of @@ -235,7 +254,7 @@ class AuthHandler(BaseHandler): """ authdict = None - sid = None + sid = None # type: Optional[str] if clientdict and "auth" in clientdict: authdict = clientdict["auth"] del clientdict["auth"] @@ -268,9 +287,9 @@ class AuthHandler(BaseHandler): creds = session["creds"] # check auth type currently being presented - errordict = {} + errordict = {} # type: Dict[str, Any] if "type" in authdict: - login_type = authdict["type"] + login_type = authdict["type"] # type: str try: result = yield self._check_auth_dict(authdict, clientip) if result: @@ -311,7 +330,7 @@ class AuthHandler(BaseHandler): raise InteractiveAuthIncompleteError(ret) @defer.inlineCallbacks - def add_oob_auth(self, stagetype, authdict, clientip): + def add_oob_auth(self, stagetype: str, authdict: Dict[str, Any], clientip: str): """ Adds the result of out-of-band authentication into an existing auth session. Currently used for adding the result of fallback auth. @@ -333,7 +352,7 @@ class AuthHandler(BaseHandler): return True return False - def get_session_id(self, clientdict): + def get_session_id(self, clientdict: Dict[str, Any]) -> Optional[str]: """ Gets the session ID for a client given the client dictionary @@ -341,7 +360,7 @@ class AuthHandler(BaseHandler): clientdict: The dictionary sent by the client in the request Returns: - str|None: The string session ID the client sent. If the client did + The string session ID the client sent. If the client did not send a session ID, returns None. """ sid = None @@ -351,40 +370,42 @@ class AuthHandler(BaseHandler): sid = authdict["session"] return sid - def set_session_data(self, session_id, key, value): + def set_session_data(self, session_id: str, key: str, value: Any) -> None: """ Store a key-value pair into the sessions data associated with this request. This data is stored server-side and cannot be modified by the client. Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - value (any): The data to store + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + value: The data to store """ sess = self._get_session_info(session_id) sess.setdefault("serverdict", {})[key] = value self._save_session(sess) - def get_session_data(self, session_id, key, default=None): + def get_session_data( + self, session_id: str, key: str, default: Optional[Any] = None + ) -> Any: """ Retrieve data stored with set_session_data Args: - session_id (string): The ID of this session as returned from check_auth - key (string): The key to store the data under - default (any): Value to return if the key has not been set + session_id: The ID of this session as returned from check_auth + key: The key to store the data under + default: Value to return if the key has not been set """ sess = self._get_session_info(session_id) return sess.setdefault("serverdict", {}).get(key, default) @defer.inlineCallbacks - def _check_auth_dict(self, authdict, clientip): + def _check_auth_dict(self, authdict: Dict[str, Any], clientip: str): """Attempt to validate the auth dict provided by a client Args: - authdict (object): auth dict provided by the client - clientip (str): IP address of the client + authdict: auth dict provided by the client + clientip: IP address of the client Returns: Deferred: result of the stage verification. @@ -410,10 +431,10 @@ class AuthHandler(BaseHandler): (canonical_id, callback) = yield self.validate_login(user_id, authdict) return canonical_id - def _get_params_recaptcha(self): + def _get_params_recaptcha(self) -> dict: return {"public_key": self.hs.config.recaptcha_public_key} - def _get_params_terms(self): + def _get_params_terms(self) -> dict: return { "policies": { "privacy_policy": { @@ -430,7 +451,9 @@ class AuthHandler(BaseHandler): } } - def _auth_dict_for_flows(self, flows, session): + def _auth_dict_for_flows( + self, flows: List[List[str]], session: Dict[str, Any] + ) -> Dict[str, Any]: public_flows = [] for f in flows: public_flows.append(f) @@ -440,7 +463,7 @@ class AuthHandler(BaseHandler): LoginType.TERMS: self._get_params_terms, } - params = {} + params = {} # type: Dict[str, Any] for f in public_flows: for stage in f: @@ -453,7 +476,13 @@ class AuthHandler(BaseHandler): "params": params, } - def _get_session_info(self, session_id): + def _get_session_info(self, session_id: Optional[str]) -> dict: + """ + Gets or creates a session given a session ID. + + The session can be used to track data across multiple requests, e.g. for + interactive authentication. + """ if session_id not in self.sessions: session_id = None @@ -466,7 +495,9 @@ class AuthHandler(BaseHandler): return self.sessions[session_id] @defer.inlineCallbacks - def get_access_token_for_user_id(self, user_id, device_id, valid_until_ms): + def get_access_token_for_user_id( + self, user_id: str, device_id: Optional[str], valid_until_ms: Optional[int] + ): """ Creates a new access token for the user with the given user ID. @@ -476,11 +507,11 @@ class AuthHandler(BaseHandler): The device will be recorded in the table if it is not there already. Args: - user_id (str): canonical User ID - device_id (str|None): the device ID to associate with the tokens. + user_id: canonical User ID + device_id: the device ID to associate with the tokens. None to leave the tokens unassociated with a device (deprecated: we should always have a device ID) - valid_until_ms (int|None): when the token is valid until. None for + valid_until_ms: when the token is valid until. None for no expiry. Returns: The access token for the user's session. @@ -515,13 +546,13 @@ class AuthHandler(BaseHandler): return access_token @defer.inlineCallbacks - def check_user_exists(self, user_id): + def check_user_exists(self, user_id: str): """ Checks to see if a user with the given id exists. Will check case insensitively, but return None if there are multiple inexact matches. Args: - (unicode|bytes) user_id: complete @user:id + user_id: complete @user:id Returns: defer.Deferred: (unicode) canonical_user_id, or None if zero or @@ -536,7 +567,7 @@ class AuthHandler(BaseHandler): return None @defer.inlineCallbacks - def _find_user_id_and_pwd_hash(self, user_id): + def _find_user_id_and_pwd_hash(self, user_id: str): """Checks to see if a user with the given id exists. Will check case insensitively, but will return None if there are multiple inexact matches. @@ -566,7 +597,7 @@ class AuthHandler(BaseHandler): ) return result - def get_supported_login_types(self): + def get_supported_login_types(self) -> Iterable[str]: """Get a the login types supported for the /login API By default this is just 'm.login.password' (unless password_enabled is @@ -574,20 +605,20 @@ class AuthHandler(BaseHandler): other login types. Returns: - Iterable[str]: login types + login types """ return self._supported_login_types @defer.inlineCallbacks - def validate_login(self, username, login_submission): + def validate_login(self, username: str, login_submission: Dict[str, Any]): """Authenticates the user for the /login API Also used by the user-interactive auth flow to validate m.login.password auth types. Args: - username (str): username supplied by the user - login_submission (dict): the whole of the login submission + username: username supplied by the user + login_submission: the whole of the login submission (including 'type' and other relevant fields) Returns: Deferred[str, func]: canonical user id, and optional callback @@ -675,13 +706,13 @@ class AuthHandler(BaseHandler): raise LoginError(403, "Invalid password", errcode=Codes.FORBIDDEN) @defer.inlineCallbacks - def check_password_provider_3pid(self, medium, address, password): + def check_password_provider_3pid(self, medium: str, address: str, password: str): """Check if a password provider is able to validate a thirdparty login Args: - medium (str): The medium of the 3pid (ex. email). - address (str): The address of the 3pid (ex. jdoe@example.com). - password (str): The password of the user. + medium: The medium of the 3pid (ex. email). + address: The address of the 3pid (ex. jdoe@example.com). + password: The password of the user. Returns: Deferred[(str|None, func|None)]: A tuple of `(user_id, @@ -709,15 +740,15 @@ class AuthHandler(BaseHandler): return None, None @defer.inlineCallbacks - def _check_local_password(self, user_id, password): + def _check_local_password(self, user_id: str, password: str): """Authenticate a user against the local password database. user_id is checked case insensitively, but will return None if there are multiple inexact matches. Args: - user_id (unicode): complete @user:id - password (unicode): the provided password + user_id: complete @user:id + password: the provided password Returns: Deferred[unicode] the canonical_user_id, or Deferred[None] if unknown user/bad password @@ -740,7 +771,7 @@ class AuthHandler(BaseHandler): return user_id @defer.inlineCallbacks - def validate_short_term_login_token_and_get_user_id(self, login_token): + def validate_short_term_login_token_and_get_user_id(self, login_token: str): auth_api = self.hs.get_auth() user_id = None try: @@ -754,11 +785,11 @@ class AuthHandler(BaseHandler): return user_id @defer.inlineCallbacks - def delete_access_token(self, access_token): + def delete_access_token(self, access_token: str): """Invalidate a single access token Args: - access_token (str): access token to be deleted + access_token: access token to be deleted Returns: Deferred @@ -783,15 +814,17 @@ class AuthHandler(BaseHandler): @defer.inlineCallbacks def delete_access_tokens_for_user( - self, user_id, except_token_id=None, device_id=None + self, + user_id: str, + except_token_id: Optional[str] = None, + device_id: Optional[str] = None, ): """Invalidate access tokens belonging to a user Args: - user_id (str): ID of user the tokens belong to - except_token_id (str|None): access_token ID which should *not* be - deleted - device_id (str|None): ID of device the tokens are associated with. + user_id: ID of user the tokens belong to + except_token_id: access_token ID which should *not* be deleted + device_id: ID of device the tokens are associated with. If None, tokens associated with any device (or no device) will be deleted Returns: @@ -815,7 +848,7 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def add_threepid(self, user_id, medium, address, validated_at): + def add_threepid(self, user_id: str, medium: str, address: str, validated_at: int): # check if medium has a valid value if medium not in ["email", "msisdn"]: raise SynapseError( @@ -841,19 +874,20 @@ class AuthHandler(BaseHandler): ) @defer.inlineCallbacks - def delete_threepid(self, user_id, medium, address, id_server=None): + def delete_threepid( + self, user_id: str, medium: str, address: str, id_server: Optional[str] = None + ): """Attempts to unbind the 3pid on the identity servers and deletes it from the local database. Args: - user_id (str) - medium (str) - address (str) - id_server (str|None): Use the given identity server when unbinding + user_id: ID of user to remove the 3pid from. + medium: The medium of the 3pid being removed: "email" or "msisdn". + address: The 3pid address to remove. + id_server: Use the given identity server when unbinding any threepids. If None then will attempt to unbind using the identity server specified when binding (if known). - Returns: Deferred[bool]: Returns True if successfully unbound the 3pid on the identity server, False if identity server doesn't support the @@ -872,17 +906,18 @@ class AuthHandler(BaseHandler): yield self.store.user_delete_threepid(user_id, medium, address) return result - def _save_session(self, session): + def _save_session(self, session: Dict[str, Any]) -> None: + """Update the last used time on the session to now and add it back to the session store.""" # TODO: Persistent storage logger.debug("Saving session %s", session) session["last_used"] = self.hs.get_clock().time_msec() self.sessions[session["id"]] = session - def hash(self, password): + def hash(self, password: str): """Computes a secure hash of password. Args: - password (unicode): Password to hash. + password: Password to hash. Returns: Deferred(unicode): Hashed password. @@ -899,12 +934,12 @@ class AuthHandler(BaseHandler): return defer_to_thread(self.hs.get_reactor(), _do_hash) - def validate_hash(self, password, stored_hash): + def validate_hash(self, password: str, stored_hash: bytes): """Validates that self.hash(password) == stored_hash. Args: - password (unicode): Password to hash. - stored_hash (bytes): Expected hash value. + password: Password to hash. + stored_hash: Expected hash value. Returns: Deferred(bool): Whether self.hash(password) == stored_hash. @@ -927,13 +962,74 @@ class AuthHandler(BaseHandler): else: return defer.succeed(False) + def complete_sso_login( + self, + registered_user_id: str, + request: SynapseRequest, + client_redirect_url: str, + ): + """Having figured out a mxid for this user, complete the HTTP request + + Args: + registered_user_id: The registered user ID to complete SSO login for. + request: The request to complete. + client_redirect_url: The URL to which to redirect the user at the end of the + process. + """ + # Create a login token + login_token = self.macaroon_gen.generate_short_term_login_token( + registered_user_id + ) + + # Append the login token to the original redirect URL (i.e. with its query + # parameters kept intact) to build the URL to which the template needs to + # redirect the users once they have clicked on the confirmation link. + redirect_url = self.add_query_param_to_url( + client_redirect_url, "loginToken", login_token + ) + + # if the client is whitelisted, we can redirect straight to it + if client_redirect_url.startswith(self._whitelisted_sso_clients): + request.redirect(redirect_url) + finish_request(request) + return + + # Otherwise, serve the redirect confirmation page. + + # Remove the query parameters from the redirect URL to get a shorter version of + # it. This is only to display a human-readable URL in the template, but not the + # URL we redirect users to. + redirect_url_no_params = client_redirect_url.split("?")[0] + + html = self._sso_redirect_confirm_template.render( + display_url=redirect_url_no_params, + redirect_url=redirect_url, + server_name=self._server_name, + ).encode("utf-8") + + request.setResponseCode(200) + request.setHeader(b"Content-Type", b"text/html; charset=utf-8") + request.setHeader(b"Content-Length", b"%d" % (len(html),)) + request.write(html) + finish_request(request) + + @staticmethod + def add_query_param_to_url(url: str, param_name: str, param: Any): + url_parts = list(urllib.parse.urlparse(url)) + query = dict(urllib.parse.parse_qsl(url_parts[4])) + query.update({param_name: param}) + url_parts[4] = urllib.parse.urlencode(query) + return urllib.parse.urlunparse(url_parts) + @attr.s class MacaroonGenerator(object): hs = attr.ib() - def generate_access_token(self, user_id, extra_caveats=None): + def generate_access_token( + self, user_id: str, extra_caveats: Optional[List[str]] = None + ) -> str: extra_caveats = extra_caveats or [] macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = access") @@ -946,16 +1042,9 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat(caveat) return macaroon.serialize() - def generate_short_term_login_token(self, user_id, duration_in_ms=(2 * 60 * 1000)): - """ - - Args: - user_id (unicode): - duration_in_ms (int): - - Returns: - unicode - """ + def generate_short_term_login_token( + self, user_id: str, duration_in_ms: int = (2 * 60 * 1000) + ) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = login") now = self.hs.get_clock().time_msec() @@ -963,12 +1052,12 @@ class MacaroonGenerator(object): macaroon.add_first_party_caveat("time < %d" % (expiry,)) return macaroon.serialize() - def generate_delete_pusher_token(self, user_id): + def generate_delete_pusher_token(self, user_id: str) -> str: macaroon = self._generate_base_macaroon(user_id) macaroon.add_first_party_caveat("type = delete_pusher") return macaroon.serialize() - def _generate_base_macaroon(self, user_id): + def _generate_base_macaroon(self, user_id: str) -> pymacaroons.Macaroon: macaroon = pymacaroons.Macaroon( location=self.hs.config.server_name, identifier="key", |